// HubHost/NegotiateHandler.cs
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Net;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Localization;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Microsoft.Azure.SignalR
{
/// <summary>
/// Handles a client negotiate request for <typeparamref name="THub"/>: asks the endpoint router to
/// pick a service endpoint, then answers with that endpoint's client URL and a client access token
/// built from the claims of the current request.
/// </summary>
/// <typeparam name="THub">The hub type; its type name is used as the hub name.</typeparam>
internal class NegotiateHandler<THub> where THub : Hub
{
    private readonly IUserIdProvider _userIdProvider;
    private readonly IConnectionRequestIdProvider _connectionRequestIdProvider;
    private readonly Func<HttpContext, IEnumerable<Claim>> _claimsProvider;
    private readonly Func<HttpContext, bool> _diagnosticClientFilter;
    private readonly IServiceEndpointManager _endpointManager;
    private readonly IEndpointRouter _router;
    private readonly IBlazorDetector _blazorDetector;
    private readonly string _serverName;
    private readonly ServerStickyMode _mode;
    private readonly bool _enableDetailedErrors;
    private readonly int _endpointsCount;
    private readonly int? _maxPollInterval;
    private readonly int _customHandshakeTimeout;
    private readonly string _hubName;
    private readonly ILogger _logger;
    private readonly Func<HttpContext, HttpTransportType> _transportTypeDetector;

    /// <summary>
    /// Captures the collaborators and snapshots the option values used while negotiating.
    /// </summary>
    /// <param name="globalHubOptions">Hub options shared by all hubs; supplies the detailed-errors flag and the fallback handshake timeout.</param>
    /// <param name="hubOptions">Options of this hub; its handshake timeout wins over the global one when set.</param>
    /// <param name="endpointManager">Resolves endpoints and endpoint providers for the hub.</param>
    /// <param name="router">Chooses the endpoint used for a negotiate request.</param>
    /// <param name="userIdProvider">Computes the user id for the connection.</param>
    /// <param name="nameProvider">Optional; when null the server name is left null.</param>
    /// <param name="connectionRequestIdProvider">Produces the request id appended to the client query string.</param>
    /// <param name="options">Service options (claims provider, sticky mode, endpoints, poll interval, ...).</param>
    /// <param name="blazorDetector">Optional; a <see cref="DefaultBlazorDetector"/> is used when null.</param>
    /// <param name="logger">Logger used to report which handshake timeout is in effect.</param>
    /// <exception cref="ArgumentNullException">A required dependency or options object is null.</exception>
    public NegotiateHandler(
        IOptions<HubOptions> globalHubOptions,
        IOptions<HubOptions<THub>> hubOptions,
        IServiceEndpointManager endpointManager,
        IEndpointRouter router,
        IUserIdProvider userIdProvider,
        IServerNameProvider nameProvider,
        IConnectionRequestIdProvider connectionRequestIdProvider,
        IOptions<ServiceOptions> options,
        IBlazorDetector blazorDetector,
        ILogger<NegotiateHandler<THub>> logger)
    {
        // The logger is assigned first because GetCustomHandshakeTimeout (below) logs through it.
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _endpointManager = endpointManager ?? throw new ArgumentNullException(nameof(endpointManager));
        _router = router ?? throw new ArgumentNullException(nameof(router));
        _userIdProvider = userIdProvider ?? throw new ArgumentNullException(nameof(userIdProvider));
        _connectionRequestIdProvider = connectionRequestIdProvider ?? throw new ArgumentNullException(nameof(connectionRequestIdProvider));

        // Every options object is dereferenced below, so fail with a named argument instead of a
        // NullReferenceException half-way through construction.
        if (globalHubOptions == null)
        {
            throw new ArgumentNullException(nameof(globalHubOptions));
        }
        if (hubOptions == null)
        {
            throw new ArgumentNullException(nameof(hubOptions));
        }
        var serviceOptions = options?.Value ?? throw new ArgumentNullException(nameof(options));

        _serverName = nameProvider?.GetName();
        _blazorDetector = blazorDetector ?? new DefaultBlazorDetector();

        _claimsProvider = serviceOptions.ClaimsProvider;
        _diagnosticClientFilter = serviceOptions.DiagnosticClientFilter;
        _mode = serviceOptions.ServerStickyMode;
        _endpointsCount = serviceOptions.Endpoints.Length;
        _maxPollInterval = serviceOptions.MaxPollIntervalInSeconds;
        _transportTypeDetector = serviceOptions.TransportTypeDetector;

        _enableDetailedErrors = globalHubOptions.Value.EnableDetailedErrors == true;
        _customHandshakeTimeout = GetCustomHandshakeTimeout(hubOptions.Value.HandshakeTimeout ?? globalHubOptions.Value.HandshakeTimeout);
        _hubName = typeof(THub).Name;
    }

    /// <summary>
    /// Builds the negotiation response for the given request.
    /// </summary>
    /// <param name="context">The HTTP context of the negotiate request.</param>
    /// <returns>
    /// The response carrying the client endpoint URL and access token, or <c>null</c> when the
    /// endpoint manager yields no provider for the routed endpoint.
    /// </returns>
    public async Task<NegotiationResponse> Process(HttpContext context)
    {
        var claims = BuildClaims(context);
        var request = context.Request;
        var cultureName = context.Features.Get<IRequestCultureFeature>()?.RequestCulture.Culture.Name;
        var originalPath = GetOriginalPath(request.Path);

        var endpoint = _router.GetNegotiateEndpoint(context, _endpointManager.GetEndpoints(_hubName));
        var provider = _endpointManager.GetEndpointProvider(endpoint);
        if (provider == null)
        {
            return null;
        }

        // Drop the leading '?' so the original query can be joined with the extra parameters.
        var originalQuery = request.QueryString.HasValue ? request.QueryString.Value.Substring(1) : null;
        var queryString = GetQueryString(originalQuery, cultureName);

        return new NegotiationResponse
        {
            Url = provider.GetClientEndpoint(_hubName, originalPath, queryString),
            AccessToken = await provider.GenerateClientAccessTokenAsync(_hubName, claims),
            // Need to set this even though it's technically protocol violation https://github.com/aspnet/SignalR/issues/2133
            AvailableTransports = new List<AvailableTransport>()
        };
    }

    /// <summary>
    /// Appends the URL-encoded connection request id and, when present, the request culture name
    /// to the caller's original query string.
    /// </summary>
    /// <param name="originalQueryString">Original query without the leading '?', or null when there is none.</param>
    /// <param name="cultureName">Culture name of the request; skipped when null or empty.</param>
    /// <returns>The combined query string (without a leading '?').</returns>
    private string GetQueryString(string originalQueryString, string cultureName)
    {
        var clientRequestId = _connectionRequestIdProvider.GetRequestId();
        if (clientRequestId != null)
        {
            clientRequestId = WebUtility.UrlEncode(clientRequestId);
        }

        var queryString = $"{Constants.QueryParameter.ConnectionRequestId}={clientRequestId}";
        if (!string.IsNullOrEmpty(cultureName))
        {
            queryString += $"&{Constants.QueryParameter.RequestCulture}={cultureName}";
        }

        return originalQueryString != null
            ? $"{originalQueryString}&{queryString}"
            : queryString;
    }

    /// <summary>
    /// Collects the inputs for the client token (user, user id, sticky mode, diagnostic flag,
    /// handshake timeout, transport type, ...) and delegates to <c>ClaimsUtility.BuildJwtClaims</c>.
    /// </summary>
    /// <param name="context">The HTTP context of the negotiate request.</param>
    /// <returns>The claims produced by <c>ClaimsUtility.BuildJwtClaims</c>.</returns>
    private IEnumerable<Claim> BuildClaims(HttpContext context)
    {
        // When the detector reports a Blazor hub, sticky mode is forced to Required regardless of configuration.
        var mode = _blazorDetector.IsBlazor(_hubName) ? ServerStickyMode.Required : _mode;
        var userId = _userIdProvider.GetUserId(new ServiceHubConnectionContext(context));
        var httpTransportType = _transportTypeDetector?.Invoke(context);
        return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context), _customHandshakeTimeout, httpTransportType);
    }

    /// <summary>
    /// Binds the configured claims provider to the current request.
    /// </summary>
    /// <param name="context">The HTTP context handed to the configured provider.</param>
    /// <returns>A parameterless delegate over the configured provider, or null when none is configured.</returns>
    private Func<IEnumerable<Claim>> GetClaimsProvider(HttpContext context)
    {
        if (_claimsProvider == null)
        {
            return null;
        }

        return () => _claimsProvider.Invoke(context);
    }

    /// <summary>
    /// Evaluates the configured diagnostic-client filter for the request.
    /// </summary>
    /// <param name="context">The HTTP context of the negotiate request.</param>
    /// <returns>False when no filter is configured; otherwise the filter's result.</returns>
    private bool IsDiagnosticClient(HttpContext context)
    {
        return _diagnosticClientFilter != null && _diagnosticClientFilter(context);
    }

    /// <summary>
    /// Converts the configured handshake timeout to whole seconds, falling back to the default when
    /// it is unset, equal to the default, or outside the allowed range. Each outcome is logged.
    /// </summary>
    /// <param name="handshakeTimeout">The configured timeout, or null when none is configured.</param>
    /// <returns>The timeout in seconds to use.</returns>
    private int GetCustomHandshakeTimeout(TimeSpan? handshakeTimeout)
    {
        if (!handshakeTimeout.HasValue)
        {
            Log.UseDefaultHandshakeTimeout(_logger);
            return Constants.Periods.DefaultHandshakeTimeout;
        }

        // Fractional seconds are truncated by the cast.
        var timeout = (int)handshakeTimeout.Value.TotalSeconds;

        // use default handshake timeout
        if (timeout == Constants.Periods.DefaultHandshakeTimeout)
        {
            Log.UseDefaultHandshakeTimeout(_logger);
            return Constants.Periods.DefaultHandshakeTimeout;
        }

        // the custom handshake timeout is invalid, use default handshake timeout instead
        // (valid range, as stated in the warning message: 1 second to MaxCustomHandshakeTimeout seconds)
        if (timeout <= 0 || timeout > Constants.Periods.MaxCustomHandshakeTimeout)
        {
            Log.FailToSetCustomHandshakeTimeout(_logger, new ArgumentOutOfRangeException(nameof(handshakeTimeout)));
            return Constants.Periods.DefaultHandshakeTimeout;
        }

        // the custom handshake timeout is valid
        Log.SucceedToSetCustomHandshakeTimeout(_logger, timeout);
        return timeout;
    }

    /// <summary>
    /// Strips trailing slashes and the negotiate suffix from the request path.
    /// </summary>
    /// <param name="path">The request path.</param>
    /// <returns>The path preceding the negotiate segment, or an empty string when the path does not end with it.</returns>
    private static string GetOriginalPath(string path)
    {
        path = path.TrimEnd('/');

        // Ordinal comparison: the suffix is removed by character count below, so the match must be
        // an exact character match rather than a culture-sensitive one.
        return path.EndsWith(Constants.Path.Negotiate, StringComparison.Ordinal)
            ? path.Substring(0, path.Length - Constants.Path.Negotiate.Length)
            : string.Empty;
    }

    /// <summary>
    /// Pre-compiled log message delegates for the handshake timeout selection.
    /// </summary>
    private static class Log
    {
        private static readonly Action<ILogger, Exception> _useDefaultHandshakeTimeout =
            LoggerMessage.Define(LogLevel.Information, new EventId(0, "UseDefaultHandshakeTimeout"), "Use default handshake timeout.");

        private static readonly Action<ILogger, int, Exception> _succeedToSetCustomHandshakeTimeout =
            LoggerMessage.Define<int>(LogLevel.Information, new EventId(1, "SucceedToSetCustomHandshakeTimeout"), "Succeed to set custom handshake timeout: {timeout} seconds.");

        private static readonly Action<ILogger, Exception> _failToSetCustomHandshakeTimeout =
            LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailToSetCustomHandshakeTimeout"), $"Fail to set custom handshake timeout, use default handshake timeout {Constants.Periods.DefaultHandshakeTimeout} seconds instead. The range of custom handshake timeout should between 1 second to {Constants.Periods.MaxCustomHandshakeTimeout} seconds.");

        /// <summary>Logs that the default handshake timeout is in effect.</summary>
        public static void UseDefaultHandshakeTimeout(ILogger logger)
        {
            _useDefaultHandshakeTimeout(logger, null);
        }

        /// <summary>Logs the accepted custom handshake timeout, in seconds.</summary>
        public static void SucceedToSetCustomHandshakeTimeout(ILogger logger, int customHandshakeTimeout)
        {
            _succeedToSetCustomHandshakeTimeout(logger, customHandshakeTimeout, null);
        }

        /// <summary>Logs a warning that the custom handshake timeout was rejected.</summary>
        public static void FailToSetCustomHandshakeTimeout(ILogger logger, Exception exception)
        {
            _failToSetCustomHandshakeTimeout(logger, exception);
        }
    }
}
}