csharp/a1q123456/Harmonic/Harmonic/Networking/Amf/Serialization/Amf0/Amf0Reader.cs

Amf0Reader.cs
using Harmonic.Networking.Amf.Attributes;
using Harmonic.Networking.Amf.Common;
using Harmonic.Networking.Amf.Data;
using Harmonic.Networking.Amf.Serialization.Amf3;
using Harmonic.Networking.Amf.Serialization.Attributes;
using Harmonic.Networking.Utils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Xml;

namespace Harmonic.Networking.Amf.Serialization.Amf0
{
    public clast Amf0Reader
    {
        /// <summary>
        /// Minimum number of payload bytes (not counting the one-byte type marker, which
        /// TryDescribeData checks separately) that must be present before a value of the given
        /// <see cref="Amf0Type"/> can be decoded. Variable-length types list only their smallest
        /// possible encoding.
        /// </summary>
        public readonly IReadOnlyDictionary TypeLengthMap = new Dictionary()
        {
            { Amf0Type.Number, 8 },
            { Amf0Type.Boolean, sizeof(byte) },
            { Amf0Type.String, Amf0CommonValues.STRING_HEADER_LENGTH },
            { Amf0Type.LongString, Amf0CommonValues.LONG_STRING_HEADER_LENGTH },
            // Smallest object body: empty utf8 key followed by the object-end marker.
            { Amf0Type.Object, /* utf8-empty */ Amf0CommonValues.STRING_HEADER_LENGTH + /* object end marker */ Amf0CommonValues.MARKER_LENGTH },
            { Amf0Type.Null, 0 },
            { Amf0Type.Undefined, 0 },
            { Amf0Type.Reference, sizeof(ushort) },
            { Amf0Type.EcmaArray, sizeof(uint) },
            { Amf0Type.StrictArray, sizeof(uint) },
            { Amf0Type.Date, 10 },
            { Amf0Type.Unsupported, 0 },
            { Amf0Type.XmlDocameent, 0 },
            // Smallest typed object: class-name header, at least one class-name character, then an
            // empty utf8 key followed by the object-end marker.
            { Amf0Type.TypedObject, /* class name header */ Amf0CommonValues.STRING_HEADER_LENGTH + /* at least one character for class name */ 1 + /* utf8-empty */ Amf0CommonValues.STRING_HEADER_LENGTH + /* object end marker */ Amf0CommonValues.MARKER_LENGTH },
            { Amf0Type.AvmPlusObject, 0 },
            { Amf0Type.ObjectEnd, 0 }
        };

        // Signature of the typed TryGetXxx readers: decode a value of type T from the start of
        // `buffer`, report how many bytes were consumed, and return false when decoding fails.
        private delegate bool ReadDataHandler(Span buffer, out T data, out int consumedLength);
        // Type-erased form of the same signature (value surfaced as object) so readers for
        // different AMF0 types can share one marker-keyed table (_readDataHandlers).
        private delegate bool ReadDataHandler(Span buffer, out object data, out int consumedLength);

        // Types registered through RegisterType, in registration order.
        private List _registeredTypes = new List();

        /// <summary>
        /// Read-only view of the types registered through RegisterType.
        /// </summary>
        public IReadOnlyList RegisteredTypes => _registeredTypes;

        // Handlers keyed by AMF0 type marker; populated once in the constructor.
        private IReadOnlyDictionary _readDataHandlers;

        // Registration state keyed by the serialized (possibly attribute-mapped) type name.
        private Dictionary _registeredTypeStates = new Dictionary();

        // Values recorded while decoding (the string and object readers append to it) so that
        // AMF0 reference markers can be looked up by TryGetReference; cleared by ResetReference.
        private List _referenceTable = new List();

        // AMF3 reader that receives the same typed-object and externalizable registrations.
        private Amf3.Amf3Reader _amf3Reader = new Amf3.Amf3Reader();

        /// <summary>
        /// When true (the default), TryGetMessage does not decode a body for a message whose
        /// declared length is zero.
        /// </summary>
        public bool StrictMode { get; set; } = true;

        /// <summary>
        /// Creates a reader and binds every supported AMF0 type marker to the routine that decodes it.
        /// </summary>
        public Amf0Reader()
        {
            // Each strongly typed reader is wrapped so its result comes back as object, which lets
            // all of them live in a single marker-keyed lookup table.
            _readDataHandlers = new Dictionary
            {
                { Amf0Type.Number, OutValueTypeEraser(TryGetNumber) },
                { Amf0Type.Boolean, OutValueTypeEraser(TryGetBoolean) },
                { Amf0Type.String, OutValueTypeEraser(TryGetString) },
                { Amf0Type.Object, OutValueTypeEraser(TryGetObject) },
                { Amf0Type.Null, OutValueTypeEraser(TryGetNull) },
                { Amf0Type.Undefined, OutValueTypeEraser(TryGetUndefined) },
                { Amf0Type.Reference, OutValueTypeEraser(TryGetReference) },
                { Amf0Type.EcmaArray, OutValueTypeEraser(TryGetEcmaArray) },
                { Amf0Type.StrictArray, OutValueTypeEraser(TryGetStrictArray) },
                { Amf0Type.Date, OutValueTypeEraser(TryGetDate) },
                { Amf0Type.LongString, OutValueTypeEraser(TryGetLongString) },
                { Amf0Type.Unsupported, OutValueTypeEraser(TryGetUnsupported) },
                { Amf0Type.XmlDocameent, OutValueTypeEraser(TryGetXmlDocameent) },
                { Amf0Type.TypedObject, OutValueTypeEraser(TryGetTypedObject) },
                { Amf0Type.AvmPlusObject, OutValueTypeEraser(TryGetAvmPlusObject) }
            };
        }

        /// <summary>
        /// Forgets every value recorded for AMF0 reference lookups.
        /// </summary>
        public void ResetReference() => _referenceTable.Clear();
        /// <summary>
        /// Registers type T so typed objects carrying its serialized name can be materialized by this
        /// reader and by the embedded AMF3 reader.
        /// </summary>
        /// <remarks>
        /// Only writable properties tagged with ClastFieldAttribute are mapped. Each is keyed by the
        /// attribute's Name, or by the property name when Name is null. The serialized type name comes
        /// from the type-level attribute returned by GetCustomAttribute when present (and its Name is
        /// non-null), otherwise from the CLR type name.
        /// </remarks>
        /// <exception cref="InvalidOperationException">A mapped field name is empty.</exception>
        /// <exception cref="ArgumentException">
        /// Two properties map to the same field name, or a type with the same serialized name is
        /// already registered.
        /// </exception>
        public void RegisterType() where T : new()
        {
            var type = typeof(T);

            // Look up each writable property's field attribute once and keep only tagged properties.
            var taggedProperties = type.GetProperties()
                .Where(p => p.CanWrite)
                .Select(p => new
                {
                    Property = p,
                    Field = (ClastFieldAttribute)Attribute.GetCustomAttribute(p, typeof(ClastFieldAttribute))
                })
                .Where(x => x.Field != null);
            var members = taggedProperties.ToDictionary(x => x.Field.Name ?? x.Property.Name, x => new Action(x.Property.SetValue));
            if (members.Keys.Any(s => string.IsNullOrEmpty(s)))
            {
                throw new InvalidOperationException("Field name cannot be empty or null");
            }

            var typeName = type.GetCustomAttribute()?.Name ?? type.Name;
            var state = new TypeRegisterState()
            {
                Members = members,
                Type = type
            };

            // Dictionary.Add throws on a duplicate name. Doing it before touching _registeredTypes
            // keeps the two collections consistent when a type name is registered twice.
            _registeredTypeStates.Add(typeName, state);
            _registeredTypes.Add(type);
            _amf3Reader.RegisterTypedObject(typeName, state);
        }

        /// <summary>
        /// Forwards an IExternalizable registration for type T to the embedded AMF3 reader.
        /// </summary>
        public void RegisterIExternalizableForAvmPlus() where T : IExternalizable, new()
            => _amf3Reader.RegisterExternalizable();

        // Adapts a strongly typed reader to the object-valued ReadDataHandler signature so readers
        // for different value types can be stored in one table.
        private ReadDataHandler OutValueTypeEraser(ReadDataHandler handler)
        {
            bool Erased(Span buffer, out object data, out int consumedLength)
            {
                var succeeded = handler(buffer, out var typedValue, out consumedLength);
                data = typedValue;
                return succeeded;
            }

            return Erased;
        }

        // Parses one packet header: a UTF-8 name, a one-byte must-understand flag, a U32 length,
        // then an AMF0 value. The flag and length are skipped; only the name and value are returned.
        private bool TryReadHeader(Span buffer, out KeyValuePair header, out int consumed)
        {
            const int flagAndLengthSize = 1 + sizeof(uint);

            header = default;
            consumed = 0;

            if (!TryGetStringImpl(buffer, Amf0.Amf0CommonValues.STRING_HEADER_LENGTH, out var headerName, out var nameConsumed))
            {
                return false;
            }

            var afterName = buffer.Slice(nameConsumed);
            if (afterName.Length < flagAndLengthSize)
            {
                return false;
            }

            if (!TryGetValue(afterName.Slice(flagAndLengthSize), out _, out var headerValue, out var valueConsumed))
            {
                return false;
            }

            header = new KeyValuePair(headerName, headerValue);
            consumed = nameConsumed + flagAndLengthSize + valueConsumed;
            return true;
        }

        /// <summary>
        /// Tries to decode one packet message: target URI, response URI, a U32 message length and
        /// the AMF0 body value.
        /// </summary>
        /// <param name="buffer">Bytes starting at the message's target URI.</param>
        /// <param name="message">The decoded message; default when decoding fails.</param>
        /// <param name="consumed">Number of bytes read from <paramref name="buffer"/> on success.</param>
        /// <returns>false when the buffer does not contain a complete, decodable message.</returns>
        public bool TryGetMessage(Span buffer, out Message message, out int consumed)
        {
            message = default;
            consumed = default;

            if (!TryGetStringImpl(buffer, Amf0CommonValues.STRING_HEADER_LENGTH, out var targetUri, out var targetUriConsumed))
            {
                return false;
            }
            buffer = buffer.Slice(targetUriConsumed);

            if (!TryGetStringImpl(buffer, Amf0CommonValues.STRING_HEADER_LENGTH, out var responseUri, out var responseUriConsumed))
            {
                return false;
            }
            buffer = buffer.Slice(responseUriConsumed);

            if (buffer.Length < sizeof(uint))
            {
                return false;
            }
            var messageLength = NetworkBitConverter.ToUInt32(buffer);
            buffer = buffer.Slice(sizeof(uint));
            var prefixConsumed = targetUriConsumed + responseUriConsumed + sizeof(uint);

            // 0xFFFFFFFF ((U32)-1) is treated as "length unknown" and not checked. Any other length
            // requires the whole body to be present after the length field.
            if (messageLength != uint.MaxValue && buffer.Length < messageLength)
            {
                return false;
            }

            if (messageLength == 0 && StrictMode)
            {
                // Strict mode: a zero-length message has no body to decode. Still report the bytes
                // read so callers can advance to the next message.
                consumed = prefixConsumed;
                message = new Message()
                {
                    TargetUri = targetUri,
                    ResponseUri = responseUri
                };
                return true;
            }

            if (!TryGetValue(buffer, out _, out var content, out var contentConsumed))
            {
                return false;
            }
            consumed = prefixConsumed + contentConsumed;
            message = new Message()
            {
                TargetUri = targetUri,
                ResponseUri = responseUri,
                Content = content
            };
            return true;
        }

        /// <summary>
        /// Tries to decode a complete packet: version (U16), header count (U16), the headers,
        /// message count (U16), then the messages.
        /// </summary>
        /// <param name="buffer">Bytes starting at the packet's version field.</param>
        /// <param name="headers">Decoded headers; default if the fixed prefix is incomplete.</param>
        /// <param name="messages">Decoded messages; default if the fixed prefix is incomplete.</param>
        /// <param name="consumed">Bytes read so far.</param>
        /// <returns>false when the buffer ends before the packet is complete.</returns>
        public bool TryGetPacket(Span buffer, out List headers, out List messages, out int consumed)
        {
            headers = default;
            messages = default;
            consumed = 0;

            // version + header-count must both be present before either is read.
            if (buffer.Length < 2 * sizeof(ushort))
            {
                return false;
            }
            // The version field is skipped; it is not used by this reader.
            buffer = buffer.Slice(sizeof(ushort));
            consumed += sizeof(ushort);
            var headerCount = NetworkBitConverter.ToUInt16(buffer);
            buffer = buffer.Slice(sizeof(ushort));
            consumed += sizeof(ushort);

            headers = new List();
            messages = new List();
            for (int i = 0; i < headerCount; i++)
            {
                if (!TryReadHeader(buffer, out var header, out var headerConsumed))
                {
                    return false;
                }
                headers.Add(header);
                buffer = buffer.Slice(headerConsumed);
                consumed += headerConsumed;
            }

            if (buffer.Length < sizeof(ushort))
            {
                return false;
            }
            var messageCount = NetworkBitConverter.ToUInt16(buffer);
            buffer = buffer.Slice(sizeof(ushort));
            consumed += sizeof(ushort);
            for (int i = 0; i < messageCount; i++)
            {
                if (!TryGetMessage(buffer, out var message, out var messageConsumed))
                {
                    return false;
                }
                messages.Add(message);
                // Advance past this message so the next iteration decodes the following one.
                buffer = buffer.Slice(messageConsumed);
                consumed += messageConsumed;
            }
            return true;
        }

        /// <summary>
        /// Reads the type marker at the start of <paramref name="buffer"/> and checks that at least
        /// the minimum payload listed in TypeLengthMap for that type follows it.
        /// </summary>
        /// <param name="buffer">Bytes starting at a type marker.</param>
        /// <param name="type">The marker's type on success.</param>
        /// <param name="consumedLength">
        /// Marker length plus the TypeLengthMap minimum for the type. This equals the encoded size
        /// only for fixed-length types.
        /// </param>
        /// <returns>false when the buffer has no marker, the marker is unknown, or the payload is too short.</returns>
        public bool TryDescribeData(Span buffer, out Amf0Type type, out int consumedLength)
        {
            type = default;
            consumedLength = default;

            if (buffer.Length >= Amf0CommonValues.MARKER_LENGTH)
            {
                var candidate = (Amf0Type)buffer[0];
                if (TypeLengthMap.TryGetValue(candidate, out var minPayload)
                    && buffer.Length - Amf0CommonValues.MARKER_LENGTH >= minPayload)
                {
                    type = candidate;
                    consumedLength = Amf0CommonValues.MARKER_LENGTH + (int)minPayload;
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Tries to decode an AMF0 Number: the marker followed by an 8-byte double read with
        /// NetworkBitConverter.
        /// </summary>
        public bool TryGetNumber(Span buffer, out double value, out int bytesConsumed)
        {
            value = default;
            bytesConsumed = default;

            var isNumber = TryDescribeData(buffer, out var type, out var length) && type == Amf0Type.Number;
            if (!isNumber)
            {
                return false;
            }

            value = NetworkBitConverter.ToDouble(buffer.Slice(Amf0CommonValues.MARKER_LENGTH));
            bytesConsumed = length;
            return true;
        }

        /// <summary>
        /// Tries to decode an AMF0 Boolean; any non-zero payload byte is read as true.
        /// </summary>
        public bool TryGetBoolean(Span buffer, out bool value, out int bytesConsumed)
        {
            value = default;
            bytesConsumed = default;

            var described = TryDescribeData(buffer, out var type, out var length);
            if (described && type == Amf0Type.Boolean)
            {
                value = buffer[1] != 0;
                bytesConsumed = length;
                return true;
            }

            return false;
        }
        /// <summary>
        /// Tries to decode an AMF0 String (marker, length header, UTF-8 bytes). On success the
        /// string is also appended to the reference table.
        /// </summary>
        public bool TryGetString(Span buffer, out string value, out int bytesConsumed)
        {
            value = default;
            bytesConsumed = default;

            if (!TryDescribeData(buffer, out var type, out _) || type != Amf0Type.String)
            {
                return false;
            }

            var payload = buffer.Slice(Amf0CommonValues.MARKER_LENGTH);
            if (!TryGetStringImpl(payload, Amf0CommonValues.STRING_HEADER_LENGTH, out value, out bytesConsumed))
            {
                return false;
            }

            bytesConsumed += Amf0CommonValues.MARKER_LENGTH;
            // NOTE(review): the AMF0 spec assigns reference indices only to complex values (objects,
            // typed objects, arrays). Recording strings here may shift indices relative to other
            // encoders. Confirm against the project's AMF0 writer before changing.
            _referenceTable.Add(value);
            return true;
        }

        /// <summary>
        /// Decodes the members of an AMF0 object body (the bytes after the type marker), up to and
        /// including the terminating empty key followed by an ObjectEnd marker.
        /// </summary>
        /// <param name="objectBuffer">Bytes starting at the first member key.</param>
        /// <param name="value">The decoded members; default on failure.</param>
        /// <param name="bytesConsumed">Bytes read, including the terminator; default on failure.</param>
        /// <returns>false if a key or value cannot be decoded.</returns>
        /// <remarks>
        /// The dictionary is added to the reference table before its members are decoded (presumably
        /// so nested references can resolve to it). If decoding fails, every entry added to the
        /// reference table during this call, including nested values, is removed again. A retry on the
        /// same bytes therefore sees the same reference indices.
        /// </remarks>
        private bool TryGetObjectImpl(Span objectBuffer, out Dictionary value, out int bytesConsumed)
        {
            value = default;
            bytesConsumed = default;

            var referenceMark = _referenceTable.Count;
            var obj = new Dictionary();
            _referenceTable.Add(obj);

            // Undo the reference-table additions made by this call, then report failure.
            bool Fail()
            {
                _referenceTable.RemoveRange(referenceMark, _referenceTable.Count - referenceMark);
                return false;
            }

            var consumed = 0;
            while (true)
            {
                if (!TryGetStringImpl(objectBuffer, Amf0CommonValues.STRING_HEADER_LENGTH, out var key, out var keyLength))
                {
                    return Fail();
                }
                consumed += keyLength;
                objectBuffer = objectBuffer.Slice(keyLength);

                if (!TryGetValue(objectBuffer, out var dataType, out var data, out var valueLength))
                {
                    return Fail();
                }
                consumed += valueLength;
                objectBuffer = objectBuffer.Slice(valueLength);

                // An empty key followed by ObjectEnd terminates the object.
                if (key.Length == 0 && dataType == Amf0Type.ObjectEnd)
                {
                    break;
                }
                obj.Add(key, data);
            }

            value = obj;
            bytesConsumed = consumed;
            return true;
        }

        /// <summary>
        /// Tries to decode an anonymous AMF0 Object. A Null marker is also accepted and yields a
        /// null <paramref name="value"/>.
        /// </summary>
        public bool TryGetObject(Span buffer, out AmfObject value, out int bytesConsumed)
        {
            value = default;
            bytesConsumed = default;

            if (!TryDescribeData(buffer, out var type, out _))
            {
                return false;
            }

            switch (type)
            {
                case Amf0Type.Null:
                    // value is already null; only the consumed length needs reporting.
                    return TryGetNull(buffer, out _, out bytesConsumed);

                case Amf0Type.Object:
                    var body = buffer.Slice(Amf0CommonValues.MARKER_LENGTH);
                    if (!TryGetObjectImpl(body, out var members, out var bodyLength))
                    {
                        return false;
                    }
                    value = new AmfObject(members);
                    bytesConsumed = Amf0CommonValues.MARKER_LENGTH + bodyLength;
                    return true;

                default:
                    return false;
            }
        }

        /// <summary>
        /// Tries to decode an AMF0 Null, which is just the one-byte marker.
        /// </summary>
        public bool TryGetNull(Span buffer, out object value, out int bytesConsumed)
        {
            value = null;
            bytesConsumed = default;

            if (TryDescribeData(buffer, out var type, out _) && type == Amf0Type.Null)
            {
                bytesConsumed = Amf0CommonValues.MARKER_LENGTH;
                return true;
            }

            return false;
        }

        /// <summary>
        /// Tries to decode an AMF0 Undefined, which is just the one-byte marker.
        /// </summary>
        public bool TryGetUndefined(Span buffer, out Undefined value, out int consumedLength)
        {
            value = default;
            consumedLength = default;

            if (TryDescribeData(buffer, out var type, out _) && type == Amf0Type.Undefined)
            {
                value = new Undefined();
                consumedLength = Amf0CommonValues.MARKER_LENGTH;
                return true;
            }

            return false;
        }

        private bool TryGetReference(Span buffer, out object value, out int consumedLength)
        {
            var index = 0;
            value = default;
            consumedLength = default;
            if (!TryDescribeData(buffer, out var type, out var length))
            {
                return false;
            }

            if (type != Amf0Type.Reference)
            {
                return false;
            }

            index = NetworkBitConverter.ToUInt16(buffer.Slice(Amf0CommonValues.MARKER_LENGTH, sizeof(ushort)));
            consumedLength = Amf0CommonValues.MARKER_LENGTH + sizeof(ushort);
            if (_referenceTable.Count