Fix heap package usage

Adds tests to ensure order of submitted
tasks with the same priority.
This commit is contained in:
Alex Vanin 2022-05-29 15:40:18 +03:00
parent e476e8570a
commit 5d503a8465
3 changed files with 49 additions and 33 deletions

13
pool.go
View file

@ -1,6 +1,7 @@
package priopool package priopool
import ( import (
"container/heap"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
@ -48,9 +49,9 @@ func New(poolCapacity, queueCapacity int) (*PriorityPool, error) {
switch { switch {
case queueCapacity >= 0: case queueCapacity >= 0:
queue = make(priorityQueue, 0, queueCapacity) queue.tasks = make([]*priorityQueueTask, 0, queueCapacity)
case queueCapacity < 0: case queueCapacity < 0:
queue = make(priorityQueue, 0, defaultQueueCapacity) queue.tasks = make([]*priorityQueueTask, 0, defaultQueueCapacity)
} }
return &PriorityPool{ return &PriorityPool{
@ -81,7 +82,7 @@ func (p *PriorityPool) Submit(priority uint32, task func()) error {
p.mu.Unlock() p.mu.Unlock()
return return
} }
queueF := p.queue.Pop() queueF := heap.Pop(&p.queue)
p.mu.Unlock() p.mu.Unlock()
queueF.(*priorityQueueTask).value() queueF.(*priorityQueueTask).value()
@ -95,15 +96,13 @@ func (p *PriorityPool) Submit(priority uint32, task func()) error {
return fmt.Errorf("pool submit: %w", err) return fmt.Errorf("pool submit: %w", err)
} }
ln := p.queue.Len() if p.limit >= 0 && p.queue.Len() >= p.limit {
if p.limit >= 0 && ln >= p.limit {
return ErrQueueOverload return ErrQueueOverload
} }
p.queue.Push(&priorityQueueTask{ heap.Push(&p.queue, &priorityQueueTask{
value: task, value: task,
priority: int(priority), priority: int(priority),
index: ln,
}) })
return nil return nil

View file

@ -1,6 +1,7 @@
package priopool_test package priopool_test
import ( import (
"fmt"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -122,6 +123,30 @@ func TestPriorityPool_Submit(t *testing.T) {
wg.Wait() wg.Wait()
}) })
t.Run("non-priority order", func(t *testing.T) {
const n = 5
p, err := priopool.New(1, -1)
require.NoError(t, err)
wg := new(sync.WaitGroup)
wg.Add(n)
result := new(syncList)
for i := 0; i < n; i++ {
id := i
err = p.Submit(lowPriority, taskGenerator(id, result, wg))
require.NoError(t, err)
}
wg.Wait()
fmt.Println(result.list)
for i := 0; i < n; i++ {
require.Equal(t, i, result.list[i])
}
})
} }
func taskGenerator(ind int, output *syncList, wg *sync.WaitGroup) func() { func taskGenerator(ind int, output *syncList, wg *sync.WaitGroup) func() {

View file

@ -4,49 +4,41 @@ package priopool
// Priority queue itself is not thread safe. // Priority queue itself is not thread safe.
// See https://cs.opensource.google/go/go/+/refs/tags/go1.17.2:src/container/heap/example_pq_test.go // See https://cs.opensource.google/go/go/+/refs/tags/go1.17.2:src/container/heap/example_pq_test.go
import (
"container/heap"
)
type priorityQueueTask struct { type priorityQueueTask struct {
value func() value func()
priority int priority int
index int // the index is needed by update and is maintained by the heap.Interface methods index uint64 // monotonusly increasing index to sort values with same priority
} }
type priorityQueue []*priorityQueueTask type priorityQueue struct {
nextIndex uint64
tasks []*priorityQueueTask
}
func (pq priorityQueue) Len() int { return len(pq) } func (pq priorityQueue) Len() int { return len(pq.tasks) }
func (pq priorityQueue) Less(i, j int) bool { func (pq priorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority if pq.tasks[i].priority == pq.tasks[j].priority {
return pq.tasks[i].index < pq.tasks[j].index
}
return pq.tasks[i].priority > pq.tasks[j].priority
} }
func (pq priorityQueue) Swap(i, j int) { func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i] pq.tasks[i], pq.tasks[j] = pq.tasks[j], pq.tasks[i]
pq[i].index = i
pq[j].index = j
} }
func (pq *priorityQueue) Push(x interface{}) { func (pq *priorityQueue) Push(x interface{}) {
n := len(*pq)
item := x.(*priorityQueueTask) item := x.(*priorityQueueTask)
item.index = n item.index = pq.nextIndex
*pq = append(*pq, item) pq.nextIndex++
pq.tasks = append(pq.tasks, item)
} }
func (pq *priorityQueue) Pop() interface{} { func (pq *priorityQueue) Pop() interface{} {
old := *pq n := len(pq.tasks)
n := len(old) item := pq.tasks[n-1]
item := old[n-1] pq.tasks[n-1] = nil
old[n-1] = nil pq.tasks = pq.tasks[0 : n-1]
item.index = -1
*pq = old[0 : n-1]
return item return item
} }
func (pq *priorityQueue) update(item *priorityQueueTask, value func(), priority int) {
item.value = value
item.priority = priority
heap.Fix(pq, item.index)
}