// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build goexperiment.rangefunc

package main

import (
	"fmt"
	"iter"
	"runtime"
)

func init() {
	register("CoroLockOSThreadIterLock", func() {
		println("expect: OK")
		CoroLockOSThread(callerExhaust, iterLock)
	})
	register("CoroLockOSThreadIterLockYield", func() {
		println("expect: OS thread locking must match")
		CoroLockOSThread(callerExhaust, iterLockYield)
	})
	register("CoroLockOSThreadLock", func() {
		println("expect: OK")
		CoroLockOSThread(callerExhaustLocked, iterSimple)
	})
	register("CoroLockOSThreadLockIterNested", func() {
		println("expect: OK")
		CoroLockOSThread(callerExhaustLocked, iterNested)
	})
	register("CoroLockOSThreadLockIterLock", func() {
		println("expect: OK")
		CoroLockOSThread(callerExhaustLocked, iterLock)
	})
	register("CoroLockOSThreadLockIterLockYield", func() {
		println("expect: OS thread locking must match")
		CoroLockOSThread(callerExhaustLocked, iterLockYield)
	})
	register("CoroLockOSThreadLockIterYieldNewG", func() {
		println("expect: OS thread locking must match")
		CoroLockOSThread(callerExhaustLocked, iterYieldNewG)
	})
	register("CoroLockOSThreadLockAfterPull", func() {
		println("expect: OS thread locking must match")
		CoroLockOSThread(callerLockAfterPull, iterSimple)
	})
	register("CoroLockOSThreadStopLocked", func() {
		println("expect: OK")
		CoroLockOSThread(callerStopLocked, iterSimple)
	})
	register("CoroLockOSThreadStopLockedIterNested", func() {
		println("expect: OK")
		CoroLockOSThread(callerStopLocked, iterNested)
	})
}

func CoroLockOSThread(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
	if err := driver(seq); err != nil {
		println("error:", err.Error())
		return
	}
	println("OK")
}

func callerExhaust(i iter.Seq[int]) error {
	next, _ := iter.Pull(i)
	for {
		v, ok := next()
		if !ok {
			break
		}
		if v != 5 {
			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
		}
	}
	return nil
}

func callerExhaustLocked(i iter.Seq[int]) error {
	runtime.LockOSThread()
	next, _ := iter.Pull(i)
	for {
		v, ok := next()
		if !ok {
			break
		}
		if v != 5 {
			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
		}
	}
	runtime.UnlockOSThread()
	return nil
}

func callerLockAfterPull(i iter.Seq[int]) error {
	n := 0
	next, _ := iter.Pull(i)
	for {
		runtime.LockOSThread()
		n++
		v, ok := next()
		if !ok {
			break
		}
		if v != 5 {
			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
		}
	}
	for range n {
		runtime.UnlockOSThread()
	}
	return nil
}

func callerStopLocked(i iter.Seq[int]) error {
	runtime.LockOSThread()
	next, stop := iter.Pull(i)
	v, _ := next()
	stop()
	if v != 5 {
		return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
	}
	runtime.UnlockOSThread()
	return nil
}

// iterSimple yields the value 5 up to three times. It stops as soon as
// yield reports that the consumer wants no more values.
func iterSimple(yield func(int) bool) {
	for remaining := 3; remaining > 0; remaining-- {
		if more := yield(5); !more {
			break
		}
	}
}

// iterNested re-yields every value of iterSimple. It pulls them through its
// own iter.Pull, so a second coroutine is created and driven from inside
// this iterator body.
//
// If the consumer stops early, the inner pull is stopped before returning.
// If iterSimple runs dry, the function simply returns.
func iterNested(yield func(int) bool) {
	pull, release := iter.Pull(iterSimple)
	for v, ok := pull(); ok; v, ok = pull() {
		if !yield(v) {
			release()
			return
		}
	}
}

// iterLock yields 5 up to three times. Before each yield it takes and
// immediately releases an OS thread lock, so the lock count is back to its
// starting value whenever control leaves through yield.
func iterLock(yield func(int) bool) {
	for remaining := 3; remaining > 0; remaining-- {
		runtime.LockOSThread()
		runtime.UnlockOSThread()

		if more := yield(5); !more {
			return
		}
	}
}

// iterLockYield yields 5 up to three times and holds an extra OS thread
// lock across each call to yield. Control therefore leaves the iterator with
// a different lock count than it entered with. Its registrations expect this
// to be reported as "OS thread locking must match".
// The lock is released after yield returns, before deciding whether to
// continue.
func iterLockYield(yield func(int) bool) {
	for remaining := 3; remaining > 0; remaining-- {
		runtime.LockOSThread()
		more := yield(5)
		runtime.UnlockOSThread()
		if !more {
			return
		}
	}
}

// iterYieldNewG yields 5 up to three times. Each yield is called from a
// freshly started goroutine rather than the goroutine running this iterator
// body. The body waits on an unbuffered channel for that goroutine to finish
// before inspecting yield's result. Its registration expects this to fail
// with "OS thread locking must match" when driven by a locked caller through
// iter.Pull.
func iterYieldNewG(yield func(int) bool) {
	for remaining := 3; remaining > 0; remaining-- {
		finished := make(chan struct{})
		var more bool
		go func() {
			more = yield(5)
			finished <- struct{}{}
		}()
		<-finished
		if !more {
			return
		}
	}
}
