// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package mlkem implements the quantum-resistant key encapsulation method
// ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
//
// [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
package mlkem

// This package targets security, correctness, simplicity, readability, and
// reviewability as its primary goals. All critical operations are performed in
// constant time.
//
// Variable and function names, as well as code layout, are selected to
// facilitate reviewing the implementation against the NIST FIPS 203 document.
//
// Reviewers unfamiliar with polynomials or linear algebra might find the
// background at https://words.filippo.io/kyber-math/ useful.
//
// This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
// parameter set implementation is auto-generated from this file.
//
//go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go

import (
	"bytes"
	"crypto/internal/fips140"
	"crypto/internal/fips140/drbg"
	"crypto/internal/fips140/sha3"
	"crypto/internal/fips140/subtle"
	"errors"
)

const (
	// ML-KEM global constants.
	//
	// n is the number of coefficients in a ringElement or nttElement (see the
	// encodingSizeX formulas below) and q is the modulus, both as fixed by
	// FIPS 203 for every parameter set.
	n = 256
	q = 3329

	// encodingSizeX is the byte size of a ringElement or nttElement encoded
	// by ByteEncode_X (FIPS 203, Algorithm 5).
	//
	// Each of the n coefficients takes X bits, hence n * X / 8 bytes.
	encodingSize12 = n * 12 / 8
	encodingSize11 = n * 11 / 8
	encodingSize10 = n * 10 / 8
	encodingSize5  = n * 5 / 8
	encodingSize4  = n * 4 / 8
	encodingSize1  = n * 1 / 8

	// messageSize is the byte size of the message m that is encrypted by
	// pkeEncrypt: one bit per coefficient.
	messageSize = encodingSize1

	// SharedKeySize is the size in bytes of the shared key produced by
	// encapsulation and decapsulation. SeedSize is the size in bytes of a
	// decapsulation key seed in the "d || z" form (32 bytes for each half).
	SharedKeySize = 32
	SeedSize      = 32 + 32
)

// ML-KEM-768 parameters.
const (
	// k is the module dimension: the length of the vectors t and s, and the
	// side of the square matrix A (stored as k*k elements).
	k = 3

	// CiphertextSize768 is the byte size of a ciphertext: k ring elements
	// compressed to 10 bits per coefficient plus one compressed to 4 bits.
	//
	// EncapsulationKeySize768 is the byte size of an encoded encapsulation
	// key: ByteEncode₁₂(t) followed by the 32-byte ρ.
	//
	// decapsulationKeySize768 is the byte size of the expanded NIST
	// decapsulation key encoding, ByteEncode₁₂(s) || ek || H(ek) || z. In this
	// file it is only used by the TestingOnly functions.
	CiphertextSize768       = k*encodingSize10 + encodingSize4
	EncapsulationKeySize768 = k*encodingSize12 + 32
	decapsulationKeySize768 = k*encodingSize12 + EncapsulationKeySize768 + 32 + 32
)

// ML-KEM-1024 parameters.
//
// These are declared here, rather than in the generated ML-KEM-1024 file,
// next to the ML-KEM-768 constants they mirror.
const (
	// k1024 is the module dimension for ML-KEM-1024.
	k1024 = 4

	// The sizes follow the same formulas as the ML-KEM-768 ones, with k1024
	// in place of k and with ciphertext ring elements compressed to 11 and 5
	// bits per coefficient instead of 10 and 4.
	CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
	decapsulationKeySize1024 = k1024*encodingSize12 + EncapsulationKeySize1024 + 32 + 32
)

// A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
// ciphertext. It includes various precomputed values.
//
// The seeds d and z are what Bytes serializes. The remaining fields are
// derived from d by kemKeyGen (or parsed by TestingOnlyNewDecapsulationKey768)
// and kept in the key so that Decapsulate does not have to recompute them.
type DecapsulationKey768 struct {
	d [32]byte // decapsulation key seed
	z [32]byte // implicit rejection sampling seed

	ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
	h [32]byte // H(ek), stored for ML-KEM.Decaps_internal

	encryptionKey // t and A, used by kemDecaps to re-encrypt the message
	decryptionKey // s, used by pkeDecrypt
}

