// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build ignore

// Generate arm64 Go assembly that XORs the AES-CTR keystream into n blocks at once using a single expanded key.
package main

import (
	"fmt"
	"os"
	"strings"
	"text/template"
)

// Index of the first vector register (Vn) in each register group used by the
// generated assembly. The groups are laid out back to back:
//
//	V0-V7:   up to 8 counter blocks being encrypted    (blockOffset)
//	V8-V22:  up to 15 AES round keys                   (roundKeyOffset)
//	V23-V30: up to 8 blocks of src data, XORed for dst (dstOffset)
const (
	blockOffset    = 0
	roundKeyOffset = blockOffset + 8     // preceded by the 8 block registers
	dstOffset      = roundKeyOffset + 15 // preceded by the 15 round-key registers
)

// tmplArm64Str is the text/template source rendered by main. For each N in
// Params.Sizes it emits a ctrBlocks{N}Asm function that:
//
//  1. builds N big-endian counter blocks from the 128-bit counter
//     (ivhi:ivlo), incrementing it with carry between blocks;
//  2. encrypts them with AES, entering at Lenc256, Lenc192 or Lenc128
//     depending on how nr compares with 12, then falling through the
//     remaining labels;
//  3. XORs the resulting keystream with N blocks read from src and stores
//     the result to dst.
//
// Round keys are loaded sequentially from xk with post-increment (VLD1.P).
// A shorter key schedule therefore starts loading at a later label, and the
// last round key always lands in V{round_key_reg 14} (V22).
//
// The register layout is V0-V7 for blocks, V8-V22 for round keys and V23-V30
// for src/dst data. It follows the blockOffset, roundKeyOffset and dstOffset
// constants, which the emitted comments call BLOCK_OFFSET, ROUND_KEY_OFFSET
// and DST_OFFSET.
//
// Everything inside the backquotes is program output. Editing it changes the
// generated assembly, not just its documentation.
var tmplArm64Str = `
// Code generated by ctr_arm64_gen.go. DO NOT EDIT.

//go:build !purego

#include "textflag.h"

#define NR R9
#define XK R10
#define DST R11
#define SRC R12
#define IV_LOW_LE R16
#define IV_HIGH_LE R17
#define IV_LOW_BE R19
#define IV_HIGH_BE R20

// V0.B16 - V7.B16 are for blocks (<=8). See BLOCK_OFFSET.
// V8.B16 - V22.B16 are for <=15 round keys (<=15). See ROUND_KEY_OFFSET.
// V23.B16 - V30.B16 are for destinations (<=8). See DST_OFFSET.

{{define "load_keys"}}
	{{- range regs_batches (round_key_reg $.FirstKey) $.NKeys }}
		VLD1.P {{ .Size }}(XK), [{{ .Regs }}]
	{{- end }}
{{ end }}

{{define "enc"}}
	{{ range $i := xrange $.N -}}
		AESE V{{ round_key_reg $.Key}}.B16, V{{ block_reg $i }}.B16
		{{- if $.WithMc }}
			AESMC V{{ block_reg $i }}.B16, V{{ block_reg $i }}.B16
		{{- end }}
	{{ end }}
{{ end }}

{{ range $N := $.Sizes }}
// func ctrBlocks{{$N}}Asm(nr int, xk *[60]uint32, dst *[{{$N}}*16]byte, src *[{{$N}}*16]byte, ivlo uint64, ivhi uint64)
TEXT ·ctrBlocks{{ $N }}Asm(SB),NOSPLIT,$0
	MOVD nr+0(FP), NR
	MOVD xk+8(FP), XK
	MOVD dst+16(FP), DST
	MOVD src+24(FP), SRC
	MOVD ivlo+32(FP), IV_LOW_LE
	MOVD ivhi+40(FP), IV_HIGH_LE

	{{/* Prepare plain from IV and blockIndex. */}}

	{{/* Copy to plaintext registers. */}}
	{{ range $i := xrange $N }}
		REV IV_LOW_LE, IV_LOW_BE
		REV IV_HIGH_LE, IV_HIGH_BE
		{{- /* https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/MOV--vector--from-general- */}}
		VMOV IV_LOW_BE, V{{ block_reg $i }}.D[1]
		VMOV IV_HIGH_BE, V{{ block_reg $i }}.D[0]
		{{- if ne (add $i 1) $N }}
			ADDS $1, IV_LOW_LE
			ADC $0, IV_HIGH_LE
		{{ end }}
	{{ end }}

	{{/* Num rounds branching. */}}
	CMP $12, NR
	BLT Lenc128
	BEQ Lenc192

	{{/* 2 extra rounds for 256-bit keys. */}}
	Lenc256:
	{{- template "load_keys" (load_keys_args 0 2) }}
	{{- template "enc" (enc_args 0 $N true) }}
	{{- template "enc" (enc_args 1 $N true) }}

	{{/* 2 extra rounds for 192-bit keys. */}}
	Lenc192:
	{{- template "load_keys" (load_keys_args 2 2) }}
	{{- template "enc" (enc_args 2 $N true) }}
	{{- template "enc" (enc_args 3 $N true) }}

	{{/* 10 rounds for 128-bit (with special handling for final). */}}
	Lenc128:
	{{- template "load_keys" (load_keys_args 4 11) }}
	{{- range $r := xrange 9 }}
		{{- template "enc" (enc_args (add $r 4) $N true) }}
	{{ end }}
	{{ template "enc" (enc_args 13 $N false) }}

	{{/* We need to XOR blocks with the last round key (key 14, register V22). */}}
	{{ range $i := xrange $N }}
		VEOR V{{ block_reg $i }}.B16, V{{ round_key_reg 14 }}.B16, V{{ block_reg $i }}.B16
	{{- end }}

	{{/* XOR results to destination. */}}
	{{- range regs_batches $.DstOffset $N }}
		VLD1.P {{ .Size }}(SRC), [{{ .Regs }}]
	{{- end }}
	{{- range $i := xrange $N }}
		VEOR V{{ add $.DstOffset $i }}.B16, V{{ block_reg $i }}.B16, V{{ add $.DstOffset $i }}.B16
	{{- end }}
	{{- range regs_batches $.DstOffset $N }}
		VST1.P [{{ .Regs }}], {{ .Size }}(DST)
	{{- end }}

	RET
{{ end }}
`

