Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protoc-gen-openapi: Select the correct schemas that correspond to the messages used #414

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions cmd/protoc-gen-openapi/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,15 @@ func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
// While we have required schemas left to generate, go through the files again
// looking for the related message and adding them to the document if required.
for len(g.reflect.requiredSchemas) > 0 {
count := len(g.reflect.requiredSchemas)
for _, file := range g.plugin.Files {
g.addSchemasForMessagesToDocumentV3(d, file.Messages)
}
g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
// clear the generated schemas
for schema := range g.reflect.requiredSchemas {
if contains(g.generatedSchemas, schema) {
delete(g.reflect.requiredSchemas, schema)
}
}
}

// If there is only 1 service, then use it's title for the
Expand Down Expand Up @@ -771,12 +775,14 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr
}
}

// addSchemaForMessageToDocumentV3 adds the schema to the document if required
// addSchemaToDocumentV3 adds the schema to the document if required
func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
if contains(g.generatedSchemas, schema.Name) {
return
// check if schema already exists in Schemas, instead of checking "generated"
for _, prop := range d.Components.Schemas.AdditionalProperties {
if prop.Name == schema.Name {
return
}
}
g.generatedSchemas = append(g.generatedSchemas, schema.Name)
d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
}

Expand All @@ -789,12 +795,25 @@ func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, m
}

schemaName := g.reflect.formatMessageName(message.Desc)
fqSchemaName := g.reflect.formatPackageMessageName(message.Desc)

// Only generate this if we need it and haven't already generated it.
if !contains(g.reflect.requiredSchemas, schemaName) ||
contains(g.generatedSchemas, schemaName) {
requiredFQSchema, ok := g.reflect.requiredSchemas[schemaName]
if !ok {
continue
} else if requiredFQSchema != fqSchemaName {
// "schemaName" with same name is required, but it's not the actual
// schema with "fqSchemaName". Try to use the fully-qualified schema.
if _, ok = g.reflect.requiredSchemas[fqSchemaName]; !ok {
continue
}
// use fully-qualified name as schema name if there are same named messages
schemaName = fqSchemaName
}
if contains(g.generatedSchemas, schemaName) {
continue
}
g.generatedSchemas = append(g.generatedSchemas, schemaName)

typeName := g.reflect.fullMessageTypeName(message.Desc)
messageDescription := g.filterCommentString(message.Comments.Leading)
Expand Down
25 changes: 21 additions & 4 deletions cmd/protoc-gen-openapi/generator/reflector.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ const (
type OpenAPIv3Reflector struct {
conf Configuration

requiredSchemas []string // Names of schemas which are used through references.
// Names of schemas which are used through references.
// map: schema name will be used actually -> fully-qualified schema name
requiredSchemas map[string]string
}

// NewOpenAPIv3Reflector creates a new reflector.
func NewOpenAPIv3Reflector(conf Configuration) *OpenAPIv3Reflector {
return &OpenAPIv3Reflector{
conf: conf,

requiredSchemas: make([]string, 0),
requiredSchemas: make(map[string]string, 0),
}
}

Expand Down Expand Up @@ -86,6 +88,14 @@ func (r *OpenAPIv3Reflector) formatMessageName(message protoreflect.MessageDescr
return name
}

// formatPackageMessageName returns the fully-qualified name of a message.
func (r *OpenAPIv3Reflector) formatPackageMessageName(message protoreflect.MessageDescriptor) string {
package_name := string(message.ParentFile().Package())
name := package_name + "." + r.getMessageName(message)

return name
}

func (r *OpenAPIv3Reflector) formatFieldName(field protoreflect.FieldDescriptor) string {
if *r.conf.Naming == "proto" {
return string(field.Name())
Expand Down Expand Up @@ -116,8 +126,15 @@ func (r *OpenAPIv3Reflector) responseContentForMessage(message protoreflect.Mess

func (r *OpenAPIv3Reflector) schemaReferenceForMessage(message protoreflect.MessageDescriptor) string {
schemaName := r.formatMessageName(message)
if !contains(r.requiredSchemas, schemaName) {
r.requiredSchemas = append(r.requiredSchemas, schemaName)
fqSchemaName := r.formatPackageMessageName(message)
requiredFQSchema, ok := r.requiredSchemas[schemaName]
if !ok {
// new required, use schemaName
r.requiredSchemas[schemaName] = fqSchemaName
} else if requiredFQSchema != fqSchemaName {
// use the fully-qualified schema name as there are same named messages
schemaName = fqSchemaName
r.requiredSchemas[schemaName] = fqSchemaName
}
return "#/components/schemas/" + schemaName
}
Expand Down