diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..962cfb0 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,80 @@ +name: Benchmark + +on: + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + go-bench: + strategy: + matrix: + go-version: [ '1.20', 'stable' ] + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 # to be able to retrieve the last commit in main + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Run benchmark and store the output to a file + run: | + set -o pipefail + make bench | tee ${{ github.sha }}_bench_output.txt + + - name: Get CPU information + uses: kenchan0130/actions-system-info@v1.2.1 + id: system-info + + - name: Get Main branch SHA + id: get-main-branch-sha + run: | + SHA=$(git rev-parse origin/main) + echo "sha=$SHA" >> $GITHUB_OUTPUT + + - name: Get benchmark JSON from main branch + id: cache + uses: actions/cache/restore@v3 + with: + path: ./cache/benchmark-data.json + key: ${{ steps.get-main-branch-sha.outputs.sha }}-${{ runner.os }}-${{ steps.system-info.outputs.cpu-model }}-go-benchmark + + - name: Compare benchmarks with Main + uses: benchmark-action/github-action-benchmark@v1 + if: steps.cache.outputs.cache-hit == 'true' + with: + # What benchmark tool the output.txt came from + tool: 'go' + # Where the output from the benchmark tool is stored + output-file-path: ${{ github.sha }}_bench_output.txt + # Where the benchmarks in main are (to compare) + external-data-json-path: ./cache/benchmark-data.json + # Do not save the data + save-data-file: false + # Workflow will fail when an alert happens + fail-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + # Enable Job Summary for PRs + summary-always: true + + - name: Run benchmarks but don't compare to Main branch + uses: 
benchmark-action/github-action-benchmark@v1 + if: steps.cache.outputs.cache-hit != 'true' + with: + # What benchmark tool the output.txt came from + tool: 'go' + # Where the output from the benchmark tool is stored + output-file-path: ${{ github.sha }}_bench_output.txt + # Do not save the data to the JSON file below, do not publish to GitHub Pages + save-data-file: false + external-data-json-path: ./cache/benchmark-data.json + # Enable Job Summary for PRs + summary-always: true diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0b1cc1f..f190c54 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - go-version: ['1.19', 'stable'] + go-version: ['1.20', 'stable'] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..87e68ae --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,52 @@ +name: Main +on: + push: + branches: + - main + +permissions: + contents: read + +jobs: + go-bench: + strategy: + matrix: + go-version: [ '1.20', 'stable' ] + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Run benchmark and store the output to a file + run: | + set -o pipefail + make bench | tee bench_output.txt + + - name: Get benchmark as JSON + uses: benchmark-action/github-action-benchmark@v1 + with: + # What benchmark tool the output.txt came from + tool: 'go' + # Where the output from the benchmark tool is stored + output-file-path: bench_output.txt + # Write benchmarks to this file + external-data-json-path: ./cache/benchmark-data.json + # Workflow will fail when an alert happens + fail-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Get CPU information + uses: kenchan0130/actions-system-info@v1.2.1 + id: system-info + + - name: Save benchmark JSON to 
cache + uses: actions/cache/save@v3 + with: + path: ./cache/benchmark-data.json + # Save with commit hash to avoid "cache already exists" + # Save with OS & CPU info to prevent comparing against results from different CPUs + key: ${{ github.sha }}-${{ runner.os }}-${{ steps.system-info.outputs.cpu-model }}-go-benchmark diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3e0720a --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +.DEFAULT_GOAL := help + +GO_BIN ?= $(shell go env GOPATH)/bin + +.PHONY: help +help: + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +$(GO_BIN)/golangci-lint: + @echo "==> Installing golangci-lint within "${GO_BIN}"" + @go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +.PHONY: lint +lint: $(GO_BIN)/golangci-lint ## Run linting on Go files + @echo "==> Linting Go source files" + @golangci-lint run -v --fix -c .golangci.yml ./... + +.PHONY: test +test: ## Run tests + go test -race -v ./... -coverprofile ./coverage.txt + +.PHONY: bench +bench: ## Run benchmarks. See https://pkg.go.dev/cmd/go#hdr-Testing_flags + go test ./... -bench . 
-benchtime 5s -timeout 0 -run=XXX -cpu 1 -benchmem diff --git a/go.mod b/go.mod index e06798f..aa49e92 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,14 @@ module github.com/sourcegraph/conc -go 1.19 +go 1.20 -require ( - github.com/stretchr/testify v1.8.1 - go.uber.org/multierr v1.9.0 -) +require github.com/stretchr/testify v1.8.1 require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect - go.uber.org/atomic v1.7.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2eaf607..8254f47 100644 --- a/go.sum +++ b/go.sum @@ -17,15 +17,10 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= -go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/internal/multierror/multierror_go119.go b/internal/multierror/multierror_go119.go deleted file mode 100644 index 7087e32..0000000 --- a/internal/multierror/multierror_go119.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !go1.20 -// +build !go1.20 - -package multierror - -import "go.uber.org/multierr" - -var ( - Join = multierr.Combine -) diff --git a/internal/multierror/multierror_go120.go b/internal/multierror/multierror_go120.go deleted file mode 100644 index 39cff82..0000000 --- a/internal/multierror/multierror_go120.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build go1.20 -// +build go1.20 - -package multierror - -import "errors" - -var ( - Join = errors.Join -) diff --git a/iter/export_test.go b/iter/export_test.go new file mode 100644 index 0000000..c406d06 --- /dev/null +++ b/iter/export_test.go @@ -0,0 +1,3 @@ +package iter + +var DefaultMaxGoroutines = defaultMaxGoroutines diff --git a/iter/iter_test.go b/iter/iter_test.go index d65af2b..48fc8bb 100644 --- a/iter/iter_test.go +++ b/iter/iter_test.go @@ -1,4 +1,4 @@ -package iter +package iter_test import ( "fmt" @@ -6,12 +6,14 @@ import ( "sync/atomic" "testing" + "github.com/sourcegraph/conc/iter" + "github.com/stretchr/testify/require" ) func ExampleIterator() { input := []int{1, 2, 3, 4} - iterator := Iterator[int]{ + iterator := iter.Iterator[int]{ MaxGoroutines: len(input) / 2, } @@ -32,7 +34,7 @@ func TestIterator(t *testing.T) { t.Run("safe for reuse", func(t *testing.T) { t.Parallel() - iterator := Iterator[int]{MaxGoroutines: 999} + iterator := iter.Iterator[int]{MaxGoroutines: 999} // iter.Concurrency > numInput case that updates iter.Concurrency iterator.ForEachIdx([]int{1, 2, 3}, func(i int, t *int) {}) @@ -43,12 +45,12 @@ func TestIterator(t *testing.T) { t.Run("allows more than 
defaultMaxGoroutines() concurrent tasks", func(t *testing.T) { t.Parallel() - wantConcurrency := 2 * defaultMaxGoroutines() + wantConcurrency := 2 * iter.DefaultMaxGoroutines() maxConcurrencyHit := make(chan struct{}) tasks := make([]int, wantConcurrency) - iterator := Iterator[int]{MaxGoroutines: wantConcurrency} + iterator := iter.Iterator[int]{MaxGoroutines: wantConcurrency} var concurrentTasks atomic.Int64 iterator.ForEach(tasks, func(t *int) { @@ -77,7 +79,7 @@ func TestForEachIdx(t *testing.T) { t.Parallel() f := func() { ints := []int{} - ForEachIdx(ints, func(i int, val *int) { + iter.ForEachIdx(ints, func(i int, val *int) { panic("this should never be called") }) } @@ -88,7 +90,7 @@ func TestForEachIdx(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - ForEachIdx(ints, func(i int, val *int) { + iter.ForEachIdx(ints, func(i int, val *int) { panic("super bad thing happened") }) } @@ -98,7 +100,7 @@ func TestForEachIdx(t *testing.T) { t.Run("mutating inputs is fine", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - ForEachIdx(ints, func(i int, val *int) { + iter.ForEachIdx(ints, func(i int, val *int) { *val += 1 }) require.Equal(t, []int{2, 3, 4, 5, 6}, ints) @@ -107,7 +109,7 @@ func TestForEachIdx(t *testing.T) { t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - ForEachIdx(ints, func(i int, val *int) { + iter.ForEachIdx(ints, func(i int, val *int) { *val = i }) expected := make([]int, 10000) @@ -125,7 +127,7 @@ func TestForEach(t *testing.T) { t.Parallel() f := func() { ints := []int{} - ForEach(ints, func(val *int) { + iter.ForEach(ints, func(val *int) { panic("this should never be called") }) } @@ -136,7 +138,7 @@ func TestForEach(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - ForEach(ints, func(val *int) { + iter.ForEach(ints, func(val *int) { panic("super bad thing happened") }) } @@ -146,7 +148,7 @@ func TestForEach(t *testing.T) { t.Run("mutating inputs is fine", func(t 
*testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - ForEach(ints, func(val *int) { + iter.ForEach(ints, func(val *int) { *val += 1 }) require.Equal(t, []int{2, 3, 4, 5, 6}, ints) @@ -155,7 +157,7 @@ func TestForEach(t *testing.T) { t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - ForEach(ints, func(val *int) { + iter.ForEach(ints, func(val *int) { *val = 1 }) expected := make([]int, 10000) @@ -171,7 +173,7 @@ func BenchmarkForEach(b *testing.B) { b.Run(strconv.Itoa(count), func(b *testing.B) { ints := make([]int, count) for i := 0; i < b.N; i++ { - ForEach(ints, func(i *int) { + iter.ForEach(ints, func(i *int) { *i = 0 }) } diff --git a/iter/map.go b/iter/map.go index efbe6bf..af8c3b2 100644 --- a/iter/map.go +++ b/iter/map.go @@ -1,9 +1,8 @@ package iter import ( + "errors" "sync" - - "github.com/sourcegraph/conc/internal/multierror" ) // Mapper is an Iterator with a result type R. It can be used to configure @@ -49,17 +48,16 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) errMux sync.Mutex - errs error + errs []error ) Iterator[T](m).ForEachIdx(input, func(i int, t *T) { var err error res[i], err = f(t) if err != nil { errMux.Lock() - // TODO: use stdlib errors once multierrors land in go 1.20 - errs = multierror.Join(errs, err) + errs = append(errs, err) errMux.Unlock() } }) - return res, errs + return res, errors.Join(errs...) 
} diff --git a/iter/map_test.go b/iter/map_test.go index 28be912..5749e9a 100644 --- a/iter/map_test.go +++ b/iter/map_test.go @@ -1,16 +1,18 @@ -package iter +package iter_test import ( "errors" "fmt" "testing" + "github.com/sourcegraph/conc/iter" + "github.com/stretchr/testify/require" ) func ExampleMapper() { input := []int{1, 2, 3, 4} - mapper := Mapper[int, bool]{ + mapper := iter.Mapper[int, bool]{ MaxGoroutines: len(input) / 2, } @@ -27,7 +29,7 @@ func TestMap(t *testing.T) { t.Parallel() f := func() { ints := []int{} - Map(ints, func(val *int) int { + iter.Map(ints, func(val *int) int { panic("this should never be called") }) } @@ -38,7 +40,7 @@ func TestMap(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - Map(ints, func(val *int) int { + iter.Map(ints, func(val *int) int { panic("super bad thing happened") }) } @@ -48,7 +50,7 @@ func TestMap(t *testing.T) { t.Run("mutating inputs is fine, though not recommended", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - Map(ints, func(val *int) int { + iter.Map(ints, func(val *int) int { *val += 1 return 0 }) @@ -58,7 +60,7 @@ func TestMap(t *testing.T) { t.Run("basic increment", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res := Map(ints, func(val *int) int { + res := iter.Map(ints, func(val *int) int { return *val + 1 }) require.Equal(t, []int{2, 3, 4, 5, 6}, res) @@ -68,7 +70,7 @@ func TestMap(t *testing.T) { t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - res := Map(ints, func(val *int) int { + res := iter.Map(ints, func(val *int) int { return 1 }) expected := make([]int, 10000) @@ -86,7 +88,7 @@ func TestMapErr(t *testing.T) { t.Parallel() f := func() { ints := []int{} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := iter.MapErr(ints, func(val *int) (int, error) { panic("this should never be called") }) require.NoError(t, err) @@ -99,7 +101,7 @@ func TestMapErr(t *testing.T) { t.Parallel() f := func() { 
ints := []int{1} - _, _ = MapErr(ints, func(val *int) (int, error) { + _, _ = iter.MapErr(ints, func(val *int) (int, error) { panic("super bad thing happened") }) } @@ -109,7 +111,7 @@ func TestMapErr(t *testing.T) { t.Run("mutating inputs is fine, though not recommended", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := iter.MapErr(ints, func(val *int) (int, error) { *val += 1 return 0, nil }) @@ -121,7 +123,7 @@ func TestMapErr(t *testing.T) { t.Run("basic increment", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := iter.MapErr(ints, func(val *int) (int, error) { return *val + 1, nil }) require.NoError(t, err) @@ -135,7 +137,7 @@ func TestMapErr(t *testing.T) { t.Run("error is propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := iter.MapErr(ints, func(val *int) (int, error) { if *val == 3 { return 0, err1 } @@ -149,7 +151,7 @@ func TestMapErr(t *testing.T) { t.Run("multiple errors are propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := iter.MapErr(ints, func(val *int) (int, error) { if *val == 3 { return 0, err1 } @@ -160,6 +162,7 @@ func TestMapErr(t *testing.T) { }) require.ErrorIs(t, err, err1) require.ErrorIs(t, err, err2) + require.ElementsMatch(t, err.(interface{ Unwrap() []error }).Unwrap(), []error{err1, err2}) require.Equal(t, []int{2, 3, 0, 0, 6}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) @@ -167,7 +170,7 @@ func TestMapErr(t *testing.T) { t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - res := Map(ints, func(val *int) int { + res := iter.Map(ints, func(val *int) int { return 1 }) expected := make([]int, 10000) diff --git a/panics/panics_test.go 
b/panics/panics_test.go index 15bb32e..3c9ef58 100644 --- a/panics/panics_test.go +++ b/panics/panics_test.go @@ -1,4 +1,4 @@ -package panics +package panics_test import ( "errors" @@ -7,12 +7,14 @@ import ( "sync" "testing" + "github.com/sourcegraph/conc/panics" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func ExampleCatcher() { - var pc Catcher + var pc panics.Catcher i := 0 pc.Try(func() { i += 1 }) pc.Try(func() { panic("abort!") }) @@ -28,7 +30,7 @@ func ExampleCatcher() { } func ExampleCatcher_callers() { - var pc Catcher + var pc panics.Catcher pc.Try(func() { panic("mayday!") }) recovered := pc.Recovered() @@ -50,9 +52,9 @@ func ExampleCatcher_callers() { // Output: // github.com/sourcegraph/conc/panics.(*Catcher).tryRecover // runtime.gopanic - // github.com/sourcegraph/conc/panics.ExampleCatcher_callers.func1 + // github.com/sourcegraph/conc/panics_test.ExampleCatcher_callers.func1 // github.com/sourcegraph/conc/panics.(*Catcher).Try - // github.com/sourcegraph/conc/panics.ExampleCatcher_callers + // github.com/sourcegraph/conc/panics_test.ExampleCatcher_callers // testing.runExample // testing.runExamples // testing.(*M).Run @@ -63,7 +65,7 @@ func ExampleCatcher_callers() { func ExampleCatcher_error() { helper := func() error { - var pc Catcher + var pc panics.Catcher pc.Try(func() { panic(errors.New("error")) }) return pc.Recovered().AsError() } @@ -88,7 +90,7 @@ func TestCatcher(t *testing.T) { t.Run("error", func(t *testing.T) { t.Parallel() - var pc Catcher + var pc panics.Catcher pc.Try(func() { panic(err1) }) recovered := pc.Recovered() require.ErrorIs(t, recovered.AsError(), err1) @@ -101,7 +103,7 @@ func TestCatcher(t *testing.T) { }) t.Run("not error", func(t *testing.T) { - var pc Catcher + var pc panics.Catcher pc.Try(func() { panic("definitely not an error") }) recovered := pc.Recovered() require.NotErrorIs(t, recovered.AsError(), err1) @@ -109,14 +111,14 @@ func TestCatcher(t *testing.T) { }) t.Run("repanic 
panics", func(t *testing.T) { - var pc Catcher + var pc panics.Catcher pc.Try(func() { panic(err1) }) require.Panics(t, pc.Repanic) }) t.Run("repanic does not panic without child panic", func(t *testing.T) { t.Parallel() - var pc Catcher + var pc panics.Catcher pc.Try(func() { _ = 1 }) require.NotPanics(t, pc.Repanic) }) @@ -124,7 +126,7 @@ func TestCatcher(t *testing.T) { t.Run("is goroutine safe", func(t *testing.T) { t.Parallel() var wg sync.WaitGroup - var pc Catcher + var pc panics.Catcher for i := 0; i < 100; i++ { i := i wg.Add(1) @@ -148,7 +150,7 @@ func TestRecoveredAsError(t *testing.T) { t.Run("as error is nil", func(t *testing.T) { t.Parallel() fn := func() error { - var c Catcher + var c panics.Catcher c.Try(func() {}) return c.Recovered().AsError() } @@ -156,10 +158,10 @@ func TestRecoveredAsError(t *testing.T) { assert.Nil(t, err) }) - t.Run("as error is not nil nil", func(t *testing.T) { + t.Run("as error is not nil", func(t *testing.T) { t.Parallel() fn := func() error { - var c Catcher + var c panics.Catcher c.Try(func() { panic("oh dear!") }) return c.Recovered().AsError() } diff --git a/panics/try_test.go b/panics/try_test.go index 02a41e9..335e8c5 100644 --- a/panics/try_test.go +++ b/panics/try_test.go @@ -1,9 +1,11 @@ -package panics +package panics_test import ( "errors" "testing" + "github.com/sourcegraph/conc/panics" + "github.com/stretchr/testify/require" ) @@ -14,7 +16,7 @@ func TestTry(t *testing.T) { t.Parallel() err := errors.New("SOS") - recovered := Try(func() { panic(err) }) + recovered := panics.Try(func() { panic(err) }) require.ErrorIs(t, recovered.AsError(), err) require.ErrorAs(t, recovered.AsError(), &err) // The exact contents aren't tested because the stacktrace contains local file paths @@ -27,7 +29,7 @@ func TestTry(t *testing.T) { t.Run("no panic", func(t *testing.T) { t.Parallel() - recovered := Try(func() {}) + recovered := panics.Try(func() {}) require.Nil(t, recovered) }) } diff --git a/pool/context_pool.go 
b/pool/context_pool.go index b2d7f8a..85c34e5 100644 --- a/pool/context_pool.go +++ b/pool/context_pool.go @@ -81,6 +81,16 @@ func (p *ContextPool) WithCancelOnError() *ContextPool { return p } +// WithFailFast is an alias for the combination of WithFirstError and +// WithCancelOnError. By default, the errors from all tasks are returned and +// the pool's context is not canceled until the parent context is canceled. +func (p *ContextPool) WithFailFast() *ContextPool { + p.panicIfInitialized() + p.WithFirstError() + p.WithCancelOnError() + return p +} + // WithMaxGoroutines limits the number of goroutines in a pool. // Defaults to unlimited. Panics if n < 1. func (p *ContextPool) WithMaxGoroutines(n int) *ContextPool { diff --git a/pool/context_pool_test.go b/pool/context_pool_test.go index 1e3f1f0..3f4c1ef 100644 --- a/pool/context_pool_test.go +++ b/pool/context_pool_test.go @@ -1,4 +1,4 @@ -package pool +package pool_test import ( "context" @@ -9,12 +9,14 @@ import ( "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func ExampleContextPool_WithCancelOnError() { - p := New(). + p := pool.New(). WithMaxGoroutines(4). WithContext(context.Background()). 
WithCancelOnError() @@ -44,14 +46,14 @@ func TestContextPool(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := New().WithContext(context.Background()) + g := pool.New().WithContext(context.Background()) g.Go(func(context.Context) error { return nil }) require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := New().WithContext(context.Background()) + g := pool.New().WithContext(context.Background()) g.Go(func(context.Context) error { return nil }) _ = g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) @@ -63,21 +65,21 @@ func TestContextPool(t *testing.T) { t.Run("wait returns no error if no errors", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx) + p := pool.New().WithContext(bgctx) p.Go(func(context.Context) error { return nil }) require.NoError(t, p.Wait()) }) t.Run("wait errors if func returns error", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx) + p := pool.New().WithContext(bgctx) p.Go(func(context.Context) error { return err1 }) require.ErrorIs(t, p.Wait(), err1) }) t.Run("wait error is all returned errors", func(t *testing.T) { t.Parallel() - p := New().WithErrors().WithContext(bgctx) + p := pool.New().WithErrors().WithContext(bgctx) p.Go(func(context.Context) error { return err1 }) p.Go(func(context.Context) error { return nil }) p.Go(func(context.Context) error { return err2 }) @@ -93,7 +95,7 @@ func TestContextPool(t *testing.T) { t.Run("canceled", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(bgctx) - p := New().WithContext(ctx) + p := pool.New().WithContext(ctx) p.Go(func(ctx context.Context) error { <-ctx.Done() return ctx.Err() @@ -106,18 +108,32 @@ func TestContextPool(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(bgctx, time.Millisecond) defer cancel() - p := New().WithContext(ctx) + p := 
pool.New().WithContext(ctx) p.Go(func(ctx context.Context) error { <-ctx.Done() return ctx.Err() }) require.ErrorIs(t, p.Wait(), context.DeadlineExceeded) }) + + t.Run("return before timed out", func(t *testing.T) { + t.Parallel() + p := pool.New().WithContext(context.Background()) + p.Go(func(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(1 * time.Millisecond): + return nil + } + }) + require.NoError(t, p.Wait()) + }) }) t.Run("WithCancelOnError", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx).WithCancelOnError() + p := pool.New().WithContext(bgctx).WithCancelOnError() p.Go(func(ctx context.Context) error { <-ctx.Done() return ctx.Err() @@ -132,7 +148,7 @@ func TestContextPool(t *testing.T) { t.Run("no WithCancelOnError", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx) + p := pool.New().WithContext(bgctx) p.Go(func(ctx context.Context) error { select { case <-ctx.Done(): @@ -151,7 +167,7 @@ func TestContextPool(t *testing.T) { t.Run("WithFirstError", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx).WithFirstError() + p := pool.New().WithContext(bgctx).WithFirstError() sync := make(chan struct{}) p.Go(func(ctx context.Context) error { defer close(sync) @@ -173,9 +189,9 @@ func TestContextPool(t *testing.T) { require.NotErrorIs(t, err, err2) }) - t.Run("WithFirstError and WithCancelOnError", func(t *testing.T) { + t.Run("WithFailFast", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx).WithFirstError().WithCancelOnError() + p := pool.New().WithContext(bgctx).WithFailFast() p.Go(func(ctx context.Context) error { return err1 }) @@ -190,7 +206,7 @@ func TestContextPool(t *testing.T) { t.Run("WithCancelOnError and panic", func(t *testing.T) { t.Parallel() - p := New().WithContext(bgctx).WithCancelOnError() + p := pool.New().WithContext(bgctx).WithCancelOnError() var cancelledTasks atomic.Int64 p.Go(func(ctx context.Context) error { <-ctx.Done() @@ 
-216,7 +232,7 @@ func TestContextPool(t *testing.T) { maxConcurrent := maxConcurrent // copy t.Parallel() - p := New().WithContext(bgctx).WithMaxGoroutines(maxConcurrent) + p := pool.New().WithContext(bgctx).WithMaxGoroutines(maxConcurrent) var currentConcurrent atomic.Int64 for i := 0; i < 100; i++ { diff --git a/pool/error_pool.go b/pool/error_pool.go index 6e5aa99..e1789e6 100644 --- a/pool/error_pool.go +++ b/pool/error_pool.go @@ -2,9 +2,8 @@ package pool import ( "context" + "errors" "sync" - - "github.com/sourcegraph/conc/internal/multierror" ) // ErrorPool is a pool that runs tasks that may return an error. @@ -20,7 +19,7 @@ type ErrorPool struct { onlyFirstError bool mu sync.Mutex - errs error + errs []error } // Go submits a task to the pool. If all goroutines in the pool @@ -35,7 +34,17 @@ func (p *ErrorPool) Go(f func() error) { // returning any errors from tasks. func (p *ErrorPool) Wait() error { p.pool.Wait() - return p.errs + + errs := p.errs + p.errs = nil // reset errs + + if len(errs) == 0 { + return nil + } else if p.onlyFirstError { + return errs[0] + } else { + return errors.Join(errs...) 
+ } } // WithContext converts the pool to a ContextPool for tasks that should @@ -85,13 +94,7 @@ func (p *ErrorPool) panicIfInitialized() { func (p *ErrorPool) addErr(err error) { if err != nil { p.mu.Lock() - if p.onlyFirstError { - if p.errs == nil { - p.errs = err - } - } else { - p.errs = multierror.Join(p.errs, err) - } + p.errs = append(p.errs, err) p.mu.Unlock() } } diff --git a/pool/error_pool_test.go b/pool/error_pool_test.go index 0ab88cb..814f90b 100644 --- a/pool/error_pool_test.go +++ b/pool/error_pool_test.go @@ -1,4 +1,4 @@ -package pool +package pool_test import ( "errors" @@ -8,11 +8,13 @@ import ( "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/require" ) func ExampleErrorPool() { - p := New().WithErrors() + p := pool.New().WithErrors() for i := 0; i < 3; i++ { i := i p.Go(func() error { @@ -37,14 +39,14 @@ func TestErrorPool(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() g.Go(func() error { return nil }) require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() g.Go(func() error { return nil }) _ = g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) @@ -53,21 +55,21 @@ func TestErrorPool(t *testing.T) { t.Run("wait returns no error if no errors", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() g.Go(func() error { return nil }) require.NoError(t, g.Wait()) }) t.Run("wait error if func returns error", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() g.Go(func() error { return err1 }) require.ErrorIs(t, g.Wait(), err1) }) t.Run("wait error is all returned errors", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() g.Go(func() 
error { return err1 }) g.Go(func() error { return nil }) g.Go(func() error { return err2 }) @@ -78,7 +80,7 @@ func TestErrorPool(t *testing.T) { t.Run("propagates panics", func(t *testing.T) { t.Parallel() - g := New().WithErrors() + g := pool.New().WithErrors() for i := 0; i < 10; i++ { i := i g.Go(func() error { @@ -95,7 +97,7 @@ func TestErrorPool(t *testing.T) { t.Parallel() for _, maxGoroutines := range []int{1, 10, 100} { t.Run(strconv.Itoa(maxGoroutines), func(t *testing.T) { - g := New().WithErrors().WithMaxGoroutines(maxGoroutines) + g := pool.New().WithErrors().WithMaxGoroutines(maxGoroutines) var currentConcurrent atomic.Int64 taskCount := maxGoroutines * 10 @@ -115,4 +117,19 @@ func TestErrorPool(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.New().WithErrors() + + p.Go(func() error { return err1 }) + wait1 := p.Wait() + require.ErrorIs(t, wait1, err1) + + p.Go(func() error { return err2 }) + wait2 := p.Wait() + // On reuse, only the new error should be returned; err1 from the first round must not leak into wait2 + require.ErrorIs(t, wait2, err2) + require.NotErrorIs(t, wait2, err1) + }) } diff --git a/pool/pool.go b/pool/pool.go index b63eb19..8f4494e 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -47,20 +47,18 @@ func (p *Pool) Go(f func()) { default: // No goroutine was available to handle the task. // Spawn a new one and send it the task. - p.handle.Go(p.worker) - p.tasks <- f + p.handle.Go(func() { + p.worker(f) + }) } } else { select { case p.limiter <- struct{}{}: // If we are below our limit, spawn a new worker rather // than waiting for one to become available. - p.handle.Go(p.worker) - - // We know there is at least one worker running, so wait - // for it to become available. This ensures we never spawn - // more workers than the number of tasks. - p.tasks <- f + p.handle.Go(func() { + p.worker(f) + }) case p.tasks <- f: // A worker is available and has accepted the task. 
return @@ -76,6 +74,10 @@ func (p *Pool) Wait() { close(p.tasks) + // After Wait() returns, reset the struct so tasks will be reinitialized on + // next use. This better matches the behavior of sync.WaitGroup + defer func() { p.initOnce = sync.Once{} }() + p.handle.Wait() } @@ -145,11 +147,15 @@ func (p *Pool) WithContext(ctx context.Context) *ContextPool { } } -func (p *Pool) worker() { +func (p *Pool) worker(initialFunc func()) { // The only time this matters is if the task panics. // This makes it possible to spin up new workers in that case. defer p.limiter.release() + if initialFunc != nil { + initialFunc() + } + for f := range p.tasks { f() } diff --git a/pool/pool_test.go b/pool/pool_test.go index a6e93da..6791b97 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -1,4 +1,4 @@ -package pool +package pool_test import ( "fmt" @@ -7,11 +7,13 @@ import ( "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/require" ) func ExamplePool() { - p := New().WithMaxGoroutines(3) + p := pool.New().WithMaxGoroutines(3) for i := 0; i < 5; i++ { p.Go(func() { fmt.Println("conc") @@ -32,7 +34,7 @@ func TestPool(t *testing.T) { t.Run("basic", func(t *testing.T) { t.Parallel() - g := New() + g := pool.New() var completed atomic.Int64 for i := 0; i < 100; i++ { g.Go(func() { @@ -47,14 +49,14 @@ func TestPool(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := New() + g := pool.New() g.Go(func() {}) require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := New() + g := pool.New() g.Go(func() {}) g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) @@ -65,7 +67,7 @@ func TestPool(t *testing.T) { t.Parallel() for _, maxConcurrent := range []int{1, 10, 100} { t.Run(strconv.Itoa(maxConcurrent), func(t *testing.T) { - g := New().WithMaxGoroutines(maxConcurrent) + g := 
pool.New().WithMaxGoroutines(maxConcurrent) var currentConcurrent atomic.Int64 var errCount atomic.Int64 @@ -89,7 +91,7 @@ func TestPool(t *testing.T) { t.Run("propagate panic", func(t *testing.T) { t.Parallel() - g := New() + g := pool.New() for i := 0; i < 10; i++ { i := i g.Go(func() { @@ -103,7 +105,7 @@ func TestPool(t *testing.T) { t.Run("panics do not exhaust goroutines", func(t *testing.T) { t.Parallel() - g := New().WithMaxGoroutines(2) + g := pool.New().WithMaxGoroutines(2) for i := 0; i < 10; i++ { g.Go(func() { panic(42) @@ -114,27 +116,47 @@ func TestPool(t *testing.T) { t.Run("panics on invalid WithMaxGoroutines", func(t *testing.T) { t.Parallel() - require.Panics(t, func() { New().WithMaxGoroutines(0) }) + require.Panics(t, func() { pool.New().WithMaxGoroutines(0) }) }) t.Run("returns correct MaxGoroutines", func(t *testing.T) { t.Parallel() - p := New().WithMaxGoroutines(42) + p := pool.New().WithMaxGoroutines(42) require.Equal(t, 42, p.MaxGoroutines()) }) + + t.Run("is reusable", func(t *testing.T) { + t.Parallel() + var count atomic.Int64 + p := pool.New() + for i := 0; i < 10; i++ { + p.Go(func() { + count.Add(1) + }) + } + p.Wait() + require.Equal(t, int64(10), count.Load()) + for i := 0; i < 10; i++ { + p.Go(func() { + count.Add(1) + }) + } + p.Wait() + require.Equal(t, int64(20), count.Load()) + }) } func BenchmarkPool(b *testing.B) { b.Run("startup and teardown", func(b *testing.B) { for i := 0; i < b.N; i++ { - p := New() + p := pool.New() p.Go(func() {}) p.Wait() } }) b.Run("per task", func(b *testing.B) { - p := New() + p := pool.New() f := func() {} for i := 0; i < b.N; i++ { p.Go(f) diff --git a/pool/result_context_pool.go b/pool/result_context_pool.go index 55dc3bc..6bc30dd 100644 --- a/pool/result_context_pool.go +++ b/pool/result_context_pool.go @@ -20,11 +20,10 @@ type ResultContextPool[T any] struct { // Go submits a task to the pool. If all goroutines in the pool // are busy, a call to Go() will block until the task can be started. 
func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) { + idx := p.agg.nextIndex() p.contextPool.Go(func(ctx context.Context) error { res, err := f(ctx) - if err == nil || p.collectErrored { - p.agg.add(res) - } + p.agg.save(idx, res, err != nil) return err }) } @@ -33,7 +32,9 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) { // returns an error if any of the tasks errored. func (p *ResultContextPool[T]) Wait() ([]T, error) { err := p.contextPool.Wait() - return p.agg.results, err + results := p.agg.collect(p.collectErrored) + p.agg = resultAggregator[T]{} + return results, err } // WithCollectErrored configures the pool to still collect the result of a task @@ -62,6 +63,15 @@ func (p *ResultContextPool[T]) WithCancelOnError() *ResultContextPool[T] { return p } +// WithFailFast is an alias for the combination of WithFirstError and +// WithCancelOnError. By default, the errors from all tasks are returned and +// the pool's context is not canceled until the parent context is canceled. +func (p *ResultContextPool[T]) WithFailFast() *ResultContextPool[T] { + p.panicIfInitialized() + p.contextPool.WithFailFast() + return p +} + // WithMaxGoroutines limits the number of goroutines in a pool. // Defaults to unlimited. Panics if n < 1. 
func (p *ResultContextPool[T]) WithMaxGoroutines(n int) *ResultContextPool[T] { diff --git a/pool/result_context_pool_test.go b/pool/result_context_pool_test.go index a116d93..fc3b68a 100644 --- a/pool/result_context_pool_test.go +++ b/pool/result_context_pool_test.go @@ -1,15 +1,16 @@ -package pool +package pool_test import ( "context" "errors" "fmt" - "sort" "strconv" "sync/atomic" "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,14 +24,14 @@ func TestResultContextPool(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()) + g := pool.NewWithResults[int]().WithContext(context.Background()) g.Go(func(context.Context) (int, error) { return 0, nil }) require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()) + g := pool.NewWithResults[int]().WithContext(context.Background()) g.Go(func(context.Context) (int, error) { return 0, nil }) _, _ = g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) @@ -42,7 +43,7 @@ func TestResultContextPool(t *testing.T) { bgctx := context.Background() t.Run("wait returns no error if no errors", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(bgctx) + g := pool.NewWithResults[int]().WithContext(bgctx) g.Go(func(context.Context) (int, error) { return 0, nil }) res, err := g.Wait() require.Len(t, res, 1) @@ -51,7 +52,7 @@ func TestResultContextPool(t *testing.T) { t.Run("wait error if func returns error", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(bgctx) + g := pool.NewWithResults[int]().WithContext(bgctx) g.Go(func(context.Context) (int, error) { return 0, err1 }) res, err := g.Wait() require.Len(t, res, 0) @@ -60,7 +61,7 @@ 
func TestResultContextPool(t *testing.T) { t.Run("wait error is all returned errors", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors().WithContext(bgctx) + g := pool.NewWithResults[int]().WithErrors().WithContext(bgctx) g.Go(func(context.Context) (int, error) { return 0, err1 }) g.Go(func(context.Context) (int, error) { return 0, nil }) g.Go(func(context.Context) (int, error) { return 0, err2 }) @@ -74,7 +75,7 @@ func TestResultContextPool(t *testing.T) { t.Run("context cancel propagates", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) - g := NewWithResults[int]().WithContext(ctx) + g := pool.NewWithResults[int]().WithContext(ctx) g.Go(func(ctx context.Context) (int, error) { <-ctx.Done() return 0, ctx.Err() @@ -87,7 +88,7 @@ func TestResultContextPool(t *testing.T) { t.Run("WithCancelOnError", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()).WithCancelOnError() + g := pool.NewWithResults[int]().WithContext(context.Background()).WithCancelOnError() g.Go(func(ctx context.Context) (int, error) { <-ctx.Done() return 0, ctx.Err() @@ -101,9 +102,25 @@ func TestResultContextPool(t *testing.T) { require.ErrorIs(t, err, err1) }) + t.Run("WithFailFast", func(t *testing.T) { + t.Parallel() + p := pool.NewWithResults[int]().WithContext(context.Background()).WithFailFast() + p.Go(func(ctx context.Context) (int, error) { + return 0, err1 + }) + p.Go(func(ctx context.Context) (int, error) { + <-ctx.Done() + return 1, ctx.Err() + }) + results, err := p.Wait() + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, context.Canceled) + require.Empty(t, results) + }) + t.Run("WithCancelOnError and panic", func(t *testing.T) { t.Parallel() - p := NewWithResults[int](). + p := pool.NewWithResults[int](). WithContext(context.Background()). 
WithCancelOnError() var cancelledTasks atomic.Int64 @@ -126,7 +143,7 @@ func TestResultContextPool(t *testing.T) { t.Run("no WithCancelOnError", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()) + g := pool.NewWithResults[int]().WithContext(context.Background()) g.Go(func(ctx context.Context) (int, error) { select { case <-ctx.Done(): @@ -146,7 +163,7 @@ func TestResultContextPool(t *testing.T) { t.Run("WithCollectErrored", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()).WithCollectErrored() + g := pool.NewWithResults[int]().WithContext(context.Background()).WithCollectErrored() g.Go(func(context.Context) (int, error) { return 0, err1 }) res, err := g.Wait() require.Len(t, res, 1) // errored value is collected @@ -155,7 +172,7 @@ func TestResultContextPool(t *testing.T) { t.Run("WithFirstError", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithContext(context.Background()).WithFirstError() + g := pool.NewWithResults[int]().WithContext(context.Background()).WithFirstError() sync := make(chan struct{}) g.Go(func(ctx context.Context) (int, error) { defer close(sync) @@ -186,7 +203,7 @@ func TestResultContextPool(t *testing.T) { t.Parallel() ctx := context.Background() - g := NewWithResults[int]().WithContext(ctx).WithMaxGoroutines(maxConcurrency) + g := pool.NewWithResults[int]().WithContext(ctx).WithMaxGoroutines(maxConcurrency) var currentConcurrent atomic.Int64 taskCount := maxConcurrency * 10 @@ -205,11 +222,26 @@ func TestResultContextPool(t *testing.T) { }) } res, err := g.Wait() - sort.Ints(res) require.Equal(t, expected, res) require.NoError(t, err) require.Equal(t, int64(0), currentConcurrent.Load()) }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]().WithContext(context.Background()) + + p.Go(func(context.Context) (int, error) { return 1, err1 }) + 
results1, errs1 := p.Wait() + require.Empty(t, results1) + require.ErrorIs(t, errs1, err1) + + p.Go(func(context.Context) (int, error) { return 2, err2 }) + results2, errs2 := p.Wait() + require.Empty(t, results2) + require.ErrorIs(t, errs2, err2) + require.NotErrorIs(t, errs2, err1) + }) } diff --git a/pool/result_error_pool.go b/pool/result_error_pool.go index 4caaadc..832cd9b 100644 --- a/pool/result_error_pool.go +++ b/pool/result_error_pool.go @@ -8,9 +8,8 @@ import ( // type and an error. Tasks are executed in the pool with Go(), then the // results of the tasks are returned by Wait(). // -// The order of the results is not guaranteed to be the same as the order the -// tasks were submitted. If your use case requires consistent ordering, -// consider using the `stream` package or `Map` from the `iter` package. +// The order of the results is guaranteed to be the same as the order the +// tasks were submitted. // // The configuration methods (With*) will panic if they are used after calling // Go() for the first time. @@ -23,11 +22,10 @@ type ResultErrorPool[T any] struct { // Go submits a task to the pool. If all goroutines in the pool // are busy, a call to Go() will block until the task can be started. func (p *ResultErrorPool[T]) Go(f func() (T, error)) { + idx := p.agg.nextIndex() p.errorPool.Go(func() error { res, err := f() - if err == nil || p.collectErrored { - p.agg.add(res) - } + p.agg.save(idx, res, err != nil) return err }) } @@ -36,7 +34,9 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) { // returning the results and any errors from tasks. 
func (p *ResultErrorPool[T]) Wait() ([]T, error) { err := p.errorPool.Wait() - return p.agg.results, err + results := p.agg.collect(p.collectErrored) + p.agg = resultAggregator[T]{} // reset for reuse + return results, err } // WithCollectErrored configures the pool to still collect the result of a task diff --git a/pool/result_error_pool_test.go b/pool/result_error_pool_test.go index 6b9ba5b..7326639 100644 --- a/pool/result_error_pool_test.go +++ b/pool/result_error_pool_test.go @@ -1,4 +1,4 @@ -package pool +package pool_test import ( "errors" @@ -8,10 +8,12 @@ import ( "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/require" ) -func TestResultErrorGroup(t *testing.T) { +func TestResultErrorPool(t *testing.T) { t.Parallel() err1 := errors.New("err1") @@ -20,14 +22,14 @@ func TestResultErrorGroup(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors() + g := pool.NewWithResults[int]().WithErrors() g.Go(func() (int, error) { return 0, nil }) require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors() + g := pool.NewWithResults[int]().WithErrors() g.Go(func() (int, error) { return 0, nil }) _, _ = g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) @@ -36,7 +38,7 @@ func TestResultErrorGroup(t *testing.T) { t.Run("wait returns no error if no errors", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors() + g := pool.NewWithResults[int]().WithErrors() g.Go(func() (int, error) { return 1, nil }) res, err := g.Wait() require.NoError(t, err) @@ -45,7 +47,7 @@ func TestResultErrorGroup(t *testing.T) { t.Run("wait error if func returns error", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors() + g := pool.NewWithResults[int]().WithErrors() g.Go(func() (int, error) { 
return 0, err1 }) res, err := g.Wait() require.Len(t, res, 0) // errored value is ignored @@ -54,7 +56,7 @@ func TestResultErrorGroup(t *testing.T) { t.Run("WithCollectErrored", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors().WithCollectErrored() + g := pool.NewWithResults[int]().WithErrors().WithCollectErrored() g.Go(func() (int, error) { return 0, err1 }) res, err := g.Wait() require.Len(t, res, 1) // errored value is collected @@ -63,7 +65,7 @@ func TestResultErrorGroup(t *testing.T) { t.Run("WithFirstError", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors().WithFirstError() + g := pool.NewWithResults[int]().WithErrors().WithFirstError() synchronizer := make(chan struct{}) g.Go(func() (int, error) { <-synchronizer @@ -89,7 +91,7 @@ func TestResultErrorGroup(t *testing.T) { t.Run("wait error is all returned errors", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]().WithErrors() + g := pool.NewWithResults[int]().WithErrors() g.Go(func() (int, error) { return 0, err1 }) g.Go(func() (int, error) { return 0, nil }) g.Go(func() (int, error) { return 0, err2 }) @@ -106,7 +108,7 @@ func TestResultErrorGroup(t *testing.T) { maxConcurrency := maxConcurrency // copy t.Parallel() - g := NewWithResults[int]().WithErrors().WithMaxGoroutines(maxConcurrency) + g := pool.NewWithResults[int]().WithErrors().WithMaxGoroutines(maxConcurrency) var currentConcurrent atomic.Int64 taskCount := maxConcurrency * 10 @@ -128,4 +130,20 @@ func TestResultErrorGroup(t *testing.T) { }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]().WithErrors() + + p.Go(func() (int, error) { return 1, err1 }) + results1, errs1 := p.Wait() + require.Empty(t, results1) + require.ErrorIs(t, errs1, err1) + + p.Go(func() (int, error) { return 2, err2 }) + results2, errs2 := p.Wait() + require.Empty(t, results2) + require.ErrorIs(t, errs2, err2) + 
require.NotErrorIs(t, errs2, err1) + }) } diff --git a/pool/result_pool.go b/pool/result_pool.go index ea304cb..f73a772 100644 --- a/pool/result_pool.go +++ b/pool/result_pool.go @@ -2,6 +2,7 @@ package pool import ( "context" + "sort" "sync" ) @@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] { // Tasks are executed in the pool with Go(), then the results of the tasks are // returned by Wait(). // -// The order of the results is not guaranteed to be the same as the order the -// tasks were submitted. If your use case requires consistent ordering, -// consider using the `stream` package or `Map` from the `iter` package. +// The order of the results is guaranteed to be the same as the order the +// tasks were submitted. type ResultPool[T any] struct { pool Pool agg resultAggregator[T] @@ -30,8 +30,9 @@ type ResultPool[T any] struct { // Go submits a task to the pool. If all goroutines in the pool // are busy, a call to Go() will block until the task can be started. func (p *ResultPool[T]) Go(f func() T) { + idx := p.agg.nextIndex() p.pool.Go(func() { - p.agg.add(f()) + p.agg.save(idx, f(), false) }) } @@ -39,7 +40,9 @@ func (p *ResultPool[T]) Go(f func() T) { // a slice of results from tasks that did not panic. func (p *ResultPool[T]) Wait() []T { p.pool.Wait() - return p.agg.results + results := p.agg.collect(true) + p.agg = resultAggregator[T]{} // reset for reuse + return results } // MaxGoroutines returns the maximum size of the pool. @@ -83,11 +86,57 @@ func (p *ResultPool[T]) panicIfInitialized() { // goroutines. The zero value is valid and ready to use. type resultAggregator[T any] struct { mu sync.Mutex + len int results []T + errored []int } -func (r *resultAggregator[T]) add(res T) { +// nextIndex reserves a slot for a result. The returned value should be passed +// to save() when adding a result to the aggregator. 
+func (r *resultAggregator[T]) nextIndex() int { r.mu.Lock() - r.results = append(r.results, res) - r.mu.Unlock() + defer r.mu.Unlock() + + nextIdx := r.len + r.len += 1 + return nextIdx +} + +func (r *resultAggregator[T]) save(i int, res T, errored bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if i >= len(r.results) { + old := r.results + r.results = make([]T, r.len) + copy(r.results, old) + } + + r.results[i] = res + + if errored { + r.errored = append(r.errored, i) + } +} + +// collect returns the aggregated results in submission order, omitting the +// results of errored tasks unless collectErrored is true. +func (r *resultAggregator[T]) collect(collectErrored bool) []T { + if !r.mu.TryLock() { + panic("collect should not be called until all goroutines have exited") + } + + if collectErrored || len(r.errored) == 0 { + return r.results + } + + filtered := r.results[:0] // filter in place, reusing the backing array + sort.Ints(r.errored) + start := 0 + for _, e := range r.errored { + filtered = append(filtered, r.results[start:e]...) + start = e + 1 + } + // Results after the last errored slot must be kept as well. + return append(filtered, r.results[start:]...) } diff --git a/pool/result_pool_test.go b/pool/result_pool_test.go index 1663968..69b9de4 100644 --- a/pool/result_pool_test.go +++ b/pool/result_pool_test.go @@ -1,18 +1,20 @@ -package pool +package pool_test import ( "fmt" - "sort" + "math/rand" "strconv" "sync/atomic" "testing" "time" + "github.com/sourcegraph/conc/pool" + "github.com/stretchr/testify/require" ) func ExampleResultPool() { - p := NewWithResults[int]() + p := pool.NewWithResults[int]() for i := 0; i < 10; i++ { i := i p.Go(func() int { @@ -20,8 +22,6 @@ func ExampleResultPool() { }) } res := p.Wait() - // Result order is nondeterministic, so sort them first - sort.Ints(res) fmt.Println(res) // Output: @@ -34,22 +34,23 @@ func TestResultGroup(t *testing.T) { t.Run("panics on configuration after init", func(t *testing.T) { t.Run("before wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]() + g := pool.NewWithResults[int]() g.Go(func() int { return 0 }) require.Panics(t, func() {
g.WithMaxGoroutines(10) }) }) t.Run("after wait", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]() + g := pool.NewWithResults[int]() g.Go(func() int { return 0 }) + _ = g.Wait() require.Panics(t, func() { g.WithMaxGoroutines(10) }) }) }) t.Run("basic", func(t *testing.T) { t.Parallel() - g := NewWithResults[int]() + g := pool.NewWithResults[int]() expected := []int{} for i := 0; i < 100; i++ { i := i @@ -59,15 +60,34 @@ func TestResultGroup(t *testing.T) { }) } res := g.Wait() - sort.Ints(res) require.Equal(t, expected, res) }) + t.Run("deterministic order", func(t *testing.T) { + t.Parallel() + p := pool.NewWithResults[int]() + results := make([]int, 100) + for i := 0; i < 100; i++ { + results[i] = i + } + for _, result := range results { + result := result + p.Go(func() int { + // Add a random sleep to make it exceedingly unlikely that the + // results are returned in the order they are submitted. + time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond) + return result + }) + } + got := p.Wait() + require.Equal(t, results, got) + }) + t.Run("limit", func(t *testing.T) { t.Parallel() for _, maxGoroutines := range []int{1, 10, 100} { t.Run(strconv.Itoa(maxGoroutines), func(t *testing.T) { - g := NewWithResults[int]().WithMaxGoroutines(maxGoroutines) + g := pool.NewWithResults[int]().WithMaxGoroutines(maxGoroutines) var currentConcurrent atomic.Int64 var errCount atomic.Int64 @@ -87,11 +107,23 @@ func TestResultGroup(t *testing.T) { }) } res := g.Wait() - sort.Ints(res) require.Equal(t, expected, res) require.Equal(t, int64(0), errCount.Load()) require.Equal(t, int64(0), currentConcurrent.Load()) }) } }) + + t.Run("reuse", func(t *testing.T) { + // Test for https://github.com/sourcegraph/conc/issues/128 + p := pool.NewWithResults[int]() + + p.Go(func() int { return 1 }) + results1 := p.Wait() + require.Equal(t, []int{1}, results1) + + p.Go(func() int { return 2 }) + results2 := p.Wait() + require.Equal(t, []int{2}, results2) + }) } diff --git 
a/stream/stream.go b/stream/stream.go index d80a923..6b11e90 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -128,7 +128,9 @@ func (s *Stream) callbacker() { callback := <-callbackCh // Execute the callback (with panic protection). - panicCatcher.Try(callback) + if callback != nil { + panicCatcher.Try(callback) + } // Return the channel to the pool of unused channels. putCh(callbackCh) diff --git a/stream/stream_test.go b/stream/stream_test.go index 48c8c85..9f5bce1 100644 --- a/stream/stream_test.go +++ b/stream/stream_test.go @@ -1,4 +1,4 @@ -package stream +package stream_test import ( "fmt" @@ -6,22 +6,24 @@ import ( "testing" "time" + "github.com/sourcegraph/conc/stream" + "github.com/stretchr/testify/require" ) func ExampleStream() { times := []int{20, 52, 16, 45, 4, 80} - stream := New() + s := stream.New() for _, millis := range times { dur := time.Duration(millis) * time.Millisecond - stream.Go(func() Callback { + s.Go(func() stream.Callback { time.Sleep(dur) // This will print in the order the tasks were submitted return func() { fmt.Println(dur) } }) } - stream.Wait() + s.Wait() // Output: // 20ms @@ -37,11 +39,11 @@ func TestStream(t *testing.T) { t.Run("simple", func(t *testing.T) { t.Parallel() - s := New() + s := stream.New() var res []int for i := 0; i < 5; i++ { i := i - s.Go(func() Callback { + s.Go(func() stream.Callback { i *= 2 return func() { res = append(res, i) @@ -52,13 +54,27 @@ func TestStream(t *testing.T) { require.Equal(t, []int{0, 2, 4, 6, 8}, res) }) + t.Run("nil callback", func(t *testing.T) { + t.Parallel() + s := stream.New() + var totalCount atomic.Int64 + for i := 0; i < 5; i++ { + s.Go(func() stream.Callback { + totalCount.Add(1) + return nil + }) + } + s.Wait() + require.Equal(t, int64(5), totalCount.Load()) + }) + t.Run("max goroutines", func(t *testing.T) { t.Parallel() - s := New().WithMaxGoroutines(5) + s := stream.New().WithMaxGoroutines(5) var currentTaskCount atomic.Int64 var currentCallbackCount atomic.Int64 
for i := 0; i < 50; i++ { - s.Go(func() Callback { + s.Go(func() stream.Callback { curr := currentTaskCount.Add(1) if curr > 5 { t.Fatal("too many concurrent tasks being executed") @@ -84,8 +100,8 @@ func TestStream(t *testing.T) { t.Run("panic in task is propagated", func(t *testing.T) { t.Parallel() - s := New().WithMaxGoroutines(5) - s.Go(func() Callback { + s := stream.New().WithMaxGoroutines(5) + s.Go(func() stream.Callback { panic("something really bad happened in the task") }) require.Panics(t, s.Wait) @@ -93,8 +109,8 @@ func TestStream(t *testing.T) { t.Run("panic in callback is propagated", func(t *testing.T) { t.Parallel() - s := New().WithMaxGoroutines(5) - s.Go(func() Callback { + s := stream.New().WithMaxGoroutines(5) + s.Go(func() stream.Callback { return func() { panic("something really bad happened in the callback") } @@ -104,14 +120,14 @@ func TestStream(t *testing.T) { t.Run("panic in callback does not block producers", func(t *testing.T) { t.Parallel() - s := New().WithMaxGoroutines(5) - s.Go(func() Callback { + s := stream.New().WithMaxGoroutines(5) + s.Go(func() stream.Callback { return func() { panic("something really bad happened in the callback") } }) for i := 0; i < 100; i++ { - s.Go(func() Callback { + s.Go(func() stream.Callback { return func() {} }) } @@ -122,17 +138,17 @@ func TestStream(t *testing.T) { func BenchmarkStream(b *testing.B) { b.Run("startup and teardown", func(b *testing.B) { for i := 0; i < b.N; i++ { - s := New() - s.Go(func() Callback { return func() {} }) + s := stream.New() + s.Go(func() stream.Callback { return func() {} }) s.Wait() } }) b.Run("per task", func(b *testing.B) { n := 0 - s := New() + s := stream.New() for i := 0; i < b.N; i++ { - s.Go(func() Callback { + s.Go(func() stream.Callback { return func() { n += 1 } diff --git a/waitgroup_test.go b/waitgroup_test.go index 1cc808b..44ae61a 100644 --- a/waitgroup_test.go +++ b/waitgroup_test.go @@ -1,17 +1,19 @@ -package conc +package conc_test import ( "fmt" 
"sync/atomic" "testing" + "github.com/sourcegraph/conc" + "github.com/stretchr/testify/require" ) func ExampleWaitGroup() { var count atomic.Int64 - var wg WaitGroup + var wg conc.WaitGroup for i := 0; i < 10; i++ { wg.Go(func() { count.Add(1) @@ -25,7 +27,7 @@ func ExampleWaitGroup() { } func ExampleWaitGroup_WaitAndRecover() { - var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") @@ -42,14 +44,14 @@ func TestWaitGroup(t *testing.T) { t.Run("ctor", func(t *testing.T) { t.Parallel() - wg := NewWaitGroup() - require.IsType(t, &WaitGroup{}, wg) + wg := conc.NewWaitGroup() + require.IsType(t, &conc.WaitGroup{}, wg) }) t.Run("all spawned run", func(t *testing.T) { t.Parallel() var count atomic.Int64 - var wg WaitGroup + var wg conc.WaitGroup for i := 0; i < 100; i++ { wg.Go(func() { count.Add(1) @@ -64,7 +66,7 @@ func TestWaitGroup(t *testing.T) { t.Run("is propagated", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") }) @@ -73,7 +75,7 @@ func TestWaitGroup(t *testing.T) { t.Run("one is propagated", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") }) @@ -85,7 +87,7 @@ func TestWaitGroup(t *testing.T) { t.Run("non-panics do not overwrite panic", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") }) @@ -97,7 +99,7 @@ func TestWaitGroup(t *testing.T) { t.Run("non-panics run successfully", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup var i atomic.Int64 wg.Go(func() { i.Add(1) @@ -114,7 +116,7 @@ func TestWaitGroup(t *testing.T) { t.Run("is caught by waitandrecover", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") }) @@ -124,7 +126,7 @@ func TestWaitGroup(t *testing.T) { t.Run("one is caught by waitandrecover", func(t *testing.T) { t.Parallel() - 
var wg WaitGroup + var wg conc.WaitGroup wg.Go(func() { panic("super bad thing") }) @@ -137,7 +139,7 @@ func TestWaitGroup(t *testing.T) { t.Run("nonpanics run successfully with waitandrecover", func(t *testing.T) { t.Parallel() - var wg WaitGroup + var wg conc.WaitGroup var i atomic.Int64 wg.Go(func() { i.Add(1)