From 5d503a846574858a42853f32542bcc9020ecfd5a Mon Sep 17 00:00:00 2001 From: Alex Vanin Date: Sun, 29 May 2022 15:40:18 +0300 Subject: [PATCH] Fix heap package usage Adds tests to ensure order of submitted tasks with the same priority. --- pool.go | 13 ++++++------- pool_test.go | 25 +++++++++++++++++++++++++ queue.go | 44 ++++++++++++++++++-------------------------- 3 files changed, 49 insertions(+), 33 deletions(-) diff --git a/pool.go b/pool.go index aa14543..f8390b7 100644 --- a/pool.go +++ b/pool.go @@ -1,6 +1,7 @@ package priopool import ( + "container/heap" "errors" "fmt" "sync" @@ -48,9 +49,9 @@ func New(poolCapacity, queueCapacity int) (*PriorityPool, error) { switch { case queueCapacity >= 0: - queue = make(priorityQueue, 0, queueCapacity) + queue.tasks = make([]*priorityQueueTask, 0, queueCapacity) case queueCapacity < 0: - queue = make(priorityQueue, 0, defaultQueueCapacity) + queue.tasks = make([]*priorityQueueTask, 0, defaultQueueCapacity) } return &PriorityPool{ @@ -81,7 +82,7 @@ func (p *PriorityPool) Submit(priority uint32, task func()) error { p.mu.Unlock() return } - queueF := p.queue.Pop() + queueF := heap.Pop(&p.queue) p.mu.Unlock() queueF.(*priorityQueueTask).value() @@ -95,15 +96,13 @@ func (p *PriorityPool) Submit(priority uint32, task func()) error { return fmt.Errorf("pool submit: %w", err) } - ln := p.queue.Len() - if p.limit >= 0 && ln >= p.limit { + if p.limit >= 0 && p.queue.Len() >= p.limit { return ErrQueueOverload } - p.queue.Push(&priorityQueueTask{ + heap.Push(&p.queue, &priorityQueueTask{ value: task, priority: int(priority), - index: ln, }) return nil diff --git a/pool_test.go b/pool_test.go index b9a55b9..dc4dea8 100644 --- a/pool_test.go +++ b/pool_test.go @@ -1,6 +1,7 @@ package priopool_test import ( + "fmt" "sync" "testing" "time" @@ -122,6 +123,30 @@ func TestPriorityPool_Submit(t *testing.T) { wg.Wait() }) + + t.Run("non-priority order", func(t *testing.T) { + const n = 5 + + p, err := priopool.New(1, -1) + require.NoError(t, err) + + wg := new(sync.WaitGroup) + wg.Add(n) + result := new(syncList) + + for i := 0; i < n; i++ { + id := i + err = p.Submit(lowPriority, taskGenerator(id, result, wg)) + require.NoError(t, err) + } + + wg.Wait() + + fmt.Println(result.list) + for i := 0; i < n; i++ { + require.Equal(t, i, result.list[i]) + } + }) } func taskGenerator(ind int, output *syncList, wg *sync.WaitGroup) func() { diff --git a/queue.go b/queue.go index 5d137ee..501fbe5 100644 --- a/queue.go +++ b/queue.go @@ -4,49 +4,41 @@ package priopool // Priority queue itself is not thread safe. // See https://cs.opensource.google/go/go/+/refs/tags/go1.17.2:src/container/heap/example_pq_test.go -import ( - "container/heap" -) - type priorityQueueTask struct { value func() priority int - index int // the index is needed by update and is maintained by the heap.Interface methods + index uint64 // monotonusly increasing index to sort values with same priority } -type priorityQueue []*priorityQueueTask +type priorityQueue struct { + nextIndex uint64 + tasks []*priorityQueueTask +} -func (pq priorityQueue) Len() int { return len(pq) } +func (pq priorityQueue) Len() int { return len(pq.tasks) } func (pq priorityQueue) Less(i, j int) bool { - return pq[i].priority > pq[j].priority + if pq.tasks[i].priority == pq.tasks[j].priority { + return pq.tasks[i].index < pq.tasks[j].index + } + return pq.tasks[i].priority > pq.tasks[j].priority } func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j + pq.tasks[i], pq.tasks[j] = pq.tasks[j], pq.tasks[i] } func (pq *priorityQueue) Push(x interface{}) { - n := len(*pq) item := x.(*priorityQueueTask) - item.index = n - *pq = append(*pq, item) + item.index = pq.nextIndex + pq.nextIndex++ + pq.tasks = append(pq.tasks, item) } func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - old[n-1] = nil - item.index = -1 - *pq = old[0 : n-1] + n := len(pq.tasks) + item := pq.tasks[n-1] + pq.tasks[n-1] = nil + pq.tasks = pq.tasks[0 : n-1] return item } - -func (pq *priorityQueue) update(item *priorityQueueTask, value func(), priority int) { - item.value = value - item.priority = priority - heap.Fix(pq, item.index) -}