Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions client_azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package openai

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
)

var (
ErrClientEmptyCallbackURL = errors.New("Error retrieving callback URL (Operation-Location) for image request") //nolint:lll
ErrClientRetrievingCallbackResponse = errors.New("Error retrieving callback response") //nolint:lll
)

type AzureClient = Client

// Azure image request callback response struct.
type CBData []struct {
URL string `json:"url"`
}
type CBResult struct {
Data CBData `json:"data"`
}
type CallBackResponse struct {
Created int64 `json:"created"`
Expires int64 `json:"expires"`
ID string `json:"id"`
Result CBResult `json:"result"`
Status string `json:"status"`
}

func (c *AzureClient) sendAzureRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")

// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
contentType := req.Header.Get("Content-Type")
if contentType == "" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
}

defer res.Body.Close()

if isFailureStatusCode(res) {
return c.handleErrorResp(res)
}

if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
// Special handling for initial call to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/images/generations") {
return c.requestImage(res, v)
}
// Special handling for callBack to Azure DALL-E API.
if strings.Contains(req.URL.Path, "openai/operations/images") {
return c.imageRequestCallback(req, v, res)
}
}

return decodeResponse(res.Body, v)
}

func (c *AzureClient) requestImage(res *http.Response, v any) error {
_, _ = io.Copy(io.Discard, res.Body)
callBackURL := res.Header.Get("Operation-Location")
if callBackURL == "" {
return ErrClientEmptyCallbackURL
}
newReq, err := http.NewRequest("GET", callBackURL, nil)
if err != nil {
return err
}
return c.sendAzureRequest(newReq, v)
}

// Handle image callback response from Azure DALL-E API.
func (c *AzureClient) imageRequestCallback(req *http.Request, v any, res *http.Response) error {
// Retry Sleep seconds for Azure DALL-E 2 callback URL.
var callBackWaitTime = 3

// Wait for the callBack to complete
var result *CallBackResponse
err := json.NewDecoder(res.Body).Decode(&result)
if err != nil {
return ErrClientRetrievingCallbackResponse
}
if result.Status == "" {
return ErrClientRetrievingCallbackResponse
}
if result.Status != "succeeded" {
time.Sleep(time.Duration(callBackWaitTime) * time.Second)
req.Header.Add("Retry", "true")
return c.sendAzureRequest(req, v)
}

// Convert the callBack response to the OpenAI ImageResponse
var urlList []ImageResponseDataInner
for _, data := range result.Result.Data {
urlList = append(urlList, ImageResponseDataInner{URL: data.URL})
}
converted, _ := json.Marshal(ImageResponse{Created: result.Created, Data: urlList})
return decodeResponse(bytes.NewReader(converted), v)
}

// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *AzureClient) fullAzureURL(suffix string, args ...any) string {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
switch {
case strings.Contains(suffix, "/models"):
// if suffix is /models change to {endpoint}/openai/models?api-version={api_version}
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
case strings.Contains(suffix, "/images"):
// if suffix is /images change to {endpoint}openai/images/generations:submit?api-version={api_version}
return fmt.Sprintf("%s/%s%s:submit?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
default:
// /openai/deployments/{model}/chat/completions?api-version={api_version}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, c.config.APIVersion,
)
}
}
18 changes: 18 additions & 0 deletions image_azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package openai

import (
"context"
"net/http"
)

// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *AzureClient) CreateAzureImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
urlSuffix := "/images/generations"
req, err := c.newRequest(ctx, http.MethodPost, c.fullAzureURL(urlSuffix), withBody(request))
if err != nil {
return
}

err = c.sendAzureRequest(req, &response)
return
}
106 changes: 106 additions & 0 deletions image_azure_api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package openai_test

import (
"bytes"
"strings"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
)

func TestAzureImages(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()
server.RegisterHandler("/openai/images/generations:submit", handleAzureImageEndpoint)
server.RegisterHandler("/openai/operations/images/request-id", handleImageCallbackEndpoint)

_, err := client.CreateAzureImage(context.Background(), ImageRequest{
Prompt: "Lorem ipsum",
ResponseFormat: CreateImageResponseFormatURL,
N: 2,
})
checks.NoError(t, err, "Azure CreateImage error")
}

// handleImageEndpoint Handles the images endpoint by the test server.
func handleAzureImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// Azure Image Generation request - respond with callback Header only & HTTP accepted status.
if strings.Contains(r.RequestURI, "/openai/images/generations:submit") {
w.Header().Add("Operation-Location", "http://"+r.Host+"/openai/operations/images/request-id")
w.WriteHeader(http.StatusAccepted)
return
}
var imageReq ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ImageResponse{
Created: time.Now().Unix(),
}
for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// handleImageCallbackEndpoint Handles the callback endpoint by the test server.
func handleImageCallbackEndpoint(w http.ResponseWriter, r *http.Request) {
var err error

// image callback only accepts GET requests
if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

// Set the status to succeeded if this is a retry request.
status := "running"
if r.Header.Get("Retry") == "true" {
status = "succeeded"
}

cbResponse := CallBackResponse{
Created: time.Now().Unix(),
Status: status,
Result: CBResult{
Data: CBData{
{URL: "http://example.com/image1"},
{URL: "http://example.com/image2"},
},
},
}
cbResponseBytes := new(bytes.Buffer)
err = json.NewEncoder(cbResponseBytes).Encode(cbResponse)
if err != nil {
http.Error(w, "could not write repsonse", http.StatusInternalServerError)
return
}
fmt.Fprintln(w, cbResponseBytes.String())
}