// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package http_test

import (
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptrace"
	"testing"
)

// TestTransportPoolConnReusePriorConnection checks that a request issued
// after an earlier request has fully finished is served on the earlier
// request's connection, without a new dial.
func TestTransportPoolConnReusePriorConnection(t *testing.T) {
	tester := newTransportDialTester(t, http1Mode)

	// Nothing has been dialed yet, so the initial request must dial.
	// Let that dial succeed, then run the request to completion so
	// its connection is no longer in use.
	first := tester.roundTrip()
	conn := tester.wantDial()
	conn.finish(nil)
	first.wantDone(conn)
	first.finish()

	// The follow-up request is expected to complete on the same
	// connection; no wantDial call is made for it.
	second := tester.roundTrip()
	second.wantDone(conn)
	second.finish()
}

// TestTransportPoolConnCannotReuseConnectionInUse checks that a request
// issued while an earlier request still occupies its connection is served
// on a freshly dialed connection.
func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
	tester := newTransportDialTester(t, http1Mode)

	// The initial request dials and gets its response headers,
	// but is deliberately left unfinished so its connection stays busy.
	busy := tester.roundTrip()
	busyConn := tester.wantDial()
	busyConn.finish(nil)
	busy.wantDone(busyConn)

	// With the only existing connection occupied, the next request
	// has to dial a connection of its own and complete on it.
	next := tester.roundTrip()
	nextConn := tester.wantDial()
	nextConn.finish(nil)
	next.wantDone(nextConn)
}

// TestTransportPoolConnConnectionBecomesAvailableDuringDial checks that a
// request waiting on a dial switches to an existing connection that frees
// up before the dial completes, and that the still-pending dial can later
// serve a different request.
func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
	tester := newTransportDialTester(t, http1Mode)

	// The initial request dials and receives its response headers,
	// but keeps its connection occupied for now.
	first := tester.roundTrip()
	conn1 := tester.wantDial()
	conn1.finish(nil)
	first.wantDone(conn1)

	// The next request starts while conn1 is busy, so a new dial begins.
	// Before that dial is allowed to complete, the initial request
	// finishes; the waiting request is expected to run on conn1.
	second := tester.roundTrip()
	conn2 := tester.wantDial()
	first.finish()
	second.wantDone(conn1)

	// This part is somewhat tied to the current Transport implementation:
	// a third request begins while the dial for conn2 is still pending.
	// Rather than relying on that dial, the request starts one of its own
	// (conn3) and then runs on whichever dial completes first, which the
	// test arranges to be conn2.
	third := tester.roundTrip()
	conn3 := tester.wantDial()
	conn2.finish(nil)
	third.wantDone(conn2)

	// Release the remaining dial so it does not stay blocked.
	conn3.finish(nil)
}

// A transportDialTester manages a test of a Transport's Dials.
//
// It owns a client/server test whose Transport reports every dial it
// starts on the dials channel, and it hands out sequential IDs to the
// RoundTrips and Dials it observes so that log lines can be correlated.
type transportDialTester struct {
	// t is the test that created this tester; it is used for logging,
	// cleanup registration, and failure reporting.
	t   *testing.T
	// cst is the client/server pair created by newTransportDialTester.
	cst *clientServerTest

	dials chan *transportDialTesterConn // each new conn is sent to this channel

	// roundTripCount is the ID that will be given to the next RoundTrip.
	roundTripCount int
	// dialCount is the ID that will be given to the next Dial
	// received by wantDial.
	dialCount      int
}

// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
//
// The res, err, and conn fields are written by the goroutine that calls
// RoundTrip and should only be read after done has been closed.
type transportDialTesterRoundTrip struct {
	t *testing.T

	roundTripID int                // distinguishes RoundTrips in logs
	cancel      context.CancelFunc // cancels the Request context
	reqBody     io.WriteCloser     // write half of the Request.Body
	// finished records that finish has been called, making later calls no-ops.
	finished    bool

	done chan struct{} // closed when RoundTrip returns
	// res and err hold the values returned by RoundTrip.
	res  *http.Response
	err  error
	// conn is the connection reported by the GotConn client trace hook.
	conn *transportDialTesterConn
}

// A transportDialTesterConn is a client connection created by the Transport as
// part of a dial test.
//
// It is created when a Dial starts; the Dial then blocks until the test
// sends a result on ready (see finish).
type transportDialTesterConn struct {
	t *testing.T

	connID int        // distinguishes Dials in logs
	ready  chan error // sent on to complete the Dial

	// Conn is the underlying network connection. It is set only after
	// the Dial has been told to proceed and the real dial has succeeded.
	net.Conn
}

