summaryrefslogtreecommitdiff
path: root/http-server/main.go
diff options
context:
space:
mode:
authorAndrew <saintruler@gmail.com>2021-05-13 21:42:24 +0400
committerAndrew <saintruler@gmail.com>2021-05-13 21:42:24 +0400
commit712c9f7153c59bc5487e781cdeab0a60dcfd6d6e (patch)
tree6300faa4bd7653841b574cb45c1605603679f454 /http-server/main.go
parent587ac4f7fecc417b4877c5f3c0fdefa58990b3c8 (diff)
Changed workflow of saving messages, so that server now saves signed messages.
Diffstat (limited to 'http-server/main.go')
-rw-r--r--http-server/main.go79
1 files changed, 44 insertions, 35 deletions
diff --git a/http-server/main.go b/http-server/main.go
index 7aaa619..03e8ed2 100644
--- a/http-server/main.go
+++ b/http-server/main.go
@@ -18,25 +18,18 @@ type Response struct {
}
type Request struct {
- User string
- Data string
-}
-
-type Message struct {
User string
Data string
- Timestamp int64
+ Payload string
+ Signature string
}
-func parseRequest(w http.ResponseWriter, r *http.Request) (Request, string, error) {
+func parseRequest(signedMessage string) (Request, error) {
var req Request
- body, _ := io.ReadAll(r.Body)
- bodyStr := string(body)
- parts := strings.Split(bodyStr, ".")
+ parts := strings.Split(signedMessage, ".")
if len(parts) != 2 {
- _ = badRequest(w)
- return req, "", errors.New("request doesn't contain exactly two parts")
+ return req, errors.New("request doesn't contain exactly two parts")
}
payload := parts[0]
@@ -44,15 +37,16 @@ func parseRequest(w http.ResponseWriter, r *http.Request) (Request, string, erro
requestBody, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
- return req, "", err
+ return req, err
}
err = json.Unmarshal(requestBody, &req)
if err != nil {
- _ = badRequest(w)
- return req, "", err
+ return req, errors.New("unable to parse request")
}
- return req, signature, nil
+ req.Payload = payload
+ req.Signature = signature
+ return req, nil
}
func jsonResponse(w http.ResponseWriter, resp Response) error {
@@ -91,8 +85,12 @@ func register(w http.ResponseWriter, r *http.Request) {
return
}
- req, signature, err := parseRequest(w, r)
+ requestBody, _ := io.ReadAll(r.Body)
+ signedMessage := string(requestBody)
+
+ req, err := parseRequest(signedMessage)
if err != nil {
+ _ = badRequest(w)
return
}
@@ -113,7 +111,7 @@ func register(w http.ResponseWriter, r *http.Request) {
return
}
- checkResult, err := checkSignature(req, signature, req.Data)
+ checkResult, err := checkSignature(req.Payload, req.Signature, req.Data)
if err != nil || !checkResult {
w.WriteHeader(http.StatusBadRequest)
_ = jsonResponse(w, Response{
@@ -136,7 +134,7 @@ func register(w http.ResponseWriter, r *http.Request) {
})
}
-func handleAuthentication(req Request, signature string) (bool, error) {
+func handleAuthentication(req Request) (bool, error) {
userRegistered, dbError := db.checkUserRegistered(req.User)
if dbError != nil {
return false, dbError
@@ -151,7 +149,7 @@ func handleAuthentication(req Request, signature string) (bool, error) {
return false, err
}
- check, err := checkSignature(req, signature, key)
+ check, err := checkSignature(req.Payload, req.Signature, key)
if err != nil {
return false, err
}
@@ -165,12 +163,16 @@ func sendMessage(w http.ResponseWriter, r *http.Request) {
return
}
- req, signature, err := parseRequest(w, r)
+ requestBody, _ := io.ReadAll(r.Body)
+ signedMessage := string(requestBody)
+
+ req, err := parseRequest(signedMessage)
if err != nil {
+ _ = badRequest(w)
return
}
- authComplete, dbErr := handleAuthentication(req, signature)
+ authComplete, dbErr := handleAuthentication(req)
if !authComplete {
w.WriteHeader(http.StatusForbidden)
_ = jsonResponse(w, Response{
@@ -183,15 +185,10 @@ func sendMessage(w http.ResponseWriter, r *http.Request) {
return
}
- msg := Message{
- User: req.User,
- Data: req.Data,
- Timestamp: time.Now().UnixNano(),
- }
-
fmt.Printf("Got message from %s: %s\n", req.User, req.Data)
- err = db.saveMessage(msg)
+ timestamp := time.Now().UnixNano()
+ err = db.saveMessage(signedMessage, timestamp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_ = jsonResponse(w, Response{
@@ -211,12 +208,16 @@ func pollMessages(w http.ResponseWriter, r *http.Request) {
return
}
- req, signature, err := parseRequest(w, r)
+ requestBody, _ := io.ReadAll(r.Body)
+ signedMessage := string(requestBody)
+
+ req, err := parseRequest(signedMessage)
if err != nil {
+ _ = badRequest(w)
return
}
- authComplete, dbErr := handleAuthentication(req, signature)
+ authComplete, dbErr := handleAuthentication(req)
if !authComplete {
w.WriteHeader(http.StatusForbidden)
_ = jsonResponse(w, Response{
@@ -260,12 +261,16 @@ func getUserKey(w http.ResponseWriter, r *http.Request) {
return
}
- req, signature, err := parseRequest(w, r)
+ requestBody, _ := io.ReadAll(r.Body)
+ signedMessage := string(requestBody)
+
+ req, err := parseRequest(signedMessage)
if err != nil {
+ _ = badRequest(w)
return
}
- authComplete, dbErr := handleAuthentication(req, signature)
+ authComplete, dbErr := handleAuthentication(req)
if !authComplete {
w.WriteHeader(http.StatusForbidden)
@@ -305,12 +310,16 @@ func tryAuth(w http.ResponseWriter, r *http.Request) {
return
}
- req, signature, err := parseRequest(w, r)
+ requestBody, _ := io.ReadAll(r.Body)
+ signedMessage := string(requestBody)
+
+ req, err := parseRequest(signedMessage)
if err != nil {
+ _ = badRequest(w)
return
}
- authComplete, dbErr := handleAuthentication(req, signature)
+ authComplete, dbErr := handleAuthentication(req)
if !authComplete {
w.WriteHeader(http.StatusForbidden)
_ = jsonResponse(w, Response{