Reimplement default values (#65)

This commit is contained in:
Kirill
2024-01-29 23:22:57 +03:00
committed by GitHub
parent 35832e6269
commit 2f94b8c774
9 changed files with 182 additions and 76 deletions

View File

@@ -25,7 +25,7 @@ jobs:
- name: Setup go - name: Setup go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.19' go-version: '1.20'
- name: Setup - name: Setup
run: | run: |
go install github.com/mattn/goveralls@latest go install github.com/mattn/goveralls@latest

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect"
) )
type HttpClient interface { type HttpClient interface {
@@ -36,6 +37,14 @@ func NewMockClient(handler http.HandlerFunc) *Client {
} }
func (c Client) newRequest(ctx context.Context, method string, uri string, body interface{}) (*http.Request, error) { func (c Client) newRequest(ctx context.Context, method string, uri string, body interface{}) (*http.Request, error) {
// Set default values for empty fields if `default` tag is present
// And body is not nil
if body != nil {
if err := getDefaultValues(reflect.ValueOf(body)); err != nil {
return nil, err
}
}
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err

130
core.go
View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"strconv"
"testing" "testing"
"time" "time"
) )
@@ -32,51 +33,103 @@ func (r Response) CopyCommonResponse(rhs *CommonResponse) {
rhs.Message = r.Message rhs.Message = r.Message
} }
func getDefaultValues(v interface{}) (map[string]string, error) { func getDefaultValues(v reflect.Value) error {
isNil, err := isZero(v) vValue := v.Elem()
if err != nil { vType := vValue.Type()
return make(map[string]string), err
}
if isNil {
return make(map[string]string), nil
}
out := make(map[string]string)
vType := reflect.TypeOf(v).Elem()
vValue := reflect.ValueOf(v).Elem()
for i := 0; i < vType.NumField(); i++ { for i := 0; i < vType.NumField(); i++ {
field := vType.Field(i) field := vType.Field(i)
tag := field.Tag.Get("json")
defaultValue := field.Tag.Get("default")
if field.Type.Kind() == reflect.Slice { switch field.Type.Kind() {
// Attach any slices as query params case reflect.Slice:
fieldVal := vValue.Field(i) for j := 0; j < vValue.Field(i).Len(); j++ {
for j := 0; j < fieldVal.Len(); j++ { // skip if slice type is primitive
out[tag] = fmt.Sprintf("%v", fieldVal.Index(j)) if vValue.Field(i).Index(j).Kind() != reflect.Struct {
}
} else {
// Add any scalar values as query params
fieldVal := fmt.Sprintf("%v", vValue.Field(i))
// If no value was set by the user, use the default
// value specified in the struct tag.
if fieldVal == "" || fieldVal == "0" {
if defaultValue == "" {
continue continue
} }
fieldVal = defaultValue // Attach any slices as query params
err := getDefaultValues(vValue.Field(i).Index(j).Addr())
if err != nil {
return err
}
}
case reflect.String:
isNil, err := isZero(vValue.Field(i).Addr())
if err != nil {
return err
}
if !isNil {
continue
} }
out[tag] = fmt.Sprintf("%v", fieldVal) defaultValue, ok := field.Tag.Lookup("default")
if !ok {
continue
}
vValue.Field(i).SetString(defaultValue)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
isNil, err := isZero(vValue.Field(i).Addr())
if err != nil {
return err
}
if !isNil {
continue
}
defaultValue, ok := field.Tag.Lookup("default")
if !ok {
continue
}
defaultValueInt, err := strconv.Atoi(defaultValue)
if err != nil {
return err
}
vValue.Field(i).SetInt(int64(defaultValueInt))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
isNil, err := isZero(vValue.Field(i).Addr())
if err != nil {
return err
}
if !isNil {
continue
}
defaultValue, ok := field.Tag.Lookup("default")
if !ok {
continue
}
defaultValueUint, err := strconv.ParseUint(defaultValue, 10, 64)
if err != nil {
return err
}
vValue.Field(i).SetUint(uint64(defaultValueUint))
case reflect.Struct:
err := getDefaultValues(vValue.Field(i).Addr())
if err != nil {
return err
}
case reflect.Ptr:
isNil, err := isZero(vValue.Field(i).Addr())
if err != nil {
return err
}
if isNil {
continue
}
if err := getDefaultValues(vValue.Field(i)); err != nil {
return err
}
default:
continue
} }
} }
return out, nil return nil
} }
func buildRawQuery(req *http.Request, v interface{}) (string, error) { func buildRawQuery(req *http.Request, v interface{}) (string, error) {
@@ -86,23 +139,20 @@ func buildRawQuery(req *http.Request, v interface{}) (string, error) {
return query.Encode(), nil return query.Encode(), nil
} }
values, err := getDefaultValues(v) err := getDefaultValues(reflect.ValueOf(v))
if err != nil { if err != nil {
return "", err return "", err
} }
for k, v := range values {
query.Add(k, v)
}
return query.Encode(), nil return query.Encode(), nil
} }
func isZero(v interface{}) (bool, error) { func isZero(v reflect.Value) (bool, error) {
t := reflect.TypeOf(v) t := v.Elem().Type()
if !t.Comparable() { if !t.Comparable() {
return false, fmt.Errorf("type is not comparable: %v", t) return false, fmt.Errorf("type is not comparable: %v", t)
} }
return v == reflect.Zero(t).Interface(), nil return reflect.Zero(t).Equal(v.Elem()), nil
} }
func TimeFromString(t *testing.T, format, datetime string) time.Time { func TimeFromString(t *testing.T, format, datetime string) time.Time {

View File

@@ -1,34 +1,56 @@
package core package core
import ( import (
"log" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type TestTagDefaultValueStruct struct { type DefaultStructure struct {
TestString string `json:"test_string" default:"something"` EmptyField string `json:"empty_field" default:"empty_string"`
TestNumber int `json:"test_number" default:"12"` Field string `json:"field" default:"string"`
} }
func TestTagDefaultValue(t *testing.T) { type DefaultRequest struct {
testStruct := &TestTagDefaultValueStruct{} Field int `json:"field" default:"100"`
EmptyField int `json:"empty_field" default:"14"`
values, err := getDefaultValues(testStruct) Structure DefaultStructure `json:"structure"`
if err != nil { Slice []DefaultStructure `json:"slice"`
log.Fatalf("error when getting default values from tags: %s", err) OptionalStructure *DefaultStructure `json:"optional_structure"`
EmptyOptionalStructure *DefaultStructure `json:"empty_optional_structure"`
} }
expected := map[string]string{ func TestDefaultValues(t *testing.T) {
"test_string": "something", req := &DefaultRequest{
"test_number": "12", Field: 50,
Structure: DefaultStructure{
Field: "something",
},
Slice: []DefaultStructure{
{
Field: "something",
},
{
Field: "something",
},
},
OptionalStructure: &DefaultStructure{
Field: "something",
},
} }
err := getDefaultValues(reflect.ValueOf(req))
assert.Nil(t, err)
if len(values) != len(expected) { assert.Equal(t, 50, req.Field)
log.Fatalf("expected equal length of values and expected: expected: %d, got: %d", len(expected), len(values)) assert.Equal(t, 14, req.EmptyField)
} assert.Equal(t, "something", req.Structure.Field)
for expKey, expValue := range expected { assert.Equal(t, "empty_string", req.Structure.EmptyField)
if expValue != values[expKey] { assert.Equal(t, "something", req.Slice[0].Field)
log.Fatalf("not equal values for key %s", expKey) assert.Equal(t, "something", req.Slice[1].Field)
} assert.Equal(t, "empty_string", req.Slice[1].EmptyField)
} assert.Equal(t, "empty_string", req.Slice[1].EmptyField)
assert.Equal(t, "something", req.OptionalStructure.Field)
assert.Equal(t, "empty_string", req.OptionalStructure.EmptyField)
assert.Equal(t, (*DefaultStructure)(nil), req.EmptyOptionalStructure)
} }

10
go.mod
View File

@@ -1,3 +1,11 @@
module github.com/diphantxm/ozon-api-client module github.com/diphantxm/ozon-api-client
go 1.19 go 1.20
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/stretchr/testify v1.8.4 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

17
go.sum Normal file
View File

@@ -0,0 +1,17 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1618,7 +1618,7 @@ func (c Products) GetProductRangeLimit(ctx context.Context) (*GetProductRangeLim
resp := &GetProductRangeLimitResponse{} resp := &GetProductRangeLimitResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, &struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -667,7 +667,7 @@ func (c Returns) IsGiveoutEnabled(ctx context.Context) (*IsGiveoutEnabledRespons
resp := &IsGiveoutEnabledResponse{} resp := &IsGiveoutEnabledResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -697,7 +697,7 @@ func (c Returns) GetGiveoutPDF(ctx context.Context) (*GetGiveoutResponse, error)
resp := &GetGiveoutResponse{} resp := &GetGiveoutResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -714,7 +714,7 @@ func (c Returns) GetGiveoutPNG(ctx context.Context) (*GetGiveoutResponse, error)
resp := &GetGiveoutResponse{} resp := &GetGiveoutResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -739,7 +739,7 @@ func (c Returns) GetGiveoutBarcode(ctx context.Context) (*GetGiveoutBarcodeRespo
resp := &GetGiveoutBarcodeResponse{} resp := &GetGiveoutBarcodeResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -758,7 +758,7 @@ func (c Returns) ResetGiveoutBarcode(ctx context.Context) (*GetGiveoutResponse,
resp := &GetGiveoutResponse{} resp := &GetGiveoutResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -814,7 +814,7 @@ func (c Returns) GetGiveoutList(ctx context.Context, params *GetGiveoutListParam
resp := &GetGiveoutListResponse{} resp := &GetGiveoutListResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -867,7 +867,7 @@ func (c Returns) GetGiveoutInfo(ctx context.Context, params *GetGiveoutInfoParam
resp := &GetGiveoutInfoResponse{} resp := &GetGiveoutInfoResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -181,7 +181,7 @@ func (c Warehouses) GetListOfDeliveryMethods(ctx context.Context, params *GetLis
resp := &GetListOfDeliveryMethodsResponse{} resp := &GetListOfDeliveryMethodsResponse{}
response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) response, err := c.client.Request(ctx, http.MethodPost, url, params, resp, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }