// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package os_test

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"math/rand/v2"
	"net"
	"os"
	"runtime"
	"sync"
	"testing"

	"golang.org/x/net/nettest"
)

// Exercise sendfile/splice fast paths with a moderately large file.
//
// https://go.dev/issue/70000

// TestLargeCopyViaNetwork writes 10 MiB of deterministic pseudo-random data
// to a file and copies it file -> client socket and server socket -> file
// over a local TCP connection. It then checks that the destination file
// holds exactly the bytes that were written.
func TestLargeCopyViaNetwork(t *testing.T) {
	const size = 10 * 1024 * 1024
	dir := t.TempDir()

	src, err := os.Create(dir + "/src")
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()
	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
		t.Fatal(err)
	}
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	dst, err := os.Create(dir + "/dst")
	if err != nil {
		t.Fatal(err)
	}
	defer dst.Close()

	client, server := createSocketPair(t, "tcp")
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		// Close the receiving end once this copy finishes. If it stops early
		// (error), the sender's writes then fail instead of blocking forever
		// on a full socket buffer, which would hang wg.Wait below.
		defer server.Close()
		if n, err := io.Copy(dst, server); n != size || err != nil {
			t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
		}
	}()
	go func() {
		defer wg.Done()
		// Closing the client signals EOF to the receiving io.Copy.
		defer client.Close()
		if n, err := io.Copy(client, src); n != size || err != nil {
			t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
		}
	}()
	wg.Wait()

	if _, err := dst.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
	if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
		t.Fatal(err)
	}
}

// TestCopyFileToFile copies between two *os.File values with io.Copy and
// io.CopyN. It covers several source offsets, destination offsets and copy
// limits. Each case checks the returned byte count and error, the final
// position of both files, and the destination file's contents.
func TestCopyFileToFile(t *testing.T) {
	const size = 1 * 1024 * 1024
	dir := t.TempDir()

	src, err := os.Create(dir + "/src")
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()
	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
		t.Fatal(err)
	}
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	// mustSeek takes the caller's *testing.T explicitly. Fatal (FailNow) must
	// be invoked on the test running in the current goroutine, so closing
	// over the parent t and calling this from a subtest would be invalid.
	mustSeek := func(t *testing.T, f *os.File, offset int64, whence int) int64 {
		t.Helper()
		ret, err := f.Seek(offset, whence)
		if err != nil {
			t.Fatal(err)
		}
		return ret
	}

	for _, srcStart := range []int64{0, 100, size} {
		remaining := size - srcStart
		for _, dstStart := range []int64{0, 200} {
			// A limit of 0 selects io.Copy; any other limit uses io.CopyN.
			for _, limit := range []int64{remaining, remaining - 100, size * 2, 0} {
				if limit < 0 {
					continue
				}
				name := fmt.Sprintf("srcStart=%v/dstStart=%v/limit=%v", srcStart, dstStart, limit)
				t.Run(name, func(t *testing.T) {
					dst, err := os.CreateTemp(dir, "dst")
					if err != nil {
						t.Fatal(err)
					}
					// Deferred in this order so the file is closed before it
					// is removed; removing an open file fails on Windows.
					defer os.Remove(dst.Name())
					defer dst.Close()

					mustSeek(t, src, srcStart, io.SeekStart)
					if _, err := io.CopyN(dst, zeroReader{}, dstStart); err != nil {
						t.Fatal(err)
					}

					var copied int64
					if limit == 0 {
						copied, err = io.Copy(dst, src)
					} else {
						copied, err = io.CopyN(dst, src, limit)
					}
					switch {
					case limit > remaining:
						// io.CopyN reports io.EOF when the source runs out early.
						if err != io.EOF {
							t.Errorf("Copy: %v; want io.EOF", err)
						}
					case err != nil:
						t.Errorf("Copy: %v; want nil", err)
					}

					wantCopied := remaining
					if limit != 0 {
						wantCopied = min(limit, wantCopied)
					}
					if copied != wantCopied {
						t.Errorf("copied %v bytes, want %v", copied, wantCopied)
					}

					srcPos := mustSeek(t, src, 0, io.SeekCurrent)
					wantSrcPos := srcStart + wantCopied
					if srcPos != wantSrcPos {
						t.Errorf("source position = %v, want %v", srcPos, wantSrcPos)
					}

					dstPos := mustSeek(t, dst, 0, io.SeekCurrent)
					wantDstPos := dstStart + wantCopied
					if dstPos != wantDstPos {
						t.Errorf("destination position = %v, want %v", dstPos, wantDstPos)
					}

					// Expected contents: dstStart zero bytes, followed by the
					// slice of the random stream that was copied from src.
					mustSeek(t, dst, 0, io.SeekStart)
					rr := newRandReader()
					if _, err := io.CopyN(io.Discard, rr, srcStart); err != nil {
						t.Fatal(err)
					}
					wantReader := io.MultiReader(
						io.LimitReader(zeroReader{}, dstStart),
						io.LimitReader(rr, wantCopied),
					)
					if err := compareReaders(dst, wantReader); err != nil {
						t.Fatal(err)
					}
				})

			}
		}
	}
}