// main renders tmplArm64Str to standard output. It produces one
// ctrBlocks{N}Asm function for each block count N in {1, 2, 4, 8}.
func main() {
	// Params is the root data passed to the template.
	type Params struct {
		DstOffset int   // First vector register of the src/dst group.
		Sizes     []int // Block counts to emit a ctrBlocks{N}Asm function for.
	}

	params := Params{
		DstOffset: dstOffset,
		Sizes:     []int{1, 2, 4, 8},
	}

	tmpl := template.Must(template.New("ctr_arm64").Funcs(templateFuncs()).Parse(tmplArm64Str))
	if err := tmpl.Execute(os.Stdout, params); err != nil {
		panic(err)
	}
}

// templateFuncs returns the helper functions that tmplArm64Str calls.
func templateFuncs() template.FuncMap {
	// loadKeysArgs is the argument of the "load_keys" sub-template: it loads
	// NKeys round keys into the registers for keys FirstKey onward.
	type loadKeysArgs struct {
		FirstKey int
		NKeys    int
	}

	// encArgs is the argument of the "enc" sub-template: it runs one AESE
	// round with round key Key over N blocks, followed by AESMC if WithMc.
	type encArgs struct {
		Key    int
		N      int
		WithMc bool
	}

	return template.FuncMap{
		// add returns a+b, for register and key index arithmetic.
		"add": func(a, b int) int { return a + b },
		// xrange returns [0, 1, ..., n-1] so templates can loop n times.
		"xrange": func(n int) []int {
			result := make([]int, n)
			for i := range result {
				result[i] = i
			}
			return result
		},
		// block_reg maps a block index to its vector register number.
		"block_reg": func(block int) int { return blockOffset + block },
		// round_key_reg maps a round-key index to its vector register number.
		"round_key_reg": func(key int) int { return roundKeyOffset + key },
		"regs_batches":  regsBatches,
		"enc_args": func(key, n int, withMc bool) encArgs {
			return encArgs{Key: key, N: n, WithMc: withMc}
		},
		"load_keys_args": func(firstKey, nkeys int) loadKeysArgs {
			return loadKeysArgs{FirstKey: firstKey, NKeys: nkeys}
		},
	}
}

// regsBatch is one register list operand of a VLD1.P/VST1.P instruction.
type regsBatch struct {
	Size int    // Bytes covered by the list (16 per register); the post-increment.
	Regs string // Comma-separated list of registers, e.g. "V0.B16, V1.B16".
}

// maxRegsPerBatch is the largest register list emitted per VLD1/VST1.
const maxRegsPerBatch = 4

// regsBatches splits the nregs consecutive 128-bit vector registers starting
// at V{firstReg} into ordered lists of at most maxRegsPerBatch registers.
//
// It returns an error when nregs or firstReg is negative, or when the range
// extends past V31. text/template then aborts execution, so a bad template
// argument fails generation instead of emitting invalid assembly.
func regsBatches(firstReg, nregs int) ([]regsBatch, error) {
	if nregs < 0 || firstReg < 0 || firstReg+nregs > 32 {
		return nil, fmt.Errorf("regs_batches: %d registers starting at V%d do not fit in V0-V31", nregs, firstReg)
	}
	result := make([]regsBatch, 0, (nregs+maxRegsPerBatch-1)/maxRegsPerBatch)
	for nregs > 0 {
		batch := maxRegsPerBatch
		if nregs < batch {
			batch = nregs
		}
		regs := make([]string, batch)
		for j := range regs {
			regs[j] = fmt.Sprintf("V%d.B16", firstReg+j)
		}
		result = append(result, regsBatch{
			Size: 16 * batch,
			Regs: strings.Join(regs, ", "),
		})
		nregs -= batch
		firstReg += batch
	}
	return result, nil
}
