init
This commit is contained in:
158
internal/chatter/ollama.go
Normal file
158
internal/chatter/ollama.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package chatter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// OllamaAdapter implements the Adapter interface for local Ollama instances.
|
||||
type OllamaAdapter struct {
|
||||
client *api.Client
|
||||
model string
|
||||
}
|
||||
|
||||
// NewOllamaAdapter initializes a new Ollama client.
|
||||
// Default address is usually "http://localhost:11434"
|
||||
func NewOllamaAdapter(endpoint string, model string) (*OllamaAdapter, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ollama client: %w", err)
|
||||
}
|
||||
|
||||
return &OllamaAdapter{
|
||||
client: client,
|
||||
model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate sends the chat history to Ollama and returns the assistant's response.
|
||||
func (a *OllamaAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
||||
var ollamaMessages []api.Message
|
||||
|
||||
// Map our internal Message struct to Ollama's API struct
|
||||
for _, m := range messages {
|
||||
ollamaMessages = append(ollamaMessages, api.Message{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: a.model,
|
||||
Messages: ollamaMessages,
|
||||
Stream: new(bool), // Set to false for a single string response
|
||||
}
|
||||
|
||||
var response string
|
||||
err := a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
fmt.Print(resp.Message)
|
||||
response = resp.Message.Content
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ollama generation failed: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (a *OllamaAdapter) GetProviderName() string {
|
||||
return "Ollama (" + a.model + ")"
|
||||
}
|
||||
|
||||
// GenerateStructured implements the same signature as your Gemini adapter.
|
||||
func (a *OllamaAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
||||
val := reflect.ValueOf(target)
|
||||
if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("target must be a pointer to a struct")
|
||||
}
|
||||
|
||||
// 1. Generate the JSON Schema from the Go struct
|
||||
schema := a.schemaFromType(val.Elem().Type())
|
||||
schemaBytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal schema: %w", err)
|
||||
}
|
||||
|
||||
// 2. Map messages
|
||||
var ollamaMessages []api.Message
|
||||
for _, m := range messages {
|
||||
ollamaMessages = append(ollamaMessages, api.Message{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// 3. Set Format to the raw JSON Schema
|
||||
req := &api.ChatRequest{
|
||||
Model: a.model,
|
||||
Messages: ollamaMessages,
|
||||
Format: json.RawMessage(schemaBytes),
|
||||
Stream: new(bool), // false
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0, // Recommended for structured tasks
|
||||
},
|
||||
}
|
||||
|
||||
var responseText string
|
||||
err = a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
responseText = resp.Message.Content
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. Parse the result into the target struct
|
||||
return json.Unmarshal([]byte(responseText), target)
|
||||
}
|
||||
|
||||
// schemaFromType (Reuse the same logic from the Gemini adapter)
|
||||
func (a *OllamaAdapter) schemaFromType(t reflect.Type) map[string]interface{} {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
props := make(map[string]interface{})
|
||||
var required []string
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
name := field.Tag.Get("json")
|
||||
if name == "" || name == "-" {
|
||||
name = field.Name
|
||||
}
|
||||
props[name] = a.schemaFromType(field.Type)
|
||||
required = append(required, name)
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": a.schemaFromType(t.Elem()),
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int64:
|
||||
return map[string]interface{}{"type": "integer"}
|
||||
case reflect.Float64:
|
||||
return map[string]interface{}{"type": "number"}
|
||||
case reflect.Bool:
|
||||
return map[string]interface{}{"type": "boolean"}
|
||||
default:
|
||||
return map[string]interface{}{"type": "string"}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user