// compareReaders reads a and b to the end in 4096-byte chunks and reports
// whether they yield identical byte streams.
//
// It returns nil if they do. If they differ, it returns an error naming the
// offset of the first differing byte, or the offset where the shorter stream
// ended. Any read error other than end-of-stream is returned as is.
func compareReaders(a, b io.Reader) error {
	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	// io.ReadFull reports an empty final read as io.EOF and a short final
	// read as io.ErrUnexpectedEOF; both simply mean the stream has ended.
	atEnd := func(err error) bool {
		return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
	}
	for off := 0; ; off += len(bufa) {
		na, erra := io.ReadFull(a, bufa)
		if erra != nil && !atEnd(erra) {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if errb != nil && !atEnd(errb) {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			// Locate the first differing byte within this chunk. If one
			// chunk is a prefix of the other, this is where it ends.
			i := 0
			for i < min(na, nb) && bufa[i] == bufb[i] {
				i++
			}
			return fmt.Errorf("contents mismatch at offset %d", off+i)
		}
		if erra != nil && errb != nil {
			return nil
		}
	}
}

// zeroReader is an io.Reader that yields an endless stream of zero bytes.
type zeroReader struct{}

// Read overwrites every byte of buf with zero, reports len(buf) bytes read,
// and never returns an error.
func (zeroReader) Read(buf []byte) (int, error) {
	for i := range buf {
		buf[i] = 0
	}
	return len(buf), nil
}

type randReader struct {
	rand *rand.Rand
}

func newRandReader() *randReader {
	return &randReader{rand.New(rand.NewPCG(0, 0))}
}

func (r *randReader) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = byte(r.rand.Uint32() & 0xff)
	}
	return len(p), nil
}

// createSocketPair returns both ends of a freshly established connection on
// network proto (for example "tcp"). client is the dialing side and server
// is the accepting side of a local listener.
//
// The test is skipped if proto is not testable on this platform. It fails
// fatally if creating the listener, dialing or accepting fails. The listener
// and both connections are closed via t.Cleanup.
func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
	t.Helper()
	if !nettest.TestableNetwork(proto) {
		t.Skipf("%s does not support %q", runtime.GOOS, proto)
	}

	ln, err := nettest.NewLocalListener(proto)
	if err != nil {
		t.Fatalf("NewLocalListener error: %v", err)
	}
	// The closure reads the named results, so it closes whatever
	// connections were actually established by the time the test ends.
	t.Cleanup(func() {
		ln.Close()
		if client != nil {
			client.Close()
		}
		if server != nil {
			server.Close()
		}
	})

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the Accept goroutine can always deliver its result and
	// exit, even if this function stops before receiving it.
	ch := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		ch <- acceptResult{conn, err}
	}()

	client, err = net.Dial(proto, ln.Addr().String())
	if err != nil {
		// Nothing will connect now. Close the listener to unblock Accept;
		// otherwise waiting for it would hang forever.
		ln.Close()
		if res := <-ch; res.conn != nil {
			res.conn.Close()
		}
		t.Fatalf("Dial new connection error: %v", err)
	}
	res := <-ch
	if res.err != nil {
		t.Fatalf("Accept new connection error: %v", res.err)
	}
	server = res.conn
	return client, server
}
