185 lines
4.4 KiB
Go
185 lines
4.4 KiB
Go
package chatter
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
"github.com/sashabaranov/go-openai/jsonschema"
|
|
)
|
|
|
|
type OpenAIAdapter struct {
|
|
client *openai.Client
|
|
model string
|
|
}
|
|
|
|
func NewOpenAIAdapter(apiKey string, model string, baseURL string) *OpenAIAdapter {
|
|
config := openai.DefaultConfig(apiKey)
|
|
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true, // Bypasses the "not standards compliant" error
|
|
},
|
|
}
|
|
|
|
config.HTTPClient = &http.Client{Transport: tr}
|
|
|
|
if baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
}
|
|
if baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
}
|
|
return &OpenAIAdapter{
|
|
client: openai.NewClientWithConfig(config),
|
|
model: model,
|
|
}
|
|
}
|
|
|
|
func (a *OpenAIAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
|
var chatMsgs []openai.ChatCompletionMessage
|
|
for _, m := range messages {
|
|
chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{
|
|
Role: string(m.Role),
|
|
Content: m.Content,
|
|
})
|
|
}
|
|
|
|
resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
|
Model: a.model,
|
|
Messages: chatMsgs,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return resp.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
func (a *OpenAIAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
|
val := reflect.ValueOf(target)
|
|
if val.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("target must be a pointer")
|
|
}
|
|
|
|
elem := val.Elem()
|
|
var schemaType reflect.Type
|
|
isSlice := elem.Kind() == reflect.Slice
|
|
|
|
// 1. Wrap slices in an object because OpenAI requires a root object
|
|
if isSlice {
|
|
schemaType = reflect.StructOf([]reflect.StructField{
|
|
{
|
|
Name: "Items",
|
|
Type: elem.Type(),
|
|
Tag: `json:"items"`,
|
|
},
|
|
})
|
|
} else {
|
|
schemaType = elem.Type()
|
|
}
|
|
|
|
// 2. Build the Schema Map
|
|
schemaObj := a.reflectTypeToSchema(schemaType)
|
|
|
|
// 3. Convert to json.RawMessage to satisfy the json.Marshaler interface
|
|
schemaBytes, err := json.Marshal(schemaObj)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal schema: %w", err)
|
|
}
|
|
|
|
var chatMsgs []openai.ChatCompletionMessage
|
|
for _, m := range messages {
|
|
chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{
|
|
Role: string(m.Role),
|
|
Content: m.Content,
|
|
})
|
|
}
|
|
|
|
// 4. Send Request
|
|
req := openai.ChatCompletionRequest{
|
|
Model: a.model,
|
|
Messages: chatMsgs,
|
|
ResponseFormat: &openai.ChatCompletionResponseFormat{
|
|
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
|
|
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
|
|
Name: "output_schema",
|
|
Strict: true,
|
|
Schema: json.RawMessage(schemaBytes),
|
|
},
|
|
},
|
|
}
|
|
|
|
resp, err := a.client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
content := resp.Choices[0].Message.Content
|
|
|
|
// 5. Unmarshal and Unwrap if necessary
|
|
if isSlice {
|
|
temp := struct {
|
|
Items json.RawMessage `json:"items"`
|
|
}{}
|
|
if err := json.Unmarshal([]byte(content), &temp); err != nil {
|
|
return err
|
|
}
|
|
return json.Unmarshal(temp.Items, target)
|
|
}
|
|
|
|
return json.Unmarshal([]byte(content), target)
|
|
}
|
|
|
|
func (a *OpenAIAdapter) reflectTypeToSchema(t reflect.Type) jsonschema.Definition {
|
|
for t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
|
|
switch t.Kind() {
|
|
case reflect.Struct:
|
|
def := jsonschema.Definition{
|
|
Type: jsonschema.Object,
|
|
Properties: make(map[string]jsonschema.Definition),
|
|
AdditionalProperties: false,
|
|
Required: []string{},
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
field := t.Field(i)
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
name := field.Tag.Get("json")
|
|
if name == "" || name == "-" {
|
|
name = field.Name
|
|
}
|
|
def.Properties[name] = a.reflectTypeToSchema(field.Type)
|
|
def.Required = append(def.Required, name)
|
|
}
|
|
return def
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
items := a.reflectTypeToSchema(t.Elem())
|
|
return jsonschema.Definition{
|
|
Type: jsonschema.Array,
|
|
Items: &items,
|
|
}
|
|
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return jsonschema.Definition{Type: jsonschema.Integer}
|
|
case reflect.Float32, reflect.Float64:
|
|
return jsonschema.Definition{Type: jsonschema.Number}
|
|
case reflect.Bool:
|
|
return jsonschema.Definition{Type: jsonschema.Boolean}
|
|
default:
|
|
return jsonschema.Definition{Type: jsonschema.String}
|
|
}
|
|
}
|
|
|
|
func (a *OpenAIAdapter) GetProviderName() string {
|
|
return "OpenAI (" + a.model + ")"
|
|
}
|