// csharp/71/Ryder/Ryder/Redirection.Method.cs
//
// Redirection.Method.cs
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace Ryder
{
    // NOTE(review): this copy of the file shows signs of extraction damage: `clast` below is
    // presumably `class`, the non-generic `List` presumably had a type argument, and the
    // constructor body is cut off part-way through. Code tokens are left untouched here;
    // restore them from version control.

    /// <summary>
    ///   Class that provides full control over a method <see cref="Redirection"/>.
    /// </summary>
    public sealed clast MethodRedirection : Redirection
    {
        /// <summary>
        ///   Methods to reference statically to prevent them from being
        ///   garbage-collected.
        /// </summary>
        private static readonly List PersistingMethods = new List();

        // NOTE(review): presumably the native code bytes of the original method and the bytes
        // written over it to perform the redirection - their assignment is not visible in this copy.
        private readonly byte[] originalBytes;
        private readonly byte[] replacementBytes;

        // NOTE(review): presumably the start address of the original method's code - TODO confirm.
        private readonly IntPtr originalMethodStart;

        /// <summary>
        ///   Gets the original <see cref="MethodBase"/>.
        /// </summary>
        public MethodBase Original { get; }

        /// <summary>
        ///   Gets the replacing <see cref="MethodBase"/>.
        /// </summary>
        public MethodBase Replacement { get; }

        /// <summary>
        ///   Creates a redirection from <paramref name="original"/> to <paramref name="replacement"/>.
        /// </summary>
        /// <param name="original">The method exposed through <see cref="Original"/>.</param>
        /// <param name="replacement">The method exposed through <see cref="Replacement"/>.</param>
        /// <param name="start">
        ///   Not read in the visible part of the constructor; presumably whether the redirection
        ///   should start immediately (TODO confirm).
        /// </param>
        /// <exception cref="InvalidOperationException">
        ///   Both methods resolve to the same start address.
        /// </exception>
        internal MethodRedirection(MethodBase original, MethodBase replacement, bool start)
        {
            Original = original;
            Replacement = replacement;

            // Resolve the runtime handle of each method once, into locals.
            RuntimeMethodHandle originalHandle = original.GetRuntimeMethodHandle();
            RuntimeMethodHandle replacementHandle = replacement.GetRuntimeMethodHandle();

            // Message for methods that have not been jitted; its use is not visible in this copy.
            const string JIT_ERROR = "The specified method hasn't been jitted yet, and thus cannot be used in a redirection.";

            // Fetch their respective start
            IntPtr originalStart = originalHandle.GetMethodStart();
            IntPtr replacementStart = replacementHandle.GetMethodStart();

            // Edge case: calling this on the same method
            if (originalStart == replacementStart)
                throw new InvalidOperationException("Cannot redirect a method to itself.");

            // Edge case: methods are too close to one another
            // The absolute distance between both start addresses, narrowed to an int.
            int difference = (int)Math.Abs(originalStart.ToInt64() - replacementStart.ToInt64());
            int sizeOfPtr = IntPtr.Size;

            // NOTE(review): the next line is truncated/corrupted - the comparison, the rest of the
            // constructor and any following members up to a `Marshal.Copy` call are missing.
            if (difference  Marshal.Copy(bytes, 0, methodStart, bytes.Length);
    }

    partial class Redirection
    {
        /// <summary>
        ///   Redirects calls to the <paramref name="original"/> method or constructor
        ///   to the <paramref name="replacement"/> method.
        /// </summary>
        /// <param name="original">The <see cref="MethodBase"/> of the method whose calls shall be redirected.</param>
        /// <param name="replacement">The <see cref="MethodBase"/> of the method providing the redirection.</param>
        /// <param name="skipChecks">If <see langword="true"/>, some safety checks will be omitted.</param>
        /// <returns>
        ///   A new <see cref="MethodRedirection"/>, constructed with its <c>start</c> argument
        ///   set to <see langword="true"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        ///   <paramref name="original"/> or <paramref name="replacement"/> is <see langword="null"/>.
        /// </exception>
        /// <exception cref="ArgumentException">
        ///   <paramref name="replacement"/> is abstract, or (unless <paramref name="skipChecks"/>
        ///   is <see langword="true"/>) the two signatures are not compatible.
        /// </exception>
        private static MethodRedirection RedirectCore(MethodBase original, MethodBase replacement, bool skipChecks)
        {
            if (original == null)
                throw new ArgumentNullException(nameof(original));
            if (replacement == null)
                throw new ArgumentNullException(nameof(replacement));

            // Check if replacement is abstract
            // We allow original abstract methods, though
            // (this check is performed even when skipChecks is set).
            if (replacement.IsAbstract)
                throw new ArgumentException(AbstractError, nameof(replacement));

            if (!skipChecks)
                EnsureCompatibleSignatures(original, replacement);

            return new MethodRedirection(original, replacement, true);
        }

        /// <summary>
        ///   Verifies that <paramref name="replacement"/> can stand in for <paramref name="original"/>:
        ///   compatible return types, matching parameter counts (taking an explicit <c>this</c>
        ///   parameter into account when exactly one of the methods is static), and
        ///   pairwise-compatible parameters.
        /// </summary>
        /// <param name="original">The method whose calls shall be redirected.</param>
        /// <param name="replacement">The method providing the redirection.</param>
        /// <exception cref="ArgumentException">The two signatures are not compatible.</exception>
        private static void EnsureCompatibleSignatures(MethodBase original, MethodBase replacement)
        {
            // Get return type; a constructor is treated as returning its declaring type.
            Type originalReturnType = (original as MethodInfo)?.ReturnType ?? (original as ConstructorInfo)?.DeclaringType;

            if (originalReturnType == null)
                throw new ArgumentException("Invalid method.", nameof(original));

            Type replacementReturnType = (replacement as MethodInfo)?.ReturnType ?? (replacement as ConstructorInfo)?.DeclaringType;

            if (replacementReturnType == null)
                throw new ArgumentException("Invalid method.", nameof(replacement));

            // Check return type: assignability in either direction is accepted.
            if (!originalReturnType.IsAssignableFrom(replacementReturnType) &&
                !replacementReturnType.IsAssignableFrom(originalReturnType))
                throw new ArgumentException("Expected both methods to return compatible types.", nameof(replacement));

            // Check signature
            ParameterInfo[] originalParams = original.GetParameters();
            ParameterInfo[] replacementParams = replacement.GetParameters();

            // Number of leading parameters to skip on each side because they stand for the
            // other method's implicit 'this' (and are validated separately below).
            int originalOffset = 0;
            int replacementOffset = 0;

            if (!original.IsStatic)
            {
                if (replacement.IsStatic)
                {
                    // Should have:
                    // instance i.original(a, b) | static replacement(i, a, b)

                    if (replacementParams.Length == 0)
                        throw new ArgumentException($"Expected first parameter of type '{original.DeclaringType}'.", nameof(replacement));
                    if (replacementParams.Length != originalParams.Length + 1)
                        throw new ArgumentException(SignatureError, nameof(replacement));

                    Type replThisType = replacementParams[0].ParameterType;
                    Type origThisType = original.DeclaringType;

                    if (!replThisType.IsAssignableFrom(origThisType) &&
                        !origThisType.IsAssignableFrom(replThisType))
                        throw new ArgumentException($"Expected first parameter assignable to or from '{origThisType}'.", nameof(replacement));

                    replacementOffset = 1;
                }
                else
                {
                    // Should have:
                    // instance i.original(a, b) | instance i.replacement(a, b)

                    if (replacementParams.Length != originalParams.Length)
                        throw new ArgumentException(SignatureError, nameof(replacement));
                }
            }
            else if (!replacement.IsStatic)
            {
                // Should have:
                // static original(i, a, b) | instance i.replacement(a, b)

                if (originalParams.Length == 0)
                    throw new ArgumentException($"Expected first parameter of type '{replacement.DeclaringType}'.", nameof(original));
                if (replacementParams.Length != originalParams.Length - 1)
                    throw new ArgumentException(SignatureError, nameof(replacement));

                Type replThisType = replacement.DeclaringType;
                Type origThisType = originalParams[0].ParameterType;

                // The parameter at fault here is the original's first one, and the type it
                // has to be compatible with is the replacement's declaring type.
                if (!replThisType.IsAssignableFrom(origThisType) &&
                    !origThisType.IsAssignableFrom(replThisType))
                    throw new ArgumentException($"Expected first parameter assignable to or from '{replThisType}'.", nameof(original));

                originalOffset = 1;
            }
            else
            {
                // Should have:
                // static original(a, b) | static replacement(a, b)

                if (originalParams.Length != replacementParams.Length)
                    throw new ArgumentException(SignatureError, nameof(replacement));
            }

            // The length checks above guarantee that both sides have the same number of
            // parameters left once the offsets are applied, so every remaining pair -
            // including the last one - is compared.
            int count = originalParams.Length - originalOffset;

            for (int i = 0; i < count; i++)
            {
                CheckParameters(originalParams[i + originalOffset], replacementParams[i + replacementOffset], nameof(replacement));
            }
        }

        #region Redirect
        /// <summary>
        ///   Redirects calls to the <paramref name="original"/> method
        ///   to the <paramref name="replacement"/> method.
        /// </summary>
        /// <param name="original">The <see cref="MethodInfo"/> of the method whose calls shall be redirected.</param>
        /// <param name="replacement">The <see cref="MethodInfo"/> of the method providing the redirection.</param>
        /// <param name="skipChecks">If <see langword="true"/>, some safety checks will be omitted.</param>
        /// <returns>The <see cref="MethodRedirection"/> describing the redirection.</returns>
        /// <exception cref="ArgumentNullException">
        ///   <paramref name="original"/> or <paramref name="replacement"/> is <see langword="null"/>.
        /// </exception>
        public static MethodRedirection Redirect(MethodInfo original, MethodInfo replacement, bool skipChecks = false)
        {
            if (original == null)
                throw new ArgumentNullException(nameof(original));
            if (replacement == null)
                throw new ArgumentNullException(nameof(replacement));

            return RedirectCore(original, replacement, skipChecks);
        }

        /// <summary>
        ///   Redirects calls to the <paramref name="original"/> constructor
        ///   to the <paramref name="replacement"/> method.
        /// </summary>
        /// <param name="original">The <see cref="ConstructorInfo"/> of the constructor whose calls shall be redirected.</param>
        /// <param name="replacement">The <see cref="MethodInfo"/> of the method providing the redirection.</param>
        /// <param name="skipChecks">If <see langword="true"/>, some safety checks will be omitted.</param>
        /// <returns>The <see cref="MethodRedirection"/> describing the redirection.</returns>
        /// <exception cref="ArgumentNullException">
        ///   <paramref name="original"/> or <paramref name="replacement"/> is <see langword="null"/>.
        /// </exception>
        public static MethodRedirection Redirect(ConstructorInfo original, MethodInfo replacement, bool skipChecks = false)
        {
            if (original == null)
                throw new ArgumentNullException(nameof(original));
            if (replacement == null)
                throw new ArgumentNullException(nameof(replacement));

            return RedirectCore(original, replacement, skipChecks);
        }

        /// <summary>
        ///   Redirects calls to the <paramref name="original"/> <see langword="delegate"/>
        ///   to the <paramref name="replacement"/> <see langword="delegate"/>.
        /// </summary>
        /// <typeparam name="TDelegate">The delegate type shared by both arguments.</typeparam>
        /// <param name="original">The <see langword="delegate"/> whose calls shall be redirected.</param>
        /// <param name="replacement">The <see langword="delegate"/> providing the redirection.</param>
        /// <param name="skipChecks">If <see langword="true"/>, some safety checks will be omitted.</param>
        /// <returns>The <see cref="MethodRedirection"/> describing the redirection.</returns>
        /// <exception cref="ArgumentNullException">
        ///   <paramref name="original"/> or <paramref name="replacement"/> is <see langword="null"/>.
        /// </exception>
        public static MethodRedirection Redirect<TDelegate>(TDelegate original, TDelegate replacement, bool skipChecks = false)
            where TDelegate : Delegate
        {
            if (original == null)
                throw new ArgumentNullException(nameof(original));
            if (replacement == null)
                throw new ArgumentNullException(nameof(replacement));

            // The redirection operates on the methods the delegates point to.
            return RedirectCore(original.GetMethodInfo(), replacement.GetMethodInfo(), skipChecks);
        }
        #endregion
    }
}