Skip to content

Commit 8427ccd

Browse files
authored
Make result order deterministic (#126)
This makes the order of results in a `Result.*Pool` deterministic so that the order of the result slice corresponds with the order of tasks submitted. As an example of why this would be useful, it makes it easy to rewrite `iter.Map` in terms of `ResultPool`. Additionally, it's a generally nice and intuitive property to be able to match the index of the result slice with the index of the input slice.
1 parent 4afefce commit 8427ccd

6 files changed

Lines changed: 85 additions & 27 deletions

pool/result_context_pool.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ type ResultContextPool[T any] struct {
2020
// Go submits a task to the pool. If all goroutines in the pool
2121
// are busy, a call to Go() will block until the task can be started.
2222
func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
23+
idx := p.agg.nextIndex()
2324
p.contextPool.Go(func(ctx context.Context) error {
2425
res, err := f(ctx)
25-
if err == nil || p.collectErrored {
26-
p.agg.add(res)
27-
}
26+
p.agg.save(idx, res, err != nil)
2827
return err
2928
})
3029
}
@@ -33,7 +32,7 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
3332
// returns an error if any of the tasks errored.
3433
func (p *ResultContextPool[T]) Wait() ([]T, error) {
3534
err := p.contextPool.Wait()
36-
return p.agg.results, err
35+
return p.agg.collect(p.collectErrored), err
3736
}
3837

3938
// WithCollectErrored configures the pool to still collect the result of a task

pool/result_context_pool_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"sort"
87
"strconv"
98
"sync/atomic"
109
"testing"
@@ -223,7 +222,6 @@ func TestResultContextPool(t *testing.T) {
223222
})
224223
}
225224
res, err := g.Wait()
226-
sort.Ints(res)
227225
require.Equal(t, expected, res)
228226
require.NoError(t, err)
229227
require.Equal(t, int64(0), currentConcurrent.Load())

