diff --git a/go.mod b/go.mod index 4363e62..6d2ce46 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/google/uuid v1.1.2 github.com/gorilla/mux v1.8.0 github.com/jessevdk/go-flags v1.4.1-0.20200711081900-c17162fe8fd7 + github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a gopkg.in/yaml.v2 v2.3.0 ) diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..7df7d62 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,87 @@ +package server_test + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + + "nano-run/server" + "nano-run/services/meta" +) + +var tmpDir string + +func TestMain(main *testing.M) { + var err error + tmpDir, err = ioutil.TempDir("", "") + if err != nil { + log.Fatal(err) + } + code := main.Run() + _ = os.RemoveAll(tmpDir) + os.Exit(code) +} + +func testServer(t *testing.T, cfg server.Config, units map[string]server.Unit) *server.Server { + sub, err := ioutil.TempDir(tmpDir, "") + if !assert.NoError(t, err) { + t.Fatal("failed to create temp dir", err) + } + cfg.ConfigDirectory = filepath.Join(sub, "config") + cfg.WorkingDirectory = filepath.Join(sub, "data") + err = cfg.CreateDirs() + if !assert.NoError(t, err) { + t.Fatal("failed to create dirs", err) + } + + for name, unit := range units { + err = unit.SaveFile(filepath.Join(cfg.ConfigDirectory, name+".yaml")) + if !assert.NoError(t, err) { + t.Fatal("failed to create unit", name, ":", err) + } + } + + srv, err := cfg.Create(context.Background()) + if !assert.NoError(t, err) { + srv.Close() + t.Fatal("failed to create server") + } + return srv +} + +func Test_create(t *testing.T) { + srv := testServer(t, server.DefaultConfig(), map[string]server.Unit{ + "hello": { + Command: "echo hello world", + }, + }) + defer srv.Close() + + req := httptest.NewRequest(http.MethodPost, "/hello/", bytes.NewBufferString("hello world")) + res := httptest.NewRecorder() + srv.ServeHTTP(res, req) + assert.Equal(t, http.StatusSeeOther, res.Code) + assert.NotEmpty(t, res.Header().Get("X-Correlation-Id")) + assert.Equal(t, "/hello/"+res.Header().Get("X-Correlation-Id"), res.Header().Get("Location")) + requestID := res.Header().Get("X-Correlation-Id") + + infoURL := res.Header().Get("Location") + req = httptest.NewRequest(http.MethodGet, infoURL, nil) + res = httptest.NewRecorder() + srv.ServeHTTP(res, req) + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, requestID, res.Header().Get("X-Correlation-Id")) + assert.Equal(t, "application/json", res.Header().Get("Content-Type")) + var info meta.Request + err := json.Unmarshal(res.Body.Bytes(), &info) + assert.NoError(t, err) +}