Skip to content

Commit e7d2bd3

Browse files
Add getBearerToken callback for BYOK providers (Managed Identity) (#1748)
Lets BYOK provider configs supply a getBearerToken callback so the SDK consumer resolves bearer tokens (e.g. Azure Managed Identity) on demand. The callback never crosses the wire: the SDK strips it from the provider config, sends a `hasBearerTokenProvider: true` flag, and answers the runtime's session-scoped `providerToken.getToken` RPC by routing to the matching per-provider callback. The returned token is applied as the Authorization header for outbound model requests; the consumer owns caching/refresh. Implemented across all SDKs (Node, .NET, Go, Java, Python, Rust) with e2e tests. The generated RPC files are intentionally left as the committed CLI 1.0.65 codegen output (providerToken.getToken + hasBearerTokenProvider) rather than hand-edited. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 4d13613 commit e7d2bd3

33 files changed

Lines changed: 3009 additions & 12 deletions

dotnet/src/BearerTokenProvider.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
*--------------------------------------------------------------------------------------------*/
4+
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace GitHub.Copilot;
8+
9+
/// <summary>
10+
/// Arguments passed to a bearer-token callback (the <c>GetBearerToken</c> property
11+
/// on <see cref="ProviderConfig"/> / <see cref="NamedProviderConfig"/>) when the
12+
/// runtime needs a fresh bearer token for a BYOK provider.
13+
/// </summary>
14+
/// <remarks>
15+
/// Part of the experimental managed-identity / bearer-token-provider surface and
16+
/// may change or be removed in future SDK or CLI releases.
17+
/// </remarks>
18+
[Experimental(Diagnostics.Experimental)]
19+
public sealed class ProviderTokenArgs
20+
{
21+
/// <summary>
22+
/// Name of the BYOK provider needing a token. For the singular, whole-session
23+
/// <see cref="ProviderConfig"/> this is the implicit provider name
24+
/// (<c>"default"</c>); for <see cref="NamedProviderConfig"/> entries it is
25+
/// <see cref="NamedProviderConfig.Name"/>.
26+
/// </summary>
27+
/// <remarks>
28+
/// The callback closes over its own token scope/audience; the runtime is
29+
/// provider-agnostic and forwards only the provider name.
30+
/// </remarks>
31+
public required string ProviderName { get; init; }
32+
}

dotnet/src/Client.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ private CopilotSession InitializeSession(
652652
}
653653
ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider);
654654
session.SetCanvasHandler(config.CanvasHandler);
655+
session.RegisterBearerTokenProviders(BuildBearerTokenCallbacks(config));
655656
RegisterSession(session);
656657
session.StartProcessingEvents();
657658
LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null,
@@ -664,6 +665,37 @@ private CopilotSession InitializeSession(
664665
return session;
665666
}
666667

668+
/// <summary>
669+
/// Implicit provider name for the singular, whole-session <see cref="ProviderConfig"/>.
670+
/// </summary>
671+
private const string DefaultBearerTokenProviderName = "default";
672+
673+
/// <summary>
674+
/// Collects the per-provider <c>GetBearerToken</c> callbacks keyed by
675+
/// provider name for session-side registration. The singular, whole-session
676+
/// <see cref="ProviderConfig"/> uses the implicit
677+
/// <see cref="DefaultBearerTokenProviderName"/>.
678+
/// </summary>
679+
private static Dictionary<string, Func<ProviderTokenArgs, Task<string>>> BuildBearerTokenCallbacks(SessionConfigBase config)
680+
{
681+
var callbacks = new Dictionary<string, Func<ProviderTokenArgs, Task<string>>>(StringComparer.Ordinal);
682+
if (config.Provider?.GetBearerToken is { } singular)
683+
{
684+
callbacks[DefaultBearerTokenProviderName] = singular;
685+
}
686+
if (config.Providers != null)
687+
{
688+
foreach (var provider in config.Providers)
689+
{
690+
if (provider.GetBearerToken is { } callback)
691+
{
692+
callbacks[provider.Name] = callback;
693+
}
694+
}
695+
}
696+
return callbacks;
697+
}
698+
667699
/// <summary>
668700
/// Catches misuse of <see cref="SessionConfigBase.AvailableTools"/> /
669701
/// <see cref="SessionConfigBase.ExcludedTools"/> at the SDK boundary so

