Skip to content

Commit

Permalink
Adding response schema
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerBarreto committed Jan 8, 2025
1 parent 5b97ad1 commit 12a6089
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 36 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ dotnet_diagnostic.IDE0005.severity = warning # Remove unnecessary using directiv
dotnet_diagnostic.IDE0009.severity = warning # Add this or Me qualification
dotnet_diagnostic.IDE0011.severity = warning # Add braces
dotnet_diagnostic.IDE0018.severity = warning # Inline variable declaration

dotnet_diagnostic.IDE0032.severity = warning # Use auto-implemented property
dotnet_diagnostic.IDE0034.severity = warning # Simplify 'default' expression
dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
Expand Down Expand Up @@ -221,20 +222,29 @@ dotnet_diagnostic.RCS1241.severity = none # Implement IComparable when implement
dotnet_diagnostic.IDE0001.severity = none # Simplify name
dotnet_diagnostic.IDE0002.severity = none # Simplify member access
dotnet_diagnostic.IDE0004.severity = none # Remove unnecessary cast
dotnet_diagnostic.IDE0010.severity = none # Populate switch
dotnet_diagnostic.IDE0021.severity = none # Use block body for constructors
dotnet_diagnostic.IDE0022.severity = none # Use block body for methods
dotnet_diagnostic.IDE0024.severity = none # Use block body for operator
dotnet_diagnostic.IDE0035.severity = none # Remove unreachable code
dotnet_diagnostic.IDE0051.severity = none # Remove unused private member
dotnet_diagnostic.IDE0052.severity = none # Remove unread private member
dotnet_diagnostic.IDE0058.severity = none # Remove unused expression value
dotnet_diagnostic.IDE0059.severity = none # Unnecessary assignment of a value
dotnet_diagnostic.IDE0060.severity = none # Remove unused parameter
dotnet_diagnostic.IDE0061.severity = none # Use block body for local function
dotnet_diagnostic.IDE0079.severity = none # Remove unnecessary suppression.
dotnet_diagnostic.IDE0080.severity = none # Remove unnecessary suppression operator.
dotnet_diagnostic.IDE0100.severity = none # Remove unnecessary equality operator
dotnet_diagnostic.IDE0110.severity = none # Remove unnecessary discards
dotnet_diagnostic.IDE0130.severity = none # Namespace does not match folder structure
dotnet_diagnostic.IDE0290.severity = none # Use primary constructor
dotnet_diagnostic.IDE0032.severity = none # Use auto property
dotnet_diagnostic.IDE0160.severity = none # Use block-scoped namespace
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
dotnet_diagnostic.IDE0046.severity = suggestion # If statement can be simplified
dotnet_diagnostic.IDE0056.severity = suggestion # Indexing can be simplified
dotnet_diagnostic.IDE0057.severity = suggestion # Substring can be simplified

###############################
# Naming Conventions #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
// Arrange
var bearerTokenGenerator = new BearerTokenGenerator()
{
BearerKeys = new List<string> { "key1", "key2", "key3" }
BearerKeys = ["key1", "key2", "key3"]
};

var responseContent = File.ReadAllText(ChatTestDataFilePath);
Expand All @@ -442,7 +442,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
httpClient: httpClient,
modelId: "fake-model",
apiVersion: VertexAIVersion.V1,
bearerTokenProvider: () => bearerTokenGenerator.GetBearerToken(),
bearerTokenProvider: bearerTokenGenerator.GetBearerToken,
location: "fake-location",
projectId: "fake-project-id");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,27 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.Google.Core;

internal sealed class GeminiRequest
{
private static JsonSerializerOptions? s_options;
private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
{
IncludeSchemaKeyword = false,
IncludeTypeInEnumSchemas = true,
RequireAllProperties = false,
DisallowAdditionalProperties = false,
};

[JsonPropertyName("contents")]
public IList<GeminiContent> Contents { get; set; } = null!;

Expand Down Expand Up @@ -249,10 +261,56 @@ private static void AddConfiguration(GeminiPromptExecutionSettings executionSett
StopSequences = executionSettings.StopSequences,
CandidateCount = executionSettings.CandidateCount,
AudioTimestamp = executionSettings.AudioTimestamp,
ResponseMimeType = executionSettings.ResponseMimeType
ResponseMimeType = executionSettings.ResponseMimeType,
ResponseSchema = GetResponseSchemaConfig(executionSettings.ResponseSchema)
};
}

private static JsonElement? GetResponseSchemaConfig(object? responseSchemaSettings)
{
if (responseSchemaSettings is null)
{
return null;
}

if (responseSchemaSettings is JsonElement jsonElement)
{
return jsonElement;
}

return responseSchemaSettings is Type type
? CreateSchema(type, GetDefaultOptions())
: CreateSchema(responseSchemaSettings.GetType(), GetDefaultOptions());
}

private static JsonElement CreateSchema(
Type type,
JsonSerializerOptions options,
string? description = null,
AIJsonSchemaCreateOptions? configuration = null)
{
configuration ??= s_schemaOptions;
return AIJsonUtilities.CreateJsonSchema(type, description, serializerOptions: options, inferenceOptions: configuration);
}

[RequiresUnreferencedCode("Uses JsonStringEnumConverter and DefaultJsonTypeInfoResolver classes, making it incompatible with AOT scenarios.")]
[RequiresDynamicCode("Uses JsonStringEnumConverter and DefaultJsonTypeInfoResolver classes, making it incompatible with AOT scenarios.")]
private static JsonSerializerOptions GetDefaultOptions()
{
if (s_options is null)
{
JsonSerializerOptions options = new()
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
Converters = { new JsonStringEnumConverter() },
};
options.MakeReadOnly();
s_options = options;
}

return s_options;
}

