diff --git a/server/memory_db.go b/server/memory_db.go index 4dd914a..28f8aaa 100644 --- a/server/memory_db.go +++ b/server/memory_db.go @@ -1,6 +1,8 @@ package server import ( + "errors" + "fmt" "log" "os" "sync" @@ -13,86 +15,74 @@ import ( // In-Memory volatile Thread-Safe Key-Value Storage type MemoryDB struct { - table map[string]*Process - mu sync.Mutex -} - -// Inits the db with an empty map of string->Process pointer -func (m *MemoryDB) New() { - m.table = make(map[string]*Process) + table sync.Map } // Get a process pointer given its id -func (m *MemoryDB) Get(id string) *Process { - m.mu.Lock() - res := m.table[id] - m.mu.Unlock() - return res +func (m *MemoryDB) Get(id string) (*Process, error) { + entry, ok := db.table.Load(id) + if !ok { + return nil, errors.New("no process found for the given key") + } + return entry.(*Process), nil } // Store a pointer of a process and return its id func (m *MemoryDB) Set(process *Process) string { id := uuid.Must(uuid.NewRandom()).String() - m.mu.Lock() - m.table[id] = process - m.mu.Unlock() + db.table.Store(id, process) return id } // Update a process info/metadata, given the process id -func (m *MemoryDB) Update(id string, info DownloadInfo) { - m.mu.Lock() - if m.table[id] != nil { - m.table[id].Info = info +func (m *MemoryDB) UpdateInfo(id string, info DownloadInfo) error { + entry, ok := db.table.Load(id) + if ok { + entry.(*Process).Info = info + db.table.Store(id, entry) + return nil } - m.mu.Unlock() + return fmt.Errorf("can't update row with id %s", id) } // Update a process progress data, given the process id // Used for updating completition percentage or ETA -func (m *MemoryDB) UpdateProgress(id string, progress DownloadProgress) { - m.mu.Lock() - if m.table[id] != nil { - m.table[id].Progress = progress +func (m *MemoryDB) UpdateProgress(id string, progress DownloadProgress) error { + entry, ok := db.table.Load(id) + if ok { + entry.(*Process).Progress = progress + db.table.Store(id, entry) + return nil } - m.mu.Unlock() + return fmt.Errorf("can't update row with id %s", id) } // Removes a process progress, given the process id func (m *MemoryDB) Delete(id string) { - m.mu.Lock() - delete(m.table, id) - m.mu.Unlock() + db.table.Delete(id) } -// Returns a slice of all currently stored processes id -func (m *MemoryDB) Keys() []string { - m.mu.Lock() - keys := make([]string, len(m.table)) - i := 0 - for k := range m.table { - keys[i] = k - i++ - } - m.mu.Unlock() - return keys +func (m *MemoryDB) Keys() *[]string { + running := []string{} + db.table.Range(func(key, value any) bool { + running = append(running, key.(string)) + return true + }) + return &running } // Returns a slice of all currently stored processes progess -func (m *MemoryDB) All() []ProcessResponse { - running := make([]ProcessResponse, len(m.table)) - i := 0 - for k, v := range m.table { - if v != nil { - running[i] = ProcessResponse{ - Id: k, - Info: v.Info, - Progress: v.Progress, - } - i++ - } - } - return running +func (m *MemoryDB) All() *[]ProcessResponse { + running := []ProcessResponse{} + db.table.Range(func(key, value any) bool { + running = append(running, ProcessResponse{ + Id: key.(string), + Info: value.(*Process).Info, + Progress: value.(*Process).Progress, + }) + return true + }) + return &running } // WIP: Persist the database in a single file named "session.dat" @@ -100,7 +90,7 @@ func (m *MemoryDB) Persist() { running := m.All() session, err := json.Marshal(Session{ - Processes: running, + Processes: *running, }) if err != nil { log.Println(cli.Red, "Failed to persist database", cli.Reset) diff --git a/server/process.go b/server/process.go index a522b27..8f0f3c8 100644 --- a/server/process.go +++ b/server/process.go @@ -118,7 +118,7 @@ func (p *Process) Start(path, filename string) { } info := DownloadInfo{URL: p.url} json.Unmarshal(stdout, &info) - p.mem.Update(p.id, info) + p.mem.UpdateInfo(p.id, info) }() // --------------- progress block --------------- // diff --git a/server/rx/extensions.go b/server/rx/extensions.go index 9886712..5c5880c 100644 --- a/server/rx/extensions.go +++ b/server/rx/extensions.go @@ -18,7 +18,7 @@ import "time" // -t-> |> // // --A-----C-----G-------|> -func Debounce(interval time.Duration, source chan string, cb func(emit string)) { +func Debounce(interval time.Duration, source chan string, f func(emit string)) { var item string timer := time.NewTimer(interval) for { @@ -27,7 +27,7 @@ func Debounce(interval time.Duration, source chan string, cb func(emit string)) timer.Reset(interval) case <-timer.C: if item != "" { - cb(item) + f(item) } } } diff --git a/server/server.go b/server/server.go index 7a9f548..9e3f419 100644 --- a/server/server.go +++ b/server/server.go @@ -17,10 +17,6 @@ import ( var db MemoryDB -func init() { - db.New() -} - func RunBlocking(ctx context.Context) { fe := ctx.Value("frontend").(fs.SubFS) port := ctx.Value("port").(int) diff --git a/server/service.go b/server/service.go index fcc0809..6d2de79 100644 --- a/server/service.go +++ b/server/service.go @@ -40,7 +40,11 @@ func (t *Service) Exec(args DownloadSpecificArgs, result *string) error { // Progess retrieves the Progress of a specific Process given its Id func (t *Service) Progess(args Args, progress *DownloadProgress) error { - *progress = db.Get(args.Id).Progress + proc, err := db.Get(args.Id) + if err != nil { + return err + } + *progress = proc.Progress return nil } @@ -54,21 +58,23 @@ func (t *Service) Formats(args Args, progress *DownloadFormats) error { // Pending retrieves a slice of all Pending/Running processes ids func (t *Service) Pending(args NoArgs, pending *Pending) error { - *pending = Pending(db.Keys()) + *pending = *db.Keys() return nil } // Running retrieves a slice of all Processes progress func (t *Service) Running(args NoArgs, running *Running) error { - *running = db.All() + *running = *db.All() return nil } // Kill kills a process given its id and remove it from the memoryDB func (t *Service) Kill(args string, killed *string) error { log.Println("Trying killing process with id", args) - proc := db.Get(args) - var err error + proc, err := db.Get(args) + if err != nil { + return err + } if proc != nil { err = proc.Kill() } @@ -81,8 +87,11 @@ func (t *Service) KillAll(args NoArgs, killed *string) error { log.Println("Killing all spawned processes", args) keys := db.Keys() var err error - for _, key := range keys { - proc := db.Get(key) + for _, key := range *keys { + proc, err := db.Get(key) + if err != nil { + return err + } if proc != nil { proc.Kill() }