diff --git a/cmd/pierre/main.go b/cmd/pierre/main.go index d296e2f..566f2e2 100644 --- a/cmd/pierre/main.go +++ b/cmd/pierre/main.go @@ -7,7 +7,7 @@ import ( "path/filepath" "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/chatter" - "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/gitadapters" + "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/gitadapters/bitbucket" "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/gitadapters/gitea" "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/pierre" "github.com/alecthomas/kong" @@ -32,7 +32,7 @@ type RepoArgs struct { type LLMConfig struct { Provider string `help:"Provider for llm (ollama or gemini)" required:"" env:"LLM_PROVIDER"` - Endpoint string `help:"Endpoint for provider (only for ollama)" env:"LLM_ENDPOINT"` + BaseURL string `help:"Endpoint for provider (only for ollama)" env:"LLM_BASE_URL"` APIKey string `help:"APIKey for provider" env:"LLM_API_KEY"` Model string `help:"Model to use" env:"LLM_MODEL"` } @@ -84,7 +84,7 @@ func main() { if cfg.Bitbucket.BaseURL == "" { log.Fatal("Bitbucket Base URL is required when using bitbucket provider.") } - git = gitadapters.NewBitbucket(cfg.Bitbucket.BaseURL, cfg.Bitbucket.Token) + git = bitbucket.NewBitbucket(cfg.Bitbucket.BaseURL, cfg.Bitbucket.Token) case "gitea": if cfg.Gitea.BaseURL == "" { log.Fatal("Gitea Base URL is required when using gitea provider.") @@ -105,7 +105,10 @@ func main() { case "gemini": ai, err = chatter.NewGeminiAdapter(context.Background(), cfg.LLM.APIKey, cfg.LLM.Model) case "ollama": - ai, err = chatter.NewOllamaAdapter(cfg.LLM.Endpoint, cfg.LLM.Model) + ai, err = chatter.NewOllamaAdapter(cfg.LLM.BaseURL, cfg.LLM.Model) + case "openai": + ai = chatter.NewOpenAIAdapter(cfg.LLM.APIKey, cfg.LLM.Model, cfg.LLM.BaseURL) + default: log.Fatalf("%s is not a valid llm provider", cfg.LLM.Provider) } diff --git a/go.mod b/go.mod index 0b2aae6..9fee8d2 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/hashicorp/go-version v1.7.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/sashabaranov/go-openai v1.41.2 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect diff --git a/go.sum b/go.sum index 7e33b82..1948e95 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/ollama/ollama v0.16.0/go.mod h1:FEk95NbAJJZk+t7cLh+bPGTul72j1O3PLLlYN github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM= +github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/internal/chatter/openai.go b/internal/chatter/openai.go new file mode 100644 index 0000000..6665891 --- /dev/null +++ b/internal/chatter/openai.go @@ -0,0 +1,184 @@ +package chatter + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "reflect" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +type OpenAIAdapter struct { + client *openai.Client + model string +} + +func NewOpenAIAdapter(apiKey string, model string, baseURL string) *OpenAIAdapter { + config := openai.DefaultConfig(apiKey) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // Bypasses the "not standards compliant" error + }, + } + + config.HTTPClient = &http.Client{Transport: tr} + + if baseURL != "" { + config.BaseURL = baseURL + } + if baseURL != "" { + config.BaseURL = baseURL + } + return &OpenAIAdapter{ + client: openai.NewClientWithConfig(config), + model: model, + } +} + +func (a *OpenAIAdapter) Generate(ctx context.Context, messages []Message) (string, error) { + var chatMsgs []openai.ChatCompletionMessage + for _, m := range messages { + chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{ + Role: string(m.Role), + Content: m.Content, + }) + } + + resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: a.model, + Messages: chatMsgs, + }) + if err != nil { + return "", err + } + return resp.Choices[0].Message.Content, nil +} + +func (a *OpenAIAdapter) GenerateStructured(ctx context.Context, messages []Message, target any) error { + val := reflect.ValueOf(target) + if val.Kind() != reflect.Ptr { + return fmt.Errorf("target must be a pointer") + } + + elem := val.Elem() + var schemaType reflect.Type + isSlice := elem.Kind() == reflect.Slice + + // 1. Wrap slices in an object because OpenAI requires a root object + if isSlice { + schemaType = reflect.StructOf([]reflect.StructField{ + { + Name: "Items", + Type: elem.Type(), + Tag: `json:"items"`, + }, + }) + } else { + schemaType = elem.Type() + } + + // 2. Build the Schema Map + schemaObj := a.reflectTypeToSchema(schemaType) + + // 3. Convert to json.RawMessage to satisfy the json.Marshaler interface + schemaBytes, err := json.Marshal(schemaObj) + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + var chatMsgs []openai.ChatCompletionMessage + for _, m := range messages { + chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{ + Role: string(m.Role), + Content: m.Content, + }) + } + + // 4. Send Request + req := openai.ChatCompletionRequest{ + Model: a.model, + Messages: chatMsgs, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "output_schema", + Strict: true, + Schema: json.RawMessage(schemaBytes), + }, + }, + } + + resp, err := a.client.CreateChatCompletion(ctx, req) + if err != nil { + return err + } + + content := resp.Choices[0].Message.Content + + // 5. Unmarshal and Unwrap if necessary + if isSlice { + temp := struct { + Items json.RawMessage `json:"items"` + }{} + if err := json.Unmarshal([]byte(content), &temp); err != nil { + return err + } + return json.Unmarshal(temp.Items, target) + } + + return json.Unmarshal([]byte(content), target) +} + +func (a *OpenAIAdapter) reflectTypeToSchema(t reflect.Type) jsonschema.Definition { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + def := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: make(map[string]jsonschema.Definition), + AdditionalProperties: false, + Required: []string{}, + } + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + name := field.Tag.Get("json") + if name == "" || name == "-" { + name = field.Name + } + def.Properties[name] = a.reflectTypeToSchema(field.Type) + def.Required = append(def.Required, name) + } + return def + + case reflect.Slice, reflect.Array: + items := a.reflectTypeToSchema(t.Elem()) + return jsonschema.Definition{ + Type: jsonschema.Array, + Items: &items, + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return jsonschema.Definition{Type: jsonschema.Integer} + case reflect.Float32, reflect.Float64: + return jsonschema.Definition{Type: jsonschema.Number} + case reflect.Bool: + return jsonschema.Definition{Type: jsonschema.Boolean} + default: + return jsonschema.Definition{Type: jsonschema.String} + } +} + +func (a *OpenAIAdapter) GetProviderName() string { + return "OpenAI (" + a.model + ")" +} diff --git a/internal/gitadapters/base.go b/internal/gitadapters/base.go deleted file mode 100644 index 192985d..0000000 --- a/internal/gitadapters/base.go +++ /dev/null @@ -1,31 +0,0 @@ -package gitadapters - -import ( - "fmt" - "io" - "net/http" - "net/url" -) - -type baseHTTP struct { - baseURL string - bearerToken string -} - -func (b *baseHTTP) createRequest(method string, body io.Reader, path ...string) (r *http.Request, err error) { - target, err := url.JoinPath(b.baseURL, path...) - if err != nil { - err = fmt.Errorf("can not parse path: %w", err) - return - } - req, err := http.NewRequest(method, target, body) - if err != nil { - return nil, err - } - - if b.bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+b.bearerToken) - } - - return req, nil -} diff --git a/internal/gitadapters/baseadapter/rest.go b/internal/gitadapters/baseadapter/rest.go index 0c8f543..2b82bde 100644 --- a/internal/gitadapters/baseadapter/rest.go +++ b/internal/gitadapters/baseadapter/rest.go @@ -1,6 +1,9 @@ package baseadapter import ( + "bytes" + "context" + "encoding/json" "fmt" "io" "net/http" @@ -12,17 +15,45 @@ type Rest struct { bearerToken string } -func (b *Rest) createRequest(method string, body io.Reader, path ...string) (r *http.Request, err error) { +func NewRest(baseURL string, bearerToken string) Rest { + return Rest{ + baseURL: baseURL, + bearerToken: bearerToken, + } +} + +const defaultBodyBufferSize = 100 + +func (b *Rest) CreateRequest(ctx context.Context, method string, body any, path ...string) (r *http.Request, err error) { target, err := url.JoinPath(b.baseURL, path...) if err != nil { err = fmt.Errorf("can not parse path: %w", err) return } - req, err := http.NewRequest(method, target, body) + + var bodyReader io.Reader + if body != nil { + bodyBuff := bytes.NewBuffer(make([]byte, 0, defaultBodyBufferSize)) + + err = json.NewEncoder(bodyBuff).Encode(body) + if err != nil { + return + } + + bodyReader = bodyBuff + } + + req, err := http.NewRequest(method, target, bodyReader) if err != nil { return nil, err } + req = req.WithContext(ctx) + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if b.bearerToken != "" { req.Header.Set("Authorization", "Bearer "+b.bearerToken) } diff --git a/internal/gitadapters/bitbucket.go b/internal/gitadapters/bitbucket.go deleted file mode 100644 index fd14fde..0000000 --- a/internal/gitadapters/bitbucket.go +++ /dev/null @@ -1,46 +0,0 @@ -package gitadapters - -import ( - "fmt" - "io" - "net/http" - - "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/pierre" -) - -type BitbucketAdapter struct { - baseHTTP -} - -func NewBitbucket(baseURL string, bearerToken string) *BitbucketAdapter { - return &BitbucketAdapter{ - baseHTTP{ - baseURL: baseURL, - bearerToken: bearerToken, - }, - } -} - -func (b *BitbucketAdapter) GetDiff(projectKey, repositorySlug string, pullRequestID int) (diff io.Reader, err error) { - r, err := b.createRequest( - http.MethodGet, - nil, - "/rest/api/1.0/projects/", projectKey, "repos", repositorySlug, "pull-requests", fmt.Sprintf("%d.diff", pullRequestID), - ) - if err != nil { - return - } - - response, err := http.DefaultClient.Do(r) - if err != nil { - return - } - diff = response.Body - return -} - -func (b *BitbucketAdapter) AddComment(projectKey, repositorySlug string, pullRequestID int, comment pierre.Comment) error { - fmt.Printf("[MOCK BITBUCKET] Adding comment to PR %s/%s #%d: %s at %s:%d\n", - projectKey, repositorySlug, pullRequestID, comment.Message, comment.File, comment.Line) - return nil -} diff --git a/internal/gitadapters/bitbucket/controller.go b/internal/gitadapters/bitbucket/controller.go new file mode 100644 index 0000000..75b3a30 --- /dev/null +++ b/internal/gitadapters/bitbucket/controller.go @@ -0,0 +1,101 @@ +package bitbucket + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/pierre" +) + +func (b *BitbucketAdapter) GetDiff(ctx context.Context, projectKey, repositorySlug string, pullRequestID int) (diff io.ReadCloser, err error) { + r, err := b.CreateRequest( + ctx, + http.MethodGet, + nil, + "/projects/", projectKey, "repos", repositorySlug, "pull-requests", fmt.Sprintf("%d.diff", pullRequestID), + ) + if err != nil { + return + } + + response, err := http.DefaultClient.Do(r) + if err != nil { + return + } + + if response.StatusCode != http.StatusOK { + sb := &strings.Builder{} + io.Copy(sb, response.Body) + err = fmt.Errorf("error while fetching bitbucket diff staus %d, body %s", response.Status, sb.String()) + } + + diff = response.Body + return +} + +func (b *BitbucketAdapter) GetPR(ctx context.Context, projectKey, repositorySlug string, pullRequestID int) (pr PullRequest, err error) { + r, err := b.CreateRequest( + ctx, + http.MethodGet, + nil, + "/projects/", projectKey, "repos", repositorySlug, "pull-requests", strconv.Itoa(pullRequestID), + ) + + response, err := http.DefaultClient.Do(r) + defer response.Body.Close() // Add this + if err != nil { + return + } + + err = json.NewDecoder(response.Body).Decode(&pr) + + return +} + +func (b *BitbucketAdapter) AddComment(ctx context.Context, owner, repo string, prID int, comment pierre.Comment) (err error) { + // pr, err := b.GetPR(ctx, owner, repo, prID) + // if err != nil { + // return + // } + + commentDTO := Comment{ + Content: comment.Message, + Anchor: Anchor{ + Path: comment.File, + Line: comment.Line, + LineType: "ADDED", + FileType: "TO", + DiffType: "EFFECTIVE", + // FromHash: pr.ToRef.LatestCommit, + // ToHash: pr.FromRef.LatestCommit, + }, + } + + r, err := b.CreateRequest(ctx, + http.MethodPost, + commentDTO, + "/projects/", owner, "/repos/", repo, "/pull-requests/", strconv.Itoa(prID), "/comments", + ) + if err != nil { + return + } + + response, err := http.DefaultClient.Do(r) + defer response.Body.Close() // Add this + if err != nil { + return err + } + + if response.StatusCode >= 300 || response.StatusCode < 200 { + sb := &strings.Builder{} + io.Copy(sb, response.Body) + err = fmt.Errorf("error while creating comment staus %d, body %s", response.StatusCode, sb.String()) + } + + return err +} diff --git a/internal/gitadapters/bitbucket/model.go b/internal/gitadapters/bitbucket/model.go new file mode 100644 index 0000000..4e5b4d9 --- /dev/null +++ b/internal/gitadapters/bitbucket/model.go @@ -0,0 +1,44 @@ +package bitbucket + +type Anchor struct { + Path string `json:"path"` + Line int `json:"line"` + LineType string `json:"lineType,omitempty"` + FileType string `json:"fileType"` + FromHash string `json:"fromHash,omitempty"` + ToHash string `json:"toHash,omitempty"` + DiffType string `json:"diffType,omitempty"` +} + +type Comment struct { + Content string `json:"text"` + Anchor Anchor `json:"anchor"` +} + +type PullRequest struct { + ID int64 `json:"id"` + Version int `json:"version"` + Title string `json:"title"` + State string `json:"state"` + Open bool `json:"open"` + Closed bool `json:"closed"` + FromRef Ref `json:"fromRef"` + ToRef Ref `json:"toRef"` + Description string `json:"description"` +} + +type Ref struct { + ID string `json:"id"` + DisplayID string `json:"displayId"` + LatestCommit string `json:"latestCommit"` + Repository Repository `json:"repository"` +} + +type Repository struct { + Slug string `json:"slug"` + Project Project `json:"project"` +} + +type Project struct { + Key string `json:"key"` +} diff --git a/internal/gitadapters/bitbucket/resource.go b/internal/gitadapters/bitbucket/resource.go new file mode 100644 index 0000000..313d7d9 --- /dev/null +++ b/internal/gitadapters/bitbucket/resource.go @@ -0,0 +1,20 @@ +package bitbucket + +import ( + "strings" + + "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/gitadapters/baseadapter" +) + +type BitbucketAdapter struct { + baseadapter.Rest +} + +func NewBitbucket(baseURL string, bearerToken string) *BitbucketAdapter { + baseURL, _ = strings.CutSuffix(baseURL, "/") + baseURL += "/rest/api/1.0" + + return &BitbucketAdapter{ + Rest: baseadapter.NewRest(baseURL, bearerToken), + } +} diff --git a/internal/gitadapters/gitea/adapter.go b/internal/gitadapters/gitea/adapter.go index 99c87fc..bf72228 100644 --- a/internal/gitadapters/gitea/adapter.go +++ b/internal/gitadapters/gitea/adapter.go @@ -2,10 +2,11 @@ package gitea import ( "bytes" + "context" "io" - "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/pierre" "code.gitea.io/sdk/gitea" + "git.schreifuchs.ch/schreifuchs/pierre-bot/internal/pierre" ) type Adapter struct { @@ -22,15 +23,17 @@ func New(baseURL, token string) (*Adapter, error) { }, nil } -func (g *Adapter) GetDiff(owner, repo string, prID int) (io.Reader, error) { +func (g *Adapter) GetDiff(ctx context.Context, owner, repo string, prID int) (io.ReadCloser, error) { + g.client.SetContext(ctx) diff, _, err := g.client.GetPullRequestDiff(owner, repo, int64(prID), gitea.PullRequestDiffOptions{}) if err != nil { return nil, err } - return bytes.NewReader(diff), nil + return io.NopCloser(bytes.NewReader(diff)), nil } -func (g *Adapter) AddComment(owner, repo string, prID int, comment pierre.Comment) error { +func (g *Adapter) AddComment(ctx context.Context, owner, repo string, prID int, comment pierre.Comment) error { + g.client.SetContext(ctx) opts := gitea.CreatePullReviewOptions{ State: gitea.ReviewStateComment, Comments: []gitea.CreatePullReviewComment{ diff --git a/internal/pierre/resource.go b/internal/pierre/resource.go index bfd6e19..ecdc8b6 100644 --- a/internal/pierre/resource.go +++ b/internal/pierre/resource.go @@ -20,10 +20,11 @@ func New(chat ChatAdapter, git GitAdapter) *Service { } type GitAdapter interface { - GetDiff(owner, repo string, prID int) (io.Reader, error) - AddComment(owner, repo string, prID int, comment Comment) error + GetDiff(ctx context.Context, owner, repo string, prID int) (io.ReadCloser, error) + AddComment(ctx context.Context, owner, repo string, prID int, comment Comment) error } type ChatAdapter interface { GenerateStructured(ctx context.Context, messages []chatter.Message, target interface{}) error + GetProviderName() string } diff --git a/internal/pierre/review.go b/internal/pierre/review.go index 4b782d7..d3fb8c3 100644 --- a/internal/pierre/review.go +++ b/internal/pierre/review.go @@ -8,7 +8,8 @@ import ( func (s *Service) MakeReview(ctx context.Context, organisation string, repo string, prID int) error { // Fetch Diff using positional args from shared RepoArgs - diff, err := s.git.GetDiff(organisation, repo, prID) + diff, err := s.git.GetDiff(ctx, organisation, repo, prID) + defer diff.Close() if err != nil { return fmt.Errorf("error fetching diff: %w", err) } @@ -21,11 +22,14 @@ func (s *Service) MakeReview(ctx context.Context, organisation string, repo stri fmt.Printf("Analysis complete. Found %d issues.\n---\n", len(comments)) + model := s.chat.GetProviderName() + for _, c := range comments { + c.Message = fmt.Sprintf("%s (Generated by: %s)", c.Message, model) fmt.Printf("File: %s\nLine: %d\nMessage: %s\n%s\n", c.File, c.Line, c.Message, "---") - if err := s.git.AddComment(organisation, repo, prID, c); err != nil { + if err := s.git.AddComment(ctx, organisation, repo, prID, c); err != nil { log.Printf("Failed to add comment: %v", err) } }