feat: correctly implement bitbucket & add OpenAIAdapter
This commit is contained in:
184
internal/chatter/openai.go
Normal file
184
internal/chatter/openai.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package chatter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
type OpenAIAdapter struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
}
|
||||
|
||||
func NewOpenAIAdapter(apiKey string, model string, baseURL string) *OpenAIAdapter {
|
||||
config := openai.DefaultConfig(apiKey)
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // Bypasses the "not standards compliant" error
|
||||
},
|
||||
}
|
||||
|
||||
config.HTTPClient = &http.Client{Transport: tr}
|
||||
|
||||
if baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
}
|
||||
if baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
}
|
||||
return &OpenAIAdapter{
|
||||
client: openai.NewClientWithConfig(config),
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *OpenAIAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
||||
var chatMsgs []openai.ChatCompletionMessage
|
||||
for _, m := range messages {
|
||||
chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
||||
Model: a.model,
|
||||
Messages: chatMsgs,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (a *OpenAIAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
||||
val := reflect.ValueOf(target)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("target must be a pointer")
|
||||
}
|
||||
|
||||
elem := val.Elem()
|
||||
var schemaType reflect.Type
|
||||
isSlice := elem.Kind() == reflect.Slice
|
||||
|
||||
// 1. Wrap slices in an object because OpenAI requires a root object
|
||||
if isSlice {
|
||||
schemaType = reflect.StructOf([]reflect.StructField{
|
||||
{
|
||||
Name: "Items",
|
||||
Type: elem.Type(),
|
||||
Tag: `json:"items"`,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
schemaType = elem.Type()
|
||||
}
|
||||
|
||||
// 2. Build the Schema Map
|
||||
schemaObj := a.reflectTypeToSchema(schemaType)
|
||||
|
||||
// 3. Convert to json.RawMessage to satisfy the json.Marshaler interface
|
||||
schemaBytes, err := json.Marshal(schemaObj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal schema: %w", err)
|
||||
}
|
||||
|
||||
var chatMsgs []openai.ChatCompletionMessage
|
||||
for _, m := range messages {
|
||||
chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// 4. Send Request
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: a.model,
|
||||
Messages: chatMsgs,
|
||||
ResponseFormat: &openai.ChatCompletionResponseFormat{
|
||||
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
|
||||
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
|
||||
Name: "output_schema",
|
||||
Strict: true,
|
||||
Schema: json.RawMessage(schemaBytes),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := a.client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content := resp.Choices[0].Message.Content
|
||||
|
||||
// 5. Unmarshal and Unwrap if necessary
|
||||
if isSlice {
|
||||
temp := struct {
|
||||
Items json.RawMessage `json:"items"`
|
||||
}{}
|
||||
if err := json.Unmarshal([]byte(content), &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(temp.Items, target)
|
||||
}
|
||||
|
||||
return json.Unmarshal([]byte(content), target)
|
||||
}
|
||||
|
||||
func (a *OpenAIAdapter) reflectTypeToSchema(t reflect.Type) jsonschema.Definition {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
def := jsonschema.Definition{
|
||||
Type: jsonschema.Object,
|
||||
Properties: make(map[string]jsonschema.Definition),
|
||||
AdditionalProperties: false,
|
||||
Required: []string{},
|
||||
}
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
name := field.Tag.Get("json")
|
||||
if name == "" || name == "-" {
|
||||
name = field.Name
|
||||
}
|
||||
def.Properties[name] = a.reflectTypeToSchema(field.Type)
|
||||
def.Required = append(def.Required, name)
|
||||
}
|
||||
return def
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
items := a.reflectTypeToSchema(t.Elem())
|
||||
return jsonschema.Definition{
|
||||
Type: jsonschema.Array,
|
||||
Items: &items,
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return jsonschema.Definition{Type: jsonschema.Integer}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return jsonschema.Definition{Type: jsonschema.Number}
|
||||
case reflect.Bool:
|
||||
return jsonschema.Definition{Type: jsonschema.Boolean}
|
||||
default:
|
||||
return jsonschema.Definition{Type: jsonschema.String}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *OpenAIAdapter) GetProviderName() string {
|
||||
return "OpenAI (" + a.model + ")"
|
||||
}
|
||||
Reference in New Issue
Block a user