diff options
| -rw-r--r-- | http-client/client.go | 41 | ||||
| -rw-r--r-- | http-server/cryptography.go | 7 | ||||
| -rw-r--r-- | http-server/database.go | 20 | ||||
| -rw-r--r-- | http-server/main.go | 79 |
4 files changed, 91 insertions, 56 deletions
diff --git a/http-client/client.go b/http-client/client.go index 211c299..a83659e 100644 --- a/http-client/client.go +++ b/http-client/client.go @@ -2,12 +2,14 @@ package main import ( "bytes" + "encoding/base64" "encoding/json" "errors" "fmt" "fyne.io/fyne/v2/data/binding" "io" "net/http" + "strings" "sync" "time" ) @@ -39,7 +41,8 @@ type Response struct { type Message struct { User string Data string - Timestamp string + Payload string + Signature string } func (msg *Message) toString() string { @@ -108,6 +111,33 @@ func pingServer(user UserData) error { return err } +func parseMessage(signedMessage string) (Message, error) { + var msg Message + + parts := strings.Split(signedMessage, ".") + if len(parts) != 2 { + return msg, errors.New("request doesn't contain exactly two parts") + } + + payload := parts[0] + signature := parts[1] + + messageBody, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return msg, err + } + + err = json.Unmarshal(messageBody, &msg) + if err != nil { + return msg, errors.New("unable to parse message") + } + msg.Payload = payload + msg.Signature = signature + return msg, nil +} + +// TODO(andrew): Заменить сбор новых сообщений через http запрос к серверу на +// получения обновлений через сокет от сервера. func runClient(user UserData) { var lastPoll int64 = 0 for { @@ -121,12 +151,13 @@ func runClient(user UserData) { body, _ := io.ReadAll(httpResp.Body) var resp Response _ = json.Unmarshal(body, &resp) - var messages []Message - _ = json.Unmarshal([]byte(resp.Message), &messages) + var signedMessages []string + _ = json.Unmarshal([]byte(resp.Message), &signedMessages) storage.Lock() - for _, msg := range messages { - fmt.Printf("Polled new message from %s: %s (%s)\n", msg.User, msg.Data, msg.Timestamp) + for _, signedMessage := range signedMessages { + msg, _ := parseMessage(signedMessage) + fmt.Printf("Polled new message from %s: %s\n", msg.User, msg.Data) _ = storage.binding.Append(msg.toString()) storage.messages = append(storage.messages, msg) } diff --git a/http-server/cryptography.go b/http-server/cryptography.go index 3340446..beaba71 100644 --- a/http-server/cryptography.go +++ b/http-server/cryptography.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "crypto/x509" "encoding/base64" - "encoding/json" "encoding/pem" "errors" "fmt" @@ -30,10 +29,8 @@ func decodeMessage(ciphertext []byte, stringKey string) ([]byte, error) { return plaintext, err } -func checkSignature(req Request, signature string, key string) (bool, error) { - reqBytes, _ := json.Marshal(req) - req64 := base64.StdEncoding.EncodeToString(reqBytes) - h := sha256.Sum256([]byte(req64)) +func checkSignature(payload string, signature string, key string) (bool, error) { + h := sha256.Sum256([]byte(payload)) requestHash := fmt.Sprintf("%x", h) decodedSign, err := base64.StdEncoding.DecodeString(signature) diff --git a/http-server/database.go b/http-server/database.go index e79b3e7..8f1d094 100644 --- a/http-server/database.go +++ b/http-server/database.go @@ -49,23 +49,21 @@ func (conn *SQLConnection) getUserKey(username string) (string, error) { } } -func (conn *SQLConnection) saveMessage(message Message) error { +func (conn *SQLConnection) saveMessage(signedData string, timestamp int64) error { var err error - query := `INSERT INTO messages (userId, data, timestamp) values ( - (SELECT id FROM users WHERE username = ?), ?, ? - )` - _, err = conn.db.Exec(query, message.User, message.Data, message.Timestamp) + query := `INSERT INTO messages (timestamp, signed_message) values (?, ?);` + _, err = conn.db.Exec(query, timestamp, signedData) return err } -func (conn *SQLConnection) getMessagesSince(timestamp int64) ([]Message, error) { +func (conn *SQLConnection) getMessagesSince(timestamp int64) ([]string, error) { var err error - var msg []Message + var msg []string - query := `SELECT username, data, timestamp FROM users JOIN messages - WHERE users.id = messages.userId AND timestamp > ? + query := `SELECT signed_message FROM messages + WHERE timestamp > ? ORDER BY timestamp;` result, err := conn.db.Query(query, timestamp) @@ -74,8 +72,8 @@ func (conn *SQLConnection) getMessagesSince(timestamp int64) ([]Message, error) } for result.Next() { - var message Message - _ = result.Scan(&message.User, &message.Data, &message.Timestamp) + var message string + _ = result.Scan(&message) msg = append(msg, message) } _ = result.Close() diff --git a/http-server/main.go b/http-server/main.go index 7aaa619..03e8ed2 100644 --- a/http-server/main.go +++ b/http-server/main.go @@ -18,25 +18,18 @@ type Response struct { } type Request struct { - User string - Data string -} - -type Message struct { User string Data string - Timestamp int64 + Payload string + Signature string } -func parseRequest(w http.ResponseWriter, r *http.Request) (Request, string, error) { +func parseRequest(signedMessage string) (Request, error) { var req Request - body, _ := io.ReadAll(r.Body) - bodyStr := string(body) - parts := strings.Split(bodyStr, ".") + parts := strings.Split(signedMessage, ".") if len(parts) != 2 { - _ = badRequest(w) - return req, "", errors.New("request doesn't contain exactly two parts") + return req, errors.New("request doesn't contain exactly two parts") } payload := parts[0] @@ -44,15 +37,16 @@ func parseRequest(w http.ResponseWriter, r *http.Request) (Request, string, erro requestBody, err := base64.StdEncoding.DecodeString(payload) if err != nil { - return req, "", err + return req, err } err = json.Unmarshal(requestBody, &req) if err != nil { - _ = badRequest(w) - return req, "", err + return req, errors.New("unable to parse request") } - return req, signature, nil + req.Payload = payload + req.Signature = signature + return req, nil } func jsonResponse(w http.ResponseWriter, resp Response) error { @@ -91,8 +85,12 @@ func register(w http.ResponseWriter, r *http.Request) { return } - req, signature, err := parseRequest(w, r) + requestBody, _ := io.ReadAll(r.Body) + signedMessage := string(requestBody) + + req, err := parseRequest(signedMessage) if err != nil { + _ = badRequest(w) return } @@ -113,7 +111,7 @@ func register(w http.ResponseWriter, r *http.Request) { return } - checkResult, err := checkSignature(req, signature, req.Data) + checkResult, err := checkSignature(req.Payload, req.Signature, req.Data) if err != nil || !checkResult { w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ @@ -136,7 +134,7 @@ func register(w http.ResponseWriter, r *http.Request) { }) } -func handleAuthentication(req Request, signature string) (bool, error) { +func handleAuthentication(req Request) (bool, error) { userRegistered, dbError := db.checkUserRegistered(req.User) if dbError != nil { return false, dbError @@ -151,7 +149,7 @@ func handleAuthentication(req Request, signature string) (bool, error) { return false, err } - check, err := checkSignature(req, signature, key) + check, err := checkSignature(req.Payload, req.Signature, key) if err != nil { return false, err } @@ -165,12 +163,16 @@ func sendMessage(w http.ResponseWriter, r *http.Request) { return } - req, signature, err := parseRequest(w, r) + requestBody, _ := io.ReadAll(r.Body) + signedMessage := string(requestBody) + + req, err := parseRequest(signedMessage) if err != nil { + _ = badRequest(w) return } - authComplete, dbErr := handleAuthentication(req, signature) + authComplete, dbErr := handleAuthentication(req) if !authComplete { w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ @@ -183,15 +185,10 @@ func sendMessage(w http.ResponseWriter, r *http.Request) { return } - msg := Message{ - User: req.User, - Data: req.Data, - Timestamp: time.Now().UnixNano(), - } - fmt.Printf("Got message from %s: %s\n", req.User, req.Data) - err = db.saveMessage(msg) + timestamp := time.Now().UnixNano() + err = db.saveMessage(signedMessage, timestamp) if err != nil { w.WriteHeader(http.StatusInternalServerError) _ = jsonResponse(w, Response{ @@ -211,12 +208,16 @@ func pollMessages(w http.ResponseWriter, r *http.Request) { return } - req, signature, err := parseRequest(w, r) + requestBody, _ := io.ReadAll(r.Body) + signedMessage := string(requestBody) + + req, err := parseRequest(signedMessage) if err != nil { + _ = badRequest(w) return } - authComplete, dbErr := handleAuthentication(req, signature) + authComplete, dbErr := handleAuthentication(req) if !authComplete { w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ @@ -260,12 +261,16 @@ func getUserKey(w http.ResponseWriter, r *http.Request) { return } - req, signature, err := parseRequest(w, r) + requestBody, _ := io.ReadAll(r.Body) + signedMessage := string(requestBody) + + req, err := parseRequest(signedMessage) if err != nil { + _ = badRequest(w) return } - authComplete, dbErr := handleAuthentication(req, signature) + authComplete, dbErr := handleAuthentication(req) if !authComplete { w.WriteHeader(http.StatusForbidden) @@ -305,12 +310,16 @@ func tryAuth(w http.ResponseWriter, r *http.Request) { return } - req, signature, err := parseRequest(w, r) + requestBody, _ := io.ReadAll(r.Body) + signedMessage := string(requestBody) + + req, err := parseRequest(signedMessage) if err != nil { + _ = badRequest(w) return } - authComplete, dbErr := handleAuthentication(req, signature) + authComplete, dbErr := handleAuthentication(req) if !authComplete { w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ |