// csharp/a1q123456/Harmonic/Harmonic/Networking/Rtmp/ChunkStreamContext.cs

// ChunkStreamContext.cs
using Harmonic.Buffers;
using Harmonic.Networking.Amf.Serialization.Amf0;
using Harmonic.Networking.Amf.Serialization.Amf3;
using Harmonic.Networking.Rtmp.Data;
using Harmonic.Networking.Rtmp.Messages;
using Harmonic.Networking.Rtmp.Messages.Commands;
using Harmonic.Networking.Rtmp.Serialization;
using Harmonic.Networking.Utils;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static Harmonic.Networking.Rtmp.IOPipeLine;

namespace Harmonic.Networking.Rtmp
{
    /// <summary>
    /// Per-connection RTMP chunk-stream state. Outgoing messages are serialized and split into
    /// chunks by <see cref="MultiplexMessageAsync"/>; incoming bytes are parsed by the buffer
    /// processors this class registers on the <see cref="IOPipeLine"/> in its constructor.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this file appears to have been mechanically damaged. Generic type arguments
    /// were stripped, and two lines (flagged below) fuse several methods whose text was lost.
    /// Type arguments that the visible code pins down have been restored. The rest are flagged
    /// and should be restored from version control.
    /// </remarks>
    class ChunkStreamContext : IDisposable
    {
        private ArrayPool<byte> _arrayPool = ArrayPool<byte>.Shared;

        // Header of the chunk currently being parsed; reset to null once its payload is consumed.
        internal ChunkHeader _processingChunk = null;

        // Computed as four times (read chunk size + Type0 message header size).
        // NOTE(review): not used in the visible code; presumably read by IOPipeLine — confirm.
        internal int ReadMinimumBufferSize { get => (ReadChunkSize + TYPE0_SIZE) * 4; }

        // Last message header written / read on each chunk stream id. The write map is what
        // SelectChunkType and GenerateMesesageHeader compare against, so that later chunks can use
        // compressed (Type1/2/3) headers.
        internal Dictionary<uint, MessageHeader> _previousWriteMessageHeader = new Dictionary<uint, MessageHeader>();
        internal Dictionary<uint, MessageHeader> _previousReadMessageHeader = new Dictionary<uint, MessageHeader>();

        // Partially reassembled incoming messages, keyed by chunk stream id.
        // NOTE(review): the generic type arguments were lost and the value type is not visible
        // here — restore from version control.
        internal Dictionary _incompleteMessageState = new Dictionary();

        // Acknowledgement windows. When WriteWindowAcknowledgementSize is set, MultiplexMessageAsync
        // throttles sending against WriteUnAcknowledgedSize; null means no throttling.
        internal uint? ReadWindowAcknowledgementSize { get; set; } = null;
        internal uint? WriteWindowAcknowledgementSize { get; set; } = null;
        internal int ReadChunkSize { get; set; } = 128;

        // Unacknowledged byte counters. This file only ever increments WriteUnAcknowledgedSize;
        // presumably acknowledgement handling elsewhere decrements it — confirm.
        internal long ReadUnAcknowledgedSize = 0;
        internal long WriteUnAcknowledgedSize = 0;

        // Maximum payload bytes per outgoing chunk; updated after a SetChunkSize message is sent.
        internal uint _writeChunkSize = 128;

        // Byte sizes of the extended timestamp field and of the Type0/1/2 chunk message headers.
        internal readonly int EXTENDED_TIMESTAMP_LENGTH = 4;
        internal readonly int TYPE0_SIZE = 11;
        internal readonly int TYPE1_SIZE = 7;
        internal readonly int TYPE2_SIZE = 3;

        internal RtmpSession _rtmpSession = null;

        internal Amf0Reader _amf0Reader = new Amf0Reader();
        internal Amf0Writer _amf0Writer = new Amf0Writer();
        internal Amf3Reader _amf3Reader = new Amf3Reader();
        internal Amf3Writer _amf3Writer = new Amf3Writer();


        private IOPipeLine _ioPipeline = null;

        // Held while a single chunk is handed to the pipeline, so that bytes of concurrently
        // multiplexed messages never interleave inside one chunk.
        private SemaphoreSlim _sync = new SemaphoreSlim(1);

        // Not referenced in the visible code.
        internal LimitType? PreviousLimitType { get; set; } = null;

        /// <summary>
        /// Wraps <paramref name="stream"/> in a new <see cref="RtmpSession"/>, makes the pipeline
        /// start at the first basic-header byte, and registers this context's chunk parsers.
        /// </summary>
        /// <param name="stream">Pipeline that delivers incoming bytes and sends outgoing ones.</param>
        public ChunkStreamContext(IOPipeLine stream)
        {
            _rtmpSession = new RtmpSession(stream);
            _ioPipeline = stream;
            _ioPipeline.NextProcessState = ProcessState.FirstByteBasicHeader;
            _ioPipeline._bufferProcessors.Add(ProcessState.ChunkMessageHeader, ProcessChunkMessageHeader);
            _ioPipeline._bufferProcessors.Add(ProcessState.CompleteMessage, ProcessCompleteMessage);
            _ioPipeline._bufferProcessors.Add(ProcessState.ExtendedTimestamp, ProcessExtendedTimestamp);
            _ioPipeline._bufferProcessors.Add(ProcessState.FirstByteBasicHeader, ProcessFirstByteBasicHeader);
        }

        /// <summary>Disposes the <see cref="RtmpSession"/> created by the constructor.</summary>
        public void Dispose()
        {
            ((IDisposable)_rtmpSession).Dispose();
        }