pool/result_error_pool.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ import (
88
// type and an error. Tasks are executed in the pool with Go(), then the
99
// results of the tasks are returned by Wait().
1010
//
11-
// The order of the results is not guaranteed to be the same as the order the
12-
// tasks were submitted. If your use case requires consistent ordering,
13-
// consider using the `stream` package or `Map` from the `iter` package.
11+
// The order of the results is guaranteed to be the same as the order the
12+
// tasks were submitted.
1413
//
1514
// The configuration methods (With*) will panic if they are used after calling
1615
// Go() for the first time.
@@ -23,11 +22,10 @@ type ResultErrorPool[T any] struct {
2322
// Go submits a task to the pool. If all goroutines in the pool
2423
// are busy, a call to Go() will block until the task can be started.
2524
func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
25+
idx := p.agg.nextIndex()
2626
p.errorPool.Go(func() error {
2727
res, err := f()
28-
if err == nil || p.collectErrored {
29-
p.agg.add(res)
30-
}
28+
p.agg.save(idx, res, err != nil)
3129
return err
3230
})
3331
}
@@ -36,7 +34,7 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
3634
// returning the results and any errors from tasks.
3735
func (p *ResultErrorPool[T]) Wait() ([]T, error) {
3836
err := p.errorPool.Wait()
39-
return p.agg.results, err
37+
return p.agg.collect(p.collectErrored), err
4038
}
4139

4240
// WithCollectErrored configures the pool to still collect the result of a task

pool/result_error_pool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16-
func TestResultErrorGroup(t *testing.T) {
16+
func TestResultErrorPool(t *testing.T) {
1717
t.Parallel()
1818

1919
err1 := errors.New("err1")

pool/result_pool.go

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pool
22

33
import (
44
"context"
5+
"sort"
56
"sync"
67
)
78

@@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] {
1920
// Tasks are executed in the pool with Go(), then the results of the tasks are
2021
// returned by Wait().
2122
//
22-
// The order of the results is not guaranteed to be the same as the order the
23-
// tasks were submitted. If your use case requires consistent ordering,
24-
// consider using the `stream` package or `Map` from the `iter` package.
23+
// The order of the results is guaranteed to be the same as the order the
24+
// tasks were submitted.
2525
type ResultPool[T any] struct {
2626
pool Pool
2727
agg resultAggregator[T]
@@ -30,16 +30,17 @@ type ResultPool[T any] struct {
3030
// Go submits a task to the pool. If all goroutines in the pool
3131
// are busy, a call to Go() will block until the task can be started.
3232
func (p *ResultPool[T]) Go(f func() T) {
33+
idx := p.agg.nextIndex()
3334
p.pool.Go(func() {
34-
p.agg.add(f())
35+
p.agg.save(idx, f(), false)
3536
})
3637
}
3738

3839
// Wait cleans up all spawned goroutines, propagating any panics, and returning
3940
// a slice of results from tasks that did not panic.
4041
func (p *ResultPool[T]) Wait() []T {
4142
p.pool.Wait()
42-
return p.agg.results
43+
return p.agg.collect(true)
4344
}
4445

4546
// MaxGoroutines returns the maximum size of the pool.
@@ -83,11 +84,57 @@ func (p *ResultPool[T]) panicIfInitialized() {
8384
// goroutines. The zero value is valid and ready to use.
8485
type resultAggregator[T any] struct {
	// mu guards every field below.
	mu sync.Mutex
	// len is the number of slots handed out by nextIndex so far.
	len int
	// results holds one slot per reserved index, in submission order.
	results []T
	// errored lists the indices whose task reported an error.
	errored []int
}

// nextIndex reserves a slot for a result. The returned value should be passed
// to save() when adding a result to the aggregator.
func (r *resultAggregator[T]) nextIndex() int {
	r.mu.Lock()
	defer r.mu.Unlock()

	nextIdx := r.len
	r.len++
	return nextIdx
}

// save stores res in slot i, where i was obtained from nextIndex(). If errored
// is true the slot is remembered so that collect(false) can leave it out.
func (r *resultAggregator[T]) save(i int, res T, errored bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Grow lazily to cover every index reserved so far; i is always below
	// r.len because it was handed out by nextIndex.
	if i >= len(r.results) {
		old := r.results
		r.results = make([]T, r.len)
		copy(r.results, old)
	}

	r.results[i] = res

	if errored {
		r.errored = append(r.errored, i)
	}
}

// collect returns the aggregated results in submission order. When
// collectErrored is false, the slots recorded as errored are omitted.
//
// collect panics if the aggregator is locked by another goroutine, since it
// must only be called once all tasks have finished saving their results.
func (r *resultAggregator[T]) collect(collectErrored bool) []T {
	if !r.mu.TryLock() {
		panic("collect should not be called until all goroutines have exited")
	}
	// Release the lock so that a later collect() or a reused aggregator does
	// not panic or deadlock on a mutex that was never unlocked.
	defer r.mu.Unlock()

	if collectErrored || len(r.errored) == 0 {
		return r.results
	}

	// Filter into a fresh slice rather than compacting r.results in place, so
	// that the stored results stay intact if collect is called again.
	sort.Ints(r.errored)
	filtered := make([]T, 0, len(r.results)-len(r.errored))
	start := 0
	for _, e := range r.errored {
		filtered = append(filtered, r.results[start:e]...)
		start = e + 1
	}
	// Keep the results that follow the last errored slot; omitting this tail
	// silently dropped successful results.
	return append(filtered, r.results[start:]...)
}

pool/result_pool_test.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package pool_test
22

33
import (
44
"fmt"
5-
"sort"
5+
"math/rand"
66
"strconv"
77
"sync/atomic"
88
"testing"
@@ -22,8 +22,6 @@ func ExampleResultPool() {
2222
})
2323
}
2424
res := p.Wait()
25-
// Result order is nondeterministic, so sort them first
26-
sort.Ints(res)
2725
fmt.Println(res)
2826

2927
// Output:
@@ -62,10 +60,29 @@ func TestResultGroup(t *testing.T) {
6260
})
6361
}
6462
res := g.Wait()
65-
sort.Ints(res)
6663
require.Equal(t, expected, res)
6764
})
6865

66+
t.Run("deterministic order", func(t *testing.T) {
67+
t.Parallel()
68+
p := pool.NewWithResults[int]()
69+
results := make([]int, 100)
70+
for i := 0; i < 100; i++ {
71+
results[i] = i
72+
}
73+
for _, result := range results {
74+
result := result
75+
p.Go(func() int {
76+
// Add a random sleep to make it exceedingly unlikely that the
77+
// results are returned in the order they are submitted.
78+
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
79+
return result
80+
})
81+
}
82+
got := p.Wait()
83+
require.Equal(t, results, got)
84+
})
85+
6986
t.Run("limit", func(t *testing.T) {
7087
t.Parallel()
7188
for _, maxGoroutines := range []int{1, 10, 100} {
@@ -90,7 +107,6 @@ func TestResultGroup(t *testing.T) {
90107
})
91108
}
92109
res := g.Wait()
93-
sort.Ints(res)
94110
require.Equal(t, expected, res)
95111
require.Equal(t, int64(0), errCount.Load())
96112
require.Equal(t, int64(0), currentConcurrent.Load())

0 commit comments

Comments
 (0)