// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package bigmod

import (
	"bufio"
	"bytes"
	cryptorand "crypto/rand"
	"encoding/hex"
	"fmt"
	"math/big"
	"math/bits"
	"math/rand"
	"os"
	"reflect"
	"slices"
	"strings"
	"testing"
	"testing/quick"
)

// setBig assigns x = n, resizing x to hold exactly the words of n.
//
// The announced length of x is therefore derived from the actual bit size of
// the input, ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat {
	words := n.Bits()
	x.reset(len(words))
	for i, w := range words {
		x.limbs[i] = uint(w)
	}
	return x
}

// asBig returns the value of n as a new big.Int. The limbs are converted one
// by one into a freshly allocated big.Word slice, so the result does not share
// storage with n.
func (n *Nat) asBig() *big.Int {
	words := make([]big.Word, 0, len(n.limbs))
	for _, limb := range n.limbs {
		words = append(words, big.Word(limb))
	}
	return new(big.Int).SetBits(words)
}

// String formats n for test output: every limb is printed in upper-case
// hexadecimal, zero-padded to at least 16 digits, walking from the last limb
// down to the first. The limbs are separated by spaces and wrapped in braces.
func (n *Nat) String() string {
	parts := make([]string, 0, len(n.limbs))
	for i := len(n.limbs) - 1; i >= 0; i-- {
		parts = append(parts, fmt.Sprintf("%016X", n.limbs[i]))
	}
	return "{" + strings.Join(parts, " ") + "}"
}

// Generate produces a random *Nat with size limbs for testing/quick, which
// calls it to build the arguments of quick.Check properties.
//
// The lowest bit of every limb is cleared, so the generated value is even.
func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
	// All bits of a limb set, except the lowest one.
	mask := uint((1 << _W) - 2)
	limbs := make([]uint, size)
	for i := range limbs {
		limbs[i] = uint(r.Uint64()) & mask
	}
	return reflect.ValueOf(&Nat{limbs})
}

// testModAddCommutative is the property behind TestModAddCommutative: it
// reports whether a+b and b+a are equal, both computed modulo the biggest
// modulus that fits in len(a.limbs) limbs. Neither a nor b is modified.
func testModAddCommutative(a *Nat, b *Nat) bool {
	m := maxModulus(uint(len(a.limbs)))
	sumAB, sumBA := new(Nat).set(a), new(Nat).set(b)
	sumAB.Add(b, m)
	sumBA.Add(a, m)
	return sumAB.Equal(sumBA) == 1
}

// TestModAddCommutative runs testModAddCommutative over random inputs
// generated by testing/quick with its default configuration.
func TestModAddCommutative(t *testing.T) {
	if err := quick.Check(testModAddCommutative, &quick.Config{}); err != nil {
		t.Error(err)
	}
}

// testModSubThenAddIdentity is the property behind TestModSubThenAddIdentity:
// it reports whether subtracting b from a and then adding b back restores the
// starting value of a. The arithmetic is done in place on a, modulo the
// biggest modulus that fits in len(a.limbs) limbs.
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
	modulus := maxModulus(uint(len(a.limbs)))
	// Snapshot a before it is mutated below.
	before := new(Nat).set(a)
	a.Sub(b, modulus)
	a.Add(b, modulus)
	if a.Equal(before) != 1 {
		return false
	}
	return true
}

// TestModSubThenAddIdentity runs testModSubThenAddIdentity over random inputs
// generated by testing/quick with its default configuration.
func TestModSubThenAddIdentity(t *testing.T) {
	if err := quick.Check(testModSubThenAddIdentity, &quick.Config{}); err != nil {
		t.Error(err)
	}
}

// TestMontgomeryRoundtrip checks, for random a, that converting a with
// montgomeryRepresentation and then multiplying the result by the plain value
// 1 with montgomeryMul gives back a. The modulus used is a + 1.
func TestMontgomeryRoundtrip(t *testing.T) {
	roundtrip := func(a *Nat) bool {
		// The plain value 1, with as many limbs as a.
		one := &Nat{make([]uint, len(a.limbs))}
		one.limbs[0] = 1

		// m = a + 1, computed through math/big.
		mBig := new(big.Int).SetBytes(natBytes(a))
		mBig.Add(mBig, big.NewInt(1))
		m, _ := NewModulus(mBig.Bytes())

		inMonty := new(Nat).set(a)
		inMonty.montgomeryRepresentation(m)
		back := new(Nat).set(inMonty)
		back.montgomeryMul(inMonty, one, m)

		if a.Equal(back) == 1 {
			return true
		}
		t.Errorf("%v != %v", a, back)
		return false
	}
	if err := quick.Check(roundtrip, &quick.Config{}); err != nil {
		t.Error(err)
	}
}

