From 31e20b242eca9358c31835e808340456fc8d4188 Mon Sep 17 00:00:00 2001 From: Alexander Baryshnikov Date: Thu, 17 Sep 2020 20:07:34 +0800 Subject: [PATCH] refactor server code for better testing --- server/internal/adapter.go | 1 + server/server.go | 85 +++++++++++++++++++++++++++----------- 2 files changed, 61 insertions(+), 25 deletions(-) diff --git a/server/internal/adapter.go b/server/internal/adapter.go index 11b4ac5..37c4d67 100644 --- a/server/internal/adapter.go +++ b/server/internal/adapter.go @@ -60,6 +60,7 @@ func getTask(wrk *worker.Worker) http.HandlerFunc { http.Error(writer, "encoding", http.StatusInternalServerError) return } + writer.Header().Set("X-Correlation-Id", requestID) writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Length", strconv.Itoa(len(data))) writer.Header().Set("Content-Version", strconv.Itoa(len(info.Attempts))) diff --git a/server/server.go b/server/server.go index 4e5d463..2c8df81 100644 --- a/server/server.go +++ b/server/server.go @@ -8,10 +8,11 @@ import ( "net/http" "os" "path/filepath" - "sync" "time" "gopkg.in/yaml.v2" + + "nano-run/worker" ) type Config struct { @@ -74,58 +75,62 @@ func (cfg Config) SaveFile(file string) error { return ioutil.WriteFile(file, data, 0600) } -func (cfg Config) Run(global context.Context) error { +func (cfg Config) Create(global context.Context) (*Server, error) { units, err := Units(cfg.ConfigDirectory) if err != nil { - return err + return nil, err } workers, err := Workers(cfg.WorkingDirectory, units) if err != nil { - return err + return nil, err } - defer func() { - for _, wrk := range workers { - wrk.Close() - } - }() - handler := Handler(units, workers) + ctx, cancel := context.WithCancel(global) + srv := &Server{ + Handler: Handler(units, workers), + workers: workers, + units: units, + done: make(chan struct{}), + cancel: cancel, + } + go srv.run(ctx) + return srv, nil +} + +func (cfg Config) Run(global context.Context) error { ctx, cancel := context.WithCancel(global) + defer cancel() + + srv, err := cfg.Create(global) + if err != nil { + return err + } + defer srv.Close() server := http.Server{ Addr: cfg.Bind, - Handler: handler, + Handler: srv, } - var wg sync.WaitGroup + done := make(chan struct{}) - wg.Add(1) go func() { - defer wg.Done() defer cancel() <-ctx.Done() t, c := context.WithTimeout(context.Background(), cfg.GracefulShutdown) _ = server.Shutdown(t) c() + close(done) }() - wg.Add(1) - go func() { - defer wg.Done() - defer cancel() - err := Run(ctx, workers) - if err != nil { - log.Println("workers stopped:", err) - } - }() if cfg.TLS.Enable { err = server.ListenAndServeTLS(cfg.TLS.Cert, cfg.TLS.Key) } else { err = server.ListenAndServe() } cancel() - wg.Wait() - return err + <-done + return ctx.Err() } func limitRequest(maxSize int64, handler http.Handler) http.Handler { @@ -142,3 +147,33 @@ func limitRequest(maxSize int64, handler http.Handler) http.Handler { handler.ServeHTTP(writer, request) }) } + +type Server struct { + http.Handler + workers []*worker.Worker + units []Unit + cancel func() + done chan struct{} + err error +} + +func (srv *Server) Close() { + for _, wrk := range srv.workers { + wrk.Close() + } + srv.cancel() + <-srv.done +} + +func (srv *Server) Err() error { + return srv.err +} + +func (srv *Server) run(ctx context.Context) { + err := Run(ctx, srv.workers) + if err != nil { + log.Println("workers stopped:", err) + } + srv.err = err + close(srv.done) +}