package main import ( "encoding/base64" "encoding/json" "errors" "fmt" "io" "log" "net/http" "os" "strconv" "strings" "time" ) type Response struct { Message string } type Request struct { User string Data string Payload string Signature string } func parseRequest(signedMessage string) (Request, error) { var req Request parts := strings.Split(signedMessage, ".") if len(parts) != 2 { return req, errors.New("request doesn't contain exactly two parts") } payload := parts[0] signature := parts[1] requestBody, err := base64.StdEncoding.DecodeString(payload) if err != nil { return req, err } err = json.Unmarshal(requestBody, &req) if err != nil { return req, errors.New("unable to parse request") } req.Payload = payload req.Signature = signature return req, nil } func handleAuthentication(req Request) (bool, error) { userRegistered, dbError := db.checkUserRegistered(req.User) if dbError != nil { return false, dbError } if !userRegistered { return false, nil } key, err := db.getUserKey(req.User) if err != nil { return false, err } check, err := checkSignature(req.Payload, req.Signature, key) if err != nil { return false, err } return check, nil } func jsonResponse(w http.ResponseWriter, resp Response) error { data, err := json.Marshal(resp) if err != nil { return err } _, err = w.Write(data) return err } func methodNotAllowedResponse(w http.ResponseWriter) error { w.WriteHeader(http.StatusMethodNotAllowed) return jsonResponse(w, Response{ Message: "К серверу разрешены обращения только по методу POST", }) } func badRequest(w http.ResponseWriter) error { w.WriteHeader(http.StatusBadRequest) return jsonResponse(w, Response{ Message: "Не удалось обработать запрос", }) } func serverError(w http.ResponseWriter) error { w.WriteHeader(http.StatusInternalServerError) return jsonResponse(w, Response{ Message: "Произошла внутренняя ошибка", }) } // Views func register(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { _ = methodNotAllowedResponse(w) return } requestBody, _ := io.ReadAll(r.Body) signedMessage := string(requestBody) req, err := parseRequest(signedMessage) if err != nil { log.Printf("[Register] Got bad request\n") _ = badRequest(w) return } userRegistered, err := db.checkUserRegistered(req.User) if err != nil { log.Printf("[Register] (%s) %s\n", req.User, err) w.WriteHeader(http.StatusInternalServerError) _ = jsonResponse(w, Response{ Message: fmt.Sprintf("%s", err), }) return } if userRegistered { log.Printf("[Register] (%s) Already registered user tried to register\n", req.User) w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ Message: "Пользователь с таким именем уже зарегистрирован", }) return } checkResult, err := checkSignature(req.Payload, req.Signature, req.Data) if err != nil || !checkResult { log.Printf("[Register] (%s) Provided key is not valid\n", req.User) w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ Message: "Указанный ключ не является действительным", }) return } err = db.registerUser(req.User, req.Data) if err != nil { log.Printf("[Register] (%s) %s\n", req.User, err) w.WriteHeader(http.StatusInternalServerError) _ = jsonResponse(w, Response{ Message: fmt.Sprintf("%s", err), }) return } log.Printf("[Register] (%s) User successfully registered\n", req.User) _ = jsonResponse(w, Response{ Message: "Пользователь успешно зарегистрирован", }) } func sendMessage(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { _ = methodNotAllowedResponse(w) return } requestBody, _ := io.ReadAll(r.Body) signedMessage := string(requestBody) req, err := parseRequest(signedMessage) if err != nil { log.Printf("[SendMessage] Got bad request\n") _ = badRequest(w) return } authComplete, dbErr := handleAuthentication(req) if !authComplete { log.Printf("[SendMessage] (%s) Request failed verification\n", req.User) w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ Message: "Запрос не прошёл аутентификацию", }) return } if dbErr != nil { log.Printf("[SendMessage] (%s) %s\n", req.User, dbErr) _ = serverError(w) return } timestamp := time.Now().UnixNano() err = db.saveMessage(signedMessage, timestamp) if err != nil { log.Printf("[SendMessage] (%s) %s\n", req.User, err) w.WriteHeader(http.StatusInternalServerError) _ = jsonResponse(w, Response{ Message: fmt.Sprintf("%s", err), }) return } log.Printf("[SendMessage] (%s) Saved message: %s\n", req.User, req.Data) _ = jsonResponse(w, Response{ Message: "Сообщение успешно сохранено", }) } func pollMessages(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { _ = methodNotAllowedResponse(w) return } requestBody, _ := io.ReadAll(r.Body) signedMessage := string(requestBody) req, err := parseRequest(signedMessage) if err != nil { log.Printf("[PollMessages] Got bad request\n") _ = badRequest(w) return } authComplete, dbErr := handleAuthentication(req) if !authComplete { log.Printf("[PollMessages] (%s) Request failed verification\n", req.User) w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ Message: "Запрос не прошёл аутентификацию", }) return } if dbErr != nil { log.Printf("[PollMessages] (%s) %s\n", req.User, dbErr) _ = serverError(w) return } timestamp, err := strconv.ParseInt(req.Data, 10, 64) if err != nil { log.Printf("[PollMessages] (%s) Got bad timestamp (%s)\n", req.User, req.Data) w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ Message: "Указана неверная временная отметка", }) return } messages, err := db.getMessagesSince(timestamp) if err != nil { log.Printf("[PollMessages] (%s) %s\n", req.User, err) w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ Message: fmt.Sprint(err), }) //_ = serverError(w) return } //log.Printf("[PollMessages] (%s) Sent messages to client\n", req.User) data, _ := json.Marshal(messages) _ = jsonResponse(w, Response{ Message: string(data), }) } func getUserKey(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { _ = methodNotAllowedResponse(w) return } requestBody, _ := io.ReadAll(r.Body) signedMessage := string(requestBody) req, err := parseRequest(signedMessage) if err != nil { log.Printf("[GetUserKey] Got bad request\n") _ = badRequest(w) return } authComplete, dbErr := handleAuthentication(req) if !authComplete { log.Printf("[GetUserKey] (%s) Request failed verification\n", req.User) w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ Message: "Запрос не прошёл аутентификацию", }) return } if dbErr != nil { log.Printf("[GetUserKey] (%s) %s\n", req.User, dbErr) _ = serverError(w) return } userRegistered, err := db.checkUserRegistered(req.Data) if !userRegistered { log.Printf("[GetUserKey] (%s) Priovided user (%s) is not registered\n", req.User, req.Data) w.WriteHeader(http.StatusBadRequest) _ = jsonResponse(w, Response{ Message: "Пользователь с таким именем не зарегистрирован", }) return } key, err := db.getUserKey(req.Data) if err != nil { log.Printf("[GetUserKey] (%s) %s\n", req.User, err) _ = serverError(w) return } log.Printf("[GetUserKey] (%s) Sent key of (%s) to client\n", req.User, req.Data) _ = jsonResponse(w, Response{ Message: key, }) } func tryAuth(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { _ = methodNotAllowedResponse(w) return } requestBody, _ := io.ReadAll(r.Body) signedMessage := string(requestBody) req, err := parseRequest(signedMessage) if err != nil { log.Printf("[TryAuth] Got bad request\n") _ = badRequest(w) return } authComplete, dbErr := handleAuthentication(req) if !authComplete { log.Printf("[TryAuth] (%s) Request failed verification\n", req.User) w.WriteHeader(http.StatusForbidden) _ = jsonResponse(w, Response{ Message: "Запрос не прошёл аутентификацию", }) return } if dbErr != nil { log.Printf("[TryAuth] (%s) %s\n", req.User, dbErr) _ = serverError(w) return } log.Printf("[TryAuth] (%s) Request passed verification\n", req.User) _ = jsonResponse(w, Response{ Message: "Запрос прошёл аутентификацию", }) } var db SQLConnection func parseArguments(args []string) (map[string]string, error) { const ( ArgName = iota ArgValue ) result := make(map[string]string) state := ArgName currentName := "" for _, arg := range args { if state == ArgName { if !strings.HasPrefix(arg, "--") { return nil, errors.New(fmt.Sprintf("expected argument name, got %s", arg)) } name := strings.TrimPrefix(arg, "--") state = ArgValue currentName = name } else if state == ArgValue { if strings.HasPrefix(arg, "--") { return nil, errors.New(fmt.Sprintf("expected argument value, got %s", arg)) } result[currentName] = arg state = ArgName } else { return nil, errors.New("invalid parser state") } } return result, nil } func main() { args, err := parseArguments(os.Args[1:]) if err != nil { fmt.Printf("Error while parsing arguments: %s\n", err) return } dbPath, ok := args["dbpath"] if !ok { dbPath = "chat.sqlite3" } err = db.init(dbPath) if err != nil { fmt.Println(err) os.Exit(1) } http.HandleFunc("/api/register", register) http.HandleFunc("/api/sendMessage", sendMessage) http.HandleFunc("/api/pollMessages", pollMessages) http.HandleFunc("/api/getUserKey", getUserKey) http.HandleFunc("/api/tryAuth", tryAuth) host, ok := args["host"] if !ok { host = "0.0.0.0" } port, ok := args["port"] if !ok { port = "12345" } addr := fmt.Sprintf("%s:%s", host, port) log.Printf("Started server on %s\n", addr) err = http.ListenAndServe(addr, nil) if err != nil { fmt.Println(err) os.Exit(1) } }