private static void AddSafetySettings(GeminiPromptExecutionSettings executionSettings, GeminiRequest request)
{
request.SafetySettings = executionSettings.SafetySettings?.Select(s
Expand Down Expand Up @@ -292,5 +350,9 @@ internal sealed class ConfigurationElement
[JsonPropertyName("responseMimeType")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ResponseMimeType { get; set; }

[JsonPropertyName("responseSchema")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public JsonElement? ResponseSchema { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
private IList<string>? _stopSequences;
private bool? _audioTimestamp;
private string? _responseMimeType;
private object? _responseSchema;
private IList<GeminiSafetySetting>? _safetySettings;
private GeminiToolCallBehavior? _toolCallBehavior;

Expand Down Expand Up @@ -206,6 +207,29 @@ public string? ResponseMimeType
}
}

/// <summary>
/// Optional. Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be objects, primitives or arrays.
/// If set, a compatible responseMimeType must also be set. Compatible MIME types: application/json: Schema for JSON response.
/// Refer to the https://ai.google.dev/gemini-api/docs/json-mode for more information.
/// </summary>
/// <remarks>
/// Possible values are:
/// <para>- <see cref="object"/> object, which type will be used to automatically create a JSON schema;</para>
/// <para>- <see cref="Type"/> object, which will be used to automatically create a JSON schema.</para>
/// </remarks>
[JsonPropertyName("response_schema")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? ResponseSchema
{
get => this._responseSchema;

set
{
this.ThrowIfFrozen();
this._responseSchema = value;
}
}

/// <inheritdoc />
public override void Freeze()
{
Expand Down Expand Up @@ -243,7 +267,8 @@ public override PromptExecutionSettings Clone()
SafetySettings = this.SafetySettings?.Select(setting => new GeminiSafetySetting(setting)).ToList(),
ToolCallBehavior = this.ToolCallBehavior?.Clone(),
AudioTimestamp = this.AudioTimestamp,
ResponseMimeType = this.ResponseMimeType
ResponseMimeType = this.ResponseMimeType,
ResponseSchema = this.ResponseSchema,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,8 @@ public VectorStoreRecordPropertyReader(
this._parameterlessConstructorInfo = new Lazy<ConstructorInfo>(() =>
{
var constructor = dataModelType.GetConstructor(Type.EmptyTypes);
if (constructor == null)
{
throw new ArgumentException($"Type {dataModelType.FullName} must have a parameterless constructor.");
}

return constructor;
return constructor
?? throw new ArgumentException($"Type {dataModelType.FullName} must have a parameterless constructor.");
});

this._keyPropertyStoragePropertyNames = new Lazy<List<string>>(() =>
Expand Down Expand Up @@ -411,9 +407,9 @@ private static (List<VectorStoreRecordKeyProperty> KeyProperties, List<VectorSto
/// <returns>The categorized properties.</returns>
private static (List<PropertyInfo> KeyProperties, List<PropertyInfo> DataProperties, List<PropertyInfo> VectorProperties) FindPropertiesInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type type)
{
List<PropertyInfo> keyProperties = new();
List<PropertyInfo> dataProperties = new();
List<PropertyInfo> vectorProperties = new();
List<PropertyInfo> keyProperties = [];
List<PropertyInfo> dataProperties = [];
List<PropertyInfo> vectorProperties = [];

foreach (var property in type.GetProperties())
{
Expand Down Expand Up @@ -449,42 +445,33 @@ private static (List<PropertyInfo> KeyProperties, List<PropertyInfo> DataPropert
/// <returns>The categorized properties.</returns>
public static (List<PropertyInfo> KeyProperties, List<PropertyInfo> DataProperties, List<PropertyInfo> VectorProperties) FindPropertiesInfo([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] Type type, VectorStoreRecordDefinition vectorStoreRecordDefinition)
{
List<PropertyInfo> keyProperties = new();
List<PropertyInfo> dataProperties = new();
List<PropertyInfo> vectorProperties = new();
List<PropertyInfo> keyProperties = [];
List<PropertyInfo> dataProperties = [];
List<PropertyInfo> vectorProperties = [];

foreach (VectorStoreRecordProperty property in vectorStoreRecordDefinition.Properties)
{
// Key.
if (property is VectorStoreRecordKeyProperty keyPropertyInfo)
{
var keyProperty = type.GetProperty(keyPropertyInfo.DataModelPropertyName);
if (keyProperty == null)
{
throw new ArgumentException($"Key property '{keyPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");
}
var keyProperty = type.GetProperty(keyPropertyInfo.DataModelPropertyName)
?? throw new ArgumentException($"Key property '{keyPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");

keyProperties.Add(keyProperty);
}
// Data.
else if (property is VectorStoreRecordDataProperty dataPropertyInfo)
{
var dataProperty = type.GetProperty(dataPropertyInfo.DataModelPropertyName);
if (dataProperty == null)
{
throw new ArgumentException($"Data property '{dataPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");
}
var dataProperty = type.GetProperty(dataPropertyInfo.DataModelPropertyName)
?? throw new ArgumentException($"Data property '{dataPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");

dataProperties.Add(dataProperty);
}
// Vector.
else if (property is VectorStoreRecordVectorProperty vectorPropertyInfo)
{
var vectorProperty = type.GetProperty(vectorPropertyInfo.DataModelPropertyName);
if (vectorProperty == null)
{
throw new ArgumentException($"Vector property '{vectorPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");
}
var vectorProperty = type.GetProperty(vectorPropertyInfo.DataModelPropertyName)
?? throw new ArgumentException($"Vector property '{vectorPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}.");

vectorProperties.Add(vectorProperty);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static Activity AttachSensitiveDataAsEvent(this Activity activity, string
{
activity.AddEvent(new ActivityEvent(
name,
tags: new ActivityTagsCollection(tags)
tags: [.. tags]
));

return activity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal sealed class NonNullCollection<T> : IList<T>, IReadOnlyList<T>
public NonNullCollection(IEnumerable<T> items)
{
Verify.NotNull(items);
this._items = new(items);
this._items = [.. items];
}

/// <summary>
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/InternalUtilities/src/Text/DataUriParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ internal static class DataUriParser
{
private const string Scheme = "data:";

private static readonly char[] s_base64Chars = {
private static readonly char[] s_base64Chars = [
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'
};
];
/// <summary>
/// Extension method to test whether the value is a base64 string
/// </summary>
Expand Down Expand Up @@ -157,7 +157,7 @@ internal sealed class DataUri
/// <summary>
/// The optional parameters of the data.
/// </summary>
internal Dictionary<string, string> Parameters { get; set; } = new();
internal Dictionary<string, string> Parameters { get; set; } = [];

/// <summary>
/// The optional format of the data. Most common is "base64".
Expand Down

0 comments on commit 12a6089

Please sign in to comment.