csharp/albyho/Tubumu/src/Tubumu.Core/Extensions/QueryableExtensions.cs

QueryableExtensions.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Tubumu.Modules.Core.Models;

namespace Tubumu.Core.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="IQueryable{T}"/> that compose filters, joins and orderings
    /// by building expression trees at runtime.
    /// </summary>
    public static class QueryableExtensions
    {
        /// <summary>
        /// Filters <paramref name="query"/> to the elements whose selected string contains at least one of
        /// <paramref name="values"/>, i.e. <c>s.Contains(v1) || s.Contains(v2) || ...</c>.
        /// </summary>
        /// <remarks>
        /// Intended effect, for <c>var tags = new[] { "A", "B", "C" };</c> combined with an earlier filter on Name:
        /// <code>SELECT * FROM [User] WHERE Name = 'Test' AND (Tags LIKE '%A%' OR Tags LIKE '%B%' OR Tags LIKE '%C%')</code>
        /// </remarks>
        /// <typeparam name="TSource">Element type of the query.</typeparam>
        /// <typeparam name="TValue">
        /// Not referenced by this method. NOTE(review): presumably kept so existing call sites that pass
        /// two type arguments keep compiling — confirm before removing.
        /// </typeparam>
        /// <param name="query">Query to filter.</param>
        /// <param name="selector">Selects the string to search in, e.g. <c>u => u.Tags</c>.</param>
        /// <param name="values">Substrings to look for; must not contain <c>null</c>.</param>
        /// <returns>
        /// The filtered query, or <paramref name="query"/> unchanged when <paramref name="values"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="values"/> contains <c>null</c>.</exception>
        public static IQueryable<TSource> WhereOrStringContains<TSource, TValue>
            (
            this IQueryable<TSource> query,
            Expression<Func<TSource, string>> selector,
            IEnumerable<string> values
            )
        {
            if (query == null)
            {
                throw new ArgumentNullException(nameof(query));
            }
            if (selector == null)
            {
                throw new ArgumentNullException(nameof(selector));
            }
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            // Materialize once: the sequence is checked for emptiness and then projected.
            List<string> valueList = values.ToList();
            if (valueList.Count == 0)
            {
                return query;
            }
            if (valueList.Contains(null))
            {
                // string.Contains(null) cannot be expressed meaningfully; fail here with a clear message.
                throw new ArgumentException("Values must not contain null.", nameof(values));
            }

            ParameterExpression parameter = selector.Parameters.Single();
            MethodInfo containsMethod = typeof(string).GetMethod(nameof(string.Contains), new[] { typeof(string) });
            Expression body = CombineWithOrElse(valueList.Select(value =>
                (Expression)Expression.Call(selector.Body, containsMethod, Expression.Constant(value, typeof(string)))));

            return query.Where(Expression.Lambda<Func<TSource, bool>>(body, parameter));
        }

        /// <summary>
        /// Filters <paramref name="query"/> to the elements whose selected collection has at least one item
        /// whose member equals one of <paramref name="values"/>:
        /// <c>c.Any(i => m(i) == v1) || c.Any(i => m(i) == v2) || ...</c>.
        /// </summary>
        /// <typeparam name="TSource">Element type of the query.</typeparam>
        /// <typeparam name="TValue">Item type of the selected collection.</typeparam>
        /// <typeparam name="TMemberValue">Type of the compared member of each item.</typeparam>
        /// <param name="query">Query to filter.</param>
        /// <param name="selector">Selects the collection to search, e.g. <c>u => u.Roles</c>.</param>
        /// <param name="memberSelector">Selects the compared member of an item, e.g. <c>r => r.RoleId</c>.</param>
        /// <param name="values">Member values to match.</param>
        /// <returns>
        /// The filtered query, or <paramref name="query"/> unchanged when <paramref name="values"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        public static IQueryable<TSource> WhereOrCollectionAnyEqual<TSource, TValue, TMemberValue>
            (
            this IQueryable<TSource> query,
            Expression<Func<TSource, IEnumerable<TValue>>> selector,
            Expression<Func<TValue, TMemberValue>> memberSelector,
            IEnumerable<TMemberValue> values
            )
        {
            if (query == null)
            {
                throw new ArgumentNullException(nameof(query));
            }
            if (selector == null)
            {
                throw new ArgumentNullException(nameof(selector));
            }
            if (memberSelector == null)
            {
                throw new ArgumentNullException(nameof(memberSelector));
            }
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            List<TMemberValue> valueList = values.ToList();
            if (valueList.Count == 0)
            {
                return query;
            }

            ParameterExpression selectorParameter = selector.Parameters.Single();
            ParameterExpression memberParameter = memberSelector.Parameters.Single();
            // Enumerable.Any<TValue>(IEnumerable<TValue>, Func<TValue, bool>)
            MethodInfo anyMethod = GetEnumerableMethod(nameof(Enumerable.Any), 2).MakeGenericMethod(typeof(TValue));

            Expression body = CombineWithOrElse(valueList.Select(value =>
            {
                // item => memberSelector(item) == value
                Expression<Func<TValue, bool>> matchesValue = Expression.Lambda<Func<TValue, bool>>(
                    Expression.Equal(memberSelector.Body, Expression.Constant(value, typeof(TMemberValue))),
                    memberParameter);
                return (Expression)Expression.Call(anyMethod, selector.Body, matchesValue);
            }));

            return query.Where(Expression.Lambda<Func<TSource, bool>>(body, selectorParameter));
        }

        /// <summary>
        /// Filters <paramref name="query"/> to the elements whose selected value equals one of
        /// <paramref name="values"/>: <c>v == a || v == b || ...</c>.
        /// </summary>
        /// <remarks>
        /// Intended effect, for <c>var names = new[] { "A", "B", "C" };</c>:
        /// <code>SELECT * FROM [User] WHERE Name = 'A' OR Name = 'B' OR Name = 'C'</code>
        /// The same can usually be written directly as
        /// <c>DbContext.User.Where(m => names.Contains(m.Name))</c>.
        /// </remarks>
        /// <typeparam name="TSource">Element type of the query.</typeparam>
        /// <typeparam name="TValue">Type of the compared value.</typeparam>
        /// <param name="query">Query to filter.</param>
        /// <param name="selector">Selects the compared value, e.g. <c>u => u.Name</c>.</param>
        /// <param name="values">Accepted values.</param>
        /// <returns>
        /// The filtered query, or <paramref name="query"/> unchanged when <paramref name="values"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        public static IQueryable<TSource> WhereIn<TSource, TValue>
          (
            this IQueryable<TSource> query,
            Expression<Func<TSource, TValue>> selector,
            IEnumerable<TValue> values
          )
        {
            if (query == null)
            {
                throw new ArgumentNullException(nameof(query));
            }
            if (selector == null)
            {
                throw new ArgumentNullException(nameof(selector));
            }
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            List<TValue> valueList = values.ToList();
            if (valueList.Count == 0)
            {
                return query;
            }

            ParameterExpression parameter = selector.Parameters.Single();
            Expression body = CombineWithOrElse(valueList.Select(value =>
                (Expression)Expression.Equal(selector.Body, Expression.Constant(value, typeof(TValue)))));

            return query.Where(Expression.Lambda<Func<TSource, bool>>(body, parameter));
        }

        /// <summary>
        /// <c>params</c> convenience overload of
        /// <see cref="WhereIn{TSource, TValue}(IQueryable{TSource}, Expression{Func{TSource, TValue}}, IEnumerable{TValue})"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of the query.</typeparam>
        /// <typeparam name="TValue">Type of the compared value.</typeparam>
        /// <param name="query">Query to filter.</param>
        /// <param name="selector">Selects the compared value.</param>
        /// <param name="values">Accepted values.</param>
        /// <returns>The filtered query, or <paramref name="query"/> unchanged when no values are given.</returns>
        public static IQueryable<TSource> WhereIn<TSource, TValue>
          (
            this IQueryable<TSource> query,
            Expression<Func<TSource, TValue>> selector,
            params TValue[] values
          )
        {
            return WhereIn(query, selector, (IEnumerable<TValue>)values);
        }

        /// <summary>
        /// Left outer join of <paramref name="outer"/> and <paramref name="inner"/> on equal keys.
        /// Every outer element produces at least one result; when no inner element matches,
        /// <paramref name="resultSelector"/> receives <c>default(TInner)</c> for the inner argument.
        /// </summary>
        /// <remarks>
        /// Built as <c>GroupJoin</c> followed by <c>SelectMany(g => g.ManyInners.DefaultIfEmpty(), ...)</c>,
        /// the shape query providers recognize as a left join.
        /// </remarks>
        /// <typeparam name="TOuter">Element type of the outer query.</typeparam>
        /// <typeparam name="TInner">Element type of the inner query.</typeparam>
        /// <typeparam name="TKey">Join key type.</typeparam>
        /// <typeparam name="TResult">Result element type.</typeparam>
        /// <param name="outer">Outer (left) query.</param>
        /// <param name="inner">Inner (right) query.</param>
        /// <param name="outerKeySelector">Extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Builds a result from an outer element and a (possibly default) inner element.</param>
        /// <returns>The joined query.</returns>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        public static IQueryable<TResult> LeftJoin<TOuter, TInner, TKey, TResult>(
            this IQueryable<TOuter> outer,
            IQueryable<TInner> inner,
            Expression<Func<TOuter, TKey>> outerKeySelector,
            Expression<Func<TInner, TKey>> innerKeySelector,
            Expression<Func<TOuter, TInner, TResult>> resultSelector)
        {
            if (outer == null)
            {
                throw new ArgumentNullException(nameof(outer));
            }
            if (inner == null)
            {
                throw new ArgumentNullException(nameof(inner));
            }
            if (outerKeySelector == null)
            {
                throw new ArgumentNullException(nameof(outerKeySelector));
            }
            if (innerKeySelector == null)
            {
                throw new ArgumentNullException(nameof(innerKeySelector));
            }
            if (resultSelector == null)
            {
                throw new ArgumentNullException(nameof(resultSelector));
            }

            // Step 1: pair each outer element with all inner elements sharing its key.
            IQueryable<LeftJoinIntermediate<TOuter, TInner>> grouped = outer.GroupJoin(
                inner,
                outerKeySelector,
                innerKeySelector,
                (oneOuter, manyInners) => new LeftJoinIntermediate<TOuter, TInner> { OneOuter = oneOuter, ManyInners = manyInners });

            // Step 2: rewrite (o, i) => body into (t, i) => body, with every use of o replaced by t.OneOuter.
            ParameterExpression outerParameter = resultSelector.Parameters[0];
            ParameterExpression innerParameter = resultSelector.Parameters[1];
            ParameterExpression intermediateParameter = Expression.Parameter(typeof(LeftJoinIntermediate<TOuter, TInner>), "t");
            MemberExpression oneOuterAccess = Expression.Property(
                intermediateParameter, nameof(LeftJoinIntermediate<TOuter, TInner>.OneOuter));
            Expression rewrittenBody = new Replacer(outerParameter, oneOuterAccess).Visit(resultSelector.Body);
            Expression<Func<LeftJoinIntermediate<TOuter, TInner>, TInner, TResult>> flattenedResultSelector =
                Expression.Lambda<Func<LeftJoinIntermediate<TOuter, TInner>, TInner, TResult>>(
                    rewrittenBody, intermediateParameter, innerParameter);

            // Step 3: flatten; DefaultIfEmpty keeps outer elements that have no match.
            return grouped.SelectMany(t => t.ManyInners.DefaultIfEmpty(), flattenedResultSelector);
        }

        /// <summary>
        /// Row produced by the GroupJoin step of <c>LeftJoin</c>: one outer element plus the inner
        /// elements whose key matched it.
        /// </summary>
        private class LeftJoinIntermediate<TOuter, TInner>
        {
            public TOuter OneOuter { get; set; }
            public IEnumerable<TInner> ManyInners { get; set; }
        }

        /// <summary>
        /// Expression visitor that substitutes every occurrence of one specific parameter
        /// with a replacement expression.
        /// </summary>
        private class Replacer : ExpressionVisitor
        {
            private readonly ParameterExpression _oldParam;
            private readonly Expression _replacement;

            public Replacer(ParameterExpression oldParam, Expression replacement)
            {
                _oldParam = oldParam;
                _replacement = replacement;
            }

            public override Expression Visit(Expression exp)
            {
                // Reference comparison: only this exact ParameterExpression instance is replaced.
                return exp == _oldParam ? _replacement : base.Visit(exp);
            }
        }

        /// <summary>
        /// Orders <paramref name="source"/> by the public instance property named <paramref name="propertyName"/>
        /// (matched case-insensitively), using <c>OrderBy</c>/<c>OrderByDescending</c>, or
        /// <c>ThenBy</c>/<c>ThenByDescending</c> when <paramref name="anotherLevel"/> is <c>true</c>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order. When <paramref name="anotherLevel"/> is <c>true</c> it must already be ordered.</param>
        /// <param name="propertyName">Name of the property to sort by.</param>
        /// <param name="descending"><c>true</c> for descending order.</param>
        /// <param name="anotherLevel"><c>true</c> to add a secondary key (ThenBy) instead of starting a new ordering.</param>
        /// <returns>The ordered query.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="propertyName"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><typeparamref name="T"/> has no matching property.</exception>
        public static IOrderedQueryable<T> Order<T>(this IQueryable<T> source, string propertyName, bool descending, bool anotherLevel = false)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }
            if (propertyName == null)
            {
                throw new ArgumentNullException(nameof(propertyName));
            }

            Type type = typeof(T);
            PropertyInfo propertyInfo = type.GetProperty(propertyName, BindingFlags.Instance | BindingFlags.IgnoreCase | BindingFlags.Public);
            if (propertyInfo == null)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(propertyName),
                    propertyName,
                    $"Type '{type.FullName}' has no public instance property named '{propertyName}'.");
            }

            // x => x.Property; the parameter name does not affect the query.
            ParameterExpression parameter = Expression.Parameter(type, String.Empty);
            MemberExpression property = Expression.Property(parameter, propertyInfo);
            LambdaExpression keySelector = Expression.Lambda(property, parameter);

            string methodName = (anotherLevel ? "ThenBy" : "OrderBy") + (descending ? "Descending" : String.Empty);
            MethodCallExpression call = Expression.Call(
                typeof(Queryable),
                methodName,
                new[] { type, property.Type },
                source.Expression,
                Expression.Quote(keySelector));
            return (IOrderedQueryable<T>)source.Provider.CreateQuery<T>(call);
        }

        /// <summary>
        /// Orders <paramref name="source"/> by <see cref="SortInfo.Sort"/>; descending when
        /// <see cref="SortInfo.SortDir"/> is <c>SortDir.DESC</c>, ascending otherwise.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="sortInfo">Sort property and direction.</param>
        /// <param name="anotherLevel"><c>true</c> to add a secondary key (ThenBy) to an already ordered query.</param>
        /// <returns>The ordered query.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="sortInfo"/> is <c>null</c>.</exception>
        public static IOrderedQueryable<T> Order<T>(this IQueryable<T> source, SortInfo sortInfo, bool anotherLevel = false)
        {
            if (sortInfo == null)
            {
                throw new ArgumentNullException(nameof(sortInfo));
            }
            return Order(source, sortInfo.Sort, sortInfo.SortDir == SortDir.DESC, anotherLevel);
        }

        /// <summary>
        /// Orders <paramref name="source"/> by every entry of <paramref name="sortInfos"/>, in enumeration order:
        /// the first entry is the primary key, each following entry a ThenBy key.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="sortInfos">Sort keys, primary first.</param>
        /// <returns>The ordered query, or <c>null</c> when <paramref name="sortInfos"/> is empty.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="sortInfos"/> is <c>null</c>.</exception>
        public static IOrderedQueryable<T> Order<T>(this IQueryable<T> source, ICollection<SortInfo> sortInfos)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }
            if (sortInfos == null)
            {
                throw new ArgumentNullException(nameof(sortInfos));
            }

            IOrderedQueryable<T> result = null;
            foreach (SortInfo sortInfo in sortInfos)
            {
                // The first key starts the ordering on source; every later key must extend
                // the ordering built so far, otherwise earlier keys would be lost.
                result = result == null
                    ? Order(source, sortInfo)
                    : Order(result, sortInfo, true);
            }
            return result;
        }

        /// <summary>
        /// Sorts ascending by the property named <paramref name="propertyName"/> (case-insensitive).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="propertyName">Name of the property to sort by.</param>
        /// <returns>The ordered query.</returns>
        public static IOrderedQueryable<T> OrderBy<T>(this IQueryable<T> source, string propertyName)
        {
            return Order(source, propertyName, false, false);
        }

        /// <summary>
        /// Sorts descending by the property named <paramref name="propertyName"/> (case-insensitive).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="propertyName">Name of the property to sort by.</param>
        /// <returns>The ordered query.</returns>
        public static IOrderedQueryable<T> OrderByDescending<T>(this IQueryable<T> source, string propertyName)
        {
            return Order(source, propertyName, true, false);
        }

        /// <summary>
        /// Adds an ascending secondary key on the property named <paramref name="propertyName"/> (case-insensitive).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Already ordered query.</param>
        /// <param name="propertyName">Name of the property to sort by.</param>
        /// <returns>The ordered query.</returns>
        public static IOrderedQueryable<T> ThenBy<T>(this IOrderedQueryable<T> source, string propertyName)
        {
            return Order(source, propertyName, false, true);
        }

        /// <summary>
        /// Adds a descending secondary key on the property named <paramref name="propertyName"/> (case-insensitive).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Already ordered query.</param>
        /// <param name="propertyName">Name of the property to sort by.</param>
        /// <returns>The ordered query.</returns>
        public static IOrderedQueryable<T> ThenByDescending<T>(this IOrderedQueryable<T> source, string propertyName)
        {
            return Order(source, propertyName, true, true);
        }

        /// <summary>
        /// Filters a <see cref="IQueryable{T}"/> by the given predicate if the given condition is true.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="query">Queryable to apply filtering to.</param>
        /// <param name="condition">Whether to apply <paramref name="predicate"/>.</param>
        /// <param name="predicate">Predicate to filter the query.</param>
        /// <returns>Filtered or unfiltered query, depending on <paramref name="condition"/>.</returns>
        /// <seealso href="https://github.com/aspnetboilerplate/aspnetboilerplate/blob/e0ded5d8702f389aa1f5947d3446f16aec845287/src/Abp/Linq/Extensions/QueryableExtensions.cs"/>
        public static IQueryable<T> WhereIf<T>(this IQueryable<T> query, bool condition, Expression<Func<T, bool>> predicate)
        {
            return condition
                ? query.Where(predicate)
                : query;
        }

        /// <summary>
        /// Filters a <see cref="IQueryable{T}"/> by the given index-aware predicate if the given condition is true.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="query">Queryable to apply filtering to.</param>
        /// <param name="condition">Whether to apply <paramref name="predicate"/>.</param>
        /// <param name="predicate">Predicate receiving the element and its index.</param>
        /// <returns>Filtered or unfiltered query, depending on <paramref name="condition"/>.</returns>
        /// <seealso href="https://github.com/aspnetboilerplate/aspnetboilerplate/blob/e0ded5d8702f389aa1f5947d3446f16aec845287/src/Abp/Linq/Extensions/QueryableExtensions.cs"/>
        public static IQueryable<T> WhereIf<T>(this IQueryable<T> query, bool condition, Expression<Func<T, int, bool>> predicate)
        {
            return condition
                ? query.Where(predicate)
                : query;
        }

        /// <summary>
        /// Folds <paramref name="terms"/> left to right into <c>t1 || t2 || ...</c>.
        /// Callers guarantee at least one term.
        /// </summary>
        private static Expression CombineWithOrElse(IEnumerable<Expression> terms)
        {
            // OrElse is the node the C# compiler emits for `||`; Or is the bitwise `|` node,
            // which some query providers translate less directly.
            return terms.Aggregate((left, right) => Expression.OrElse(left, right));
        }

        /// <summary>
        /// Finds the single <see cref="Enumerable"/> method named <paramref name="name"/> that has
        /// <paramref name="parameterCount"/> parameters and, when given, satisfies <paramref name="predicate"/>.
        /// </summary>
        /// <returns>The matching method definition (open generic where applicable).</returns>
        /// <exception cref="InvalidOperationException">No method, or more than one method, matches.</exception>
        private static MethodInfo GetEnumerableMethod(string name, int parameterCount = 0, Func<MethodInfo, bool> predicate = null)
        {
            return typeof(Enumerable)
                .GetTypeInfo()
                .GetDeclaredMethods(name)
                .Single(method => method.GetParameters().Length == parameterCount && (predicate == null || predicate(method)));
        }
    }
}