// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // For the following go:generate line to work, install the protoveener tool: // git clone https://github.com/googleapis/google-cloud-go // cd google-cloud-go // go install ./internal/protoveneer/cmd/protoveneer // //go:generate ./generate.sh package genai import ( "context" "errors" "fmt" "io" "reflect" "strings" gl "cloud.google.com/go/ai/generativelanguage/apiv1beta" pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb" "github.com/google/generative-ai-go/genai/internal" gld "github.com/google/generative-ai-go/genai/internal/generativelanguage/v1beta" // discovery client "google.golang.org/api/iterator" "google.golang.org/api/option" ) // A Client is a Google generative AI client. type Client struct { gc *gl.GenerativeClient mc *gl.ModelClient fc *gl.FileClient cc *gl.CacheClient ds *gld.Service } // NewClient creates a new Google generative AI client. // // Clients should be reused instead of created as needed. The methods of Client // are safe for concurrent use by multiple goroutines. // // You may configure the client by passing in options from the [google.golang.org/api/option] // package. func NewClient(ctx context.Context, opts ...option.ClientOption) (*Client, error) { if !hasAuthOption(opts) { return nil, errors.New(`You need an auth option to use this client. for an API Key: Visit https://ai.google.dev to get one, put it in an environment variable like GEMINI_API_KEY, then pass it as an option: genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY"))) (If you're doing that already, then maybe the environment variable is empty or unset.) Import the option package as "google.golang.org/api/option".`) } gc, err := gl.NewGenerativeRESTClient(ctx, opts...) if err != nil { return nil, fmt.Errorf("creating generative client: %w", err) } mc, err := gl.NewModelRESTClient(ctx, opts...) if err != nil { return nil, fmt.Errorf("creating model client: %w", err) } fc, err := gl.NewFileRESTClient(ctx, opts...) if err != nil { return nil, fmt.Errorf("creating file client: %w", err) } // Workaround for https://github.com/google/generative-ai-go/issues/151 optsForCache := removeHTTPClientOption(opts) cc, err := gl.NewCacheClient(ctx, optsForCache...) if err != nil { return nil, fmt.Errorf("creating cache client: %w", err) } ds, err := gld.NewService(ctx, opts...) if err != nil { return nil, fmt.Errorf("creating discovery client: %w", err) } kvs := []string{"gccl", "v" + internal.Version, "genai-go", internal.Version} if a, ok := optionOfType[*clientInfo](opts); ok { kvs = append(kvs, a.key, a.value) } gc.SetGoogleClientInfo(kvs...) mc.SetGoogleClientInfo(kvs...) fc.SetGoogleClientInfo(kvs...) return &Client{gc, mc, fc, cc, ds}, nil } // hasAuthOption reports whether an authentication-related option was provided. // // There is no good way to make these checks, because the types of the options // are unexported, and the struct that they populates is in an internal package. func hasAuthOption(opts []option.ClientOption) bool { for _, opt := range opts { v := reflect.ValueOf(opt) ts := v.Type().String() switch ts { case "option.withAPIKey": return v.String() != "" case "option.withHTTPClient", "option.withTokenSource", "option.withCredFile", "option.withCredentialsJSON": return true } } return false } // removeHTTPClientOption removes option.withHTTPClient from the given list // of options, if it exists; it returns the new (filtered) list. func removeHTTPClientOption(opts []option.ClientOption) []option.ClientOption { var newOpts []option.ClientOption for _, opt := range opts { ts := reflect.ValueOf(opt).Type().String() if ts != "option.withHTTPClient" { newOpts = append(newOpts, opt) } } return newOpts } // Close closes the client. func (c *Client) Close() error { return errors.Join(c.gc.Close(), c.mc.Close(), c.fc.Close()) } // GenerativeModel is a model that can generate text. // Create one with [Client.GenerativeModel], then configure // it by setting the exported fields. type GenerativeModel struct { c *Client fullName string GenerationConfig SafetySettings []*SafetySetting Tools []*Tool ToolConfig *ToolConfig // configuration for tools // SystemInstruction (also known as "system prompt") is a more forceful prompt to the model. // The model will adhere the instructions more strongly than if they appeared in a normal prompt. SystemInstruction *Content // The name of the CachedContent to use. // Must have already been created with [Client.CreateCachedContent]. CachedContentName string } // GenerativeModel creates a new instance of the named generative model. // For instance, "gemini-1.0-pro" or "models/gemini-1.0-pro". // // To access a tuned model named NAME, pass "tunedModels/NAME". func (c *Client) GenerativeModel(name string) *GenerativeModel { return &GenerativeModel{ c: c, fullName: fullModelName(name), } } func fullModelName(name string) string { if strings.ContainsRune(name, '/') { return name } return "models/" + name } // GenerateContent produces a single request and response. func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) { content := NewUserContent(parts...) req, err := m.newGenerateContentRequest(content) if err != nil { return nil, err } res, err := m.c.gc.GenerateContent(ctx, req) if err != nil { return nil, err } return protoToResponse(res) } // GenerateContentStream returns an iterator that enumerates responses. func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator { iter := &GenerateContentResponseIterator{} req, err := m.newGenerateContentRequest(NewUserContent(parts...)) if err != nil { iter.err = err } else { iter.sc, iter.err = m.c.gc.StreamGenerateContent(ctx, req) } return iter } func (m *GenerativeModel) generateContent(ctx context.Context, req *pb.GenerateContentRequest) (*GenerateContentResponse, error) { streamClient, err := m.c.gc.StreamGenerateContent(ctx, req) iter := &GenerateContentResponseIterator{ sc: streamClient, err: err, } for { _, err := iter.Next() if err == iterator.Done { return iter.MergedResponse(), nil } if err != nil { return nil, err } } } func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.GenerateContentRequest, error) { return pvCatchPanic(func() *pb.GenerateContentRequest { var cc *string if m.CachedContentName != "" { cc = &m.CachedContentName } req := &pb.GenerateContentRequest{ Model: m.fullName, Contents: transformSlice(contents, (*Content).toProto), SafetySettings: transformSlice(m.SafetySettings, (*SafetySetting).toProto), Tools: transformSlice(m.Tools, (*Tool).toProto), ToolConfig: m.ToolConfig.toProto(), GenerationConfig: m.GenerationConfig.toProto(), SystemInstruction: m.SystemInstruction.toProto(), CachedContent: cc, } debugPrint(req) return req }) } // GenerateContentResponseIterator is an iterator over GnerateContentResponse. type GenerateContentResponseIterator struct { sc pb.GenerativeService_StreamGenerateContentClient err error merged *GenerateContentResponse cs *ChatSession } // Next returns the next response. func (iter *GenerateContentResponseIterator) Next() (*GenerateContentResponse, error) { if iter.err != nil { return nil, iter.err } resp, err := iter.sc.Recv() iter.err = err if err == io.EOF { if iter.cs != nil && iter.merged != nil { iter.cs.addToHistory(iter.merged.Candidates) } return nil, iterator.Done } if err != nil { return nil, err } gcp, err := protoToResponse(resp) if err != nil { iter.err = err return nil, err } // Merge this response in with the ones we've already seen. iter.merged = joinResponses(iter.merged, gcp) // If this is part of a ChatSession, remember the response for the history. return gcp, nil } func protoToResponse(resp *pb.GenerateContentResponse) (*GenerateContentResponse, error) { gcp, err := fromProto[GenerateContentResponse](resp) if err != nil { return nil, err } if gcp == nil { return nil, errors.New("empty response from model") } // Assume a non-nil PromptFeedback is an error. // TODO: confirm. if gcp.PromptFeedback != nil && gcp.PromptFeedback.BlockReason != BlockReasonUnspecified { return nil, &BlockedError{PromptFeedback: gcp.PromptFeedback} } // If any candidate is blocked, error. // TODO: is this too harsh? for _, c := range gcp.Candidates { if c.FinishReason == FinishReasonSafety || c.FinishReason == FinishReasonRecitation { return nil, &BlockedError{Candidate: c} } } return gcp, nil } // MergedResponse returns the result of combining all the streamed responses seen so far. // After iteration completes, the merged response should match the response obtained without streaming // (that is, if [GenerativeModel.GenerateContent] were called). func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentResponse { return iter.merged } // CountTokens counts the number of tokens in the content. func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) { req, err := m.newCountTokensRequest(NewUserContent(parts...)) if err != nil { return nil, err } res, err := m.c.gc.CountTokens(ctx, req) if err != nil { return nil, err } return fromProto[CountTokensResponse](res) } func (m *GenerativeModel) newCountTokensRequest(contents ...*Content) (*pb.CountTokensRequest, error) { gcr, err := m.newGenerateContentRequest(contents...) if err != nil { return nil, err } req := &pb.CountTokensRequest{ Model: m.fullName, GenerateContentRequest: gcr, } debugPrint(req) return req, nil } // Info returns information about the model. func (m *GenerativeModel) Info(ctx context.Context) (*ModelInfo, error) { return m.c.modelInfo(ctx, m.fullName) } func (c *Client) modelInfo(ctx context.Context, fullName string) (*ModelInfo, error) { req := &pb.GetModelRequest{Name: fullName} debugPrint(req) res, err := c.mc.GetModel(ctx, req) if err != nil { return nil, err } return fromProto[ModelInfo](res) } // A BlockedError indicates that the model's response was blocked. // There can be two underlying causes: the prompt or a candidate response. type BlockedError struct { // If non-nil, the model's response was blocked. // Consult the FinishReason field for details. Candidate *Candidate // If non-nil, there was a problem with the prompt. PromptFeedback *PromptFeedback } func (e *BlockedError) Error() string { var b strings.Builder fmt.Fprintf(&b, "blocked: ") if e.Candidate != nil { fmt.Fprintf(&b, "candidate: %s", e.Candidate.FinishReason) } if e.PromptFeedback != nil { if e.Candidate != nil { fmt.Fprintf(&b, ", ") } fmt.Fprintf(&b, "prompt: %v", e.PromptFeedback.BlockReason) } return b.String() } // joinResponses merges the two responses, which should be the result of a streaming call. // The first argument is modified. func joinResponses(dest, src *GenerateContentResponse) *GenerateContentResponse { if dest == nil { return src } dest.Candidates = joinCandidateLists(dest.Candidates, src.Candidates) // Keep dest.PromptFeedback. // TODO: Take the last UsageMetadata. return dest } func joinCandidateLists(dest, src []*Candidate) []*Candidate { indexToSrcCandidate := map[int32]*Candidate{} for _, s := range src { indexToSrcCandidate[s.Index] = s } for _, d := range dest { s := indexToSrcCandidate[d.Index] if s != nil { d.Content = joinContent(d.Content, s.Content) // Take the last of these. d.FinishReason = s.FinishReason // d.FinishMessage = s.FinishMessage d.SafetyRatings = s.SafetyRatings d.CitationMetadata = joinCitationMetadata(d.CitationMetadata, s.CitationMetadata) } } return dest } func joinCitationMetadata(dest, src *CitationMetadata) *CitationMetadata { if dest == nil { return src } if src == nil { return dest } dest.CitationSources = append(dest.CitationSources, src.CitationSources...) return dest } func joinContent(dest, src *Content) *Content { if dest == nil { return src } if src == nil { return dest } // Assume roles are the same. dest.Parts = joinParts(dest.Parts, src.Parts) return dest } func joinParts(dest, src []Part) []Part { return mergeTexts(append(dest, src...)) } func mergeTexts(in []Part) []Part { var out []Part i := 0 for i < len(in) { if t, ok := in[i].(Text); ok { texts := []string{string(t)} var j int for j = i + 1; j < len(in); j++ { if t, ok := in[j].(Text); ok { texts = append(texts, string(t)) } else { break } } // j is just after the last Text. out = append(out, Text(strings.Join(texts, ""))) i = j } else { out = append(out, in[i]) i++ } } return out } // transformSlice applies f to each element of from and returns // a new slice with the results. func transformSlice[From, To any](from []From, f func(From) To) []To { if from == nil { return nil } to := make([]To, len(from)) for i, e := range from { to[i] = f(e) } return to } func fromProto[V interface{ fromProto(P) *V }, P any](p P) (*V, error) { var v V return pvCatchPanic(func() *V { return v.fromProto(p) }) }