Reloading Entity Framework entry navigation properties

This helper class recursively reloads all navigation properties (including reverse navigations) of an Entity Framework entry. This is useful when you have just detached and re-attached an entry and need to make sure all of its navigation properties are loaded. Do not use it on a set of data: it performs a `Find` lookup for every navigation property of every entity it visits.

The helper is fairly long, but in short, here is what it does:

  1. Get the entry's entity-set metadata from the context and build a list of its navigation properties.
  2. Iterate through each navigation property and ensure it is loaded.
  3. If a navigation is a reverse (collection) navigation, ensure each of its detached items is reloaded.

    /// <summary>
    /// Extension helpers that re-resolve the navigation properties of an Entity Framework entity
    /// against the entities known to a <see cref="DbContext"/> (e.g. after a detach / re-attach).
    /// </summary>
    public static class Extensions
    {
        /// <summary>
        /// Reloads every navigation property (reference and collection) of <paramref name="poco"/>
        /// through <paramref name="ctx"/>. The entity is modified in place.
        /// </summary>
        /// <typeparam name="TPoco">Entity type.</typeparam>
        /// <param name="ctx">Context used to look up related entities.</param>
        /// <param name="poco">Entity whose navigations are reloaded.</param>
        /// <param name="throwWhenNavNotFound">
        /// When <c>true</c>, throw <see cref="InvalidOperationException"/> if a previously referenced
        /// related entity can no longer be found.
        /// </param>
        /// <param name="force">
        /// NOTE(review): null reference navigations are reloaded whether or not this flag is set, so it
        /// currently has no observable effect; it is kept for API compatibility — confirm intent.
        /// </param>
        /// <returns>The same <paramref name="poco"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="ctx"/> or <paramref name="poco"/> is null.</exception>
        public static TPoco ReloadNavigations<TPoco>(this DbContext ctx, TPoco poco, bool throwWhenNavNotFound = true, bool force = false)
            where TPoco : class
        {
            if (ctx == null)
            {
                throw new ArgumentNullException(nameof(ctx));
            }
            if (poco == null)
            {
                throw new ArgumentNullException(nameof(poco));
            }

            return ctx.ReloadNavigationsCore(poco, throwWhenNavNotFound, force);
        }

        /// <summary>
        /// Recursive worker: walks every navigation property of <paramref name="poco"/> and
        /// re-resolves it against <paramref name="ctx"/>.
        /// </summary>
        private static TPoco ReloadNavigationsCore<TPoco>(this DbContext ctx, TPoco poco, bool throwWhenNavNotFound = true, bool force = false)
            where TPoco : class
        {
            foreach (var navProp in ctx.GetNavigationProperties(poco))
            {
                var navValue = navProp.GetValue(poco);

                // Collection (reverse) navigation: replace detached items with the instances the context finds.
                var items = navValue as System.Collections.IList;
                if (items != null && navValue.GetType().IsGenericType)
                {
                    ctx.ReloadCollectionItems(poco, navProp, items, throwWhenNavNotFound, force);
                    continue;
                }

                // Reference navigation already pointing at a tracked, non-Added entity: keep it.
                if (navValue != null)
                {
                    var state = ctx.Entry(navValue).State;
                    if (state != System.Data.Entity.EntityState.Detached &&
                        state != System.Data.Entity.EntityState.Added)
                    {
                        continue;
                    }
                }

                // Null, detached or Added reference: look it up again by key.
                ctx.LoadNavigationProperty(poco, navProp, true, throwWhenNavNotFound);
            }

            return poco;
        }

        /// <summary>
        /// Replaces every detached element of a collection navigation with the entity that
        /// <c>DbSet.Find</c> returns for the element's key.
        /// </summary>
        private static void ReloadCollectionItems<TPoco>(this DbContext ctx, TPoco owner, PropertyInfo navProp,
            System.Collections.IList items, bool throwWhenNavNotFound, bool force)
            where TPoco : class
        {
            for (int i = items.Count - 1; i >= 0; i--)
            {
                var item = items[i];

                // Null items cannot be looked up; ctx.Entry(null) would throw.
                if (item == null || ctx.Entry(item).State != System.Data.Entity.EntityState.Detached)
                {
                    continue;
                }

                var found = ctx.Set(item.GetUnproxiedtype()).Find(GetKey(item));
                if (found != null)
                {
                    items[i] = found;
                }
                else if (throwWhenNavNotFound)
                {
                    throw new InvalidOperationException($"Unable to restore member reference {navProp.Name} of {owner}");
                }
                else
                {
                    // Item is unknown to the context: reload the item's own navigations instead.
                    // (Recursing on the owner, as before, revisited this same item forever.)
                    ctx.ReloadNavigationsCore(item, throwWhenNavNotFound, force);
                }
            }
        }

        /// <summary>
        /// Returns the value of the key property of <paramref name="objectInstance"/> (see
        /// <see cref="GetKeyField(Type)"/>), or <c>null</c> when the instance is null or no key
        /// property can be determined. Only a single key property is supported.
        /// </summary>
        internal static object GetKey(object objectInstance)
        {
            if (objectInstance == null)
            {
                return null;
            }

            var type = objectInstance.GetType();
            var keyField = GetKeyField(type);
            if (string.IsNullOrEmpty(keyField))
            {
                return null;
            }

            // Check the resolved property (not the name) before reading it.
            var keyProp = type.GetProperty(keyField);
            return keyProp?.GetValue(objectInstance);
        }

        /// <summary>Generic convenience overload of <see cref="GetKeyField(Type)"/>.</summary>
        private static string GetKeyField<T>()
        {
            return GetKeyField(typeof(T));
        }

        /// <summary>
        /// Returns the name of the first public property marked with <see cref="KeyAttribute"/>,
        /// otherwise <c>"Id"</c> if such a property exists, otherwise <c>null</c>.
        /// </summary>
        private static string GetKeyField(Type type)
        {
            var keyProp = type.GetProperties()
                .FirstOrDefault(p => p.CustomAttributes.Any(a => a.AttributeType == typeof(KeyAttribute)));

            if (keyProp != null)
            {
                return keyProp.Name;
            }

            return type.GetProperty("Id") != null ? "Id" : null;
        }

        /// <summary>
        /// Returns the CLR properties of <paramref name="entity"/>'s runtime type that correspond to the
        /// navigation properties of its EF entity type. Model navigations without a matching CLR
        /// property are skipped.
        /// </summary>
        private static List<PropertyInfo> GetNavigationProperties<TPoco>(this DbContext ctx, TPoco entity)
            where TPoco : class
        {
            var runtimeType = entity.GetType();
            var pocoType = entity.GetUnproxiedtype();

            // CreateObjectSet<T>() is generic, so bind it to the un-proxied POCO type via reflection.
            var createObjectSet = typeof(ObjectContext).GetMethods()
                .First(m => m.Name == "CreateObjectSet" && m.GetParameters().Length == 0);
            var objectSet = createObjectSet.MakeGenericMethod(pocoType)
                .Invoke(((IObjectContextAdapter)ctx).ObjectContext, null);
            var elementType = (EntityType)((dynamic)objectSet).EntitySet.ElementType;

            var properties = new List<PropertyInfo>();
            foreach (var navigationProperty in elementType.NavigationProperties)
            {
                var clrProperty = runtimeType.GetProperty(navigationProperty.Name);
                if (clrProperty != null)
                {
                    properties.Add(clrProperty);
                }
            }
            return properties;
        }

        /// <summary>
        /// Returns the POCO type of <paramref name="entity"/>: for an EF dynamic proxy (type in the
        /// <c>System.Data.Entity.DynamicProxies</c> namespace) its base type, otherwise its own type.
        /// </summary>
        private static Type GetUnproxiedtype<TPoco>(this TPoco entity)
                where TPoco : class
        {
            var pocoType = entity.GetType();
            if (pocoType.FullName.StartsWith("System.Data.Entity.DynamicProxies", StringComparison.Ordinal))
            {
                pocoType = pocoType.BaseType;
            }

            return pocoType;
        }

        /// <summary>
        /// Returns the key of the entity referenced by <paramref name="prop"/>. If the navigation is null,
        /// falls back to a foreign-key property named <c>{prop.Name}Id</c> on <paramref name="poco"/>, if any.
        /// </summary>
        private static object GetNavigationPropertyId<TPoco>(this TPoco poco, PropertyInfo prop)
            where TPoco : class
        {
            var value = prop.GetValue(poco);
            if (value != null)
            {
                return GetKey(value);
            }

            var fkProp = poco.GetType().GetProperty(prop.Name + "Id");
            return fkProp != null ? fkProp.GetValue(poco) : null;
        }

        /// <summary>
        /// Re-resolves a single navigation property of <paramref name="poco"/> by key through
        /// <paramref name="ctx"/> and assigns the result to the property.
        /// </summary>
        /// <param name="ctx">Context used for the lookup.</param>
        /// <param name="poco">Entity that owns the navigation property.</param>
        /// <param name="prop">The navigation property to load.</param>
        /// <param name="force">Reload even if the navigation is set and <paramref name="poco"/> is attached.</param>
        /// <param name="throwWhenNavNotFound">Throw when a previously set navigation can no longer be found.</param>
        /// <exception cref="InvalidOperationException">
        /// The navigation was set, the lookup returned nothing, and <paramref name="throwWhenNavNotFound"/> is true.
        /// </exception>
        private static void LoadNavigationProperty<TPoco>(this DbContext ctx, TPoco poco, PropertyInfo prop, bool force = false, bool throwWhenNavNotFound = true)
            where TPoco : class
        {
            var currVal = prop.GetValue(poco);

            if (currVal != null && ctx.Entry(poco).State != System.Data.Entity.EntityState.Detached && !force)
            {
                return;
            }

            var navId = poco.GetNavigationPropertyId(prop);
            if (navId == null)
            {
                return;
            }

            var obj = ctx.Set(prop.PropertyType).Find(navId);

            if (currVal != null && obj == null && throwWhenNavNotFound)
            {
                throw new InvalidOperationException($"Unable to restore member reference {prop.Name} of {poco}");
            }

            prop.SetValue(poco, obj);
        }

        /// <summary>Returns the public property named <paramref name="propName"/> on <paramref name="poco"/>'s runtime type.</summary>
        private static PropertyInfo GetNavigationProperty<TPoco>(this TPoco poco, string propName)
            where TPoco : class
        {
            return poco.GetType().GetProperty(propName);
        }

        /// <summary>Expression-based overload: resolves the property selected by <paramref name="prop"/> and returns its key.</summary>
        private static object GetNavigationPropertyId<TPoco>(this TPoco poco, Expression<Func<TPoco, object>> prop)
            where TPoco : class
        {
            return poco.GetNavigationPropertyId(poco.GetNavigationProperty(prop.GetPropertyNavigation()));
        }

        /// <summary>
        /// Returns the dotted member path(s) selected by a lambda (e.g. <c>x => x.A.B</c> gives <c>"A.B"</c>).
        /// Paths from several method-call arguments are concatenated without a separator.
        /// </summary>
        private static string GetPropertyNavigation(this Expression body, bool throwOnMethod = true)
        {
            var memberInfos = body.GetMembersInfo(throwOnMethod);

            var sb = new StringBuilder();
            foreach (var memberInfo in memberInfos)
            {
                if (memberInfo != null)
                {
                    sb.Append(ParseMemberPath(memberInfo));
                }
            }
            return sb.ToString();
        }

        /// <summary>
        /// Extracts the <see cref="MemberExpression"/> from <paramref name="body"/>, unwrapping a single
        /// Convert node. Other node kinds throw <see cref="ArgumentException"/> when
        /// <paramref name="throwOnMethod"/> is true, otherwise return <c>null</c>.
        /// </summary>
        private static MemberExpression GetLamdaMember(Expression body, bool throwOnMethod = true)
        {
            if (body.NodeType == ExpressionType.Convert)
            {
                return ((UnaryExpression)body).Operand as MemberExpression;
            }
            if (body.NodeType == ExpressionType.MemberAccess)
            {
                return body as MemberExpression;
            }
            if (body.NodeType == ExpressionType.Call)
            {
                // NOTE(review): the paths built here are never used; this loop only matters when
                // GetMemberInfo throws for a non-lambda argument. Kept to preserve existing behavior.
                var methodCall = (MethodCallExpression)body;
                var sb = new StringBuilder();
                foreach (var param in methodCall.Arguments)
                {
                    sb.AppendLine(ParseMemberPath(GetMemberInfo(param)));
                }
            }

            // unhandled.
            if (throwOnMethod)
            {
                throw new ArgumentException("method");
            }
            return null;
        }

        /// <summary>Joins the member names of <paramref name="expressions"/> with '.'.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="expressions"/> is null.</exception>
        public static string ParseMemberPath(List<MemberExpression> expressions)
        {
            if (expressions == null)
            {
                throw new ArgumentNullException(nameof(expressions));
            }

            return string.Join(".", expressions.Select(e => e.Member.Name));
        }

        /// <summary>
        /// Returns the member chain of a lambda body, outermost member first
        /// (for <c>x => x.A.B</c>: <c>[x.A, x.A.B]</c>). Returns <c>null</c> if the body is not a member access.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="method"/> is not a <see cref="LambdaExpression"/>.</exception>
        public static List<MemberExpression> GetMemberInfo(this Expression method, bool throwOnMethod = true)
        {
            var lambda = method as LambdaExpression;
            if (lambda == null)
            {
                throw new ArgumentNullException(nameof(method));
            }

            var member = GetLamdaMember(lambda.Body, throwOnMethod: throwOnMethod);
            if (member == null)
            {
                return null;
            }

            var ret = new List<MemberExpression> { member };

            // Walk toward the lambda parameter. Expression is null for static members, so stop there.
            while (ret[0].Expression != null && ret[0].Expression.NodeType != ExpressionType.Parameter)
            {
                member = GetLamdaMember(ret[0].Expression, throwOnMethod: throwOnMethod);
                if (member == null)
                {
                    // Previously `continue`, which re-tested the same node forever.
                    break;
                }
                ret.Insert(0, member);
            }

            return ret;
        }

        /// <summary>
        /// For a lambda whose body is a method call, returns the member chain of each call argument;
        /// otherwise returns a single-element list with the body's member chain.
        /// NOTE(review): call arguments must themselves be lambdas, or GetMemberInfo throws.
        /// </summary>
        private static List<List<MemberExpression>> GetMembersInfo(this Expression method, bool throwOnMethod = true)
        {
            var lambda = method as LambdaExpression;
            if (lambda == null)
            {
                throw new ArgumentNullException(nameof(method));
            }

            var res = new List<List<MemberExpression>>();
            if (lambda.Body.NodeType == ExpressionType.Call)
            {
                var methodCall = (MethodCallExpression)lambda.Body;
                foreach (var param in methodCall.Arguments)
                {
                    res.Add(GetMemberInfo(param, throwOnMethod));
                }
            }
            else
            {
                res.Add(GetMemberInfo(lambda, throwOnMethod));
            }

            return res;
        }
    }

Submit a Comment

Your email address will not be published. Required fields are marked *

Share This