init
This commit is contained in:
158
internal/chatter/gemini.go
Normal file
158
internal/chatter/gemini.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package chatter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
// GeminiAdapter is a chatter adapter backed by Google Gemini via the
// generative-ai-go SDK.
type GeminiAdapter struct {
	client *genai.Client // underlying Gemini API client
	model  string        // model identifier passed to GenerativeModel
}
|
||||
|
||||
func NewGeminiAdapter(ctx context.Context, apiKey string, modelName string) (*GeminiAdapter, error) {
|
||||
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gemini client: %w", err)
|
||||
}
|
||||
|
||||
return &GeminiAdapter{
|
||||
client: client,
|
||||
model: modelName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *GeminiAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
||||
model := a.client.GenerativeModel(a.model)
|
||||
cs := model.StartChat()
|
||||
|
||||
// Convert history (excluding the last message which is the current prompt)
|
||||
var history []*genai.Content
|
||||
for i := 0; i < len(messages)-1; i++ {
|
||||
history = append(history, &genai.Content{
|
||||
Role: mapRoleToGemini(messages[i].Role),
|
||||
Parts: []genai.Part{genai.Text(messages[i].Content)},
|
||||
})
|
||||
}
|
||||
cs.History = history
|
||||
|
||||
// Send the last message
|
||||
lastMsg := messages[len(messages)-1].Content
|
||||
resp, err := cs.SendMessage(ctx, genai.Text(lastMsg))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gemini generation failed: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 {
|
||||
return "", fmt.Errorf("empty response from gemini")
|
||||
}
|
||||
|
||||
// Extract text from the first part of the first candidate
|
||||
if part, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok {
|
||||
return string(part), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unexpected response format")
|
||||
}
|
||||
|
||||
// GenerateStructured reflects the 'target' to create a schema, then unmarshals the result.
|
||||
func (a *GeminiAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
||||
t := reflect.TypeOf(target)
|
||||
|
||||
model := a.client.GenerativeModel(a.model)
|
||||
|
||||
// 1. Automatically generate the JSON Schema from the Go struct
|
||||
// 1. Recursively map the Go struct to Gemini's Schema format
|
||||
model.ResponseMIMEType = "application/json"
|
||||
model.ResponseSchema = schemaFromType(t.Elem()) // 2. Convert to the internal genai.Schema format
|
||||
|
||||
var prompt []genai.Part
|
||||
for _, m := range messages {
|
||||
prompt = append(prompt, genai.Text(m.Content))
|
||||
}
|
||||
|
||||
resp, err := model.GenerateContent(ctx, prompt...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
part := resp.Candidates[0].Content.Parts[0].(genai.Text)
|
||||
err = json.Unmarshal([]byte(part), target)
|
||||
return err
|
||||
}
|
||||
|
||||
// schemaFromType recursively builds a *genai.Schema tree from a reflect.Type
|
||||
func schemaFromType(t reflect.Type) *genai.Schema {
|
||||
// Follow pointers to the base type
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
props := make(map[string]*genai.Schema)
|
||||
var required []string
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get JSON tag or use field name
|
||||
name := field.Tag.Get("json")
|
||||
if name == "" || name == "-" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
props[name] = schemaFromType(field.Type)
|
||||
required = append(required, name)
|
||||
}
|
||||
|
||||
return &genai.Schema{
|
||||
Type: genai.TypeObject,
|
||||
Properties: props,
|
||||
Required: required,
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return &genai.Schema{
|
||||
Type: genai.TypeArray,
|
||||
Items: schemaFromType(t.Elem()),
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return &genai.Schema{Type: genai.TypeInteger}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return &genai.Schema{Type: genai.TypeNumber}
|
||||
|
||||
case reflect.Bool:
|
||||
return &genai.Schema{Type: genai.TypeBoolean}
|
||||
|
||||
default:
|
||||
return &genai.Schema{Type: genai.TypeString}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *GeminiAdapter) GetProviderName() string {
|
||||
return "Google Gemini (" + a.model + ")"
|
||||
}
|
||||
|
||||
// Helper to map your roles to Gemini's expected roles
|
||||
func mapRoleToGemini(role Role) string {
|
||||
switch role {
|
||||
case RoleUser:
|
||||
return "user"
|
||||
case RoleAssistant:
|
||||
return "model"
|
||||
default:
|
||||
return "user" // Gemini doesn't have a specific 'system' role in chat history
|
||||
}
|
||||
}
|
||||
23
internal/chatter/model.go
Normal file
23
internal/chatter/model.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package chatter
|
||||
|
||||
// Role defines who sent the message.
type Role string
|
||||
|
||||
// Recognized conversation roles.
const (
	RoleSystem    Role = "system"    // system / instruction turns
	RoleUser      Role = "user"      // end-user turns
	RoleAssistant Role = "assistant" // model-generated turns
)
|
||||
|
||||
// Message represents a single turn in a conversation.
type Message struct {
	Role    Role   // who authored this turn
	Content string // text content of the turn
}
|
||||
|
||||
// PredictionConfig allows for per-request overrides of generation settings.
type PredictionConfig struct {
	Temperature float64  // sampling temperature override
	MaxTokens   int      // maximum number of tokens to generate
	Stop        []string // stop sequences that terminate generation
}
|
||||
158
internal/chatter/ollama.go
Normal file
158
internal/chatter/ollama.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package chatter
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"reflect"

	"github.com/ollama/ollama/api"
)
|
||||
|
||||
// OllamaAdapter implements the Adapter interface for local Ollama instances.
type OllamaAdapter struct {
	client *api.Client // Ollama HTTP API client
	model  string      // model name sent with every chat request
}
|
||||
|
||||
// NewOllamaAdapter initializes a new Ollama client.
|
||||
// Default address is usually "http://localhost:11434"
|
||||
func NewOllamaAdapter(endpoint string, model string) (*OllamaAdapter, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ollama client: %w", err)
|
||||
}
|
||||
|
||||
return &OllamaAdapter{
|
||||
client: client,
|
||||
model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate sends the chat history to Ollama and returns the assistant's response.
|
||||
func (a *OllamaAdapter) Generate(ctx context.Context, messages []Message) (string, error) {
|
||||
var ollamaMessages []api.Message
|
||||
|
||||
// Map our internal Message struct to Ollama's API struct
|
||||
for _, m := range messages {
|
||||
ollamaMessages = append(ollamaMessages, api.Message{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
req := &api.ChatRequest{
|
||||
Model: a.model,
|
||||
Messages: ollamaMessages,
|
||||
Stream: new(bool), // Set to false for a single string response
|
||||
}
|
||||
|
||||
var response string
|
||||
err := a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
fmt.Print(resp.Message)
|
||||
response = resp.Message.Content
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ollama generation failed: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (a *OllamaAdapter) GetProviderName() string {
|
||||
return "Ollama (" + a.model + ")"
|
||||
}
|
||||
|
||||
// GenerateStructured implements the same signature as your Gemini adapter.
|
||||
func (a *OllamaAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error {
|
||||
val := reflect.ValueOf(target)
|
||||
if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct {
|
||||
return fmt.Errorf("target must be a pointer to a struct")
|
||||
}
|
||||
|
||||
// 1. Generate the JSON Schema from the Go struct
|
||||
schema := a.schemaFromType(val.Elem().Type())
|
||||
schemaBytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal schema: %w", err)
|
||||
}
|
||||
|
||||
// 2. Map messages
|
||||
var ollamaMessages []api.Message
|
||||
for _, m := range messages {
|
||||
ollamaMessages = append(ollamaMessages, api.Message{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// 3. Set Format to the raw JSON Schema
|
||||
req := &api.ChatRequest{
|
||||
Model: a.model,
|
||||
Messages: ollamaMessages,
|
||||
Format: json.RawMessage(schemaBytes),
|
||||
Stream: new(bool), // false
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0, // Recommended for structured tasks
|
||||
},
|
||||
}
|
||||
|
||||
var responseText string
|
||||
err = a.client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
||||
responseText = resp.Message.Content
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. Parse the result into the target struct
|
||||
return json.Unmarshal([]byte(responseText), target)
|
||||
}
|
||||
|
||||
// schemaFromType (Reuse the same logic from the Gemini adapter)
|
||||
func (a *OllamaAdapter) schemaFromType(t reflect.Type) map[string]interface{} {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
props := make(map[string]interface{})
|
||||
var required []string
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
name := field.Tag.Get("json")
|
||||
if name == "" || name == "-" {
|
||||
name = field.Name
|
||||
}
|
||||
props[name] = a.schemaFromType(field.Type)
|
||||
required = append(required, name)
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": a.schemaFromType(t.Elem()),
|
||||
}
|
||||
|
||||
case reflect.Int, reflect.Int64:
|
||||
return map[string]interface{}{"type": "integer"}
|
||||
case reflect.Float64:
|
||||
return map[string]interface{}{"type": "number"}
|
||||
case reflect.Bool:
|
||||
return map[string]interface{}{"type": "boolean"}
|
||||
default:
|
||||
return map[string]interface{}{"type": "string"}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user