diff --git a/go.mod b/go.mod index aa3922f..fd2d748 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/reactivex/rxgo/v2 v2.5.0 + golang.org/x/sync v0.6.0 golang.org/x/sys v0.18.0 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.29.5 @@ -29,7 +30,6 @@ require ( github.com/stretchr/testify v1.9.0 // indirect github.com/teivah/onecontext v1.3.0 // indirect golang.org/x/net v0.22.0 // indirect - golang.org/x/sync v0.6.0 // indirect modernc.org/gc/v3 v3.0.0-20240304020402-f0dba7c97c2b // indirect modernc.org/libc v1.47.0 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/server/internal/message_queue.go b/server/internal/message_queue.go index 79a2536..4c21e1a 100644 --- a/server/internal/message_queue.go +++ b/server/internal/message_queue.go @@ -1,18 +1,20 @@ package internal import ( + "context" "log/slog" evbus "github.com/asaskevich/EventBus" "github.com/marcopeocchi/yt-dlp-web-ui/server/config" + "golang.org/x/sync/semaphore" ) const queueName = "process:pending" type MessageQueue struct { - eventBus evbus.Bus - consumerCh chan struct{} - logger *slog.Logger + concurrency int + eventBus evbus.Bus + logger *slog.Logger } // Creates a new message queue. @@ -27,9 +29,9 @@ func NewMessageQueue(l *slog.Logger) *MessageQueue { } return &MessageQueue{ - eventBus: evbus.New(), - consumerCh: make(chan struct{}, qs), - logger: l, + concurrency: qs, + eventBus: evbus.New(), + logger: l, } } @@ -49,23 +51,25 @@ func (m *MessageQueue) SetupConsumers() { // Setup the consumer listener which subscribes to the changes to the producer // channel and triggers the "download" action. func (m *MessageQueue) downloadConsumer() { + sem := semaphore.NewWeighted(int64(m.concurrency)) + m.eventBus.SubscribeAsync(queueName, func(p *Process) { - m.consumerCh <- struct{}{} + //TODO: provide valid context + sem.Acquire(context.TODO(), 1) + defer sem.Release(1) m.logger.Info("received process from event bus", slog.String("bus", queueName), slog.String("consumer", "downloadConsumer"), - slog.String("id", p.Id), + slog.String("id", p.getShortId()), ) p.Start() m.logger.Info("started process", slog.String("bus", queueName), - slog.String("id", p.Id), + slog.String("id", p.getShortId()), ) - - <-m.consumerCh }, false) } @@ -78,21 +82,14 @@ func (m *MessageQueue) metadataSubscriber() { m.logger.Info("received process from event bus", slog.String("bus", queueName), slog.String("consumer", "metadataConsumer"), - slog.String("id", p.Id), + slog.String("id", p.getShortId()), ) if err := p.SetMetadata(); err != nil { m.logger.Error("failed to retrieve metadata", - slog.String("id", p.Id), + slog.String("id", p.getShortId()), slog.String("err", err.Error()), ) } }) } - -// Empties the message queue -func (m *MessageQueue) Empty() { - for range m.consumerCh { - <-m.consumerCh - } -} diff --git a/server/rpc/service.go b/server/rpc/service.go index fd59346..1b86ef1 100644 --- a/server/rpc/service.go +++ b/server/rpc/service.go @@ -109,8 +109,12 @@ func (s *Service) Kill(args string, killed *string) error { // the memory db func (s *Service) KillAll(args NoArgs, killed *string) error { s.logger.Info("Killing all spawned processes") - keys := s.db.Keys() - var err error + + var ( + keys = s.db.Keys() + err error + ) + for _, key := range *keys { proc, err := s.db.Get(key) if err != nil { @@ -121,7 +125,7 @@ func (s *Service) KillAll(args NoArgs, killed *string) error { s.db.Delete(proc.Id) } } - s.mq.Empty() + return err }