// TestShiftIn runs shiftIn against fixed vectors. Each vector gives a modulus
// m, a starting value x, a word y to shift in, and the expected result; x and
// the expected value are expanded to the size of m before use.
//
// The vectors assume 64-bit words, so the test is skipped on other platforms.
func TestShiftIn(t *testing.T) {
	if bits.UintSize != 64 {
		t.Skip("examples are only valid in 64 bit")
	}
	type example struct {
		m, x, expected []byte
		y              uint64
	}
	examples := []example{{
		m:        []byte{13},
		x:        []byte{0},
		y:        0xFFFF_FFFF_FFFF_FFFF,
		expected: []byte{2},
	}, {
		m:        []byte{13},
		x:        []byte{7},
		y:        0xFFFF_FFFF_FFFF_FFFF,
		expected: []byte{10},
	}, {
		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
		x:        make([]byte, 9),
		y:        0xFFFF_FFFF_FFFF_FFFF,
		expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
	}, {
		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
		x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		y:        0,
		expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
	}}

	for i, ex := range examples {
		mod := modulusFromBytes(ex.m)
		got := natFromBytes(ex.x).ExpandFor(mod).shiftIn(uint(ex.y), mod)
		want := natFromBytes(ex.expected).ExpandFor(mod)
		if got.Equal(want) == 1 {
			continue
		}
		t.Errorf("%d: got %v, expected %v", i, got, want)
	}
}

// TestModulusAndNatSizes checks that a value just below the modulus can be
// sized for that modulus even when both are serialized with leading zero bits.
func TestModulusAndNatSizes(t *testing.T) {
	// These are 126 bit values, serialized as 128 bits worth of bytes. The
	// number of limbs they need can depend on whether the leading zeroes are
	// stripped. This can be a problem because modulus strips leading zeroes
	// and nat does not.
	m := modulusFromBytes([]byte{
		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
	xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
	natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
	// xb is m - 1, so SetBytes has no reason to reject it. Check the error
	// instead of discarding it, so that a rejection fails the test.
	if _, err := NewNat().SetBytes(xb, m); err != nil {
		t.Errorf("SetBytes rejected a value below the modulus: %v", err)
	}
}

// TestSetBytes checks Nat.SetBytes in two ways.
//
// First, a table of (modulus, input) pairs states whether SetBytes must fail;
// when it succeeds, the result must equal the input parsed through math/big
// and expanded to the size of the modulus.
//
// Second, a testing/quick property feeds random byte strings to SetBytes with
// a modulus one limb larger than len(input)*8/_W, and requires the same
// equality with no error.
func TestSetBytes(t *testing.T) {
	type testCase struct {
		m, b []byte
		fail bool
	}
	tests := []testCase{{
		m: []byte{0xff, 0xff},
		b: []byte{0x00, 0x01},
	}, {
		m:    []byte{0xff, 0xff},
		b:    []byte{0xff, 0xff},
		fail: true,
	}, {
		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b: []byte{0x00, 0x01},
	}, {
		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
	}, {
		m:    []byte{0xff, 0xff},
		b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
		fail: true,
	}, {
		m:    []byte{0xff, 0xff},
		b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
		fail: true,
	}, {
		m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
	}, {
		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
		fail: true,
	}, {
		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		fail: true,
	}, {
		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
		fail: true,
	}, {
		m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		fail: true,
	}}

	for i, tt := range tests {
		m := modulusFromBytes(tt.m)
		got, err := NewNat().SetBytes(tt.b, m)
		switch {
		case err != nil && !tt.fail:
			t.Errorf("%d: unexpected error: %v", i, err)
		case err != nil:
			// The failure was expected; there is no value to compare.
		case tt.fail:
			t.Errorf("%d: unexpected success", i)
		default:
			expected := natFromBytes(tt.b).ExpandFor(m)
			if got.Equal(expected) != yes {
				t.Errorf("%d: got %v, expected %v", i, got, expected)
			}
		}
	}

	// roundTrips reports whether SetBytes accepts xBytes and agrees with the
	// math/big based parser.
	roundTrips := func(xBytes []byte) bool {
		m := maxModulus(uint(len(xBytes)*8/_W + 1))
		got, err := NewNat().SetBytes(xBytes, m)
		return err == nil && got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
	}

	if err := quick.Check(roundTrips, &quick.Config{}); err != nil {
		t.Error(err)
	}
}

// TestExpand checks that expand(n) yields a Nat with exactly the expected
// limbs. The second example starts from a two-element view of a four-element
// slice whose tail holds non-zero values, and still expects the added limbs to
// be zero.
func TestExpand(t *testing.T) {
	backing := []uint{1, 2, 3, 4}
	type example struct {
		in  []uint
		n   int
		out []uint
	}
	examples := []example{
		{in: []uint{1, 2}, n: 4, out: []uint{1, 2, 0, 0}},
		{in: backing[:2], n: 4, out: []uint{1, 2, 0, 0}},
		{in: []uint{1, 2}, n: 2, out: []uint{1, 2}},
	}

	for i, ex := range examples {
		got := (&Nat{ex.in}).expand(ex.n)
		sameLen := len(got.limbs) == len(ex.out)
		if sameLen && got.Equal(&Nat{ex.out}) == 1 {
			continue
		}
		t.Errorf("%d: got %v, expected %v", i, got, ex.out)
	}
}

// TestMod checks Nat.Mod on one fixed vector: a 16-byte value reduced by a
// 9-byte modulus must match the hard-coded expected remainder.
func TestMod(t *testing.T) {
	modulus := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
	input := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
	want := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})

	got := new(Nat)
	got.Mod(input, modulus)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}