        /// <summary>
        /// Serializes <paramref name="message"/>, splits it into chunks of at most the current
        /// write chunk size, and sends them on chunk stream <paramref name="chunkStreamId"/>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Each chunk's header type is chosen against the previous header written on the same
        /// chunk stream. When <see cref="WriteWindowAcknowledgementSize"/> is set, each send waits
        /// until it fits inside the window, and an oversized chunk is sent in window-sized pieces.
        /// </para>
        /// <para>
        /// After the last chunk of a SetChunkSize, SetPeerBandwidth or WindowAcknowledgementSize
        /// message is handed to the pipeline, the matching local setting is updated.
        /// </para>
        /// </remarks>
        /// <param name="chunkStreamId">Chunk stream to send on.</param>
        /// <param name="message">Message to send; must be attached to a message stream.</param>
        /// <returns>A task that completes when the send of the final chunk completes.</returns>
        /// <exception cref="InvalidOperationException">The message has no message stream id.</exception>
        internal async Task MultiplexMessageAsync(uint chunkStreamId, Message message)
        {
            if (!message.MessageHeader.MessageStreamId.HasValue)
            {
                throw new InvalidOperationException("cannot send message that has not attached to a message stream");
            }

            byte[] buffer = null;
            try
            {
                uint length;
                using (var writeBuffer = new ByteBuffer())
                {
                    var context = new Serialization.SerializationContext()
                    {
                        Amf0Reader = _amf0Reader,
                        Amf0Writer = _amf0Writer,
                        Amf3Reader = _amf3Reader,
                        Amf3Writer = _amf3Writer,
                        WriteBuffer = writeBuffer
                    };
                    message.Serialize(context);
                    length = (uint)writeBuffer.Length;
                    Debug.Assert(length != 0);
                    // Rented inside the outer try so it is returned even if TakeOutMemory throws.
                    buffer = _arrayPool.Rent((int)length);
                    writeBuffer.TakeOutMemory(buffer);
                }

                message.MessageHeader.MessageLength = length;
                Debug.Assert(message.MessageHeader.MessageLength != 0);
                if (message.MessageHeader.MessageType == 0)
                {
                    // NOTE(review): the generic type argument of GetCustomAttribute was lost in this
                    // file — restore it from version control.
                    message.MessageHeader.MessageType = message.GetType().GetCustomAttribute().MessageTypes.First();
                }
                Debug.Assert(message.MessageHeader.MessageType != 0);

                Task lastSend = null;
                bool isFirstChunk = true;
                _rtmpSession.AssertStreamId(message.MessageHeader.MessageStreamId.Value);
                for (int i = 0; i < message.MessageHeader.MessageLength;)
                {
                    // NOTE(review): the per-stream previous header is read and written outside
                    // _sync, so concurrent sends on the same chunk stream id race here.
                    _previousWriteMessageHeader.TryGetValue(chunkStreamId, out var prevHeader);
                    var chunkHeaderType = SelectChunkType(message.MessageHeader, prevHeader, isFirstChunk);
                    isFirstChunk = false;
                    GenerateBasicHeader(chunkHeaderType, chunkStreamId, out var basicHeader, out var basicHeaderLength);
                    GenerateMesesageHeader(chunkHeaderType, message.MessageHeader, prevHeader, out var messageHeader, out var messageHeaderLength);
                    _previousWriteMessageHeader[chunkStreamId] = (MessageHeader)message.MessageHeader.Clone();
                    var headerLength = basicHeaderLength + messageHeaderLength;
                    var bodySize = (int)(length - i >= _writeChunkSize ? _writeChunkSize : length - i);
                    var chunkLength = headerLength + bodySize;

                    var chunkBuffer = _arrayPool.Rent(chunkLength);
                    try
                    {
                        // Assemble the chunk (headers + payload slice) before taking the lock; only
                        // this call's buffers are touched, and the header buffers are always returned.
                        try
                        {
                            basicHeader.AsSpan(0, basicHeaderLength).CopyTo(chunkBuffer);
                            messageHeader.AsSpan(0, messageHeaderLength).CopyTo(chunkBuffer.AsSpan(basicHeaderLength));
                        }
                        finally
                        {
                            _arrayPool.Return(basicHeader);
                            _arrayPool.Return(messageHeader);
                        }
                        buffer.AsSpan(i, bodySize).CopyTo(chunkBuffer.AsSpan(headerLength));
                        i += bodySize;
                        var isLastChunk = message.MessageHeader.MessageLength - i == 0;

                        await _sync.WaitAsync();
                        try
                        {
                            long offset = 0;
                            while (offset < chunkLength)
                            {
                                // Always start from what is actually left of this chunk. Reusing the
                                // previous piece's size would read past the chunk and overshoot the end.
                                long sendSize = chunkLength - offset;
                                if (WriteWindowAcknowledgementSize.HasValue && Interlocked.Read(ref WriteUnAcknowledgedSize) + sendSize > WriteWindowAcknowledgementSize.Value)
                                {
                                    sendSize = Math.Min(WriteWindowAcknowledgementSize.Value, sendSize);
                                    // Wait until this piece fits in the window. Use '>' to match the
                                    // check above: with '>=', a window-sized piece could only be sent
                                    // once the unacknowledged count went negative.
                                    while (sendSize + Interlocked.Read(ref WriteUnAcknowledgedSize) > WriteWindowAcknowledgementSize.Value)
                                    {
                                        await Task.Delay(1);
                                    }
                                }
                                // NOTE(review): chunkBuffer goes back to the pool when this iteration
                                // ends, while only the final send is awaited. That is safe only if
                                // SendRawData is done with the memory before returning — confirm.
                                var send = _ioPipeline.SendRawData(chunkBuffer.AsMemory((int)offset, (int)sendSize));
                                offset += sendSize;

                                if (WriteWindowAcknowledgementSize.HasValue)
                                {
                                    Interlocked.Add(ref WriteUnAcknowledgedSize, sendSize);
                                }

                                if (isLastChunk)
                                {
                                    lastSend = send;
                                }
                            }

                            if (isLastChunk)
                            {
                                // Protocol control messages we send change our own settings.
                                if (message.MessageHeader.MessageType == MessageType.SetChunkSize)
                                {
                                    var setChunkSize = message as SetChunkSizeMessage;
                                    _writeChunkSize = setChunkSize.ChunkSize;
                                }
                                else if (message.MessageHeader.MessageType == MessageType.SetPeerBandwidth)
                                {
                                    var m = message as SetPeerBandwidthMessage;
                                    ReadWindowAcknowledgementSize = m.WindowSize;
                                }
                                else if (message.MessageHeader.MessageType == MessageType.WindowAcknowledgementSize)
                                {
                                    var m = message as WindowAcknowledgementSizeMessage;
                                    WriteWindowAcknowledgementSize = m.WindowSize;
                                }
                            }
                        }
                        finally
                        {
                            _sync.Release();
                        }
                    }
                    finally
                    {
                        _arrayPool.Return(chunkBuffer);
                    }
                }
                Debug.Assert(lastSend != null);
                await lastSend;
            }
            finally
            {
                if (buffer != null)
                {
                    _arrayPool.Return(buffer);
                }
            }
        }

        /// <summary>
        /// Encodes the chunk message header for <paramref name="chunkHeaderType"/> into a buffer
        /// rented from the array pool. The caller must return the buffer to the pool.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Type0 writes the absolute timestamp, message length, message type and message stream id.
        /// Type1 writes the timestamp delta from <paramref name="prevHeader"/>, the length and the
        /// type. Type2 writes only the delta. Type3 writes no message header.
        /// </para>
        /// <para>
        /// A timestamp (or delta) of 0xFFFFFF or more is written as 0xFFFFFF, followed by a 4-byte
        /// extended timestamp.
        /// </para>
        /// <para>
        /// NOTE(review): for Type3, the absolute timestamp decides whether an extended timestamp is
        /// appended. The RTMP spec instead ties that field to the preceding chunk on the stream.
        /// Also, if Timestamp is unsigned, a timestamp lower than prevHeader.Timestamp wraps when
        /// the delta is computed. Confirm both cases against SelectChunkType.
        /// </para>
        /// </remarks>
        /// <param name="chunkHeaderType">Chunk header format to encode.</param>
        /// <param name="header">Header of the message being sent.</param>
        /// <param name="prevHeader">Previous header on this chunk stream; read for Type1/Type2.</param>
        /// <param name="buffer">Rented buffer holding the encoded header.</param>
        /// <param name="length">Number of valid bytes in <paramref name="buffer"/>.</param>
        /// <exception cref="ArgumentException">Unknown chunk header type.</exception>
        private void GenerateMesesageHeader(ChunkHeaderType chunkHeaderType, MessageHeader header, MessageHeader prevHeader, out byte[] buffer, out int length)
        {
            var timestamp = header.Timestamp;
            switch (chunkHeaderType)
            {
                case ChunkHeaderType.Type0:
                    buffer = _arrayPool.Rent(TYPE0_SIZE + EXTENDED_TIMESTAMP_LENGTH);
                    NetworkBitConverter.TryGetUInt24Bytes(timestamp >= 0xFFFFFF ? 0xFFFFFF : timestamp, buffer.AsSpan(0, 3));
                    NetworkBitConverter.TryGetUInt24Bytes(header.MessageLength, buffer.AsSpan(3, 3));
                    NetworkBitConverter.TryGetBytes((byte)header.MessageType, buffer.AsSpan(6, 1));
                    NetworkBitConverter.TryGetBytes(header.MessageStreamId.Value, buffer.AsSpan(7, 4), true);
                    length = TYPE0_SIZE;
                    break;
                case ChunkHeaderType.Type1:
                    buffer = _arrayPool.Rent(TYPE1_SIZE + EXTENDED_TIMESTAMP_LENGTH);
                    timestamp = timestamp - prevHeader.Timestamp;
                    NetworkBitConverter.TryGetUInt24Bytes(timestamp >= 0xFFFFFF ? 0xFFFFFF : timestamp, buffer.AsSpan(0, 3));
                    NetworkBitConverter.TryGetUInt24Bytes(header.MessageLength, buffer.AsSpan(3, 3));
                    NetworkBitConverter.TryGetBytes((byte)header.MessageType, buffer.AsSpan(6, 1));
                    length = TYPE1_SIZE;
                    break;
                case ChunkHeaderType.Type2:
                    buffer = _arrayPool.Rent(TYPE2_SIZE + EXTENDED_TIMESTAMP_LENGTH);
                    timestamp = timestamp - prevHeader.Timestamp;
                    NetworkBitConverter.TryGetUInt24Bytes(timestamp >= 0xFFFFFF ? 0xFFFFFF : timestamp, buffer.AsSpan(0, 3));
                    length = TYPE2_SIZE;
                    break;
                case ChunkHeaderType.Type3:
                    // No message header, but room is reserved for a possible extended timestamp.
                    buffer = _arrayPool.Rent(EXTENDED_TIMESTAMP_LENGTH);
                    length = 0;
                    break;
                default:
                    throw new ArgumentException();
            }
            if (timestamp >= 0xFFFFFF)
            {
                NetworkBitConverter.TryGetBytes(timestamp, buffer.AsSpan(length, EXTENDED_TIMESTAMP_LENGTH));
                length += EXTENDED_TIMESTAMP_LENGTH;
            }
        }

