Implementing the Repository pattern with direct IQueryable support

Lately I have written a lot of posts about the Repository and Unit of Work patterns. You can check out all the software design posts on why I think they are (still) useful and how you can use the Specification pattern to improve on them. So this pattern has been on my mind a lot.

The usual implementation

The usual implementation for Entity Framework looks something like this:

public interface IRepository<T>
{        
  void Add(T entity);
  void Delete(T entity);
  void Update(T entity);
  IQueryable<T> List();
}

This is just a basic setup; you might have extra read-type methods, or maybe some parameters for the List() method. An implementation for a model class like Category then looks like this:

public class CategoryRepository : IRepository<Category>
{
  private readonly DbSet<Category> targetDbSet;
  private readonly DbContext dbContext;
  
  public CategoryRepository(DbContext context)
  {
    this.dbContext = context;
    this.targetDbSet = dbContext.Set<Category>();
  }
  
  public void Add(Category entity)
  {
    targetDbSet.Add(entity);
  }
  
  public void Delete(Category entity)
  {
    // Attach detached entities first, then fetch the entry again so the new state applies to the tracked entry.
    if (dbContext.Entry(entity).State == EntityState.Detached)
    {
      targetDbSet.Attach(entity);
    }
    dbContext.Entry(entity).State = EntityState.Deleted;
  }
  
  public void Update(Category entity)
  {
    if (dbContext.Entry(entity).State == EntityState.Detached)
    {
      targetDbSet.Attach(entity);
    }
    dbContext.Entry(entity).State = EntityState.Modified;
  }
  
  public IQueryable<Category> List()
  {
    return targetDbSet;
  }
}

Or something along these lines. Then you probably know the drill: create a repository for each entity, then create a unit of work to handle transactions, and boom, you have a nicely designed, loosely coupled architecture (at this level, at least).

Then, you have to use it like this:

unitOfWork.CategoryRepository.List().Where(c=>c.CategoryId>8).ToList();

Not bad, but I see two possible problems with this. First, you always have to go through the List() method. How cool would it be to simply call the LInQ methods on the repository itself? That would also solve the second problem: right now, the repository interface returns an IQueryable<T>, and some people think this is a bad idea (you can check my specification posts on why). But if the repository itself is the IQueryable<T>, that's a whole different story. So why not? Basically, I just had to somehow implement my own LInQ provider. What could possibly be hard about that, right?
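Why does that call for a provider? On an IQueryable<T>, operators such as Where() don't filter anything themselves: Queryable.Where() builds a MethodCallExpression around the source's Expression and hands it to source.Provider.CreateQuery(). The small sketch below, using only the standard library, makes that visible. The QueryInspection class and the sample IDs are hypothetical, and an in-memory AsQueryable() source stands in for a DbSet<T>.

```csharp
using System.Linq;
using System.Linq.Expressions;

public static class QueryInspection
{
    // Builds a query without enumerating it and returns the expression tree LINQ recorded.
    public static Expression DescribeQuery()
    {
        // Hypothetical in-memory source standing in for a DbSet<T>.
        IQueryable<int> categoryIds = new[] { 3, 8, 9, 12 }.AsQueryable();

        // Queryable.Where() does not filter here: it calls categoryIds.Provider.CreateQuery<int>()
        // with a MethodCallExpression whose first argument is categoryIds.Expression.
        IQueryable<int> query = categoryIds.Where(id => id > 8);
        return query.Expression;
    }
}
```

The returned tree is a call to Queryable.Where() whose first argument is the constant at the root of the source query. Nothing runs until the query is enumerated, and then it's the provider that decides how to execute the tree. So a repository that is its own IQueryable<T> needs a provider that knows what to do with trees rooted at the repository.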

Defining the interface

First things first: let's define the interface. I wanted all the C.U.D. operations to be present, while the read operations are done directly through LInQ. So a repository interface might look like this:

public interface IRepository<T> : IOrderedQueryable<T> where T : class
{        
  void Add(T entity);
  void Delete(T entity);
  void Update(T entity);          
}

Implementing the interface

I have created LInQ providers before, and I used a very easy trick: I created a 'wrapper provider' around the actual provider.

I used the same idea here: a LInQ provider that, when executing the query, simply replaces the root of the query (the repository) with the actual DbSet<T>. Implementing IQueryable<T> (or in this case, IOrderedQueryable<T>) is actually not the big part of the task:

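using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
// plus the Entity Framework namespace you use: System.Data.Entity (EF6) or Microsoft.EntityFrameworkCore (EF Core)

public class RepositoryBase<T> : IRepository<T> where T : class
{
  private readonly DbContext dbContext;
  private readonly DbSet<T> targetDbSet;

  // Creates the repository itself; the instance becomes the root of every query built on it.
  public RepositoryBase(DbContext dbContext)
  {
    this.dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
    this.targetDbSet = dbContext.Set<T>();
    Expression = Expression.Constant(this);
    Provider = new RepositoryBaseQueryProvider<T>(targetDbSet);
  }

  // Used by the query provider for derived queries (Where(), OrderBy(), ...).
  // Such instances only carry a query and have no DbContext, so Add/Delete/Update are not meant for them.
  public RepositoryBase(IQueryProvider provider, Expression expression)
  {
    Provider = provider;
    Expression = expression;
  }

  public Type ElementType => typeof(T);
  public Expression Expression { get; }
  public IQueryProvider Provider { get; }

  // Enumerating the queryable hands the expression to the provider for execution.
  public IEnumerator<T> GetEnumerator() =>
    Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();

  IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

  public void Add(T entity)
  {
    targetDbSet.Add(entity);
  }

  public void Delete(T entity)
  {
    // Attach detached entities first, then fetch the entry again so the new state applies to the tracked entry.
    if (dbContext.Entry(entity).State == EntityState.Detached)
    {
      targetDbSet.Attach(entity);
    }
    dbContext.Entry(entity).State = EntityState.Deleted;
  }

  public void Update(T entity)
  {
    if (dbContext.Entry(entity).State == EntityState.Detached)
    {
      targetDbSet.Attach(entity);
    }
    dbContext.Entry(entity).State = EntityState.Modified;
  }
}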

So, the key points here:

  • The class is a straightforward implementation of the IRepository<T> interface.
  • The implementation takes an instance of the DbContext, like every other repository would. This instance is used for the usual C.U.D. operations. Nothing fancy here.
  • The IOrderedQueryable<T> members are implemented in a pretty standard way: the properties are simply added to the class, GetEnumerator() executes the query through the provider and returns the result's enumerator, and the Expression of the repository itself is initialized to a constant expression wrapping the current instance.
  • The magic happens inside the query provider: that's the component actually responsible for executing the query. An instance of this query provider, RepositoryBaseQueryProvider<T>, is stored in the Provider property of the queryable.
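One detail of the constructor is easy to miss: Expression.Constant(this) records the runtime type of the instance, so the root of a CategoryRepository is typed CategoryRepository, not RepositoryBase<Category>. RootReplacingVisitor identifies the root by its type, so this matters. The sketch below uses hypothetical stand-in types (RepositoryStub<T>, CategoryRepositoryStub) and a hypothetical check on the immediate BaseType to show which roots such a check matches.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical stand-in for RepositoryBase<T>.
public class RepositoryStub<T>
{
    // Same call as the repository constructor: the constant's Type is this.GetType().
    public ConstantExpression Root() => Expression.Constant(this);
}

// Hypothetical stand-in for CategoryRepository.
public class CategoryRepositoryStub : RepositoryStub<string> { }

public static class RootCheck
{
    // Hypothetical check that only looks at the immediate base type.
    public static bool MatchesBaseTypeCheck(Type type) =>
        type.BaseType != null
        && type.BaseType.IsGenericType
        && type.BaseType.GetGenericTypeDefinition() == typeof(RepositoryStub<>);
}
```

For a CategoryRepositoryStub root the check succeeds, but for a plain RepositoryStub<string> instance the BaseType is object and it fails. A check on the immediate BaseType therefore recognizes direct subclasses of RepositoryBase<T> (like every repository in this post), but not a bare RepositoryBase<T> or a deeper hierarchy; walking up the BaseType chain covers those cases too.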

Implementing the query provider

Even though the technical details are not that easy, the query provider itself is structured pretty simply. You have to implement the IQueryProvider interface, which has two jobs: creating a query and executing it. (Actually, there are four methods, a generic and a non-generic version of each, but the non-generic versions simply use reflection to do the same work as the generic ones.)

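using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.ExceptionServices;
// plus the Entity Framework namespace you use: System.Data.Entity (EF6) or Microsoft.EntityFrameworkCore (EF Core)

public class RepositoryBaseQueryProvider<TEntity> : IQueryProvider where TEntity : class
{
  private readonly Type queryType;
  private readonly DbSet<TEntity> targetDbSet;

  public RepositoryBaseQueryProvider(DbSet<TEntity> targetDbSet)
  {
    this.queryType = typeof(RepositoryBase<>);
    this.targetDbSet = targetDbSet;
  }

  public IQueryable CreateQuery(Expression expression)
  {
    // GetElementTypeForExpression() is a helper extension method on Type that isn't listed in this post.
    var elementType = expression.Type.GetElementTypeForExpression();
    try
    {
      return (IQueryable)Activator.CreateInstance(queryType.MakeGenericType(elementType), new object[] { this, expression });
    }
    catch (TargetInvocationException tie)
    {
      // Rethrow the real exception without losing its original stack trace.
      ExceptionDispatchInfo.Capture(tie.InnerException).Throw();
      throw;
    }
  }

  public object Execute(Expression expression)
  {
    // Close the generic Execute<TResult>() over the type the expression produces, then invoke it.
    var genericExecute = GetType()
      .GetMethods()
      .Single(m => m.Name == nameof(Execute) && m.IsGenericMethodDefinition)
      .MakeGenericMethod(expression.Type);
    try
    {
      return genericExecute.Invoke(this, new object[] { expression });
    }
    catch (TargetInvocationException tie)
    {
      ExceptionDispatchInfo.Capture(tie.InnerException).Throw();
      throw;
    }
  }

  // See https://msdn.microsoft.com/en-us/library/bb546158.aspx for more details
  public TResult Execute<TResult>(Expression expression)
  {
    IQueryable<TEntity> newRoot = targetDbSet;
    // Rewrite the tree so that it starts from the DbSet<TEntity> instead of the repository.
    var treeCopier = new RootReplacingVisitor(newRoot);
    var newExpressionTree = treeCopier.Visit(expression);
    // Sequence results (requested by GetEnumerator()) become a query on the DbSet's own provider...
    var isEnumerable = typeof(TResult).IsGenericType && typeof(TResult).GetGenericTypeDefinition() == typeof(IEnumerable<>);
    if (isEnumerable)
    {
      return (TResult)newRoot.Provider.CreateQuery(newExpressionTree);
    }
    // ...while single values (Count(), First(), ...) are executed right away.
    var result = newRoot.Provider.Execute(newExpressionTree);
    return (TResult)result;
  }

  public IQueryable<T> CreateQuery<T>(Expression expression)
  {
    // The element type of the new query is simply T.
    var type = queryType.MakeGenericType(typeof(T));
    return (IQueryable<T>)Activator.CreateInstance(type, new object[] { this, expression });
  }
}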
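The provider relies on GetElementTypeForExpression(), a helper extension method on Type that isn't listed in this post. Assuming it does what its name suggests (given IQueryable<Category>, return Category), a minimal version modelled on the TypeSystem.GetElementType helper from the walkthrough linked in the code comment could look like this; the TypeExtensions class name is hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class TypeExtensions
{
    // Hypothetical reimplementation: returns T when the type is, or implements, IEnumerable<T>;
    // otherwise returns the type itself (for scalar results such as int).
    public static Type GetElementTypeForExpression(this Type seqType)
    {
        var enumerableType = FindIEnumerable(seqType);
        return enumerableType == null ? seqType : enumerableType.GetGenericArguments()[0];
    }

    private static Type FindIEnumerable(Type seqType)
    {
        // Strings implement IEnumerable<char>, but here they are treated as single values.
        if (seqType == null || seqType == typeof(string)) return null;
        if (seqType.IsArray) return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType());
        if (seqType.IsGenericType && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) return seqType;
        // Covers IQueryable<T>, IOrderedQueryable<T>, List<T>, ... (picks the first match if several exist).
        return seqType.GetInterfaces()
            .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
    }
}
```

For a type that implements IEnumerable<T> more than once this simply picks the first match, which is good enough for the query types LInQ produces.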

CreateQuery() simply creates an instance of the IQueryable implementation class, RepositoryBase<T>, wrapping the expression. Execute() is where the magic happens: it uses a custom-built ExpressionVisitor to replace the root of the expression tree, that is, to swap the RepositoryBase<T> for the DbSet<T>. The code for the visitor looks like this:

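internal class RootReplacingVisitor : ExpressionVisitor
{
  private readonly IQueryable newRoot;
  public RootReplacingVisitor(IQueryable newRoot)
  {
    this.newRoot = newRoot;
  }

  // Swap the constant holding the repository for the DbSet's own root expression; keep every other node.
  protected override Expression VisitConstant(ConstantExpression node) =>
    IsRepository(node.Type) ? newRoot.Expression : base.VisitConstant(node);

  // Walk up the inheritance chain, so RepositoryBase<T> itself and any subclass of it are recognized.
  private static bool IsRepository(Type type)
  {
    for (var current = type; current != null; current = current.BaseType)
    {
      if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(RepositoryBase<>))
        return true;
    }
    return false;
  }
}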
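To see the whole round trip without a database, here's a self-contained in-memory sketch of the same wrapper-provider idea. InMemoryRepository<T> and RootReplacingProvider are hypothetical names, an EnumerableQuery<T> from AsQueryable() plays the role of the DbSet<T>, and, as a deliberate variation, the root is matched by reference rather than by type.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical in-memory counterpart of RepositoryBase<T>: the queryable is its own query root.
public class InMemoryRepository<T> : IOrderedQueryable<T>
{
    public InMemoryRepository(IEnumerable<T> items)
    {
        Expression = Expression.Constant(this);
        Provider = new RootReplacingProvider(this, items.AsQueryable());
    }

    // Used by the provider for every derived query (Where, OrderBy, Select, ...).
    public InMemoryRepository(IQueryProvider provider, Expression expression)
    {
        Provider = provider;
        Expression = expression;
    }

    public Type ElementType => typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }
    public IEnumerator<T> GetEnumerator() =>
        Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

// Wrapper provider: queries are built against the repository, executed against the real source.
public class RootReplacingProvider : IQueryProvider
{
    private readonly object root;       // the repository instance at the root of every query
    private readonly IQueryable target; // the real source (a DbSet<T> in the article)

    public RootReplacingProvider(object root, IQueryable target)
    {
        this.root = root;
        this.target = target;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) =>
        new InMemoryRepository<TElement>(this, expression);

    public IQueryable CreateQuery(Expression expression)
    {
        // Find T in IEnumerable<T>; assumes the expression produces a generic sequence.
        var elementType = expression.Type.GetInterfaces().Append(expression.Type)
            .First(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            .GetGenericArguments()[0];
        var queryType = typeof(InMemoryRepository<>).MakeGenericType(elementType);
        return (IQueryable)Activator.CreateInstance(queryType, this, expression);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        var rewritten = new RootSwapper(root, target.Expression).Visit(expression);
        // Sequence results (from GetEnumerator) become a query on the real source.
        if (typeof(TResult).IsGenericType && typeof(TResult).GetGenericTypeDefinition() == typeof(IEnumerable<>))
            return (TResult)target.Provider.CreateQuery(rewritten);
        // Scalar results (Count, First, Max, ...) are executed directly.
        return target.Provider.Execute<TResult>(rewritten);
    }

    public object Execute(Expression expression) =>
        target.Provider.Execute(new RootSwapper(root, target.Expression).Visit(expression));

    private sealed class RootSwapper : ExpressionVisitor
    {
        private readonly object oldRoot;
        private readonly Expression newRoot;

        public RootSwapper(object oldRoot, Expression newRoot)
        {
            this.oldRoot = oldRoot;
            this.newRoot = newRoot;
        }

        // Replace the constant holding the repository; leave all other nodes untouched.
        protected override Expression VisitConstant(ConstantExpression node) =>
            ReferenceEquals(node.Value, oldRoot) ? newRoot : base.VisitConstant(node);
    }
}
```

With repo = new InMemoryRepository<int>(new[] { 5, 1, 9, 3, 12 }), repo.Where(n => n > 3).OrderBy(n => n).ToList() yields 5, 9 and 12, and repo.Count(n => n > 3) returns 3. OrderBy() works because Queryable.OrderBy() casts the result of CreateQuery() to IOrderedQueryable<T>, which is why the query class has to implement IOrderedQueryable<T>. Since InMemoryRepository<T> has no class constraint, value-type projections such as Select(n => n * 2) work too; with the where T : class constraint on RepositoryBase<T>, MakeGenericType() would reject a projection like Select(c => c.CategoryId).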

Testing it out

And now, drumroll, please. The newly structured repository for the Category entity looks like this:

public interface ICategoryRepository : IRepository<Category>
{
}

public class CategoryRepository : RepositoryBase<Category>, ICategoryRepository
{
  public CategoryRepository(DbContext dbContext) : base(dbContext)
  {
  }
}

The unit of work interface stays the same (actually, the implementation stays the same as well):

public interface IUnitOfWork
{
  ICategoryRepository Categories { get; }
  IProductRepository Products { get; }
  Task CommitAsync();
}

And here's how you use it:

using (var context = new ApplicationContext())
{
  IUnitOfWork uow = new UnitOfWork(new CategoryRepository(context), new ProductRepository(context), context);
  var categories = uow.Categories.Where(c => c.CategoryId > 3).ToList();
}

How cool is that? Very cool. It doesn't have async support yet, but stay tuned, because I'm working on it. Until then, you can check out the full code, with a sample based on the awesome Northwind database, on GitHub.