@@ -2,6 +2,7 @@ package pool
22
33import (
44 "context"
5+ "sort"
56 "sync"
67)
78
@@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] {
1920// Tasks are executed in the pool with Go(), then the results of the tasks are
2021// returned by Wait().
2122//
22- // The order of the results is not guaranteed to be the same as the order the
23- // tasks were submitted. If your use case requires consistent ordering,
24- // consider using the `stream` package or `Map` from the `iter` package.
23+ // The order of the results is guaranteed to be the same as the order the
24+ // tasks were submitted.
2525type ResultPool [T any ] struct {
2626 pool Pool
2727 agg resultAggregator [T ]
@@ -30,16 +30,17 @@ type ResultPool[T any] struct {
3030// Go submits a task to the pool. If all goroutines in the pool
3131// are busy, a call to Go() will block until the task can be started.
3232func (p * ResultPool [T ]) Go (f func () T ) {
33+ idx := p .agg .nextIndex ()
3334 p .pool .Go (func () {
34- p .agg .add ( f ())
35+ p .agg .save ( idx , f (), false )
3536 })
3637}
3738
3839// Wait cleans up all spawned goroutines, propagating any panics, and returning
3940// a slice of results from tasks that did not panic.
4041func (p * ResultPool [T ]) Wait () []T {
4142 p .pool .Wait ()
42- return p .agg .results
43+ return p .agg .collect ( true )
4344}
4445
4546// MaxGoroutines returns the maximum size of the pool.
@@ -83,11 +84,57 @@ func (p *ResultPool[T]) panicIfInitialized() {
8384// goroutines. The zero value is valid and ready to use.
8485type resultAggregator [T any ] struct {
8586 mu sync.Mutex
87+ len int
8688 results []T
89+ errored []int
8790}
8891
89- func (r * resultAggregator [T ]) add (res T ) {
92+ // nextIndex reserves a slot for a result. The returned value should be passed
93+ // to save() when adding a result to the aggregator.
94+ func (r * resultAggregator [T ]) nextIndex () int {
9095 r .mu .Lock ()
91- r .results = append (r .results , res )
92- r .mu .Unlock ()
96+ defer r .mu .Unlock ()
97+
98+ nextIdx := r .len
99+ r .len += 1
100+ return nextIdx
101+ }
102+
103+ func (r * resultAggregator [T ]) save (i int , res T , errored bool ) {
104+ r .mu .Lock ()
105+ defer r .mu .Unlock ()
106+
107+ if i >= len (r .results ) {
108+ old := r .results
109+ r .results = make ([]T , r .len )
110+ copy (r .results , old )
111+ }
112+
113+ r .results [i ] = res
114+
115+ if errored {
116+ r .errored = append (r .errored , i )
117+ }
118+ }
119+
120+ // collect returns the set of aggregated results.
121+ func (r * resultAggregator [T ]) collect (collectErrored bool ) []T {
122+ if ! r .mu .TryLock () {
123+ panic ("collect should not be called until all goroutines have exited" )
124+ }
125+
126+ if collectErrored || len (r .errored ) == 0 {
127+ return r .results
128+ }
129+
130+ filtered := r .results [:0 ]
131+ sort .Ints (r .errored )
132+ for i , e := range r .errored {
133+ if i == 0 {
134+ filtered = append (filtered , r .results [:e ]... )
135+ } else {
136+ filtered = append (filtered , r .results [r .errored [i - 1 ]+ 1 :e ]... )
137+ }
138+ }
139+ return filtered
93140}
0 commit comments