csharp/Azure/azure-signalr/test/Microsoft.Azure.SignalR.Tests/ServiceEndpointProviderFacts.cs

ServiceEndpointProviderFacts.cs
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.Extensions.Options;
using Xunit;

namespace Microsoft.Azure.SignalR.Tests
{
    /// <summary>
    /// Tests for <see cref="ServiceEndpointProvider"/>. They check the generated server and client
    /// endpoint URLs and access tokens. Each case runs with and without an application name
    /// prefix, against connection strings that have no version, "1.0-preview" and "1.0".
    /// </summary>
    public class ServiceEndpointProviderFacts
    {
        private const string Endpoint = "https://myendpoint";
        private const string AccessKey = "nOu3jXsHnsO5urMumc87M9skQbUWuQ+PE5IvSUEic8w=";

        // Lower-cased with the invariant culture so the expected hub name does not depend on the
        // culture the test runner happens to use.
        private static readonly string HubName = nameof(TestHub).ToLowerInvariant();
        private static readonly string AppName = "testapp";

        private static readonly string ConnectionStringWithoutVersion =
            $"Endpoint={Endpoint};AccessKey={AccessKey};";

        private static readonly string ConnectionStringWithPreviewVersion =
            $"Endpoint={Endpoint};AccessKey={AccessKey};Version=1.0-preview";

        private static readonly string ConnectionStringWithV1Version = $"Endpoint={Endpoint};AccessKey={AccessKey};Version=1.0";

        private static readonly ServiceOptions _optionsWithoutAppName = Options.Create(new ServiceOptions()).Value;
        private static readonly ServiceOptions _optionsWithAppName = Options.Create(new ServiceOptions { ApplicationName = AppName }).Value;

