Microsoft.Azure.Devices.Edge.Hub.Http.Test
WebSocketHandlingMiddlewareTest.cs
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Http.Test
{
    using System;
    using System.Collections.Generic;
    using System.Net;
    using System.Net.WebSockets;
    using System.Security.Authentication;
    using System.Security.Cryptography.X509Certificates;
    using System.Threading.Tasks;
    using Microsoft.AspNetCore.Http;
    using Microsoft.AspNetCore.Http.Features;
    using Microsoft.Azure.Devices.Edge.Hub.Core;
    using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
    using Microsoft.Azure.Devices.Edge.Hub.Http.Extensions;
    using Microsoft.Azure.Devices.Edge.Hub.Http.Middleware;
    using Microsoft.Azure.Devices.Edge.Util;
    using Microsoft.Azure.Devices.Edge.Util.Test.Common;
    using Microsoft.Extensions.Primitives;
    using Moq;
    using Xunit;

    /// <summary>
    /// Unit tests for <see cref="WebSocketHandlingMiddleware"/>: constructor argument checks,
    /// dispatching WebSocket requests to registered listeners, rejecting requests with no
    /// matching listener, passing non-WebSocket requests on to the next delegate, and the
    /// forwarded (proxied) client-certificate path.
    /// </summary>
    [Unit]
    public class WebSocketHandlingMiddlewareTest
    {
        [Fact]
        public void CtorThrowsWhenRequestDelegateIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(null, Mock.Of<IWebSocketListenerRegistry>(), Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>())));
        }

        [Fact]
        public void CtorThrowsWhenWebSocketListenerRegistryIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(Mock.Of<RequestDelegate>(), null, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>())));
        }

        [Fact]
        public void CtorThrowsWhenHttpProxiedCertificateExtractorIsNull()
        {
            Assert.Throws<ArgumentNullException>(
                () => new WebSocketHandlingMiddleware(Mock.Of<RequestDelegate>(), Mock.Of<IWebSocketListenerRegistry>(), null));
        }

        [Fact]
        public async Task InvokeAllowsExceptionsToBubbleUpToServer()
        {
            var middleware = new WebSocketHandlingMiddleware(
                (ctx) => Task.CompletedTask,
                Mock.Of<IWebSocketListenerRegistry>(),
                Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));

            // A null context must surface as an exception rather than being swallowed.
            await Assert.ThrowsAnyAsync<Exception>(() => middleware.Invoke(null));
        }

        [Fact]
        public async Task HandlesAWebSocketRequest()
        {
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc");

            var listener = Mock.Of<IWebSocketListener>(wsl => wsl.SubProtocol == "abc");
            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener);

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Mock.Get(listener).Verify(r => r.ProcessWebSocketRequestAsync(It.IsAny<WebSocket>(), It.IsAny<Option<EndPoint>>(), It.IsAny<EndPoint>(), It.IsAny<string>()));
        }

        [Fact]
        public async Task ProducesANewCorrelationIdForEachWebSocketRequest()
        {
            var correlationIds = new List<string>();
            IWebSocketListenerRegistry registry = ObservingWebSocketListenerRegistry(correlationIds);
            HttpContext httpContext = this.WebSocketRequestContext();

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);
            await middleware.Invoke(httpContext);

            Assert.Equal(2, correlationIds.Count);
            Assert.NotEqual(correlationIds[0], correlationIds[1]);
        }

        [Fact]
        public async Task DoesNotHandleANonWebSocketRequest()
        {
            var next = Mock.Of<RequestDelegate>();
            HttpContext httpContext = this.NonWebSocketRequestContext();

            // The registry throws if consulted, so this also asserts the middleware never looks up a listener.
            var middleware = new WebSocketHandlingMiddleware(next, this.ThrowingWebSocketListenerRegistry(), Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Mock.Get(next).Verify(n => n(httpContext));
        }

        [Fact]
        public async Task SetsBadrequestWhenANonExistentListener()
        {
            var listener = Mock.Of<IWebSocketListener>(wsl => wsl.SubProtocol == "abc");
            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener);
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("xyz");

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode);
        }

        [Fact]
        public async Task SetsBadrequestWhenNoRegisteredListener()
        {
            var registry = new WebSocketListenerRegistry();
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("xyz");

            var middleware = new WebSocketHandlingMiddleware(this.ThrowingNextDelegate(), registry, Task.FromResult(Mock.Of<IHttpProxiedCertificateExtractor>()));
            await middleware.Invoke(httpContext);

            Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode);
        }

        [Fact]
        public async Task UnauthorizedRequestWhenProxyAuthFails()
        {
            var next = Mock.Of<RequestDelegate>();
            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");

            // When certificate extraction fails, the listener is expected to be handed a NullAuthenticator.
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>(),
                        It.IsAny<X509Certificate2>(),
                        It.IsAny<IList<X509Certificate2>>(),
                        It.Is<IAuthenticator>(auth => auth != null && auth.GetType() == typeof(NullAuthenticator))))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);

            byte[] certContentBytes = GenerateTestCertBytes();
            string clientCertString = ToForwardedCertHeaderValue(certContentBytes);
            HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc");

            var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
            certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ThrowsAsync(new AuthenticationException());

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
            await middleware.Invoke(httpContext);

            listener.VerifyAll();
        }

        [Fact]
        public async Task AuthorizedRequestWhenProxyAuthSuccess()
        {
            var next = Mock.Of<RequestDelegate>();
            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");

            // When a certificate is extracted successfully, no proxy authenticator (null) is expected.
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>(),
                        It.IsAny<X509Certificate2>(),
                        It.IsAny<IList<X509Certificate2>>(),
                        It.Is<IAuthenticator>(auth => auth == null)))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);

            byte[] certContentBytes = GenerateTestCertBytes();
            string clientCertString = ToForwardedCertHeaderValue(certContentBytes);
            HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc");

            var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
            certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ReturnsAsync(Option.Some(new X509Certificate2(certContentBytes)));

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
            await middleware.Invoke(httpContext);

            listener.VerifyAll();
        }

        [Fact]
        public async Task AuthorizedRequestWhenCertIsNotSet()
        {
            var next = Mock.Of<RequestDelegate>();
            var listener = new Mock<IWebSocketListener>();
            listener.Setup(wsl => wsl.SubProtocol).Returns("abc");

            // Without a forwarded-certificate header, the 4-argument overload is expected.
            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>()))
                .Returns(Task.CompletedTask);

            var registry = new WebSocketListenerRegistry();
            registry.TryRegister(listener.Object);
            HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc");

            var authenticator = new Mock<IAuthenticator>();
            authenticator.Setup(p => p.AuthenticateAsync(It.IsAny<IClientCredentials>())).ReturnsAsync(false);
            IHttpProxiedCertificateExtractor certExtractor = new HttpProxiedCertificateExtractor(authenticator.Object, Mock.Of<IClientCredentialsFactory>(), "hub", "edge", "proxy");

            var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor));
            await middleware.Invoke(httpContext);

            authenticator.Verify(auth => auth.AuthenticateAsync(It.IsAny<IClientCredentials>()), Times.Never);
            listener.VerifyAll();
        }

        /// <summary>
        /// Builds a registry whose <c>GetListener</c> always returns a listener that records
        /// the correlation id of every 4-argument <c>ProcessWebSocketRequestAsync</c> call.
        /// </summary>
        /// <param name="correlationIds">List that receives each observed correlation id.</param>
        static IWebSocketListenerRegistry ObservingWebSocketListenerRegistry(List<string> correlationIds)
        {
            var registry = new Mock<IWebSocketListenerRegistry>();
            var listener = new Mock<IWebSocketListener>();

            listener.Setup(
                    wsl => wsl.ProcessWebSocketRequestAsync(
                        It.IsAny<WebSocket>(),
                        It.IsAny<Option<EndPoint>>(),
                        It.IsAny<EndPoint>(),
                        It.IsAny<string>()))
                .Returns(Task.CompletedTask)
                .Callback<WebSocket, Option<EndPoint>, EndPoint, string>((ws, ep1, ep2, id) => correlationIds.Add(id));

            registry
                .Setup(wslr => wslr.GetListener(It.IsAny<IEnumerable<string>>()))
                .Returns(Option.Some(listener.Object));

            return registry.Object;
        }

        /// <summary>
        /// Generates a self-signed test certificate and returns its exported bytes
        /// (<see cref="X509ContentType.Cert"/>). The certificate object is disposed after export.
        /// </summary>
        static byte[] GenerateTestCertBytes()
        {
            using (var cert = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert("test_cert"))
            {
                return cert.Export(X509ContentType.Cert);
            }
        }

        /// <summary>
        /// Wraps certificate bytes in PEM BEGIN/END markers and URL-encodes the result.
        /// This is the value the tests place in the forwarded client-certificate header.
        /// </summary>
        /// <param name="certContentBytes">Raw certificate bytes to encode.</param>
        static string ToForwardedCertHeaderValue(byte[] certContentBytes)
        {
            string certContentBase64 = Convert.ToBase64String(certContentBytes);
            string pem = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n";
            return WebUtility.UrlEncode(pem);
        }

        /// <summary>
        /// Mocked context for a WebSocket request that sets no requested subprotocols.
        /// It has empty headers, an empty TLS chain, and fixed local/remote IP addresses.
        /// </summary>
        HttpContext WebSocketRequestContext()
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(wsm => wsm.IsWebSocketRequest == true)
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            req.Headers == Mock.Of<IHeaderDictionary>())
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc =>
                            fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(
                                f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123)
                                && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>
        /// Mocked context whose WebSocket manager reports a non-WebSocket request.
        /// </summary>
        HttpContext NonWebSocketRequestContext()
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.IsWebSocketRequest == false));
        }

        /// <summary>
        /// Mocked context for a WebSocket request asking for <paramref name="subprotocols"/>.
        /// Accepting the socket yields a mocked <see cref="WebSocket"/>, and the headers are empty.
        /// </summary>
        /// <param name="subprotocols">Subprotocols reported as requested by the client.</param>
        HttpContext ContextWithRequestedSubprotocols(params string[] subprotocols)
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.WebSocketRequestedProtocols == subprotocols
                            && wsm.IsWebSocketRequest
                            && wsm.AcceptWebSocketAsync(It.IsAny<string>()) == Task.FromResult(Mock.Of<WebSocket>()))
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            req.Headers == Mock.Of<IHeaderDictionary>())
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc => fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123)
                                && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>
        /// Same as <see cref="ContextWithRequestedSubprotocols"/>, but the request headers carry
        /// <paramref name="cert"/> under the "x-ms-edge-clientcert" key.
        /// </summary>
        /// <param name="cert">Header value returned by <c>TryGetValue("x-ms-edge-clientcert", out ...)</c>.</param>
        /// <param name="subprotocols">Subprotocols reported as requested by the client.</param>
        HttpContext ContextWithRequestedSubprotocolsAndForwardedCert(StringValues cert, params string[] subprotocols)
        {
            return Mock.Of<HttpContext>(
                ctx =>
                    ctx.WebSockets == Mock.Of<WebSocketManager>(
                        wsm =>
                            wsm.WebSocketRequestedProtocols == subprotocols
                            && wsm.IsWebSocketRequest
                            && wsm.AcceptWebSocketAsync(It.IsAny<string>()) == Task.FromResult(Mock.Of<WebSocket>()))
                    && ctx.Request == Mock.Of<HttpRequest>(
                        req =>
                            // Stub TryGetValue for the forwarded-cert header to succeed and yield `cert`.
                            req.Headers == Mock.Of<IHeaderDictionary>(h => h.TryGetValue("x-ms-edge-clientcert", out cert)))
                    && ctx.Response == Mock.Of<HttpResponse>()
                    && ctx.Features == Mock.Of<IFeatureCollection>(
                        fc => fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(f => f.ChainElements == new List<X509Certificate2>()))
                    && ctx.Connection == Mock.Of<ConnectionInfo>(
                        conn => conn.LocalIpAddress == new IPAddress(123)
                                && conn.LocalPort == It.IsAny<int>()
                                && conn.RemoteIpAddress == new IPAddress(123)
                                && conn.RemotePort == It.IsAny<int>()
                                && conn.ClientCertificate == new X509Certificate2()));
        }

        /// <summary>
        /// A <see cref="RequestDelegate"/> that throws if invoked.
        /// Used to assert the middleware does not forward to the next component.
        /// </summary>
        RequestDelegate ThrowingNextDelegate()
        {
            return ctx => throw new Exception("delegate 'next' should not be called");
        }

        /// <summary>
        /// A registry whose <c>GetListener</c> throws.
        /// Used to assert the middleware never performs a listener lookup.
        /// </summary>
        IWebSocketListenerRegistry ThrowingWebSocketListenerRegistry()
        {
            var registry = new Mock<IWebSocketListenerRegistry>();
            registry
                .Setup(wslr => wslr.GetListener(It.IsAny<IEnumerable<string>>()))
                .Throws(new Exception("IWebSocketListenerRegistry.GetListener should not be called"));
            return registry.Object;
        }
    }
}