// TestModSub subtracts 7 from 6 modulo 13 twice in a row and expects 12
// (the subtraction wraps around the modulus) and then 5.
func TestModSub(t *testing.T) {
	m := modulusFromBytes([]byte{13})
	x := &Nat{[]uint{6}}
	y := &Nat{[]uint{7}}
	for _, want := range []uint{12, 5} {
		x.Sub(y, m)
		expected := &Nat{[]uint{want}}
		if x.Equal(expected) != 1 {
			t.Errorf("%+v != %+v", x, expected)
		}
	}
}

// TestModAdd adds 7 to 6 modulo 13 twice in a row and expects 0 (the sum
// equals the modulus) and then 7.
func TestModAdd(t *testing.T) {
	m := modulusFromBytes([]byte{13})
	x := &Nat{[]uint{6}}
	y := &Nat{[]uint{7}}
	for _, want := range []uint{0, 7} {
		x.Add(y, m)
		expected := &Nat{[]uint{want}}
		if x.Equal(expected) != 1 {
			t.Errorf("%+v != %+v", x, expected)
		}
	}
}

// TestExp checks Nat.Exp on a single small case: 3^12 mod 13 must be 1.
func TestExp(t *testing.T) {
	modulus := modulusFromBytes([]byte{13})
	base := &Nat{[]uint{3}}
	want := &Nat{[]uint{1}}

	got := &Nat{[]uint{0}}
	got.Exp(base, []byte{12}, modulus)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}

// TestExpShort checks Nat.ExpShortVarTime on a single small case:
// 3^12 mod 13 must be 1.
func TestExpShort(t *testing.T) {
	modulus := modulusFromBytes([]byte{13})
	base := &Nat{[]uint{3}}
	want := &Nat{[]uint{1}}

	got := &Nat{[]uint{0}}
	got.ExpShortVarTime(base, 12, modulus)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}

// TestMulReductions tests that Mul fully reduces results that are equal to or
// slightly greater than the modulus. Some Montgomery algorithms don't and need
// extra care to return correct results. See https://go.dev/issue/13907.
func TestMulReductions(t *testing.T) {
	// Two short but multi-limb primes.
	a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
	b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)

	// First case: a * b modulo a * b has to come out as zero.
	product := new(big.Int).Mul(a, b)
	modAB, _ := NewModulus(product.Bytes())
	left := NewNat().setBig(a).ExpandFor(modAB)
	right := NewNat().setBig(b).ExpandFor(modAB)

	if left.Mul(right, modAB).IsZero() != 1 {
		t.Error("a * b mod (a * b) != 0")
	}

	// Second case: a times its inverse modulo b has to come out as one.
	inverse := new(big.Int).ModInverse(a, b)
	modB, _ := NewModulus(b.Bytes())
	left = NewNat().setBig(a).ExpandFor(modB)
	inv := NewNat().setBig(inverse).ExpandFor(modB)
	one := NewNat().setBig(big.NewInt(1)).ExpandFor(modB)

	if left.Mul(inv, modB).Equal(one) != 1 {
		t.Error("a * inv(a) mod b != 1")
	}
}

// TestMul runs testMul as one subtest per operand size. Sizes are passed to
// testMul in bytes.
func TestMul(t *testing.T) {
	sizes := []struct {
		name  string
		bytes int
	}{
		{"small", 760 / 8},
		{"1024", 1024 / 8},
		{"1536", 1536 / 8},
		{"2048", 2048 / 8},
	}
	for _, size := range sizes {
		t.Run(size.name, func(t *testing.T) { testMul(t, size.bytes) })
	}
}

// testMul draws three random n-byte values, uses the largest one as the
// modulus, and checks that Nat.Mul of the other two matches the product
// computed with math/big and reduced by the same modulus.
func testMul(t *testing.T, n int) {
	a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
	for _, buf := range [][]byte{a, b, m} {
		cryptorand.Read(buf)
	}

	// Pick the highest as the modulus.
	if bytes.Compare(a, m) > 0 {
		a, m = m, a
	}
	if bytes.Compare(b, m) > 0 {
		b, m = m, b
	}

	M, err := NewModulus(m)
	if err != nil {
		t.Fatal(err)
	}
	A, err := NewNat().SetBytes(a, M)
	if err != nil {
		t.Fatal(err)
	}
	B, err := NewNat().SetBytes(b, M)
	if err != nil {
		t.Fatal(err)
	}

	A.Mul(B, M)
	got := A.Bytes(M)

	// Reference result, serialized to the same length as got.
	ref := new(big.Int).Mul(new(big.Int).SetBytes(a), new(big.Int).SetBytes(b))
	ref.Mod(ref, new(big.Int).SetBytes(m))
	want := ref.FillBytes(make([]byte, len(got)))

	if !bytes.Equal(got, want) {
		t.Errorf("got %x, want %x", got, want)
	}
}

// TestIs checks the IsMinusOne, IsZero, IsOne and IsOdd predicates while
// stepping a value down from 3 with SubOne modulo 4, through 2, 1, 0 and the
// value after subtracting one from 0. It then checks IsOne on the value 1
// sized for a two-limb modulus.
func TestIs(t *testing.T) {
	// check records msg as a test error unless got equals want.
	check := func(got, want choice, msg string) {
		t.Helper()
		if got != want {
			t.Error(msg)
		}
	}

	mFour := modulusFromBytes([]byte{4})
	n, err := NewNat().SetBytes([]byte{3}, mFour)
	if err != nil {
		t.Fatal(err)
	}
	check(n.IsMinusOne(mFour), yes, "3 is not -1 mod 4")
	check(n.IsZero(), no, "3 is zero")
	check(n.IsOne(), no, "3 is one")
	check(n.IsOdd(), yes, "3 is not odd")

	n.SubOne(mFour)
	check(n.IsMinusOne(mFour), no, "2 is -1 mod 4")
	check(n.IsZero(), no, "2 is zero")
	check(n.IsOne(), no, "2 is one")
	check(n.IsOdd(), no, "2 is odd")

	n.SubOne(mFour)
	check(n.IsMinusOne(mFour), no, "1 is -1 mod 4")
	check(n.IsZero(), no, "1 is zero")
	check(n.IsOne(), yes, "1 is not one")
	check(n.IsOdd(), yes, "1 is not odd")

	n.SubOne(mFour)
	check(n.IsMinusOne(mFour), no, "0 is -1 mod 4")
	check(n.IsZero(), yes, "0 is not zero")
	check(n.IsOne(), no, "0 is one")
	check(n.IsOdd(), no, "0 is odd")

	// One more step below zero has to wrap around to -1 mod 4.
	n.SubOne(mFour)
	check(n.IsMinusOne(mFour), yes, "-1 is not -1 mod 4")
	check(n.IsZero(), no, "-1 is zero")
	check(n.IsOne(), no, "-1 is one")
	check(n.IsOdd(), yes, "-1 mod 4 is not odd")

	// IsOne on the value 1 when sized for a two-limb modulus.
	mTwoLimbs := maxModulus(2)
	n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs)
	if err != nil {
		t.Fatal(err)
	}
	if n.IsOne() != 1 {
		t.Errorf("1 is not one")
	}
}

