complement/should/should.go

290 lines
9.4 KiB
Go

// Package should contains assertions for tests, which returns an error if the assertion fails.
package should
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/tidwall/gjson"
"golang.org/x/exp/slices"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/complement/match"
)
// EXPERIMENTAL
// ParseJSON will ensure that the HTTP request/response body is valid JSON, then return the body, else returns an error.
func ParseJSON(b io.ReadCloser) (res gjson.Result, err error) {
body, err := io.ReadAll(b)
if err != nil {
return res, fmt.Errorf("ParseJSON: reading body returned %s", err)
}
if !gjson.ValidBytes(body) {
return res, fmt.Errorf("ParseJSON: not valid JSON")
}
return gjson.ParseBytes(body), nil
}
// EXPERIMENTAL
// MatchRequest consumes the HTTP request and performs HTTP-level assertions on it. Returns the raw response body.
func MatchRequest(req *http.Request, m match.HTTPRequest) ([]byte, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("MatchRequest: Failed to read request body: %s", err)
}
contextStr := fmt.Sprintf("%s => %s", req.URL.String(), string(body))
if m.Headers != nil {
for name, val := range m.Headers {
if req.Header.Get(name) != val {
return nil, fmt.Errorf("MatchRequest got %s: %s want %s - %s", name, req.Header.Get(name), val, contextStr)
}
}
}
if m.JSON != nil {
if !gjson.ValidBytes(body) {
return nil, fmt.Errorf("MatchRequest request body is not valid JSON - %s", contextStr)
}
parsedBody := gjson.ParseBytes(body)
for _, jm := range m.JSON {
if err = jm(parsedBody); err != nil {
return nil, fmt.Errorf("MatchRequest %s - %s", err, contextStr)
}
}
}
return body, nil
}
// EXPERIMENTAL
// MatchSuccess consumes the HTTP response and fails if the response is non-2xx.
func MatchSuccess(res *http.Response) error {
if res.StatusCode < 200 || res.StatusCode > 299 {
return fmt.Errorf("MatchSuccess got status %d instead of a success code", res.StatusCode)
}
return nil
}
// EXPERIMENTAL
// MatchFailure consumes the HTTP response and fails if the response is 2xx.
func MatchFailure(res *http.Response) error {
if res.StatusCode >= 200 && res.StatusCode <= 299 {
return fmt.Errorf("MatchFailure got status %d instead of a failure code", res.StatusCode)
}
return nil
}
// EXPERIMENTAL
// MatchResponse consumes the HTTP response and performs HTTP-level assertions on it. Returns the raw response body.
func MatchResponse(res *http.Response, m match.HTTPResponse) ([]byte, error) {
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("MatchResponse: Failed to read response body: %s", err)
}
contextStr := fmt.Sprintf("%s => %s", res.Request.URL.String(), string(body))
if m.StatusCode != 0 {
if res.StatusCode != m.StatusCode {
return nil, fmt.Errorf("MatchResponse got status %d want %d - %s", res.StatusCode, m.StatusCode, contextStr)
}
}
if m.Headers != nil {
for name, val := range m.Headers {
if res.Header.Get(name) != val {
return nil, fmt.Errorf("MatchResponse got %s: %s want %s - %s", name, res.Header.Get(name), val, contextStr)
}
}
}
if m.JSON != nil {
if !gjson.ValidBytes(body) {
return nil, fmt.Errorf("MatchResponse response body is not valid JSON - %s", contextStr)
}
parsedBody := gjson.ParseBytes(body)
for _, jm := range m.JSON {
if err = jm(parsedBody); err != nil {
return nil, fmt.Errorf("MatchResponse %s - %s", err, contextStr)
}
}
}
return body, nil
}
// MatchFederationRequest performs JSON assertions on incoming federation requests.
func MatchFederationRequest(fedReq *fclient.FederationRequest, matchers ...match.JSON) error {
content := fedReq.Content()
if !gjson.ValidBytes(content) {
return fmt.Errorf("MatchFederationRequest content is not valid JSON - %s", fedReq.RequestURI())
}
parsedContent := gjson.ParseBytes(content)
for _, jm := range matchers {
if err := jm(parsedContent); err != nil {
return fmt.Errorf("MatchFederationRequest %s - %s", err, fedReq.RequestURI())
}
}
return nil
}
// EXPERIMENTAL
// MatchGJSON performs JSON assertions on a gjson.Result object.
func MatchGJSON(jsonResult gjson.Result, matchers ...match.JSON) error {
return MatchJSONBytes([]byte(jsonResult.Raw), matchers...)
}
// EXPERIMENTAL
// MatchJSONBytes performs JSON assertions on a raw json byte slice.
func MatchJSONBytes(rawJson []byte, matchers ...match.JSON) error {
if !gjson.ValidBytes(rawJson) {
return fmt.Errorf("MatchJSONBytes: rawJson is not valid JSON")
}
body := gjson.ParseBytes(rawJson)
for _, jm := range matchers {
if err := jm(body); err != nil {
return fmt.Errorf("MatchJSONBytes %s with input = %v", err, body.Get("@pretty").String())
}
}
return nil
}
// EXPERIMENTAL
// GetJSONFieldStr extracts the string value under `wantKey` or fails the test.
// The format of `wantKey` is specified at https://godoc.org/github.com/tidwall/gjson#Get
func GetJSONFieldStr(body gjson.Result, wantKey string) (string, error) {
res := body.Get(wantKey)
if res.Index == 0 {
return "", fmt.Errorf("JSONFieldStr: key '%s' missing from %s", wantKey, body.Raw)
}
if res.Str == "" {
return "", fmt.Errorf("JSONFieldStr: key '%s' is not a string, body: %s", wantKey, body.Raw)
}
return res.Str, nil
}
// EXPERIMENTAL
// HaveInOrder checks that the two slices match exactly, failing the test on mismatches or omissions.
func HaveInOrder[V comparable](gots []V, wants []V) error {
if len(gots) != len(wants) {
return fmt.Errorf("HaveInOrder: length mismatch, got %v want %v", gots, wants)
}
for i := range gots {
if gots[i] != wants[i] {
return fmt.Errorf("HaveInOrder: index %d got %v want %v", i, gots[i], wants[i])
}
}
return nil
}
// EXPERIMENTAL
// ContainSubset checks that every item in smaller is in larger, failing the test if at least 1 item isn't. Ignores extra elements
// in larger. Ignores ordering.
func ContainSubset[V comparable](larger []V, smaller []V) error {
if len(larger) < len(smaller) {
return fmt.Errorf("ContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller))
}
for i, item := range smaller {
if !slices.Contains(larger, item) {
return fmt.Errorf("ContainSubset: element not found in larger set: smaller[%d] (%v) larger=%v", i, item, larger)
}
}
return nil
}
// EXPERIMENTAL
// NotContainSubset checks that every item in smaller is NOT in larger, failing the test if at least 1 item is. Ignores extra elements
// in larger. Ignores ordering.
func NotContainSubset[V comparable](larger []V, smaller []V) error {
if len(larger) < len(smaller) {
return fmt.Errorf("NotContainSubset: length mismatch, larger=%d smaller=%d", len(larger), len(smaller))
}
for i, item := range smaller {
if slices.Contains(larger, item) {
return fmt.Errorf("NotContainSubset: element found in larger set: smaller[%d] (%v)", i, item)
}
}
return nil
}
// EXPERIMENTAL
// GetTimelineEventIDs returns the timeline event IDs in the sync response for the given room ID. If the room is missing
// this returns a 0 element slice.
func GetTimelineEventIDs(topLevelSyncJSON gjson.Result, roomID string) []string {
timeline := topLevelSyncJSON.Get(fmt.Sprintf("rooms.join.%s.timeline.events", client.GjsonEscape(roomID))).Array()
eventIDs := make([]string, len(timeline))
for i := range timeline {
eventIDs[i] = timeline[i].Get("event_id").Str
}
return eventIDs
}
// EXPERIMENTAL
// CheckOffAll checks that a list contains exactly the given items, in any order.
//
// if an item is not present, an error is returned.
// if an item not present in the want list is present, an error is returned.
// Items are compared using match.JSONDeepEqual
func CheckOffAll(items []interface{}, wantItems []interface{}) error {
remaining, err := CheckOffAllAllowUnwanted(items, wantItems)
if err != nil {
return err
}
if len(remaining) > 0 {
return fmt.Errorf("CheckOffAll: unexpected items %v", remaining)
}
return nil
}
// EXPERIMENTAL
// CheckOffAllAllowUnwanted checks that a list contains all of the given items, in any order.
// The updated list with the matched items removed from it is returned.
//
// if an item is not present, an error is returned
// Items are compared using match.JSONDeepEqual
func CheckOffAllAllowUnwanted(items []interface{}, wantItems []interface{}) ([]interface{}, error) {
var err error
for _, wantItem := range wantItems {
items, err = CheckOff(items, wantItem)
if err != nil {
return nil, err
}
}
return items, nil
}
// EXPERIMENTAL
// CheckOff an item from the list. If the item is not present an error is returned
// The updated list with the matched item removed from it is returned. Items are
// compared using JSON deep equal.
func CheckOff(items []interface{}, wantItem interface{}) ([]interface{}, error) {
// check off the item
want := -1
for i, w := range items {
wBytes, _ := json.Marshal(w)
if jsonDeepEqual(wBytes, wantItem) {
want = i
break
}
}
if want == -1 {
return nil, fmt.Errorf("CheckOff: item %s not present", wantItem)
}
// delete the wanted item
items = append(items[:want], items[want+1:]...)
return items, nil
}
func jsonDeepEqual(gotJson []byte, wantValue interface{}) bool {
// marshal what the test gave us
wantBytes, _ := json.Marshal(wantValue)
// re-marshal what the network gave us to acount for key ordering
var gotVal interface{}
_ = json.Unmarshal(gotJson, &gotVal)
gotBytes, _ := json.Marshal(gotVal)
return bytes.Equal(gotBytes, wantBytes)
}