// Bytes encodes the decapsulation key as its seed: the 32-byte d followed by
// the 32-byte z, for a total of [SeedSize] bytes.
//
// The returned bytes are secret key material and must be kept secret.
func (dk *DecapsulationKey768) Bytes() []byte {
	seed := make([]byte, 0, SeedSize)
	seed = append(seed, dk.d[:]...)
	return append(seed, dk.z[:]...)
}

// TestingOnlyExpandedBytes768 serializes dk in the full expanded NIST
// encoding, ByteEncode₁₂(s) || ByteEncode₁₂(t) || ρ || H(ek) || z.
//
// It exists for ACVP testing only. Everything else should use the Bytes
// method, which returns the (much smaller) seed instead.
func TestingOnlyExpandedBytes768(dk *DecapsulationKey768) []byte {
	out := make([]byte, 0, decapsulationKeySize768)

	// dk_PKE = ByteEncode₁₂(s)
	s := &dk.s
	for i := range s {
		out = polyByteEncode(out, s[i])
	}

	// ek = ByteEncode₁₂(t) || ρ
	t := &dk.t
	for i := range t {
		out = polyByteEncode(out, t[i])
	}
	out = append(out, dk.ρ[:]...)

	// H(ek), then z
	out = append(out, dk.h[:]...)
	return append(out, dk.z[:]...)
}

// EncapsulationKey returns the public encapsulation key that peers need in
// order to produce ciphertexts for this decapsulation key.
//
// The result holds its own copies of ρ, H(ek), t and A; it does not alias dk.
func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
	ek := new(EncapsulationKey768)
	ek.ρ = dk.ρ
	ek.h = dk.h
	ek.encryptionKey = dk.encryptionKey
	return ek
}

// An EncapsulationKey768 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding [DecapsulationKey768].
//
// It is obtained either from a decapsulation key, through its
// EncapsulationKey method, or by parsing an encoded key with
// [NewEncapsulationKey768].
type EncapsulationKey768 struct {
	ρ [32]byte // sampleNTT seed for A
	h [32]byte // H(ek)
	encryptionKey // t and A, precomputed for pkeEncrypt
}

// Bytes returns the encoded form of the encapsulation key, which is
// [EncapsulationKeySize768] bytes long.
func (ek *EncapsulationKey768) Bytes() []byte {
	// Only the allocation happens here; the encoding logic lives in bytes so
	// that this allocation is outlined from it.
	return ek.bytes(make([]byte, 0, EncapsulationKeySize768))
}

// bytes appends the encoding of ek, ByteEncode₁₂(t) || ρ, to b and returns the
// extended slice. H(ek) is not part of the encoding.
func (ek *EncapsulationKey768) bytes(b []byte) []byte {
	t := &ek.t
	for i := range t {
		b = polyByteEncode(b, t[i])
	}
	return append(b, ek.ρ[:]...)
}

// encryptionKey is the parsed and expanded form of a PKE encryption key.
//
// Both fields are filled in once, by kemKeyGen or parseEK, and then read by
// pkeEncrypt. The matrix A is stored flattened, with k elements per row.
type encryptionKey struct {
	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
	a [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}

// decryptionKey is the parsed and expanded form of a PKE decryption key.
//
// s is sampled by kemKeyGen (or decoded from the first k*encodingSize12 bytes
// of the expanded encoding by TestingOnlyNewDecapsulationKey768) and is read
// by pkeDecrypt.
type decryptionKey struct {
	s [k]nttElement // ByteDecode₁₂(dk[:384k])
}

// GenerateKey768 generates a fresh decapsulation key from random bytes drawn
// from a DRBG. The decapsulation key must be kept secret.
//
// The returned error is currently always nil: a failed pairwise consistency
// test causes a panic rather than an error.
func GenerateKey768() (*DecapsulationKey768, error) {
	// Only the allocation happens here; the generation logic lives in
	// generateKey so that this allocation is outlined from it.
	return generateKey(new(DecapsulationKey768))
}

// generateKey fills dk with a key derived from two fresh 32-byte seeds read
// from the DRBG (d first, then z), runs the pairwise consistency test on it,
// and returns dk. It panics if the consistency test fails.
func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
	var d, z [32]byte
	drbg.Read(d[:])
	drbg.Read(z[:])
	kemKeyGen(dk, &d, &z)

	// This clearly can't happen, but FIPS 140-3 requires us to check.
	pct := func() error { return kemPCT(dk) }
	if err := fips140.PCT("ML-KEM PCT", pct); err != nil {
		panic(err)
	}

	fips140.RecordApproved()
	return dk, nil
}