// TestTrailingZeroBits starts from a value ending in 0x7e shifted left by 128
// bits, which has 129 trailing zero bits, and shifts it right one bit at a
// time. TrailingZeroBitsVarTime must report every count from 129 down to 0.
func TestTrailingZeroBits(t *testing.T) {
	nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e})
	nb.Lsh(nb, 128)
	for expected := 129; expected >= 0; expected-- {
		got := NewNat().setBig(nb).TrailingZeroBitsVarTime()
		if got != uint(expected) {
			t.Errorf("%d != %d", got, expected)
		}
		nb.Rsh(nb, 1)
	}
}

// TestRightShift checks ShiftRightVarTime against big.Int.Rsh on one random
// value below 2^1024. Each listed shift amount is tested together with its two
// neighbours (shift-1 and shift+1), one subtest per amount. The shift must
// also leave the number of limbs unchanged.
func TestRightShift(t *testing.T) {
	nb, err := cryptorand.Int(cryptorand.Reader, new(big.Int).Lsh(big.NewInt(1), 1024))
	if err != nil {
		t.Fatal(err)
	}

	checkShift := func(t *testing.T, shift uint) {
		n := NewNat().setBig(nb)
		limbsBefore := len(n.limbs)
		n.ShiftRightVarTime(shift)
		if limbsAfter := len(n.limbs); limbsAfter != limbsBefore {
			t.Errorf("len(n.limbs) = %d, want %d", limbsAfter, limbsBefore)
		}
		want := new(big.Int).Rsh(nb, shift)
		if got := n.asBig(); got.Cmp(want) != 0 {
			t.Errorf("%v != %v", got, want)
		}
	}

	for _, base := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} {
		for _, shift := range []uint{base - 1, base, base + 1} {
			t.Run(fmt.Sprint(shift), func(t *testing.T) { checkShift(t, shift) })
		}
	}
}

// natBytes serializes n with Nat.Bytes, using the biggest modulus that fits
// in the number of limbs n currently has.
func natBytes(n *Nat) []byte {
	limbCount := uint(len(n.limbs))
	return n.Bytes(maxModulus(limbCount))
}

// natFromBytes parses the big-endian bytes b into a new Nat by going through
// math/big and setBig.
func natFromBytes(b []byte) *Nat {
	// Must not use Nat.SetBytes as it's used in TestSetBytes.
	return NewNat().setBig(new(big.Int).SetBytes(b))
}

// modulusFromBytes builds a Modulus from the big-endian bytes b. The value is
// round-tripped through big.Int first, so NewModulus receives it without
// leading zero bytes.
//
// It panics if NewModulus rejects the value. This is a test helper and its
// callers pass fixed, valid moduli, so an error here is a bug in the test
// itself; failing at the source is clearer than handing back a modulus that
// was never successfully built.
func modulusFromBytes(b []byte) *Modulus {
	trimmed := new(big.Int).SetBytes(b).Bytes()
	m, err := NewModulus(trimmed)
	if err != nil {
		panic(fmt.Sprintf("modulusFromBytes(%x): %v", b, err))
	}
	return m
}

// maxModulus returns the biggest modulus that can fit in n limbs, that is
// 2^(n*_W) - 1.
//
// It panics if NewModulus rejects that value (for example when n is 0 and the
// value is 0). This is a test helper, so failing at the source with the
// underlying error is clearer than handing back a modulus that was never
// successfully built.
func maxModulus(n uint) *Modulus {
	b := new(big.Int).Lsh(big.NewInt(1), n*_W)
	b.Sub(b, big.NewInt(1))
	m, err := NewModulus(b.Bytes())
	if err != nil {
		panic(fmt.Sprintf("maxModulus(%d): %v", n, err))
	}
	return m
}