        // One provider per connection-string variant, without an application name.
        private static readonly ServiceEndpointProvider[] EndpointProviderArray =
        {
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithoutVersion), _optionsWithoutAppName),
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithPreviewVersion), _optionsWithoutAppName),
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithV1Version), _optionsWithoutAppName)
        };

        // Same connection-string variants, but configured with ApplicationName = AppName.
        private static readonly ServiceEndpointProvider[] EndpointProviderArrayWithPrefix =
        {
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithoutVersion), _optionsWithAppName),
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithPreviewVersion), _optionsWithAppName),
            new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithV1Version), _optionsWithAppName)
        };

        // (user path, user query string, expected suffix appended after "?hub=<name>").
        // A non-empty path is expected to show up URL-encoded in the OriginalPath query parameter.
        private static readonly (string path, string queryString, string expectedQuery)[] PathAndQueryArray =
        {
            ("", "", ""),
            (null, "", ""),
            ("/user/path", "", $"&{Constants.QueryParameter.OriginalPath}=%2Fuser%2Fpath"),
            ("", "customKey=customValue", "&customKey=customValue"),
            ("/user/path", "customKey=customValue", $"&{Constants.QueryParameter.OriginalPath}=%2Fuser%2Fpath&customKey=customValue")
        };

        /// <summary>Theory data: each provider without an application name.</summary>
        public static IEnumerable<object[]> DefaultEndpointProviders =>
            EndpointProviderArray.Select(provider => new object[] { provider });

        /// <summary>Theory data: each (path, query string, expected query) triple.</summary>
        public static IEnumerable<object[]> PathAndQueries =>
            PathAndQueryArray.Select(t => new object[] { t.path, t.queryString, t.expectedQuery });

        /// <summary>Theory data: cross product of un-prefixed providers and path/query triples.</summary>
        public static IEnumerable<object[]> DefaultEndpointProvidersWithPath =>
            from provider in EndpointProviderArray
            from t in PathAndQueryArray
            select new object[] { provider, t.path, t.queryString, t.expectedQuery };

        /// <summary>Theory data: cross product of app-name-prefixed providers and path/query triples.</summary>
        public static IEnumerable<object[]> DefaultEndpointProvidersWithPathPlusPrefix =>
            from provider in EndpointProviderArrayWithPrefix
            from t in PathAndQueryArray
            select new object[] { provider, t.path, t.queryString, t.expectedQuery };

        /// <summary>Theory data: each provider configured with an application name.</summary>
        public static IEnumerable<object[]> DefaultEndpointProvidersPlusPrefix =>
            EndpointProviderArrayWithPrefix.Select(provider => new object[] { provider });

        [Theory]
        [MemberData(nameof(DefaultEndpointProviders))]
        internal void GetServerEndpoint(IServiceEndpointProvider provider)
        {
            var expected = $"{Endpoint}/server/?hub={HubName}";
            var actual = provider.GetServerEndpoint(nameof(TestHub));
            Assert.Equal(expected, actual);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProvidersWithPath))]
        internal void GetClientEndpoint(IServiceEndpointProvider provider, string path, string queryString, string expectedQueryString)
        {
            var expected = $"{Endpoint}/client/?hub={HubName}{expectedQueryString}";
            var actual = provider.GetClientEndpoint(HubName, path, queryString);
            Assert.Equal(expected, actual);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProvidersPlusPrefix))]
        internal void GetServerEndpointWithAppName(IServiceEndpointProvider provider)
        {
            var expected = $"{Endpoint}/server/?hub={AppName}_{HubName}";
            var actual = provider.GetServerEndpoint(nameof(TestHub));
            Assert.Equal(expected, actual);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProvidersWithPathPlusPrefix))]
        internal void GetClientEndpointWithAppName(IServiceEndpointProvider provider, string path, string queryString, string expectedQueryString)
        {
            var expected = $"{Endpoint}/client/?hub={AppName}_{HubName}{expectedQueryString}";
            var actual = provider.GetClientEndpoint(HubName, path, queryString);
            Assert.Equal(expected, actual);
        }

        [Fact(Skip = "Access token does not need to be unique")]
        internal async Task GenerateMultipleAccessTokenShouldBeUnique()
        {
            var count = 1000;
            var sep = new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithPreviewVersion), _optionsWithoutAppName);
            var userId = Guid.NewGuid().ToString();
            var tokens = new List<string>();
            for (int i = 0; i < count; i++)
            {
                tokens.Add(await sep.GenerateClientAccessTokenAsync(nameof(TestHub)));
                tokens.Add(await sep.GenerateServerAccessTokenAsync(nameof(TestHub), userId));
            }

            var distinct = tokens.Distinct();
            Assert.Equal(tokens.Count, distinct.Count());
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProviders))]
        internal async Task GenerateServerAccessToken(IServiceEndpointProvider provider)
        {
            const string userId = "UserA";
            var tokenString = await provider.GenerateServerAccessTokenAsync(nameof(TestHub), userId);
            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

            // Rebuild the expected token using the time values read back from the generated token,
            // so that only the audience, claims and signing key can cause a mismatch.
            var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={HubName}",
                new[]
                {
                    new Claim(ClaimTypes.NameIdentifier, userId)
                },
                token.ValidTo,
                token.ValidFrom,
                token.ValidFrom,
                AccessKey
            );

            Assert.Equal(expectedTokenString, tokenString);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProvidersPlusPrefix))]
        internal async Task GenerateServerAccessTokenWithPrefix(IServiceEndpointProvider provider)
        {
            const string userId = "UserA";
            var tokenString = await provider.GenerateServerAccessTokenAsync(nameof(TestHub), userId);
            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

            // The audience must carry the "<app>_<hub>" prefixed hub name.
            var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/server/?hub={AppName}_{HubName}",
                new[]
                {
                    new Claim(ClaimTypes.NameIdentifier, userId)
                },
                token.ValidTo,
                token.ValidFrom,
                token.ValidFrom,
                AccessKey
            );

            Assert.Equal(expectedTokenString, tokenString);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProviders))]
        internal async Task GenerateClientAccessToken(IServiceEndpointProvider provider)
        {
            var tokenString = await provider.GenerateClientAccessTokenAsync(HubName);
            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

            // Client tokens are generated here without extra claims, hence the null claim list.
            var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/client/?hub={HubName}",
                null,
                token.ValidTo,
                token.ValidFrom,
                token.ValidFrom,
                AccessKey
            );

            Assert.Equal(expectedTokenString, tokenString);
        }

        [Theory]
        [MemberData(nameof(DefaultEndpointProvidersPlusPrefix))]
        internal async Task GenerateClientAccessTokenWithPrefix(IServiceEndpointProvider provider)
        {
            var tokenString = await provider.GenerateClientAccessTokenAsync(HubName);
            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(tokenString);

            var expectedTokenString = JwtTokenHelper.GenerateJwtBearer($"{Endpoint}/client/?hub={AppName}_{HubName}",
                null,
                token.ValidTo,
                token.ValidFrom,
                token.ValidFrom,
                AccessKey
            );

            Assert.Equal(expectedTokenString, tokenString);
        }

        [Theory]
        [InlineData(AccessTokenAlgorithm.HS256)]
        [InlineData(AccessTokenAlgorithm.HS512)]
        public async Task GenerateServerAccessTokenWithSpecifedAlgorithm(AccessTokenAlgorithm algorithm)
        {
            var provider = new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithV1Version), new ServiceOptions() { AccessTokenAlgorithm = algorithm });
            var serverToken = await provider.GenerateServerAccessTokenAsync("hub1", "user1");

            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(serverToken);

            // The enum member name is expected to match the JWT "alg" header value.
            Assert.Equal(algorithm.ToString(), token.SignatureAlgorithm);
        }

        [Theory]
        [InlineData(AccessTokenAlgorithm.HS256)]
        [InlineData(AccessTokenAlgorithm.HS512)]
        public async Task GenerateClientAccessTokenWithSpecifedAlgorithm(AccessTokenAlgorithm algorithm)
        {
            var provider = new ServiceEndpointProvider(new ServiceEndpoint(ConnectionStringWithV1Version), new ServiceOptions() { AccessTokenAlgorithm = algorithm });
            var generatedToken = await provider.GenerateClientAccessTokenAsync("hub1");

            var token = JwtTokenHelper.JwtHandler.ReadJwtToken(generatedToken);

            // The enum member name is expected to match the JWT "alg" header value.
            Assert.Equal(algorithm.ToString(), token.SignatureAlgorithm);
        }
    }
}