159 lines
3.9 KiB
Go
159 lines
3.9 KiB
Go
package chatter
|
|
|
|
import (
	"context"
	"encoding"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"strings"

	"github.com/ollama/ollama/api"
)
|
|
|
|
// OllamaAdapter implements the Adapter interface for local Ollama instances.
|
|
type OllamaAdapter struct {
|
|
client *api.Client
|
|
model string
|
|
}
|
|
|
|
// NewOllamaAdapter initializes a new Ollama client.
|
|
// Default address is usually "http://localhost:11434"
|
|
func NewOllamaAdapter(endpoint string, model string) (*OllamaAdapter, error) {
|
|
client, err := api.ClientFromEnvironment()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create ollama client: %w", err)
|
|
}
|
|
|
|
return &OllamaAdapter{
|
|
client: client,
|
|
model: model,
|
|
}, nil
|
|
}
|
|
|
|
// Generate sends the chat history to Ollama and returns the assistant's response.
|
|
func (a *OllamaAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
|
var ollamaMessages []api.Message
|
|
|
|
// Map our internal Message struct to Ollama's API struct
|
|
for _, m := range messages {
|
|
ollamaMessages = append(ollamaMessages, api.Message{
|
|
Role: string(m.Role),
|
|
Content: m.Content,
|
|
})
|
|
}
|
|
|
|
req := &api.ChatRequest{
|
|
Model: a.model,
|
|
Messages: ollamaMessages,
|
|
Stream: new(bool), // Set to false for a single string response
|
|
}
|
|
|
|
var response string
|
|
err := a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
fmt.Print(resp.Message)
|
|
response = resp.Message.Content
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("ollama generation failed: %w", err)
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (a *OllamaAdapter) GetProviderName() string {
|
|
return "Ollama (" + a.model + ")"
|
|
}
|
|
|
|
// GenerateStructured implements the same signature as your Gemini adapter.
|
|
func (a *OllamaAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
|
val := reflect.ValueOf(target)
|
|
if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("target must be a pointer to a struct")
|
|
}
|
|
|
|
// 1. Generate the JSON Schema from the Go struct
|
|
schema := a.schemaFromType(val.Elem().Type())
|
|
schemaBytes, err := json.Marshal(schema)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal schema: %w", err)
|
|
}
|
|
|
|
// 2. Map messages
|
|
var ollamaMessages []api.Message
|
|
for _, m := range messages {
|
|
ollamaMessages = append(ollamaMessages, api.Message{
|
|
Role: string(m.Role),
|
|
Content: m.Content,
|
|
})
|
|
}
|
|
|
|
// 3. Set Format to the raw JSON Schema
|
|
req := &api.ChatRequest{
|
|
Model: a.model,
|
|
Messages: ollamaMessages,
|
|
Format: json.RawMessage(schemaBytes),
|
|
Stream: new(bool), // false
|
|
Options: map[string]interface{}{
|
|
"temperature": 0, // Recommended for structured tasks
|
|
},
|
|
}
|
|
|
|
var responseText string
|
|
err = a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
responseText = resp.Message.Content
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 4. Parse the result into the target struct
|
|
return json.Unmarshal([]byte(responseText), target)
|
|
}
|
|
|
|
// schemaFromType (Reuse the same logic from the Gemini adapter)
|
|
func (a *OllamaAdapter) schemaFromType(t reflect.Type) map[string]interface{} {
|
|
for t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
|
|
switch t.Kind() {
|
|
case reflect.Struct:
|
|
props := make(map[string]interface{})
|
|
var required []string
|
|
|
|
for i := 0; i < t.NumField(); i++ {
|
|
field := t.Field(i)
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
name := field.Tag.Get("json")
|
|
if name == "" || name == "-" {
|
|
name = field.Name
|
|
}
|
|
props[name] = a.schemaFromType(field.Type)
|
|
required = append(required, name)
|
|
}
|
|
return map[string]interface{}{
|
|
"type": "object",
|
|
"properties": props,
|
|
"required": required,
|
|
}
|
|
|
|
case reflect.Slice, reflect.Array:
|
|
return map[string]interface{}{
|
|
"type": "array",
|
|
"items": a.schemaFromType(t.Elem()),
|
|
}
|
|
|
|
case reflect.Int, reflect.Int64:
|
|
return map[string]interface{}{"type": "integer"}
|
|
case reflect.Float64:
|
|
return map[string]interface{}{"type": "number"}
|
|
case reflect.Bool:
|
|
return map[string]interface{}{"type": "boolean"}
|
|
default:
|
|
return map[string]interface{}{"type": "string"}
|
|
}
|
|
}
|