csharp/Azure/iotedge/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/WebSocketHandlingMiddlewareTest.cs

WebSocketHandlingMiddlewareTest.cs
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Http.Test
{
    using System;
    using System.Collections.Generic;
    using System.Net;
    using System.Net.WebSockets;
    using System.Security.Authentication;
    using System.Security.Cryptography.X509Certificates;
    using System.Threading.Tasks;
    using Microsoft.AspNetCore.Http;
    using Microsoft.AspNetCore.Http.Features;
    using Microsoft.Azure.Devices.Edge.Hub.Core;
    using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
    using Microsoft.Azure.Devices.Edge.Hub.Http.Extensions;
    using Microsoft.Azure.Devices.Edge.Hub.Http.Middleware;
    using Microsoft.Azure.Devices.Edge.Util;
    using Microsoft.Azure.Devices.Edge.Util.Test.Common;
    using Microsoft.Extensions.Primitives;
    using Moq;
    using Xunit;

    /// <summary>
    /// Unit tests for <see cref="WebSocketHandlingMiddleware"/>: constructor argument validation,
    /// dispatch of WebSocket requests to registered listeners, pass-through of non-WebSocket
    /// requests, and handling of a forwarded client certificate via
    /// <see cref="IHttpProxiedCertificateExtractor"/>.
    /// </summary>
    [Unit]
    public class WebSocketHandlingMiddlewareTest
    {
        [Fact]
        public void CtorThrowsWhenRequestDelegateIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(null, Mock.Of<IWebSocketListenerRegistry>(), Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>())));
        }

        [Fact]
        public void CtorThrowsWhenWebSocketListenerRegistryIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(Mock.Of<RequestDelegate>(), null, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>())));
        }

        [Fact]
        public void CtorThrowsWhenHttpProxiedCertificateExtractorIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(Mock.Of<RequestDelegate>(), Mock.Of<IWebSocketListenerRegistry>(), null));
        }

        [Fact]
        public async Task InvokeAllowsExceptionsToBubbleUpToServer()
        {
            var middleware = new WebSocketHandlingMiddleware(
                (ctx) => Task.CompletedTask,
                Mock.Of<IWebSocketListenerRegistry>(),
                Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));

            // A null context must surface as an exception rather than being swallowed.
            await Assert.ThrowsAnyAsync<Exception>(() => middleware.Invoke(null));
        }

        [Fact]
        public async Task HandlesAWebSocketRequest()
        {
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc");

            var listener = Mock.Of<IWebSocketListener>(wsl => wsl.SubProtocol == "abc");

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener);

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Mock.Get(listener).Verify(r => r.ProcessWebSocketRequestAsync(It.IsAny<WebSocket>(), It.IsAny<Option<EndPoint>>(), It.IsAny<EndPoint>(), It.IsAny<string>()));
        }

        [Fact]
        public async Task ProducesANewCorrelationIdForEachWebSocketRequest()
        {
            var correlationIds = new List<string>();
            IWebSocketListenerRegistry registry = ObservingWebSocketListenerRegistry(correlationIds);
            HttpContext httpContext = this.WebSocketRequestContext();

            // The same context is processed twice; each invocation must get its own correlation id.
            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);
            await middleware.Invoke(httpContext);

            Assert.Equal(2, correlationIds.Count);
            Assert.NotEqual(correlationIds[0], correlationIds[1]);
        }

        [Fact]
        public async Task DoesNotHandleANonWebSocketRequest()
        {
            var next = Mock.Of<RequestDelegate>();
            HttpContext httpContext = this.NonWebSocketRequestContext();

            // The registry throws if consulted, so this also proves the listener lookup is skipped.
            var middleware = new WebSocketHandlingMiddleware(next, this.ThrowingWebSocketListenerRegistry(), Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Mock.Get(next).Verify(n => n(httpContext));
        }

        [Fact]
        public async Task SetsBadrequestWhenANonExistentListener()
        {
            var listener = Mock.Of<IWebSocketListener>(wsl => wsl.SubProtocol == "abc");
            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener);

            // "xyz" is requested but only "abc" is registered.
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("xyz");

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode);
        }

        [Fact]
        public async Task SetsBadrequestWhenNoRegisteredListener()
        {
            var registry = new WebSocketListenerRegistry();
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("xyz");

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode);
        }

        [Fact]
        public async Task UnauthorizedRequestWhenProxyAuthFails()
        {
            var next = Mock.Of<RequestDelegate>();

            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");

            // The certificate extractor below throws AuthenticationException. In that case the
            // listener is expected to be invoked with a NullAuthenticator (enforced by VerifyAll).
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>(),
                        It.IsAny<X509Certificate2>(),
                        It.IsAny<IList<X509Certificate2>>(),
                        It.Is<IAuthenticator>(auth => auth != null && auth.GetType() == typeof(NullAuthenticator))))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);
            byte[] certContentBytes = GenerateSelfSignedCertBytes();
            HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(ToForwardedCertHeaderValue(certContentBytes)), "abc");
            var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
            certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ThrowsAsync(new AuthenticationException());

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
            await middleware.Invoke(httpContext);

            listener.VerifyAll();
        }

        [Fact]
        public async Task AuthorizedRequestWhenProxyAuthSuccess()
        {
            var next = Mock.Of<RequestDelegate>();

            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");

            // When the extractor returns a certificate, the listener is expected to be invoked
            // without a proxy authenticator (null).
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>(),
                        It.IsAny<X509Certificate2>(),
                        It.IsAny<IList<X509Certificate2>>(),
                        It.Is<IAuthenticator>(auth => auth == null)))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);
            byte[] certContentBytes = GenerateSelfSignedCertBytes();
            HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(ToForwardedCertHeaderValue(certContentBytes)), "abc");
            var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
            certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ReturnsAsync(Option.Some(new X509Certificate2(certContentBytes)));

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
            await middleware.Invoke(httpContext);

            listener.VerifyAll();
        }

        [Fact]
        public async Task AuthorizedRequestWhenCertIsNotSet()
        {
            var next = Mock.Of<RequestDelegate>();

            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>()))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);

            // This context carries no forwarded-certificate header, so the real extractor's
            // authenticator must never be consulted.
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc");
            var authenticator = new Mock<IAuthenticator>();
            authenticator.Setup(p => p.AuthenticateAsync(It.IsAny<IClientCredentials>())).ReturnsAsync(false);

            IHttpProxiedCertificateExtractor certExtractor = new HttpProxiedCertificateExtractor(authenticator.Object, Mock.Of<IClientCredentialsFactory>(), "hub", "edge", "proxy");

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor));
            await middleware.Invoke(httpContext);

            authenticator.Verify(auth => auth.AuthenticateAsync(It.IsAny<IClientCredentials>()), Times.Never);
            listener.VerifyAll();
        }

        /// <summary>
        /// Builds a registry whose <c>GetListener</c> always returns a single mock listener.
        /// That listener records the correlation id of every 4-argument
        /// <c>ProcessWebSocketRequestAsync</c> call into <paramref name="correlationIds"/>.
        /// </summary>
        static IWebSocketListenerRegistry ObservingWebSocketListenerRegistry(List<string> correlationIds)
        {
            var registry = new Mock<IWebSocketListenerRegistry>();
            var listener = new Mock<IWebSocketListener>();

            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>()))
                .Returns(Task.CompletedTask)
                .Callback<WebSocket, Option<EndPoint>, EndPoint, string>((ws, ep1, ep2, id) => correlationIds.Add(id));
            registry
                .Setup(wslr => wslr.GetListener(It.IsAny<IEnumerable<string>>()))
                .Returns(Option.Some(listener.Object));

            return registry.Object;
        }

        /// <summary>
        /// Generates a throw-away self-signed certificate ("test_cert") and returns its DER
        /// (<see cref="X509ContentType.Cert"/>) encoding.
        /// </summary>
        static byte[] GenerateSelfSignedCertBytes()
        {
            using (var cert = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert("test_cert"))
            {
                return cert.Export(X509ContentType.Cert);
            }
        }

        /// <summary>
        /// Wraps DER certificate bytes in PEM armor and URL-encodes the result. The output is
        /// the value these tests place in the forwarded client-certificate header.
        /// </summary>
        static string ToForwardedCertHeaderValue(byte[] certContentBytes)
        {
            string certContentBase64 = Convert.ToBase64String(certContentBytes);
            string pem = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n";
            return WebUtility.UrlEncode(pem);
        }

        /// <summary>
        /// Mock WebSocket request context. It has no requested subprotocols, empty headers, an
        /// empty TLS chain, and fixed local/remote addresses.
        /// </summary>
        HttpContext WebSocketRequestContext()
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(wsm => wsm.IsWebSocketRequest == true)
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            req.Headers == Mock.Of<IHeaderDictionary>())
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc =>
                            fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(
                                f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123)
                                && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>Mock context for a plain (non-WebSocket) HTTP request.</summary>
        HttpContext NonWebSocketRequestContext()
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.IsWebSocketRequest == false));
        }

        /// <summary>
        /// Mock WebSocket request context that requests <paramref name="subprotocols"/> and
        /// accepts the upgrade with a mock <see cref="WebSocket"/>. It has no forwarded-certificate header.
        /// </summary>
        HttpContext ContextWithRequestedSubprotocols(params string[] subprotocols)
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.WebSocketRequestedProtocols == subprotocols
                            && wsm.IsWebSocketRequest
                            && wsm.AcceptWebSocketAsync(It.IsAny<string>()) == Task.FromResult(Mock.Of<WebSocket>()))
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            req.Headers == Mock.Of<IHeaderDictionary>())
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc => fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123) && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>
        /// Same as <see cref="ContextWithRequestedSubprotocols"/>, but the request headers report
        /// <paramref name="cert"/> as the value of "x-ms-edge-clientcert".
        /// </summary>
        HttpContext ContextWithRequestedSubprotocolsAndForwardedCert(StringValues cert, params string[] subprotocols)
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.WebSocketRequestedProtocols == subprotocols
                            && wsm.IsWebSocketRequest
                            && wsm.AcceptWebSocketAsync(It.IsAny<string>()) == Task.FromResult(Mock.Of<WebSocket>()))
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            req.Headers == Mock.Of<IHeaderDictionary>(
                                h => h.TryGetValue("x-ms-edge-clientcert", out cert) == true))
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc => fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123) && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>A <see cref="RequestDelegate"/> that fails the test if the middleware calls it.</summary>
        RequestDelegate ThrowingNextDelegate()
        {
            return ctx => throw new Exception("delegate 'next' should not be called");
        }

        /// <summary>A listener registry whose <c>GetListener</c> fails the test if called.</summary>
        IWebSocketListenerRegistry ThrowingWebSocketListenerRegistry()
        {
            var registry = new Mock<IWebSocketListenerRegistry>();

            registry
                .Setup(wslr => wslr.GetListener(It.IsAny<IEnumerable<string>>()))
                .Throws(new Exception("IWebSocketListenerRegistry.GetListener should not be called"));

            return registry.Object;
        }
    }
}