dotnet/src/Session.cs

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public sealed partial class CopilotSession : IAsyncDisposable
5858
{
5959
private readonly Dictionary<string, AIFunction> _toolHandlers = [];
6060
private readonly Dictionary<string, Func<CommandContext, Task>> _commandHandlers = [];
61+
private readonly Dictionary<string, Func<ProviderTokenArgs, Task<string>>> _bearerTokenProviders = new(StringComparer.Ordinal);
6162
private readonly ILogger _logger;
6263
private readonly CopilotClient _parentClient;
6364

@@ -76,9 +77,7 @@ private sealed record EventSubscription(Type EventType, Action<SessionEvent> Han
7677
private Dictionary<string, Func<string, Task<string>>>? _transformCallbacks;
7778
private readonly SemaphoreSlim _transformCallbacksLock = new(1, 1);
7879

79-
#pragma warning disable GHCP001
8080
private IReadOnlyList<OpenCanvasInstance> _openCanvases = Array.Empty<OpenCanvasInstance>();
81-
#pragma warning restore GHCP001
8281

8382
private int _isDisposed;
8483

@@ -126,7 +125,6 @@ public SessionCapabilities Capabilities
126125
private set;
127126
}
128127

129-
#pragma warning disable GHCP001
130128
/// <summary>
131129
/// Canvas instances currently known to be open for this session.
132130
/// </summary>
@@ -136,7 +134,6 @@ public SessionCapabilities Capabilities
136134
/// </remarks>
137135
[Experimental(Diagnostics.Experimental)]
138136
public IReadOnlyList<OpenCanvasInstance> OpenCanvases => _openCanvases;
139-
#pragma warning restore GHCP001
140137

141138
/// <summary>
142139
/// Gets the UI API for eliciting information from the user during this session.
@@ -873,6 +870,51 @@ internal void RegisterAutoModeSwitchHandler(Func<AutoModeSwitchRequest, AutoMode
873870
_autoModeSwitchHandler = handler;
874871
}
875872

873+
/// <summary>
874+
/// Registers per-provider <c>GetBearerToken</c> callbacks for BYOK
875+
/// providers configured with managed-identity / on-demand bearer-token auth.
876+
/// </summary>
877+
/// <remarks>
878+
/// The runtime never receives the callback itself; the SDK strips it from the
879+
/// provider config and instead sends <c>hasBearerTokenProvider: true</c>. When
880+
/// the runtime needs a token it issues a session-scoped
881+
/// <c>providerToken.getToken</c> request, which this handler routes to the
882+
/// matching per-provider callback.
883+
/// </remarks>
884+
/// <param name="providers">Map of provider name to callback, or null/empty to clear.</param>
885+
internal void RegisterBearerTokenProviders(IReadOnlyDictionary<string, Func<ProviderTokenArgs, Task<string>>>? providers)
886+
{
887+
_bearerTokenProviders.Clear();
888+
if (providers is null || providers.Count == 0)
889+
{
890+
ClientSessionApis.ProviderToken = null;
891+
return;
892+
}
893+
foreach (var (name, callback) in providers)
894+
{
895+
_bearerTokenProviders[name] = callback;
896+
}
897+
ClientSessionApis.ProviderToken = new BearerTokenProviderHandler(this);
898+
}
899+
900+
/// <summary>
901+
/// Routes runtime <c>providerToken.getToken</c> requests to the matching
902+
/// per-provider <c>GetBearerToken</c> callback registered on the session.
903+
/// </summary>
904+
private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler
905+
{
906+
public async Task<ProviderTokenAcquireResult> GetTokenAsync(ProviderTokenAcquireRequest request, CancellationToken cancellationToken = default)
907+
{
908+
if (!session._bearerTokenProviders.TryGetValue(request.ProviderName, out var callback))
909+
{
910+
throw new InvalidOperationException(
911+
$"No bearer-token provider registered for provider \"{request.ProviderName}\"");
912+
}
913+
var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false);
914+
return new ProviderTokenAcquireResult { Token = token };
915+
}
916+
}
917+
876918
/// <summary>
877919
/// Sets the capabilities reported by the host for this session.
878920
/// </summary>
@@ -882,7 +924,6 @@ internal void SetCapabilities(SessionCapabilities? capabilities)
882924
Capabilities = capabilities ?? new SessionCapabilities();
883925
}
884926

