Reimplement default values (#65)
This commit is contained in:
130
core.go
130
core.go
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -32,51 +33,103 @@ func (r Response) CopyCommonResponse(rhs *CommonResponse) {
|
||||
rhs.Message = r.Message
|
||||
}
|
||||
|
||||
func getDefaultValues(v interface{}) (map[string]string, error) {
|
||||
isNil, err := isZero(v)
|
||||
if err != nil {
|
||||
return make(map[string]string), err
|
||||
}
|
||||
|
||||
if isNil {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
out := make(map[string]string)
|
||||
|
||||
vType := reflect.TypeOf(v).Elem()
|
||||
vValue := reflect.ValueOf(v).Elem()
|
||||
func getDefaultValues(v reflect.Value) error {
|
||||
vValue := v.Elem()
|
||||
vType := vValue.Type()
|
||||
|
||||
for i := 0; i < vType.NumField(); i++ {
|
||||
field := vType.Field(i)
|
||||
tag := field.Tag.Get("json")
|
||||
defaultValue := field.Tag.Get("default")
|
||||
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
// Attach any slices as query params
|
||||
fieldVal := vValue.Field(i)
|
||||
for j := 0; j < fieldVal.Len(); j++ {
|
||||
out[tag] = fmt.Sprintf("%v", fieldVal.Index(j))
|
||||
}
|
||||
} else {
|
||||
// Add any scalar values as query params
|
||||
fieldVal := fmt.Sprintf("%v", vValue.Field(i))
|
||||
|
||||
// If no value was set by the user, use the default
|
||||
// value specified in the struct tag.
|
||||
if fieldVal == "" || fieldVal == "0" {
|
||||
if defaultValue == "" {
|
||||
switch field.Type.Kind() {
|
||||
case reflect.Slice:
|
||||
for j := 0; j < vValue.Field(i).Len(); j++ {
|
||||
// skip if slice type is primitive
|
||||
if vValue.Field(i).Index(j).Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldVal = defaultValue
|
||||
// Attach any slices as query params
|
||||
err := getDefaultValues(vValue.Field(i).Index(j).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
isNil, err := isZero(vValue.Field(i).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isNil {
|
||||
continue
|
||||
}
|
||||
|
||||
out[tag] = fmt.Sprintf("%v", fieldVal)
|
||||
defaultValue, ok := field.Tag.Lookup("default")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
vValue.Field(i).SetString(defaultValue)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
isNil, err := isZero(vValue.Field(i).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isNil {
|
||||
continue
|
||||
}
|
||||
|
||||
defaultValue, ok := field.Tag.Lookup("default")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
defaultValueInt, err := strconv.Atoi(defaultValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vValue.Field(i).SetInt(int64(defaultValueInt))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
isNil, err := isZero(vValue.Field(i).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isNil {
|
||||
continue
|
||||
}
|
||||
|
||||
defaultValue, ok := field.Tag.Lookup("default")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
defaultValueUint, err := strconv.ParseUint(defaultValue, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vValue.Field(i).SetUint(uint64(defaultValueUint))
|
||||
case reflect.Struct:
|
||||
err := getDefaultValues(vValue.Field(i).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Ptr:
|
||||
isNil, err := isZero(vValue.Field(i).Addr())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isNil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := getDefaultValues(vValue.Field(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRawQuery(req *http.Request, v interface{}) (string, error) {
|
||||
@@ -86,23 +139,20 @@ func buildRawQuery(req *http.Request, v interface{}) (string, error) {
|
||||
return query.Encode(), nil
|
||||
}
|
||||
|
||||
values, err := getDefaultValues(v)
|
||||
err := getDefaultValues(reflect.ValueOf(v))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for k, v := range values {
|
||||
query.Add(k, v)
|
||||
}
|
||||
|
||||
return query.Encode(), nil
|
||||
}
|
||||
|
||||
func isZero(v interface{}) (bool, error) {
|
||||
t := reflect.TypeOf(v)
|
||||
func isZero(v reflect.Value) (bool, error) {
|
||||
t := v.Elem().Type()
|
||||
if !t.Comparable() {
|
||||
return false, fmt.Errorf("type is not comparable: %v", t)
|
||||
}
|
||||
return v == reflect.Zero(t).Interface(), nil
|
||||
return reflect.Zero(t).Equal(v.Elem()), nil
|
||||
}
|
||||
|
||||
func TimeFromString(t *testing.T, format, datetime string) time.Time {
|
||||
|
||||
Reference in New Issue
Block a user