        // Builds the basic header (fmt + chunk stream id) for an outgoing chunk into a rented buffer.
        private void GenerateBasicHeader(ChunkHeaderType chunkHeaderType, uint chunkStreamId, out byte[] buffer, out int length)
        {
            byte fmt = (byte)chunkHeaderType;
            // NOTE(review): the next line is damaged. Text between the chunk-stream-id range check
            // and what looks like a '>> 6' fmt extraction was lost. That lost text includes the rest
            // of GenerateBasicHeader and the start of what appears to be ProcessFirstByteBasicHeader.
            // Restore from version control.
            if (chunkStreamId >= 2 && chunkStreamId  6);
            header.ChunkBasicHeader.ChunkStreamId = (uint)basicHeader & 0x3F;
            // A 6-bit id of 0 or 0x3F means more basic-header bytes follow, so the message-header
            // stage must run even for Type3. Otherwise a Type3 chunk has no message header and
            // continues straight to the payload.
            if (header.ChunkBasicHeader.ChunkStreamId != 0 && header.ChunkBasicHeader.ChunkStreamId != 0x3F)
            {
                if (header.ChunkBasicHeader.RtmpChunkHeaderType == ChunkHeaderType.Type3)
                {
                    FillHeader(header);
                    _ioPipeline.NextProcessState = ProcessState.CompleteMessage;
                    return true;
                }
            }
            _ioPipeline.NextProcessState = ProcessState.ChunkMessageHeader;
            return true;
        }

        // Parses the remainder of an incoming chunk header. The bytes needed are 1 or 2 extra
        // basic-header bytes (6-bit id 0 or 0x3F), plus the message-header size for the chunk type
        // (Type3 adds none).
        private bool ProcessChunkMessageHeader(ReadOnlySequence<byte> buffer, ref int consumed)
        {
            int bytesNeed = 0;
            switch (_processingChunk.ChunkBasicHeader.ChunkStreamId)
            {
                case 0:
                    bytesNeed = 1;
                    break;
                case 0x3F:
                    bytesNeed = 2;
                    break;
            }
            switch (_processingChunk.ChunkBasicHeader.RtmpChunkHeaderType)
            {
                case ChunkHeaderType.Type0:
                    bytesNeed += TYPE0_SIZE;
                    break;
                case ChunkHeaderType.Type1:
                    bytesNeed += TYPE1_SIZE;
                    break;
                case ChunkHeaderType.Type2:
                    bytesNeed += TYPE2_SIZE;
                    break;
            }

            // NOTE(review): the next line is damaged. The rest of ProcessChunkMessageHeader, the
            // ProcessExtendedTimestamp processor registered in the constructor, FillHeader,
            // SelectChunkType, and the start of ProcessCompleteMessage (which declares 'header',
            // 'state' and 'bytesNeed') appear to have been lost. Restore from version control.
            if (buffer.Length - consumed = ReadChunkSize ? ReadChunkSize : state.RemainBytes);

            if (buffer.Length - consumed < bytesNeed)
            {
                // Wait for more data.
                return false;
            }

            if (_previousReadMessageHeader.TryGetValue(header.ChunkBasicHeader.ChunkStreamId, out var prevHeader))
            {
                if (prevHeader.MessageStreamId != header.MessageHeader.MessageStreamId)
                {
                    // inform user previous message will never be received
                    // NOTE(review): no notification is implemented; this assignment has no effect.
                    prevHeader = null;
                }
            }
            _previousReadMessageHeader[_processingChunk.ChunkBasicHeader.ChunkStreamId] = (MessageHeader)_processingChunk.MessageHeader.Clone();
            _processingChunk = null;

            // Append this chunk's payload to the message being reassembled.
            buffer.Slice(consumed, bytesNeed).CopyTo(state.Body.AsSpan(state.CurrentIndex));
            consumed += bytesNeed;
            state.CurrentIndex += bytesNeed;

            if (state.IsCompleted)
            {
                _incompleteMessageState.Remove(header.ChunkBasicHeader.ChunkStreamId);
                try
                {
                    var context = new Serialization.SerializationContext()
                    {
                        Amf0Reader = _amf0Reader,
                        Amf0Writer = _amf0Writer,
                        Amf3Reader = _amf3Reader,
                        Amf3Writer = _amf3Writer,
                        ReadBuffer = state.Body.AsMemory(0, (int)state.MessageLength)
                    };
                    if (header.MessageHeader.MessageType == MessageType.AggregateMessage)
                    {
                        var agg = new AggregateMessage()
                        {
                            MessageHeader = header.MessageHeader
                        };
                        agg.Deserialize(context);
                        foreach (var message in agg.Messages)
                        {
                            // NOTE(review): this path reads Options.MessageFactories while the
                            // non-aggregate path reads Options._messageFactories — confirm they agree.
                            if (!_ioPipeline.Options.MessageFactories.TryGetValue(message.Header.MessageType, out var factory))
                            {
                                // No factory registered for this sub-message type: skip it.
                                continue;
                            }
                            var msgContext = new Serialization.SerializationContext()
                            {
                                Amf0Reader = context.Amf0Reader,
                                Amf3Reader = context.Amf3Reader,
                                Amf0Writer = context.Amf0Writer,
                                Amf3Writer = context.Amf3Writer,
                                ReadBuffer = context.ReadBuffer.Slice(message.DataOffset, (int)message.DataLength)
                            };
                            try
                            {
                                // NOTE(review): the factory and the message get the aggregate's header,
                                // not the sub-message header (message.Header) — confirm intended.
                                var msg = factory(header.MessageHeader, msgContext, out var factoryConsumed);
                                msg.MessageHeader = header.MessageHeader;
                                // Skip the prefix the factory already read, as the non-aggregate path does.
                                msgContext.ReadBuffer = msgContext.ReadBuffer.Slice(factoryConsumed);
                                msg.Deserialize(msgContext);
                                context.Amf0Reader.ResetReference();
                                context.Amf3Reader.ResetReference();
                                _rtmpSession.MessageArrived(msg);
                            }
                            catch (NotSupportedException)
                            {
                                // Unsupported sub-message: dropped on purpose. Still clear the AMF
                                // reference tables so a partially parsed message cannot leak
                                // references into the next one.
                                context.Amf0Reader.ResetReference();
                                context.Amf3Reader.ResetReference();
                            }
                        }
                    }
                    else
                    {
                        if (_ioPipeline.Options._messageFactories.TryGetValue(header.MessageHeader.MessageType, out var factory))
                        {
                            try
                            {
                                var message = factory(header.MessageHeader, context, out var factoryConsumed);
                                message.MessageHeader = header.MessageHeader;
                                // The factory may have read a prefix of the payload to choose the type.
                                context.ReadBuffer = context.ReadBuffer.Slice(factoryConsumed);
                                message.Deserialize(context);
                                context.Amf0Reader.ResetReference();
                                context.Amf3Reader.ResetReference();
                                _rtmpSession.MessageArrived(message);
                            }
                            catch (NotSupportedException)
                            {
                                // Unsupported message: dropped on purpose. Still clear the AMF
                                // reference tables so a partially parsed message cannot leak
                                // references into the next one.
                                context.Amf0Reader.ResetReference();
                                context.Amf3Reader.ResetReference();
                            }
                        }
                    }
                }
                finally
                {
                    _arrayPool.Return(state.Body);
                }
            }
            _ioPipeline.NextProcessState = ProcessState.FirstByteBasicHeader;
            return true;
        }

    }
}