// GenerateKeyInternal768 derives a decapsulation key from the caller-supplied
// seeds d and z instead of reading them from the DRBG. Unlike GenerateKey768
// it does not run the pairwise consistency test.
//
// It is meant exclusively for use in tests.
func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
	var dk DecapsulationKey768
	kemKeyGen(&dk, d, z)
	return &dk
}

// NewDecapsulationKey768 expands a decapsulation key from its 64-byte seed in
// the "d || z" form, as returned by the Bytes method. The seed must be
// uniformly random.
//
// It returns an error if seed is not exactly [SeedSize] bytes long.
func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
	// Only the allocation happens here; the expansion logic lives in
	// newKeyFromSeed so that this allocation is outlined from it.
	return newKeyFromSeed(new(DecapsulationKey768), seed)
}

// newKeyFromSeed fills dk with the key derived from seed, which must be
// SeedSize bytes in the "d || z" form, runs the pairwise consistency test on
// it, and returns dk. It panics if the consistency test fails.
func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
	if len(seed) != SeedSize {
		return nil, errors.New("mlkem: invalid seed length")
	}

	// seed = d || z. kemKeyGen copies both halves into dk, so dk does not
	// retain a reference to the caller's slice.
	d, z := (*[32]byte)(seed[:32]), (*[32]byte)(seed[32:])
	kemKeyGen(dk, d, z)

	// This clearly can't happen, but FIPS 140-3 requires us to check.
	pct := func() error { return kemPCT(dk) }
	if err := fips140.PCT("ML-KEM PCT", pct); err != nil {
		panic(err)
	}

	fips140.RecordApproved()
	return dk, nil
}

// TestingOnlyNewDecapsulationKey768 parses a decapsulation key from its
// expanded NIST format, ByteEncode₁₂(s) || ek || H(ek) || z.
//
// The expanded format does not contain the seed d, so Bytes() must not be
// called on the returned key: it will not produce the original seed.
//
// This function exists for ACVP testing only. Use NewDecapsulationKey768 for
// all other purposes.
func TestingOnlyNewDecapsulationKey768(b []byte) (*DecapsulationKey768, error) {
	if len(b) != decapsulationKeySize768 {
		return nil, errors.New("mlkem: invalid NIST decapsulation key length")
	}
	dk := new(DecapsulationKey768)

	// dk_PKE = ByteEncode₁₂(s)
	s := &dk.s
	for i := range s {
		var err error
		if s[i], err = polyByteDecode[nttElement](b[:encodingSize12]); err != nil {
			return nil, errors.New("mlkem: invalid secret key encoding")
		}
		b = b[encodingSize12:]
	}

	// ek is parsed and checked the same way as a standalone encapsulation
	// key, which also recomputes H(ek) and expands A.
	ek, err := NewEncapsulationKey768(b[:EncapsulationKeySize768])
	if err != nil {
		return nil, err
	}
	dk.ρ, dk.h, dk.encryptionKey = ek.ρ, ek.h, ek.encryptionKey
	b = b[EncapsulationKeySize768:]

	// H(ek) || z. The stored hash must match the one recomputed from ek.
	storedH, z := b[:32], b[32:]
	if !bytes.Equal(dk.h[:], storedH) {
		return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
	}
	copy(dk.z[:], z)

	// Fill d with random bytes as a safety net: should Bytes() be called
	// despite the warning in the function comment, it yields the seed of an
	// unrelated random key instead of a broken one.
	drbg.Read(dk.d[:])

	return dk, nil
}

