194 lines
7.0 KiB
Go
194 lines
7.0 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package genai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
gl "cloud.google.com/go/ai/generativelanguage/apiv1beta"
|
|
pb "cloud.google.com/go/ai/generativelanguage/apiv1beta/generativelanguagepb"
|
|
"google.golang.org/api/iterator"
|
|
durationpb "google.golang.org/protobuf/types/known/durationpb"
|
|
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
|
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
type cacheClient = gl.CacheClient
|
|
|
|
var (
|
|
newCacheClient = gl.NewCacheClient
|
|
newCacheRESTClient = gl.NewCacheRESTClient
|
|
)
|
|
|
|
// GenerativeModelFromCachedContent returns a [GenerativeModel] that uses the given [CachedContent].
|
|
// The argument should come from a call to [Client.CreateCachedContent] or [Client.GetCachedContent].
|
|
func (c *Client) GenerativeModelFromCachedContent(cc *CachedContent) *GenerativeModel {
|
|
return &GenerativeModel{
|
|
c: c,
|
|
fullName: cc.Model,
|
|
CachedContentName: cc.Name,
|
|
}
|
|
}
|
|
|
|
// CreateCachedContent creates a new CachedContent.
|
|
// The argument should contain a model name and some data to be cached, which can include
|
|
// contents, a system instruction, tools and/or tool configuration. It can also
|
|
// include an expiration time or TTL. But it should not include a name; the system
|
|
// will generate one.
|
|
//
|
|
// The return value will contain the name, which should be used to refer to the CachedContent
|
|
// in other API calls. It will also hold various metadata like expiration and creation time.
|
|
// It will not contain any of the actual content provided as input.
|
|
//
|
|
// You can use the return value to create a model with [Client.GenerativeModelFromCachedContent].
|
|
// Or you can set [GenerativeModel.CachedContentName] to the name of the CachedContent, in which
|
|
// case you must ensure that the model provided in this call matches the name in the [GenerativeModel].
|
|
func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*CachedContent, error) {
|
|
if cc.Name != "" {
|
|
return nil, errors.New("genai.CreateCachedContent: do not provide a name; one will be generated")
|
|
}
|
|
pcc := cc.toProto()
|
|
pcc.Model = Ptr(fullModelName(cc.Model))
|
|
req := &pb.CreateCachedContentRequest{
|
|
CachedContent: pcc,
|
|
}
|
|
debugPrint(req)
|
|
return c.cachedContentFromProto(c.cc.CreateCachedContent(ctx, req))
|
|
}
|
|
|
|
// GetCachedContent retrieves the CachedContent with the given name.
|
|
func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedContent, error) {
|
|
return c.cachedContentFromProto(c.cc.GetCachedContent(ctx, &pb.GetCachedContentRequest{Name: name}))
|
|
}
|
|
|
|
// DeleteCachedContent deletes the CachedContent with the given name.
|
|
func (c *Client) DeleteCachedContent(ctx context.Context, name string) error {
|
|
return c.cc.DeleteCachedContent(ctx, &pb.DeleteCachedContentRequest{Name: name})
|
|
}
|
|
|
|
// CachedContentToUpdate specifies which fields of a CachedContent to modify in a call to
|
|
// [Client.UpdateCachedContent].
|
|
type CachedContentToUpdate struct {
|
|
// If non-nil, update the expire time or TTL.
|
|
Expiration *ExpireTimeOrTTL
|
|
}
|
|
|
|
// UpdateCachedContent modifies the [CachedContent] according to the values
|
|
// of the [CachedContentToUpdate] struct.
|
|
// It returns the modified CachedContent.
|
|
//
|
|
// The argument CachedContent must have its Name field populated.
|
|
// If its UpdateTime field is non-zero, it will be compared with the update time
|
|
// of the stored CachedContent and the call will fail if they differ.
|
|
// This avoids a race condition when two updates are attempted concurrently.
|
|
// All other fields of the argument CachedContent are ignored.
|
|
func (c *Client) UpdateCachedContent(ctx context.Context, cc *CachedContent, ccu *CachedContentToUpdate) (*CachedContent, error) {
|
|
if ccu == nil || ccu.Expiration == nil {
|
|
return nil, errors.New("genai.UpdateCachedContent: no update specified")
|
|
}
|
|
cc2 := &CachedContent{
|
|
Name: cc.Name,
|
|
UpdateTime: cc.UpdateTime,
|
|
Expiration: *ccu.Expiration,
|
|
}
|
|
mask := "expire_time"
|
|
if ccu.Expiration.ExpireTime.IsZero() {
|
|
mask = "ttl"
|
|
}
|
|
req := &pb.UpdateCachedContentRequest{
|
|
CachedContent: cc2.toProto(),
|
|
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{mask}},
|
|
}
|
|
debugPrint(req)
|
|
return c.cachedContentFromProto(c.cc.UpdateCachedContent(ctx, req))
|
|
}
|
|
|
|
// ListCachedContents lists all the CachedContents associated with the project and location.
|
|
func (c *Client) ListCachedContents(ctx context.Context) *CachedContentIterator {
|
|
return &CachedContentIterator{
|
|
it: c.cc.ListCachedContents(ctx, &pb.ListCachedContentsRequest{}),
|
|
}
|
|
}
|
|
|
|
// A CachedContentIterator iterates over CachedContents.
|
|
type CachedContentIterator struct {
|
|
it *gl.CachedContentIterator
|
|
}
|
|
|
|
// Next returns the next result. Its second return value is iterator.Done if there are no more
|
|
// results. Once Next returns Done, all subsequent calls will return Done.
|
|
func (it *CachedContentIterator) Next() (*CachedContent, error) {
|
|
m, err := it.it.Next()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return (CachedContent{}).fromProto(m), nil
|
|
}
|
|
|
|
// PageInfo supports pagination. See the google.golang.org/api/iterator package for details.
|
|
func (it *CachedContentIterator) PageInfo() *iterator.PageInfo {
|
|
return it.it.PageInfo()
|
|
}
|
|
|
|
func (c *Client) cachedContentFromProto(pcc *pb.CachedContent, err error) (*CachedContent, error) {
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cc := (CachedContent{}).fromProto(pcc)
|
|
return cc, nil
|
|
}
|
|
|
|
// ExpireTimeOrTTL specifies when a resource expires, either as an absolute
// time or as a duration relative to the current time.
// A non-zero ExpireTime is used as the expiration time; otherwise the
// expiration time is the current time plus TTL ("time to live").
type ExpireTimeOrTTL struct {
	// ExpireTime is the absolute expiration time. It takes effect whenever it
	// is non-zero.
	ExpireTime time.Time
	// TTL is the time to live, added to the current time to obtain the
	// expiration time when ExpireTime is zero.
	TTL time.Duration
}
|
|
|
|
// populateCachedContentTo populates some fields of p from v.
|
|
func populateCachedContentTo(p *pb.CachedContent, v *CachedContent) {
|
|
exp := v.Expiration
|
|
if !exp.ExpireTime.IsZero() {
|
|
p.Expiration = &pb.CachedContent_ExpireTime{
|
|
ExpireTime: timestamppb.New(exp.ExpireTime),
|
|
}
|
|
} else if exp.TTL != 0 {
|
|
p.Expiration = &pb.CachedContent_Ttl{
|
|
Ttl: durationpb.New(exp.TTL),
|
|
}
|
|
}
|
|
// If both fields of v.Expiration are zero, leave p.Expiration unset.
|
|
}
|
|
|
|
// populateCachedContentFrom populates some fields of v from p.
|
|
func populateCachedContentFrom(v *CachedContent, p *pb.CachedContent) {
|
|
if p.Expiration == nil {
|
|
return
|
|
}
|
|
switch e := p.Expiration.(type) {
|
|
case *pb.CachedContent_ExpireTime:
|
|
v.Expiration.ExpireTime = pvTimeFromProto(e.ExpireTime)
|
|
case *pb.CachedContent_Ttl:
|
|
v.Expiration.TTL = e.Ttl.AsDuration()
|
|
default:
|
|
panic(fmt.Sprintf("unknown type of CachedContent.Expiration: %T", p.Expiration))
|
|
}
|
|
}
|