csharp/71/Cometary/src/Cometary.Expressions/Internal/TrackingExpressionVisitor.cs

TrackingExpressionVisitor.cs
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;

namespace Cometary.Expressions
{
    /// <summary>
    ///   An <see cref="ExpressionVisitor"/> that tracks every variable it encounters, so that
    ///   <see cref="GenerateBody"/> can store them in the locals array of a
    ///   <see cref="VirtualStateMachine"/>, allowing their values to be saved between each
    ///   run of the compiled expression.
    /// </summary>
    internal class TrackingExpressionVisitor : ExpressionVisitor
    {
        /// <summary>
        ///   Every single-target projector contained in <see cref="Projector"/>, flattened once
        ///   at construction so that <see cref="Visit(Expression)"/> does not allocate a new
        ///   invocation list for every visited node.
        /// </summary>
        private readonly ExpressionProjector[] _projectors;

        /// <summary>
        ///   Gets the list of tracked variables, in the order they were first visited.
        ///   A variable's index in this list is the index of its slot in the state machine's
        ///   locals array used by <see cref="GenerateBody"/>.
        /// </summary>
        public IList<ParameterExpression> Variables { get; }

        /// <summary>
        ///   Gets an optional <see cref="ExpressionProjector"/> used by this visitor,
        ///   or <see langword="null"/> if no projector was given.
        /// </summary>
        public ExpressionProjector Projector { get; }

        /// <summary>
        ///   Creates a new visitor.
        /// </summary>
        /// <param name="projectors">
        ///   Projectors applied, in order, to every node before it is visited.
        ///   May be empty or <see langword="null"/>.
        /// </param>
        public TrackingExpressionVisitor(params ExpressionProjector[] projectors)
        {
            Variables = new LightList<ParameterExpression>();

            // An explicit null array (e.g. `new TrackingExpressionVisitor(null)`) means "no projectors".
            if (projectors != null && projectors.Length > 0)
                Projector = projectors.Length == 1 ? projectors[0] : (ExpressionProjector)Delegate.Combine(projectors);

            // Delegate.Combine ignores null entries and may return null, hence the null check.
            _projectors = Projector == null
                ? Array.Empty<ExpressionProjector>()
                : Array.ConvertAll(Projector.GetInvocationList(), d => (ExpressionProjector)d);
        }

        /// <summary>
        ///   Visits an expression, tracking it if it represents a variable.
        /// </summary>
        /// <param name="node">The expression to visit; may be <see langword="null"/>.</param>
        /// <returns>
        ///   <see langword="null"/> if <paramref name="node"/> is <see langword="null"/>;
        ///   otherwise the projected, reduced and visited expression.
        /// </returns>
        public override Expression Visit(Expression node)
        {
            if (node == null)
                return null;

            // Project things in here, to avoid traversing the tree multiple times.
            // Each projector receives the result of the previous one.
            foreach (ExpressionProjector projector in _projectors)
                node = projector(node);

            node = node.ReduceExtensions();

            if (node is BlockExpression block)
                // Drop the block's own variable scope: smaller scopes would break the variable
                // assignments. Variables used inside the block are tracked when visited, and
                // GenerateBody declares them in a single outer scope.
                node = block.Update(Enumerable.Empty<ParameterExpression>(), block.Expressions);

            if (node is ParameterExpression variable && !Variables.Contains(variable) && !variable.IsAssignableTo<VirtualStateMachine>())
                // Keep track of variable (parameters of the state-machine type are not tracked).
                Variables.Add(variable);

            if (node.NodeType == ExpressionType.Extension && !node.CanReduce)
                // In case we're returning a special expression (e.g. YieldExpression): it cannot
                // be reduced, so it is returned as-is instead of being traversed by the base visitor.
                return node;

            return base.Visit(node);
        }

        /// <summary>
        ///   Generates an <see cref="Expression"/> tree that starts by setting all previously
        ///   visited variables to the values stored in the state machine's locals, then runs
        ///   <paramref name="body"/>, and ends by saving their values back into those locals.
        /// </summary>
        /// <param name="body">The body to transform.</param>
        /// <param name="vsmExpression">
        ///   The <see cref="VirtualStateMachine"/> in which the state will be saved.
        /// </param>
        /// <returns>
        ///   <paramref name="body"/> itself if no variable was tracked; otherwise a block that
        ///   declares every tracked variable and has the same type as <paramref name="body"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        ///   <paramref name="body"/> or <paramref name="vsmExpression"/> is <see langword="null"/>.
        /// </exception>
        public Expression GenerateBody(Expression body, Expression vsmExpression)
        {
            if (body == null)
                throw new ArgumentNullException(nameof(body));
            if (vsmExpression == null)
                throw new ArgumentNullException(nameof(vsmExpression));

            Debug.Assert(vsmExpression.Type.IsAssignableTo<VirtualStateMachine>());

            int count = Variables.Count;

            if (count == 0)
                return body;

            Expression[] loadVariableExpressions = new Expression[count];
            Expression[] saveVariableExpressions = new Expression[count];

            Expression localsExpression = Expression.Field(vsmExpression, VirtualStateMachine.LocalsField);

            for (int i = 0; i < count; i++)
            {
                ParameterExpression variable = Variables[i];

                loadVariableExpressions[i] = Expression.Assign(
                    variable,
                    LoadSlot(Expression.ArrayAccess(localsExpression, Expression.Constant(i)), variable.Type)
                );

                saveVariableExpressions[i] = Expression.Assign(
                    Expression.ArrayAccess(localsExpression, Expression.Constant(i)),
                    Expression.Convert(variable, typeof(object))
                );
            }

            // count > 0 here, so neither block is empty.
            Expression init = Expression.Block(typeof(void), loadVariableExpressions);
            Expression end = Expression.Block(typeof(void), saveVariableExpressions);

            if (body.Type == typeof(void))
                return Expression.Block(Variables, init, body, end);

            // Capture the body's value so the variables can be saved before it is returned.
            ParameterExpression result = Expression.Variable(body.Type, "result");

            return Expression.Block(Variables.Prepend(result), init, Expression.Assign(result, body), end, result);
        }

        /// <summary>
        ///   Returns an expression that reads the <see cref="object"/> held by
        ///   <paramref name="slot"/> as a value of type <paramref name="type"/>.
        /// </summary>
        /// <remarks>
        ///   Unboxing <see langword="null"/> to a non-nullable value type throws at runtime, so a
        ///   <see langword="null"/> slot yields <c>default(T)</c> instead. For reference and
        ///   nullable types this is the same value a plain conversion would produce.
        ///   NOTE(review): assumes a slot may still be null before its variable is first saved;
        ///   confirm against how VirtualStateMachine initializes its locals.
        /// </remarks>
        private static Expression LoadSlot(Expression slot, Type type)
        {
            return Expression.Condition(
                Expression.ReferenceEqual(slot, Expression.Constant(null)),
                Expression.Default(type),
                Expression.Convert(slot, type)
            );
        }
    }
}