885-
#pragma warning disable GHCP001
886927
internal void SetOpenCanvases(IList<OpenCanvasInstance>? canvases)
887928
{
888929
_openCanvases = canvases is { Count: > 0 }
@@ -959,7 +1000,6 @@ private static JsonElement SerializeActionResult(object? value)
9591000
var element = CopilotClient.ToJsonElementForWire(value);
9601001
return element ?? NullJsonElement;
9611002
}
962-
#pragma warning restore GHCP001
9631003

9641004
private sealed class CanvasHandlerAdapter(ICanvasHandler handler) : Rpc.ICanvasHandler
9651005
{

dotnet/src/Types.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
using GitHub.Copilot.Rpc;
66
using Microsoft.Extensions.AI;
77
using Microsoft.Extensions.Logging;
8+
using System;
89
using System.ComponentModel;
910
using System.Diagnostics;
1011
using System.Diagnostics.CodeAnalysis;
1112
using System.Text.Json;
1213
using System.Text.Json.Nodes;
1314
using System.Text.Json.Serialization;
15+
using System.Threading.Tasks;
1416

1517
namespace GitHub.Copilot;
1618

@@ -2041,6 +2043,28 @@ public sealed class ProviderConfig
20412043
[JsonPropertyName("bearerToken")]
20422044
public string? BearerToken { get; set; }
20432045

2046+
/// <summary>
2047+
/// Wire-only flag, emitted automatically when <see cref="GetBearerToken"/> is set, that tells
2048+
/// the runtime to request a token over the session-scoped <c>providerToken.getToken</c> RPC
2049+
/// before each outbound request to this provider. Derived from <see cref="GetBearerToken"/>;
2050+
/// internal and never part of the public API.
2051+
/// </summary>
2052+
[JsonInclude]
2053+
[JsonPropertyName("hasBearerTokenProvider")]
2054+
internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null;
2055+
2056+
/// <summary>
2057+
/// Per-request callback that resolves a bearer token on demand for this BYOK provider (for
2058+
/// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a
2059+
/// callback backed by your own identity library. Never serialized — setting it makes the SDK send
2060+
/// <c>hasBearerTokenProvider: true</c> on the wire and answer the runtime's
2061+
/// <c>providerToken.getToken</c> requests. Mutually exclusive with <see cref="ApiKey"/> and
2062+
/// <see cref="BearerToken"/>.
2063+
/// </summary>
2064+
[JsonIgnore]
2065+
[Experimental(Diagnostics.Experimental)]
2066+
public Func<ProviderTokenArgs, Task<string>>? GetBearerToken { get; set; }
2067+
20442068
/// <summary>
20452069
/// Azure-specific configuration options.
20462070
/// </summary>
@@ -2173,6 +2197,28 @@ public sealed class NamedProviderConfig
21732197
[JsonPropertyName("bearerToken")]
21742198
public string? BearerToken { get; set; }
21752199

2200+
/// <summary>
2201+
/// Wire-only flag, emitted automatically when <see cref="GetBearerToken"/> is set, that tells
2202+
/// the runtime to request a token over the session-scoped <c>providerToken.getToken</c> RPC
2203+
/// before each outbound request to this provider. Derived from <see cref="GetBearerToken"/>;
2204+
/// internal and never part of the public API.
2205+
/// </summary>
2206+
[JsonInclude]
2207+
[JsonPropertyName("hasBearerTokenProvider")]
2208+
internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null;
2209+
2210+
/// <summary>
2211+
/// Per-request callback that resolves a bearer token on demand for this BYOK provider (for
2212+
/// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a
2213+
/// callback backed by your own identity library. Never serialized — setting it makes the SDK send
2214+
/// <c>hasBearerTokenProvider: true</c> on the wire and answer the runtime's
2215+
/// <c>providerToken.getToken</c> requests. Mutually exclusive with <see cref="ApiKey"/> and
2216+
/// <see cref="BearerToken"/>.
2217+
/// </summary>
2218+
[JsonIgnore]
2219+
[Experimental(Diagnostics.Experimental)]
2220+
public Func<ProviderTokenArgs, Task<string>>? GetBearerToken { get; set; }
2221+
21762222
/// <summary>
21772223
/// Azure-specific configuration options.
21782224
/// </summary>

0 commit comments

Comments
 (0)