Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Source/Mockolate.SourceGenerators/Entities/Type.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ internal Type(ITypeSymbol typeSymbol)
.ToArray());
}

if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol)
{
ElementType = From(arrayTypeSymbol.ElementType);
}

SpecialGenericType = typeSymbol.GetSpecialType();
SpecialType = typeSymbol.SpecialType;
CanBeNullable = typeSymbol.NullableAnnotation == NullableAnnotation.Annotated ||
Expand All @@ -53,6 +58,11 @@ typeSymbol is INamedTypeSymbol
public SpecialGenericType SpecialGenericType { get; }
public EquatableArray<Type>? TupleTypes { get; }
public EquatableArray<Type>? GenericTypeParameters { get; }

/// <summary>
/// The element type when this type is an array (e.g. <c>bool</c> for <c>bool[]</c>); otherwise <see langword="null" />.
/// </summary>
public Type? ElementType { get; }
public string? Namespace { get; }

internal static Type Void { get; } = new("void");
Expand Down
171 changes: 157 additions & 14 deletions Source/Mockolate.SourceGenerators/Sources/Sources.MockClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3235,6 +3235,12 @@ bool MethodPredicate(Method method)
{
AppendMethodSetupDefinition(sb, @class, method, false,
hasOverloadResolutionPriority: hasOverloadResolutionPriority);
if (TryGetPerElementParamsParameter(method))
{
AppendMethodSetupDefinition(sb, @class, method, false,
hasOverloadResolutionPriority: hasOverloadResolutionPriority, perElementParams: true);
}

if (method.Parameters.Count <= MaxExplicitParameters)
{
foreach (bool[] valueFlags in GenerateValueFlagCombinations(method.Parameters))
Expand Down Expand Up @@ -3312,9 +3318,49 @@ private static void AppendOverloadDifferentiatorRemark(StringBuilder sb,
sb.AppendXmlRemarks(text);
}

/// <summary>
/// Detects whether <paramref name="method" /> ends in a <c>params T[]</c> parameter that can carry a
/// per-element matcher overload. The element type must flow through the regular <c>IParameter&lt;T&gt;</c>
/// pipeline; ref-struct element types are intentionally not supported, as they cannot satisfy
/// <c>IParameter&lt;T&gt;</c>.
/// </summary>
private static bool TryGetPerElementParamsParameter(Method method)
{
if (method.Parameters.Count == 0)
{
return false;
}

MethodParameter last = method.Parameters.AsArray()[method.Parameters.Count - 1];
if (!last.IsParams || last.Type.ElementType is null || last.NeedsRefStructPipeline())
{
return false;
}

return true;
}

/// <summary>
/// True when <paramref name="valueFlags" /> marks the trailing <c>params T[]</c> parameter as a
/// literal value. Such overloads render the parameter as <c>params T[]</c> and match it element-wise
/// by value (via <c>It.SequenceEquals</c>) rather than by whole-array reference equality.
/// </summary>
private static bool HasParamsValueParameter(Method method, bool[]? valueFlags)
{
if (valueFlags is null || method.Parameters.Count == 0)
{
return false;
}

int last = method.Parameters.Count - 1;
MethodParameter parameter = method.Parameters.AsArray()[last];
return valueFlags[last] && parameter.IsParams && parameter.Type.ElementType is not null &&
!parameter.NeedsRefStructPipeline();
}

private static void AppendMethodSetupDefinition(StringBuilder sb, Class @class, Method method,
bool useParameters, string? methodNameOverride = null, bool[]? valueFlags = null,
bool hasOverloadResolutionPriority = false)
bool hasOverloadResolutionPriority = false, bool perElementParams = false)
{
// Methods using a generic type parameter that declares `allows ref struct` cannot expose
// a setup surface: IReturnMethodSetup<T> / IVoidMethodSetup<T> do not carry the same
Expand Down Expand Up @@ -3478,8 +3524,19 @@ private static void AppendMethodSetupDefinition(StringBuilder sb, Class @class,
}

bool isValueParam = valueFlags?[i] == true;
if (isValueParam)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
sb.Append("params global::Mockolate.Parameters.IParameter<")
.Append(parameter.Type.ElementType.Fullname).Append(">[] ").Append(parameter.Name);
}
else if (isValueParam)
{
if (parameter.IsParams && parameter.Type.ElementType is not null)
{
sb.Append("params ");
}

sb.Append(parameter.ToNullableType()).Append(' ').Append(parameter.Name);
}
else
Expand Down Expand Up @@ -3645,6 +3702,12 @@ bool MethodPredicate(Method method)
{
AppendMethodSetupImplementation(sb, method, mockRegistryName, setupName, false,
memberIds, memberIdPrefix, scopeExpression: scopeExpression);
if (TryGetPerElementParamsParameter(method))
{
AppendMethodSetupImplementation(sb, method, mockRegistryName, setupName, false,
memberIds, memberIdPrefix, scopeExpression: scopeExpression, perElementParams: true);
}

if (method.Parameters.Count <= MaxExplicitParameters)
{
foreach (bool[] valueFlags in GenerateValueFlagCombinations(method.Parameters))
Expand Down Expand Up @@ -3674,7 +3737,7 @@ private static void AppendMethodSetupImplementation(StringBuilder sb, Method met
string setupName,
bool useParameters, MemberIdTable memberIds, string memberIdPrefix,
string? methodNameOverride = null, bool[]? valueFlags = null,
string? scopeExpression = null)
string? scopeExpression = null, bool perElementParams = false)
{
// Setup-side carve-out: methods using a generic type parameter that declares
// `allows ref struct` have no setup interface declaration (see
Expand Down Expand Up @@ -3773,8 +3836,19 @@ private static void AppendMethodSetupImplementation(StringBuilder sb, Method met
}

bool isValueParam = valueFlags?[i] == true;
if (isValueParam)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
sb.Append("params global::Mockolate.Parameters.IParameter<")
.Append(parameter.Type.ElementType.Fullname).Append(">[] ").Append(parameter.Name);
}
else if (isValueParam)
{
if (parameter.IsParams && parameter.Type.ElementType is not null)
{
sb.Append("params ");
}

sb.Append(parameter.ToNullableType()).Append(' ').Append(parameter.Name);
}
else
Expand Down Expand Up @@ -3863,8 +3937,11 @@ private static void AppendMethodSetupImplementation(StringBuilder sb, Method met
// skip the per-parameter IParameterMatch<T> allocations that WithParameterCollection
// would otherwise force via It.IsValue<T>(...). Gated to 1..4 parameters because that is
// the arity range covered by the WithLiteralValues nested types.
// A params value parameter matches element-wise by value (It.SequenceEquals), which the
// reference-equality WithLiteralValues fast path cannot express — force the collection path.
bool useLiteralValues = valueFlags is { Length: > 0 and <= MaxExplicitParameters, } &&
valueFlags.All(x => x) &&
!HasParamsValueParameter(method, valueFlags) &&
!method.Parameters.Any(p => p.RefKind == RefKind.Out ||
p.RefKind == RefKind.Ref ||
p.RefKind == RefKind.RefReadOnlyParameter);
Expand All @@ -3887,9 +3964,27 @@ private static void AppendMethodSetupImplementation(StringBuilder sb, Method met
foreach (MethodParameter parameter in method.Parameters)
{
sb.Append(", ");
if (valueFlags?[j] == true)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
AppendNamedValueParameter(sb, parameter);
sb.Append("new global::Mockolate.Parameters.ParamsArrayParameterMatch<")
.Append(parameter.Type.ElementType.Fullname).Append(">(").Append(parameter.Name)
.Append(")");
}
else if (valueFlags?[j] == true)
{
if (parameter.IsParams && parameter.Type.ElementType is not null)
{
// params value parameter: match element-wise by value via It.SequenceEquals.
sb.Append("CovariantParameterAdapter<").Append(parameter.Type.Fullname)
.Append(">.Wrap(global::Mockolate.It.SequenceEquals<")
.Append(parameter.Type.ElementType.Fullname).Append(">(").Append(parameter.Name)
.Append("))");
}
else
{
AppendNamedValueParameter(sb, parameter);
}
}
else
{
Expand Down Expand Up @@ -4987,6 +5082,12 @@ bool MethodPredicate(Method method)
{
AppendMethodVerifyDefinition(sb, method, verifyName, false,
hasOverloadResolutionPriority: hasOverloadResolutionPriority);
if (TryGetPerElementParamsParameter(method))
{
AppendMethodVerifyDefinition(sb, method, verifyName, false,
hasOverloadResolutionPriority: hasOverloadResolutionPriority, perElementParams: true);
}

if (method.Parameters.Count <= MaxExplicitParameters)
{
foreach (bool[] valueFlags in GenerateValueFlagCombinations(method.Parameters))
Expand Down Expand Up @@ -5029,7 +5130,7 @@ bool MethodPredicate(Method method)

private static void AppendMethodVerifyDefinition(StringBuilder sb, Method method, string verifyName,
bool useParameters, string? methodNameOverride = null, bool[]? valueFlags = null,
bool hasOverloadResolutionPriority = false)
bool hasOverloadResolutionPriority = false, bool perElementParams = false)
{
// For methods with ref-struct parameters, skip Verify emission entirely. The
// VerificationResult pipeline takes IParameter<T>? matchers that then feed into
Expand Down Expand Up @@ -5115,8 +5216,19 @@ private static void AppendMethodVerifyDefinition(StringBuilder sb, Method method
}

bool isValueParam = valueFlags?[i] == true;
if (isValueParam)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
sb.Append("params global::Mockolate.Parameters.IParameter<")
.Append(parameter.Type.ElementType.Fullname).Append(">[] ").Append(parameter.Name);
}
else if (isValueParam)
{
if (parameter.IsParams && parameter.Type.ElementType is not null)
{
sb.Append("params ");
}

sb.Append(parameter.ToNullableType()).Append(' ').Append(parameter.Name);
}
else
Expand Down Expand Up @@ -5255,6 +5367,12 @@ bool MethodPredicate(Method method)
{
AppendMethodVerifyImplementation(sb, method, mockRegistryName, verifyName, false,
memberIds, memberIdPrefix, useFastBuffers);
if (TryGetPerElementParamsParameter(method))
{
AppendMethodVerifyImplementation(sb, method, mockRegistryName, verifyName, false,
memberIds, memberIdPrefix, useFastBuffers, perElementParams: true);
}

if (method.Parameters.Count <= MaxExplicitParameters)
{
foreach (bool[] valueFlags in GenerateValueFlagCombinations(method.Parameters))
Expand Down Expand Up @@ -5317,7 +5435,7 @@ bool MethodPredicate(Method method)
private static void AppendMethodVerifyImplementation(StringBuilder sb, Method method, string mockRegistryName,
string verifyName,
bool useParameters, MemberIdTable memberIds, string memberIdPrefix, bool useFastBuffers,
string? methodNameOverride = null, bool[]? valueFlags = null)
string? methodNameOverride = null, bool[]? valueFlags = null, bool perElementParams = false)
#pragma warning restore S107
{
// Mirror the AppendMethodVerifyDefinition short-circuit for ref-struct signatures.
Expand Down Expand Up @@ -5355,8 +5473,19 @@ private static void AppendMethodVerifyImplementation(StringBuilder sb, Method me
}

bool isValueParam = valueFlags?[i] == true;
if (isValueParam)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
sb.Append("params global::Mockolate.Parameters.IParameter<")
.Append(parameter.Type.ElementType.Fullname).Append(">[] ").Append(parameter.Name);
}
else if (isValueParam)
{
if (parameter.IsParams && parameter.Type.ElementType is not null)
{
sb.Append("params ");
}

sb.Append(parameter.ToNullableType()).Append(' ').Append(parameter.Name);
}
else
Expand Down Expand Up @@ -5395,8 +5524,13 @@ private static void AppendMethodVerifyImplementation(StringBuilder sb, Method me

bool canUseLiteralVerify = baseEligible &&
valueFlags is { Length: > 0 and <= 4, } &&
valueFlags.All(x => x);
bool canUseTypedVerify = useFastForMethod
valueFlags.All(x => x) &&
!HasParamsValueParameter(method, valueFlags);
// The per-element params overload passes a `params IParameter<TElement>[]` argument, which neither the
// literal nor the typed fast path can render (both assume a whole-array IParameter<TElement[]>). Route it
// through the slow predicate path, where the composite matcher is built explicitly.
bool canUseTypedVerify = !perElementParams
&& useFastForMethod
&& !useParameters
&& method.Parameters.Count <= 4
&& (method.GenericParameters is null || method.GenericParameters.Value.Count == 0)
Expand Down Expand Up @@ -5491,10 +5625,19 @@ private static void AppendMethodVerifyImplementation(StringBuilder sb, Method me
sb.AppendLine().Append("\t\t\t\t");

bool isValueParam = valueFlags?[i] == true;
if (isValueParam)
if (perElementParams && parameter.IsParams &&
parameter.Type.ElementType is not null)
{
sb.Append("(new global::Mockolate.Parameters.ParamsArrayParameterMatch<")
.Append(parameter.Type.ElementType.Fullname).Append(">(").Append(parameter.Name)
.Append(").Matches(__i.Parameter").Append(i + 1).Append("))");
}
else if (isValueParam)
{
sb.Append(
$"(global::System.Collections.Generic.EqualityComparer<{parameter.ToTypeOrWrapper()}>.Default.Equals({parameter.Name}, __i.Parameter{i + 1}))");
parameter.IsParams && parameter.Type.ElementType is not null
? $"(CovariantParameterAdapter<{parameter.Type.Fullname}>.Wrap(global::Mockolate.It.SequenceEquals<{parameter.Type.ElementType.Fullname}>({parameter.Name})).Matches(__i.Parameter{i + 1}))"
: $"(global::System.Collections.Generic.EqualityComparer<{parameter.ToTypeOrWrapper()}>.Default.Equals({parameter.Name}, __i.Parameter{i + 1}))");
}
else if (parameter.RefKind == RefKind.Out || parameter.RefKind == RefKind.Ref ||
parameter.RefKind == RefKind.RefReadOnlyParameter)
Expand Down
5 changes: 5 additions & 0 deletions Source/Mockolate/It.Contains.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ IContainsParameter<T> IContainsParameter<T>.Using(IEqualityComparer<T> comparer,
/// <inheritdoc cref="CollectionMatchCore{T}.MatchesCollection(IEnumerable{T})" />
protected override bool MatchesCollection(IEnumerable<T> value)
{
if (value is null)
{
return false;
}

IEqualityComparer<T> comparer = _comparer ?? EqualityComparer<T>.Default;
return value.Contains(item, comparer);
}
Expand Down
5 changes: 5 additions & 0 deletions Source/Mockolate/It.SequenceEquals.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ ISequenceEqualsParameter<T> ISequenceEqualsParameter<T>.Using(IEqualityComparer<
/// <inheritdoc cref="CollectionMatchCore{T}.MatchesCollection(IEnumerable{T})" />
protected override bool MatchesCollection(IEnumerable<T> value)
{
if (value is null)
{
return false;
}

IEqualityComparer<T> comparer = _comparer ?? EqualityComparer<T>.Default;
return value.SequenceEqual(expected, comparer);
}
Expand Down
Loading
Loading