// kemKeyGen generates a decapsulation key.
//
// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
// copies and allocations.
//
// The seeds d and z are copied into dk; every other field of dk (ρ, A, s, t
// and h) is derived from d and overwritten in place. No randomness is read
// here.
func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
	dk.d = *d
	dk.z = *z

	// (ρ, σ) = G(d || k)
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k}) // Module dimension as a domain separator.
	G := g.Sum(make([]byte, 0, 64))
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand the matrix A from ρ. Entry (i, j) is stored at A[i*k+j]; note
	// that sampleNTT receives the indices as (j, i).
	A := &dk.a
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// N is the counter passed to samplePolyCBD along with σ. It keeps
	// increasing across s and e, so every polynomial gets a distinct value.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	// e is only needed to compute t below, so it is not stored in dk.
	e := make([]nttElement, k)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	t := &dk.t
	for i := range t { // t = A ◦ s + e
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// h = H(ek). dk.h is still unset when EncapsulationKey copies it here, but
	// that is harmless because the encoding only covers t and ρ.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}

// kemPCT runs the Pairwise Consistency Test on dk: it encapsulates to dk's own
// encapsulation key, decapsulates the resulting ciphertext, and reports an
// error unless both sides obtain the same shared key.
//
// This is what FIPS 140-3 IG 10.3.A Additional Comment 1 asks for: "For key
// pairs generated for use with approved KEMs in FIPS 203, the PCT shall
// consist of applying the encapsulation key ek to encapsulate a shared secret
// K leading to ciphertext c, and then applying decapsulation key dk to
// retrieve the same shared secret K. The PCT passes if the two shared secret K
// values are equal. The PCT shall be performed either when keys are
// generated/imported, prior to the first exportation, or prior to the first
// operational use (if not exported before the first use)."
func kemPCT(dk *DecapsulationKey768) error {
	K, c := dk.EncapsulationKey().Encapsulate()
	K1, err := dk.Decapsulate(c)
	if err != nil {
		return err
	}
	if subtle.ConstantTimeCompare(K, K1) == 1 {
		return nil
	}
	return errors.New("mlkem: PCT failed")
}

// Encapsulate produces a fresh shared key together with the ciphertext that
// lets the holder of the matching decapsulation key recover it. The
// randomness is drawn from a DRBG.
//
// The shared key must be kept secret.
func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
	// Only the ciphertext buffer is allocated here; the encapsulation logic
	// lives in encapsulate so that this allocation is outlined from it.
	cc := new([CiphertextSize768]byte)
	return ek.encapsulate(cc)
}

// encapsulate reads a random message from the DRBG and encapsulates it to ek,
// using cc as the ciphertext buffer.
func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
	m := new([messageSize]byte)
	drbg.Read(m[:])

	// The modulus check (step 2 of the encapsulation key check from FIPS 203,
	// Section 7.2) is not repeated here: polyByteDecode already performed it
	// when parseEK decoded the key.
	fips140.RecordApproved()
	return kemEncaps(cc, ek, m)
}

// EncapsulateInternal encapsulates the caller-supplied message m instead of
// one read from the DRBG, making the result reproducible.
//
// It is meant exclusively for use in tests.
func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
	var cc [CiphertextSize768]byte
	return kemEncaps(&cc, ek, m)
}

// kemEncaps derives a shared key from the message m and encrypts m into the
// ciphertext buffer cc.
//
// It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
	// (K, r) = G(m || H(ek))
	hasher := sha3.New512()
	hasher.Write(m[:])
	hasher.Write(ek.h[:])
	Kr := hasher.Sum(nil)

	K = Kr[:SharedKeySize]
	c = pkeEncrypt(cc, &ek.encryptionKey, m, Kr[SharedKeySize:])
	return K, c
}

// NewEncapsulationKey768 parses an encapsulation key from its encoded form,
// which must be exactly [EncapsulationKeySize768] bytes long.
// It returns an error if the encapsulation key is not valid.
func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
	// Only the allocation happens here; the parsing logic lives in parseEK so
	// that this allocation is outlined from it.
	return parseEK(new(EncapsulationKey768), encapsulationKey)
}

// parseEK decodes the encoded encryption key ekPKE into ek, filling in H(ek),
// t, ρ and the expanded matrix A. On success it returns ek itself.
//
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
// Algorithm 14.
func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
	if len(ekPKE) != EncapsulationKeySize768 {
		return nil, errors.New("mlkem: invalid encapsulation key length")
	}

	// H(ek) covers the whole encoding, so it is computed before ekPKE is
	// consumed piece by piece below.
	digest := sha3.New256()
	digest.Write(ekPKE)
	digest.Sum(ek.h[:0])

	// t = ByteDecode₁₂(ek[:384k]). A decoding error rejects the key.
	t := &ek.t
	for i := range t {
		var err error
		if t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]); err != nil {
			return nil, err
		}
		ekPKE = ekPKE[encodingSize12:]
	}

	// What is left of the encoding is ρ.
	copy(ek.ρ[:], ekPKE)

	// Expand the matrix A from ρ. Entry (i, j) is stored at a[i*k+j]; note
	// that sampleNTT receives the indices as (j, i).
	ρ := ek.ρ[:]
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			ek.a[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	return ek, nil
}