// makeBenchmarkModulus returns the modulus shared by the benchmarks: the
// biggest one that fits in 32 limbs.
func makeBenchmarkModulus() *Modulus {
	const benchmarkLimbs = 32
	return maxModulus(benchmarkLimbs)
}

// makeBenchmarkValue returns a new 32-limb Nat with every bit of every limb
// set.
func makeBenchmarkValue() *Nat {
	limbs := make([]uint, 32)
	for i := range limbs {
		limbs[i] = ^uint(0)
	}
	return &Nat{limbs: limbs}
}

// makeBenchmarkExponent returns the 256-byte exponent used by the Exp
// benchmarks. Only its first 32 bytes are set to 0xFF; the remaining bytes
// are zero.
func makeBenchmarkExponent() []byte {
	e := make([]byte, 256)
	for i := range e[:32] {
		e[i] = 0xFF
	}
	return e
}

// BenchmarkModAdd measures Nat.Add, applied in place and repeatedly to the
// benchmark value with the benchmark modulus.
func BenchmarkModAdd(b *testing.B) {
	acc := makeBenchmarkValue()
	addend := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Add(addend, mod)
	}
}

// BenchmarkModSub measures Nat.Sub, applied in place and repeatedly to the
// benchmark value with the benchmark modulus.
func BenchmarkModSub(b *testing.B) {
	acc := makeBenchmarkValue()
	subtrahend := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Sub(subtrahend, mod)
	}
}

// BenchmarkMontgomeryRepr measures montgomeryRepresentation, applied in place
// and repeatedly to the benchmark value with the benchmark modulus.
func BenchmarkMontgomeryRepr(b *testing.B) {
	value := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		value.montgomeryRepresentation(mod)
	}
}

// BenchmarkMontgomeryMul measures montgomeryMul on two benchmark values,
// writing the result into a third Nat, with the benchmark modulus.
func BenchmarkMontgomeryMul(b *testing.B) {
	lhs := makeBenchmarkValue()
	rhs := makeBenchmarkValue()
	dst := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dst.montgomeryMul(lhs, rhs, mod)
	}
}

// BenchmarkModMul measures Nat.Mul, applied in place and repeatedly to the
// benchmark value with the benchmark modulus.
func BenchmarkModMul(b *testing.B) {
	acc := makeBenchmarkValue()
	factor := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Mul(factor, mod)
	}
}

// BenchmarkExpBig measures big.Int.Exp, for comparison with BenchmarkExp. The
// base and the exponent are both the benchmark exponent bytes read as an
// integer, and the modulus is that integer plus one.
func BenchmarkExpBig(b *testing.B) {
	raw := makeBenchmarkExponent()
	base := new(big.Int).SetBytes(raw)
	exponent := new(big.Int).SetBytes(raw)
	modulus := new(big.Int).SetBytes(raw)
	modulus.Add(modulus, new(big.Int).SetUint64(1))
	result := new(big.Int)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result.Exp(base, exponent, modulus)
	}
}

// BenchmarkExp measures Nat.Exp with the benchmark value as base, the
// benchmark exponent and the benchmark modulus, writing into a separate Nat.
func BenchmarkExp(b *testing.B) {
	base := makeBenchmarkValue()
	exponent := makeBenchmarkExponent()
	dst := makeBenchmarkValue()
	mod := makeBenchmarkModulus()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dst.Exp(base, exponent, mod)
	}
}