// newTransportDialTester starts a client/server test in the given mode and
// returns a tester that intercepts every dial made by the client Transport.
//
// Each dial is announced on dt.dials and then blocks until a value is sent
// on the new connection's ready channel. A non-nil value fails the dial with
// that error; a nil value lets the real network dial proceed.
//
// The server handler sends response headers as soon as a request arrives and
// then reads the request body until it ends.
func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
	t.Helper()
	dt := &transportDialTester{
		t:     t,
		dials: make(chan *transportDialTesterConn),
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Respond with headers right away, before the request body
		// has been read.
		rc := http.NewResponseController(w)
		rc.EnableFullDuplex()
		w.WriteHeader(http.StatusOK)
		rc.Flush()
		// Block until the client finishes sending the request body;
		// the test uses this to control when the request completes.
		io.ReadAll(r.Body)
	})

	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		conn := &transportDialTesterConn{
			t:     t,
			ready: make(chan error),
		}
		// Tell the test that a Dial has begun, then hold here until
		// the test decides how the Dial should end.
		dt.dials <- conn
		if dialErr := <-conn.ready; dialErr != nil {
			return nil, dialErr
		}
		underlying, err := net.Dial(network, address)
		if err != nil {
			return nil, err
		}
		// Hand the Transport the *transportDialTesterConn itself so
		// that tests can tell which connection served a request.
		conn.Conn = underlying
		return conn, nil
	}

	dt.cst = newClientServerTest(t, mode, handler, func(tr *http.Transport) {
		tr.DialContext = dial
	})
	return dt
}

// roundTrip starts a RoundTrip.
// It returns immediately, without waiting for the RoundTrip call to complete.
//
// The request is a POST whose body is the read half of a pipe; the write half
// is kept in the returned value so that finish can end the request body.
// A cleanup is registered that cancels the request context and calls finish,
// so every RoundTrip is waited for before the test ends.
//
// The result of the RoundTrip (or the error from building the request) is
// stored in the returned value before its done channel is closed.
func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
	dt.t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	rt := &transportDialTesterRoundTrip{
		t:           dt.t,
		roundTripID: dt.roundTripCount,
		done:        make(chan struct{}),
		reqBody:     pw,
		cancel:      cancel,
	}
	dt.roundTripCount++
	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
	dt.t.Cleanup(func() {
		rt.cancel()
		rt.finish()
	})
	go func() {
		// Closing done is the last thing this goroutine does, so the
		// fields written below are visible to anyone who waits on it.
		defer close(rt.done)

		// Record which connection the Transport picked for this request.
		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				rt.conn = info.Conn.(*transportDialTesterConn)
			},
		})
		req, err := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
		if err != nil {
			// Report the failure through rt.err instead of
			// dereferencing a nil request below.
			rt.err = err
			dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
			return
		}
		req.Header.Set("Content-Type", "text/plain")
		rt.res, rt.err = dt.cst.tr.RoundTrip(req)
		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
	}()
	return rt
}

// wantDone waits for the RoundTrip to return and then verifies that it
// succeeded and was served on connection c.
//
// The test fails immediately if RoundTrip returned an error, if no
// connection was recorded for the request, or if the recorded connection
// is not c.
func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
	rt.t.Helper()
	<-rt.done
	if rt.err != nil {
		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
	}
	if rt.conn == nil {
		// Report a missing connection as a test failure rather than
		// panicking on rt.conn.connID while formatting the message below.
		rt.t.Fatalf("RoundTrip %v: want on conn %v, got no conn", rt.roundTripID, c.connID)
	}
	if rt.conn != c {
		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
	}
}

// finish completes a RoundTrip: it waits for RoundTrip to return and, if it
// succeeded, closes the request body, drains the response body, and closes
// the response body.
//
// Only the first call does any work; later calls return immediately.
// If RoundTrip returned an error there is no response to clean up and
// finish returns once RoundTrip is done.
func (rt *transportDialTesterRoundTrip) finish() {
	rt.t.Helper()

	alreadyFinished := rt.finished
	rt.finished = true
	if alreadyFinished {
		return
	}

	<-rt.done

	if rt.err == nil {
		// Closing the write half of the pipe ends the request body.
		rt.reqBody.Close()
		body := rt.res.Body
		io.Copy(io.Discard, body)
		body.Close()
		rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
	}
}

// wantDial waits for the Transport to start a Dial.
//
// It blocks until a Dial is announced on dt.dials, assigns the new
// connection the next sequential ID, and returns it. The Dial itself stays
// blocked until the caller calls finish on the returned connection.
func (dt *transportDialTester) wantDial() *transportDialTesterConn {
	// Attribute the log line below to the calling test line,
	// matching the other helpers in this file.
	dt.t.Helper()
	c := <-dt.dials
	c.connID = dt.dialCount
	dt.dialCount++
	dt.t.Logf("Dial %v: started", c.connID)
	return c
}

// finish completes a Dial.
//
// The value err is delivered to the blocked Dial: a non-nil err makes the
// Dial fail with that error, while nil lets it go on to make the real
// network connection. The ready channel is closed afterwards, so finish
// must be called at most once per connection.
func (c *transportDialTesterConn) finish(err error) {
	// Attribute the log line below to the calling test line,
	// matching the other helpers in this file.
	c.t.Helper()
	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
	c.ready <- err
	close(c.ready)
}