// pkeEncrypt encrypts a plaintext message.
//
// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
// computation of t and AT is done in parseEK.
//
// m is the message and rnd is the encryption randomness from which r, e1 and
// e2 are sampled. The ciphertext is encoded starting at cc[:0] and the
// resulting slice is returned; it is expected to fit in cc, whose size is
// exactly that of k 10-bit encodings plus one 4-bit encoding.
func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
	// N is the counter passed to samplePolyCBD along with rnd. It keeps
	// increasing across r, e1 and e2, so every polynomial gets a distinct
	// value.
	var N byte
	r, e1 := make([]nttElement, k), make([]ringElement, k)
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// Note that i and j are inverted, as we need the transpose of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
		}
	}

	// μ is the message decoded into a ring element, one bit per coefficient.
	μ := ringDecodeAndDecompress1(m)

	var vNTT nttElement // t⊺ ◦ r
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	// v = NTT⁻¹(t⊺ ◦ r) + e2 + μ
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// c = c1 || c2: the compressed encodings of each u[i], then that of v.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode10(c, f)
	}
	c = ringCompressAndEncode4(c, v)

	return c
}

// Decapsulate recovers the shared key that ciphertext encapsulates for this
// decapsulation key.
//
// It returns an error if ciphertext is not exactly [CiphertextSize768] bytes
// long. A ciphertext of the right length that does not match its
// re-encryption yields no error but a key derived from the implicit rejection
// seed z.
//
// The shared key must be kept secret.
func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
	if len(ciphertext) != CiphertextSize768 {
		return nil, errors.New("mlkem: invalid ciphertext length")
	}

	// The hash check (step 3 of the decapsulation input check from FIPS 203,
	// Section 7.3) is skipped on purpose: a DecapsulationKey is always validly
	// generated by ML-KEM.KeyGen_internal.
	sharedKey = kemDecaps(dk, (*[CiphertextSize768]byte)(ciphertext))
	return sharedKey, nil
}

// kemDecaps produces a shared key from a ciphertext.
//
// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
//
// The ciphertext is decrypted and the result re-encrypted. The key derived
// from the decrypted message is returned only if the re-encryption equals c;
// otherwise the returned key is derived from the implicit rejection seed z and
// c. Both candidates are always computed, and the choice between them is made
// with the constant-time helpers of the subtle package.
func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
	fips140.RecordApproved()
	// m = K-PKE.Decrypt(dk_PKE, c)
	m := pkeDecrypt(&dk.decryptionKey, c)
	// (K', r) = G(m || h)
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum(make([]byte, 0, 64))
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// Kout starts out as the implicit rejection key J(z || c).
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := make([]byte, SharedKeySize)
	J.Read(Kout)
	// c1 is the re-encryption of m with the derived randomness r, written
	// into a scratch buffer.
	var cc [CiphertextSize768]byte
	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)

	// Overwrite Kout with Kprime if and only if c matches its re-encryption.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}

// pkeDecrypt decrypts a ciphertext.
//
// It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
// although s is retained from kemKeyGen.
//
// The returned slice is the encoding of the recovered message, produced by
// ringCompressAndEncode1 into a newly allocated buffer (nil is passed as the
// destination).
func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
	// c = c1 || c2. c1 holds k encodings of encodingSize10 bytes each, one
	// per element of u.
	u := make([]ringElement, k)
	for i := range u {
		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
		u[i] = ringDecodeAndDecompress10(b)
	}

	// c2 is the remaining encodingSize4 bytes, holding v.
	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
	v := ringDecodeAndDecompress4(b)

	var mask nttElement // s⊺ ◦ NTT(u)
	for i := range dx.s {
		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
	}
	// w = v - NTT⁻¹(s⊺ ◦ NTT(u))
	w := polySub(v, inverseNTT(mask))

	return ringCompressAndEncode1(nil, w)
}
