diff --git a/frontend/src/views/Archive.tsx b/frontend/src/views/Archive.tsx index 628c2ab..bb95aaa 100644 --- a/frontend/src/views/Archive.tsx +++ b/frontend/src/views/Archive.tsx @@ -9,6 +9,7 @@ import { DialogContent, DialogContentText, DialogTitle, + IconButton, List, ListItem, ListItemButton, @@ -39,6 +40,8 @@ import { useI18n } from '../hooks/useI18n' import { ffetch } from '../lib/httpClient' import { DeleteRequest, DirectoryEntry } from '../types' import { base64URLEncode, roundMiB } from '../utils' +import DownloadIcon from '@mui/icons-material/Download' + export default function Downloaded() { const serverAddr = useRecoilValue(serverURL) @@ -119,7 +122,7 @@ export default function Downloaded() { combineLatestWith(selected$), map(([data, selected]) => data.map(x => ({ ...x, - selected: selected.includes(x.name) + selected: selected.includes(x.path) }))), share() ), []) @@ -155,7 +158,13 @@ export default function Downloaded() { const onFileClick = (path: string) => startTransition(() => { const encoded = base64URLEncode(path) - window.open(`${serverAddr}/archive/d/${encoded}`) + window.open(`${serverAddr}/archive/v/${encoded}?token=${localStorage.getItem('token')}`) + }) + + const downloadFile = (path: string) => startTransition(() => { + const encoded = base64URLEncode(path) + + window.open(`${serverAddr}/archive/d/${encoded}?token=${localStorage.getItem('token')}`) }) const onFolderClick = (path: string) => startTransition(() => { @@ -192,11 +201,20 @@ export default function Downloaded() { {roundMiB(file.size)} } - {!file.isDirectory && addSelected(file.name)} - />} + {!file.isDirectory && <> + downloadFile(file.path)} + sx={{ marginLeft: 1.5 }} + > + + + addSelected(file.path)} + /> + } } disablePadding diff --git a/server/handlers/archive.go b/server/handlers/archive.go index f4e057e..dfcc67b 100644 --- a/server/handlers/archive.go +++ b/server/handlers/archive.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/base64" "encoding/json" + "io" "net/http" "net/url" "os" @@ -136,14 +137,55 @@ func SendFile(w http.ResponseWriter, r *http.Request) { // TODO: further path / file validations if strings.Contains(filepath.Dir(filename), root) { - w.Header().Add( - "Content-Disposition", - "inline; filename="+filepath.Base(filename), - ) - http.ServeFile(w, r, filename) return } w.WriteHeader(http.StatusUnauthorized) } + +func DownloadFile(w http.ResponseWriter, r *http.Request) { + path := chi.URLParam(r, "id") + + if path == "" { + http.Error(w, "inexistent path", http.StatusBadRequest) + return + } + + path, err := url.QueryUnescape(path) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + decoded, err := base64.StdEncoding.DecodeString(path) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + filename := string(decoded) + + root := config.Instance().DownloadPath + + if strings.Contains(filepath.Dir(filename), root) { + w.Header().Add( + "Content-Disposition", + "inline; filename="+filepath.Base(filename), + ) + w.Header().Set( + "Content-Type", + "application/octet-stream", + ) + + fd, err := os.Open(filename) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + io.Copy(w, fd) + } + + w.WriteHeader(http.StatusUnauthorized) +} diff --git a/server/middleware/jwt.go b/server/middleware/jwt.go index 6a77ad8..d7de94f 100644 --- a/server/middleware/jwt.go +++ b/server/middleware/jwt.go @@ -42,19 +42,9 @@ func validateToken(tokenValue string) error { func Authenticated(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("X-Authentication") - - if err := validateToken(token); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - next.ServeHTTP(w, r) - }) -} - -func WebSocketAuthentication(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") + if token == "" { + token = r.URL.Query().Get("token") + } if err := validateToken(token); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/server/rpc/container.go b/server/rpc/container.go index 8bd7586..1528c8d 100644 --- a/server/rpc/container.go +++ b/server/rpc/container.go @@ -18,18 +18,10 @@ func Container(db *internal.MemoryDB, mq *internal.MessageQueue) *Service { // RPC service must be registered before applying this router! func ApplyRouter() func(chi.Router) { return func(r chi.Router) { - r.Route("/ws", func(r chi.Router) { - if config.Instance().RequireAuth { - r.Use(middlewares.WebSocketAuthentication) - } - r.Get("/", WebSocket) - }) - - r.Route("/http", func(r chi.Router) { - if config.Instance().RequireAuth { - r.Use(middlewares.Authenticated) - } - r.Post("/", Post) - }) + if config.Instance().RequireAuth { + r.Use(middlewares.Authenticated) + } + r.Get("/ws", WebSocket) + r.Post("/http", Post) } } diff --git a/server/server.go b/server/server.go index bd1a3b7..216fe13 100644 --- a/server/server.go +++ b/server/server.go @@ -102,7 +102,8 @@ func newServer(c serverConfig) *http.Server { } r.Post("/downloaded", handlers.ListDownloaded) r.Post("/delete", handlers.DeleteFile) - r.Get("/d/{id}", handlers.SendFile) + r.Get("/d/{id}", handlers.DownloadFile) + r.Get("/v/{id}", handlers.SendFile) }) // Authentication routes