From 2f94b8c77424589f6fd317d64663f65e99b2f335 Mon Sep 17 00:00:00 2001 From: Kirill Date: Mon, 29 Jan 2024 23:22:57 +0300 Subject: [PATCH] Reimplement default values (#65) --- .github/workflows/tests.yml | 2 +- client.go | 9 +++ core.go | 130 +++++++++++++++++++++++++----------- core_test.go | 72 +++++++++++++------- go.mod | 10 ++- go.sum | 17 +++++ ozon/products.go | 2 +- ozon/returns.go | 14 ++-- ozon/warehouses.go | 2 +- 9 files changed, 182 insertions(+), 76 deletions(-) create mode 100644 go.sum diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ff5d6cd..5a2e8ae 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: - name: Setup go uses: actions/setup-go@v2 with: - go-version: '1.19' + go-version: '1.20' - name: Setup run: | go install github.com/mattn/goveralls@latest diff --git a/client.go b/client.go index 7ffcaba..1b9f06a 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" ) type HttpClient interface { @@ -36,6 +37,14 @@ func NewMockClient(handler http.HandlerFunc) *Client { } func (c Client) newRequest(ctx context.Context, method string, uri string, body interface{}) (*http.Request, error) { + // Set default values for empty fields if `default` tag is present + // And body is not nil + if body != nil { + if err := getDefaultValues(reflect.ValueOf(body)); err != nil { + return nil, err + } + } + bodyJson, err := json.Marshal(body) if err != nil { return nil, err diff --git a/core.go b/core.go index a3bb8c9..7d5209c 100644 --- a/core.go +++ b/core.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "reflect" + "strconv" "testing" "time" ) @@ -32,51 +33,103 @@ func (r Response) CopyCommonResponse(rhs *CommonResponse) { rhs.Message = r.Message } -func getDefaultValues(v interface{}) (map[string]string, error) { - isNil, err := isZero(v) - if err != nil { - return make(map[string]string), err - } - - if isNil { - return make(map[string]string), nil - } - - out := make(map[string]string) - - vType := reflect.TypeOf(v).Elem() - vValue := reflect.ValueOf(v).Elem() +func getDefaultValues(v reflect.Value) error { + vValue := v.Elem() + vType := vValue.Type() for i := 0; i < vType.NumField(); i++ { field := vType.Field(i) - tag := field.Tag.Get("json") - defaultValue := field.Tag.Get("default") - if field.Type.Kind() == reflect.Slice { - // Attach any slices as query params - fieldVal := vValue.Field(i) - for j := 0; j < fieldVal.Len(); j++ { - out[tag] = fmt.Sprintf("%v", fieldVal.Index(j)) - } - } else { - // Add any scalar values as query params - fieldVal := fmt.Sprintf("%v", vValue.Field(i)) - - // If no value was set by the user, use the default - // value specified in the struct tag. - if fieldVal == "" || fieldVal == "0" { - if defaultValue == "" { + switch field.Type.Kind() { + case reflect.Slice: + for j := 0; j < vValue.Field(i).Len(); j++ { + // skip if slice type is primitive + if vValue.Field(i).Index(j).Kind() != reflect.Struct { continue } - fieldVal = defaultValue + // Attach any slices as query params + err := getDefaultValues(vValue.Field(i).Index(j).Addr()) + if err != nil { + return err + } + } + case reflect.String: + isNil, err := isZero(vValue.Field(i).Addr()) + if err != nil { + return err + } + if !isNil { + continue } - out[tag] = fmt.Sprintf("%v", fieldVal) + defaultValue, ok := field.Tag.Lookup("default") + if !ok { + continue + } + + vValue.Field(i).SetString(defaultValue) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + isNil, err := isZero(vValue.Field(i).Addr()) + if err != nil { + return err + } + if !isNil { + continue + } + + defaultValue, ok := field.Tag.Lookup("default") + if !ok { + continue + } + defaultValueInt, err := strconv.Atoi(defaultValue) + if err != nil { + return err + } + + vValue.Field(i).SetInt(int64(defaultValueInt)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + isNil, err := isZero(vValue.Field(i).Addr()) + if err != nil { + return err + } + if !isNil { + continue + } + + defaultValue, ok := field.Tag.Lookup("default") + if !ok { + continue + } + defaultValueUint, err := strconv.ParseUint(defaultValue, 10, 64) + if err != nil { + return err + } + + vValue.Field(i).SetUint(uint64(defaultValueUint)) + case reflect.Struct: + err := getDefaultValues(vValue.Field(i).Addr()) + if err != nil { + return err + } + case reflect.Ptr: + isNil, err := isZero(vValue.Field(i).Addr()) + if err != nil { + return err + } + if isNil { + continue + } + + if err := getDefaultValues(vValue.Field(i)); err != nil { + return err + } + default: + continue } } - return out, nil + return nil } func buildRawQuery(req *http.Request, v interface{}) (string, error) { @@ -86,23 +139,20 @@ func buildRawQuery(req *http.Request, v interface{}) (string, error) { return query.Encode(), nil } - values, err := getDefaultValues(v) + err := getDefaultValues(reflect.ValueOf(v)) if err != nil { return "", err } - for k, v := range values { - query.Add(k, v) - } return query.Encode(), nil } -func isZero(v interface{}) (bool, error) { - t := reflect.TypeOf(v) +func isZero(v reflect.Value) (bool, error) { + t := v.Elem().Type() if !t.Comparable() { return false, fmt.Errorf("type is not comparable: %v", t) } - return v == reflect.Zero(t).Interface(), nil + return reflect.Zero(t).Equal(v.Elem()), nil } func TimeFromString(t *testing.T, format, datetime string) time.Time { diff --git a/core_test.go b/core_test.go index dd46581..819f80c 100644 --- a/core_test.go +++ b/core_test.go @@ -1,34 +1,56 @@ package core import ( - "log" + "reflect" "testing" + + "github.com/stretchr/testify/assert" ) -type TestTagDefaultValueStruct struct { - TestString string `json:"test_string" default:"something"` - TestNumber int `json:"test_number" default:"12"` +type DefaultStructure struct { + EmptyField string `json:"empty_field" default:"empty_string"` + Field string `json:"field" default:"string"` } -func TestTagDefaultValue(t *testing.T) { - testStruct := &TestTagDefaultValueStruct{} - - values, err := getDefaultValues(testStruct) - if err != nil { - log.Fatalf("error when getting default values from tags: %s", err) - } - - expected := map[string]string{ - "test_string": "something", - "test_number": "12", - } - - if len(values) != len(expected) { - log.Fatalf("expected equal length of values and expected: expected: %d, got: %d", len(expected), len(values)) - } - for expKey, expValue := range expected { - if expValue != values[expKey] { - log.Fatalf("not equal values for key %s", expKey) - } - } +type DefaultRequest struct { + Field int `json:"field" default:"100"` + EmptyField int `json:"empty_field" default:"14"` + Structure DefaultStructure `json:"structure"` + Slice []DefaultStructure `json:"slice"` + OptionalStructure *DefaultStructure `json:"optional_structure"` + EmptyOptionalStructure *DefaultStructure `json:"empty_optional_structure"` +} + +func TestDefaultValues(t *testing.T) { + req := &DefaultRequest{ + Field: 50, + Structure: DefaultStructure{ + Field: "something", + }, + Slice: []DefaultStructure{ + { + Field: "something", + }, + { + Field: "something", + }, + }, + OptionalStructure: &DefaultStructure{ + Field: "something", + }, + } + err := getDefaultValues(reflect.ValueOf(req)) + assert.Nil(t, err) + + assert.Equal(t, 50, req.Field) + assert.Equal(t, 14, req.EmptyField) + assert.Equal(t, "something", req.Structure.Field) + assert.Equal(t, "empty_string", req.Structure.EmptyField) + assert.Equal(t, "something", req.Slice[0].Field) + assert.Equal(t, "something", req.Slice[1].Field) + assert.Equal(t, "empty_string", req.Slice[1].EmptyField) + assert.Equal(t, "empty_string", req.Slice[1].EmptyField) + assert.Equal(t, "something", req.OptionalStructure.Field) + assert.Equal(t, "empty_string", req.OptionalStructure.EmptyField) + assert.Equal(t, (*DefaultStructure)(nil), req.EmptyOptionalStructure) } diff --git a/go.mod b/go.mod index f81cfaa..761d518 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/diphantxm/ozon-api-client -go 1.19 +go 1.20 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5bddba9 --- /dev/null +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ozon/products.go b/ozon/products.go index e3ef5e3..dbe6a30 100644 --- a/ozon/products.go +++ b/ozon/products.go @@ -1618,7 +1618,7 @@ func (c Products) GetProductRangeLimit(ctx context.Context) (*GetProductRangeLim resp := &GetProductRangeLimitResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, &struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } diff --git a/ozon/returns.go b/ozon/returns.go index fa5b3e7..d023220 100644 --- a/ozon/returns.go +++ b/ozon/returns.go @@ -667,7 +667,7 @@ func (c Returns) IsGiveoutEnabled(ctx context.Context) (*IsGiveoutEnabledRespons resp := &IsGiveoutEnabledResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -697,7 +697,7 @@ func (c Returns) GetGiveoutPDF(ctx context.Context) (*GetGiveoutResponse, error) resp := &GetGiveoutResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -714,7 +714,7 @@ func (c Returns) GetGiveoutPNG(ctx context.Context) (*GetGiveoutResponse, error) resp := &GetGiveoutResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -739,7 +739,7 @@ func (c Returns) GetGiveoutBarcode(ctx context.Context) (*GetGiveoutBarcodeRespo resp := &GetGiveoutBarcodeResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -758,7 +758,7 @@ func (c Returns) ResetGiveoutBarcode(ctx context.Context) (*GetGiveoutResponse, resp := &GetGiveoutResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -814,7 +814,7 @@ func (c Returns) GetGiveoutList(ctx context.Context, params *GetGiveoutListParam resp := &GetGiveoutListResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } @@ -867,7 +867,7 @@ func (c Returns) GetGiveoutInfo(ctx context.Context, params *GetGiveoutInfoParam resp := &GetGiveoutInfoResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, struct{}{}, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) if err != nil { return nil, err } diff --git a/ozon/warehouses.go b/ozon/warehouses.go index 0fd2905..715f3ac 100644 --- a/ozon/warehouses.go +++ b/ozon/warehouses.go @@ -181,7 +181,7 @@ func (c Warehouses) GetListOfDeliveryMethods(ctx context.Context, params *GetLis resp := &GetListOfDeliveryMethodsResponse{} - response, err := c.client.Request(ctx, http.MethodPost, url, nil, resp, nil) + response, err := c.client.Request(ctx, http.MethodPost, url, params, resp, nil) if err != nil { return nil, err }