Skip to content

Commit

Permalink
Merge pull request #922 from jbogard/explicit-processor-registration
Browse files Browse the repository at this point in the history
Explicit processor registration
  • Loading branch information
jbogard authored Jul 7, 2023
2 parents 4452ce8 + 3e1c399 commit 9fbba24
Show file tree
Hide file tree
Showing 4 changed files with 442 additions and 52 deletions.
251 changes: 251 additions & 0 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Reflection;
using MediatR;
using MediatR.NotificationPublishers;
using MediatR.Pipeline;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -52,6 +53,16 @@ public class MediatRServiceConfiguration
/// </summary>
public List<ServiceDescriptor> StreamBehaviorsToRegister { get; } = new();

/// <summary>
/// List of request pre processors to register in specific order
/// </summary>
public List<ServiceDescriptor> RequestPreProcessorsToRegister { get; } = new();

/// <summary>
/// List of request post processors to register in specific order
/// </summary>
public List<ServiceDescriptor> RequestPostProcessorsToRegister { get; } = new();

/// <summary>
/// Register various handlers from assembly containing given type
/// </summary>
Expand Down Expand Up @@ -103,6 +114,41 @@ public MediatRServiceConfiguration RegisterServicesFromAssemblies(
public MediatRServiceConfiguration AddBehavior<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed behavior type against all <see cref="IPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed behavior implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
return AddBehavior(typeof(TImplementationType), serviceLifetime);
}

/// <summary>
/// Register a closed behavior type against all <see cref="IPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed behavior implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>)));

if (implementedBehaviorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
{
BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Register a closed behavior type
/// </summary>
Expand Down Expand Up @@ -170,6 +216,39 @@ public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type impl
return this;
}

/// <summary>
/// Register a closed stream behavior type against all <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed stream behavior implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddStreamBehavior(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed stream behavior type against all <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed stream behavior implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>)));

if (implementedBehaviorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
{
StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Registers an open stream behavior type against the <see cref="IStreamPipelineBehavior{TRequest,TResponse}"/> open generic interface type
/// </summary>
Expand Down Expand Up @@ -199,5 +278,177 @@ public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType,
return this;
}

/// <summary>
/// Register a closed request pre processor type
/// </summary>
/// <typeparam name="TServiceType">Closed request pre processor interface type</typeparam>
/// <typeparam name="TImplementationType">Closed request pre processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request pre processor type
/// </summary>
/// <param name="serviceType">Closed request pre processor interface type</param>
/// <param name="implementationType">Closed request pre processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

/// <summary>
/// Register a closed request pre processor type against all <see cref="IRequestPreProcessor{TRequest}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed request pre processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor<TImplementationType>(
ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPreProcessor(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request pre processor type against all <see cref="IRequestPreProcessor{TRequest}"/> implementations
/// </summary>
/// <param name="implementationType">Closed request pre processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPreProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));

if (implementedPreProcessorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var implementedPreProcessorType in implementedPreProcessorTypes)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime));
}

return this;
}

/// <summary>
/// Registers an open request pre processor type against the <see cref="IRequestPreProcessor{TRequest}"/> open generic interface type
/// </summary>
/// <param name="openBehaviorType">An open generic request pre processor type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
if (!openBehaviorType.IsGenericType)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must be generic");
}

var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedOpenBehaviorInterfaces = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));

if (implementedOpenBehaviorInterfaces.Count == 0)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime));
}

return this;
}

/// <summary>
/// Register a closed request post processor type
/// </summary>
/// <typeparam name="TServiceType">Closed request post processor interface type</typeparam>
/// <typeparam name="TImplementationType">Closed request post processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor<TServiceType, TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPostProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request post processor type
/// </summary>
/// <param name="serviceType">Closed request post processor interface type</param>
/// <param name="implementationType">Closed request post processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

/// <summary>
/// Register a closed request post processor type against all <see cref="IRequestPostProcessor{TRequest,TResponse}"/> implementations
/// </summary>
/// <typeparam name="TImplementationType">Closed request post processor implementation type</typeparam>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor<TImplementationType>(ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
=> AddRequestPostProcessor(typeof(TImplementationType), serviceLifetime);

/// <summary>
/// Register a closed request post processor type against all <see cref="IRequestPostProcessor{TRequest,TResponse}"/> implementations
/// </summary>
/// <param name="implementationType">Closed request post processor implementation type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPostProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));

if (implementedPostProcessorTypes.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var implementedPostProcessorType in implementedPostProcessorTypes)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime));
}
return this;
}

/// <summary>
/// Registers an open request post processor type against the <see cref="IRequestPostProcessor{TRequest,TResponse}"/> open generic interface type
/// </summary>
/// <param name="openBehaviorType">An open generic request post processor type</param>
/// <param name="serviceLifetime">Optional service lifetime, defaults to <see cref="ServiceLifetime.Transient"/>.</param>
/// <returns>This</returns>
public MediatRServiceConfiguration AddOpenRequestPostProcessor(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
if (!openBehaviorType.IsGenericType)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must be generic");
}

var implementedGenericInterfaces = openBehaviorType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedOpenBehaviorInterfaces = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));

if (implementedOpenBehaviorInterfaces.Count == 0)
{
throw new InvalidOperationException($"{openBehaviorType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var openBehaviorInterface in implementedOpenBehaviorInterfaces)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(openBehaviorInterface, openBehaviorType, serviceLifetime));
}

return this;
}


}
22 changes: 13 additions & 9 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@ public static void AddMediatRClasses(IServiceCollection services, MediatRService
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPreProcessor<>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestPostProcessor<,>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration);
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionAction<,>), services, assembliesToScan, true, configuration);

var multiOpenInterfaces = new[]
{
typeof(INotificationHandler<>),
typeof(IRequestPreProcessor<>),
typeof(IRequestPostProcessor<,>),
typeof(IRequestExceptionHandler<,,>),
typeof(IRequestExceptionAction<,>)
};
Expand Down Expand Up @@ -224,6 +220,19 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi

services.TryAdd(notificationPublisherServiceDescriptor);

// Register pre processors, then post processors, then behaviors
if (serviceConfiguration.RequestPreProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient));
services.TryAddEnumerable(serviceConfiguration.RequestPreProcessorsToRegister);
}

if (serviceConfiguration.RequestPostProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPostProcessorBehavior<,>), ServiceLifetime.Transient));
services.TryAddEnumerable(serviceConfiguration.RequestPostProcessorsToRegister);
}

foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.TryAddEnumerable(serviceDescriptor);
Expand All @@ -234,11 +243,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAddEnumerable(serviceDescriptor);
}

// Use built-in Microsoft TryAddEnumerable method, we do want to register our Pre/Post processor behavior,
// even if (a more concrete) registration for IPipelineBehavior<,> already exists. But only once.
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPreProcessorBehavior<,>), typeof(IRequestPreProcessor<>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPostProcessorBehavior<,>), typeof(IRequestPostProcessor<,>));

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli

services.AddSingleton(output);
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(ConstructorTestBehavior<,>));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddOpenBehavior(typeof(ConstructorTestBehavior<,>));
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand All @@ -93,11 +97,7 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli
output.Messages.ShouldBe(new[]
{
"ConstructorTestBehavior before",
"First pre processor",
"Next pre processor",
"Handler",
"First post processor",
"Next post processor",
"ConstructorTestBehavior after"
});
ConstructorTestHandler.ConstructorCallCount.ShouldBe(1);
Expand Down
Loading

0 comments on commit 9fbba24

Please sign in to comment.