// TestNewModulus checks that NewModulus rejects the values 0 and 1, in short
// and in zero-padded encodings (including the empty slice), with the error
// text "modulus must be > 1".
func TestNewModulus(t *testing.T) {
	expected := "modulus must be > 1"

	// rejects calls NewModulus(in) and reports whether it failed with the
	// expected message, together with the error it returned.
	rejects := func(in []byte) (bool, error) {
		_, err := NewModulus(in)
		return err != nil && err.Error() == expected, err
	}

	if ok, err := rejects([]byte{}); !ok {
		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
	}
	if ok, err := rejects([]byte{0}); !ok {
		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
	}
	if ok, err := rejects([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); !ok {
		t.Errorf("NewModulus(0) got %q, want %q", err, expected)
	}
	if ok, err := rejects([]byte{1}); !ok {
		t.Errorf("NewModulus(1) got %q, want %q", err, expected)
	}
	if ok, err := rejects([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); !ok {
		t.Errorf("NewModulus(1) got %q, want %q", err, expected)
	}
}

// makeTestValue returns a slice of nbits/_W limbs with every bit of every
// limb set.
func makeTestValue(nbits int) []uint {
	x := make([]uint, nbits/_W)
	for i := range x {
		x[i] = ^uint(0)
	}
	return x
}

// TestAddMulVVWSized checks the fixed-size addMulVVW variants against the
// generic addMulVVW. Both run on identical all-ones operands with an all-ones
// multiplier, and must produce the same limbs and the same returned word.
func TestAddMulVVWSized(t *testing.T) {
	// Sized addMulVVW have architecture-specific implementations on
	// a number of architectures. Test that they match the generic
	// implementation.
	type sizedImpl struct {
		n int
		f func(z, x *uint, y uint) uint
	}
	impls := []sizedImpl{
		{1024, addMulVVW1024},
		{1536, addMulVVW1536},
		{2048, addMulVVW2048},
	}
	for _, impl := range impls {
		t.Run(fmt.Sprint(impl.n), func(t *testing.T) {
			x := makeTestValue(impl.n)
			want := makeTestValue(impl.n)
			got := slices.Clone(want)
			y := ^uint(0)
			wantCarry := addMulVVW(want, x, y)
			gotCarry := impl.f(&got[0], &x[0], y)
			if wantCarry != gotCarry || !slices.Equal(want, got) {
				t.Errorf("%016X, %016X != %016X, %016X", want, wantCarry, got, gotCarry)
			}
		})
	}
}

// TestInverse checks InverseVarTime against the vectors in
// testdata/mod_inv_tests.txt.
//
// The file is read line by line. Empty lines and lines starting with '#' are
// skipped; every other line must have the form "key = value" with key one of
// ModInv, A or M, otherwise the test fails. A subtest runs each time an M line
// is read, using the most recently seen ModInv and A values, and expects the
// inverse of A modulo M to equal ModInv. Vectors whose modulus is rejected by
// NewModulus are skipped.
func TestInverse(t *testing.T) {
	f, err := os.Open("testdata/mod_inv_tests.txt")
	if err != nil {
		t.Fatal(err)
	}
	// Release the file when the test returns, including on the t.Fatal paths
	// below.
	defer f.Close()

	var ModInv, A, M string
	var lineNum int
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lineNum++
		line := scanner.Text()
		if len(line) == 0 || line[0] == '#' {
			continue
		}

		// A line without " = " leaves k set to the whole line, which is then
		// reported by the default case below.
		k, v, _ := strings.Cut(line, " = ")
		switch k {
		case "ModInv":
			ModInv = v
		case "A":
			A = v
		case "M":
			M = v

			// M is the line that triggers the subtest for the current
			// ModInv, A and M values.
			t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
				m, err := NewModulus(decodeHex(t, M))
				if err != nil {
					t.Skip("modulus <= 1")
				}
				a, err := NewNat().SetBytes(decodeHex(t, A), m)
				if err != nil {
					t.Fatal(err)
				}

				got, ok := NewNat().InverseVarTime(a, m)
				if !ok {
					t.Fatal("not invertible")
				}
				exp, err := NewNat().SetBytes(decodeHex(t, ModInv), m)
				if err != nil {
					t.Fatal(err)
				}
				if got.Equal(exp) != 1 {
					t.Errorf("%v != %v", got, exp)
				}
			})
		default:
			t.Fatalf("unknown key %q on line %d", k, lineNum)
		}
	}
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}
}

// decodeHex decodes the hexadecimal string s into bytes. A string with an odd
// number of digits is first padded with a leading "0". Any decoding error
// fails the test immediately through t.Fatalf.
func decodeHex(t *testing.T, s string) []byte {
	t.Helper()
	padded := s
	if len(padded)&1 == 1 {
		padded = "0" + padded
	}
	decoded, err := hex.DecodeString(padded)
	if err != nil {
		t.Fatalf("failed to decode hex %q: %v", padded, err)
	}
	return decoded
}
