nano-run/server/ui/oauth.go
2020-10-02 12:40:18 +08:00

131 lines
2.9 KiB
Go

package ui
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/oauth2"
)
type OAuth2 struct {
Config oauth2.Config
ProfileURL string
RedirectTo string
LoginField string
}
func (cfg OAuth2) Attach(router gin.IRouter, storage SessionStorage) {
router.GET("/login", func(gctx *gin.Context) {
state := uuid.New().String()
gctx.SetCookie("oauth", state, 3600, "", "", false, true)
u := cfg.Config.AuthCodeURL(state)
gctx.Redirect(http.StatusTemporaryRedirect, u)
})
router.GET("/callback", func(gctx *gin.Context) {
savedState, _ := gctx.Cookie("oauth")
if savedState != gctx.Query("state") {
gctx.AbortWithStatus(http.StatusForbidden)
return
}
gctx.SetCookie("oauth", "", -1, "", "", false, true)
token, err := cfg.Config.Exchange(gctx.Request.Context(), gctx.Query("code"))
if err != nil {
_ = gctx.AbortWithError(http.StatusForbidden, err)
return
}
if !token.Valid() {
gctx.AbortWithStatus(http.StatusForbidden)
return
}
sessionID := uuid.New().String()
session := newOAuthSession(token)
if cfg.ProfileURL != "" {
err = session.fetchLogin(gctx.Request.Context(), cfg.ProfileURL, cfg.LoginField)
if err != nil {
_ = gctx.AbortWithError(http.StatusForbidden, err)
return
}
}
gctx.SetCookie(sessionCookie, sessionID, 0, "", "", false, true)
storage.Save(sessionID, session)
log.Println("user", session.Login(), "authorized via oauth2")
gctx.Redirect(http.StatusTemporaryRedirect, cfg.RedirectTo)
})
}
func newOAuthSession(token *oauth2.Token) *oauthSession {
return &oauthSession{
token: oauth2.StaticTokenSource(token),
}
}
type oauthSession struct {
login string
token oauth2.TokenSource
}
func (ss *oauthSession) Login() string {
return ss.login
}
func (ss *oauthSession) Valid() bool {
token, err := ss.token.Token()
if err != nil {
log.Println(err)
return false
}
return token.Valid()
}
func (ss *oauthSession) GetJSON(ctx context.Context, url string, response interface{}) error {
t, err := ss.token.Token()
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
t.SetAuthHeader(req)
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New(res.Status)
}
return json.NewDecoder(res.Body).Decode(response)
}
func (ss *oauthSession) fetchLogin(ctx context.Context, url string, field string) error {
var profile = make(map[string]interface{})
err := ss.GetJSON(ctx, url, &profile)
if err != nil {
return err
}
log.Printf("profile: %+v", profile)
l, ok := profile[field]
if !ok {
return fmt.Errorf("not field %s in response", field)
}
s, ok := l.(string)
if !ok {
return fmt.Errorf("field %s is not string", field)
}
ss.login = s
return nil
}