import torch
import triton
from torch.autograd import Function
from torch.amp import custom_fwd, custom_bwd
from gelu_fwd import gelu_fwd
from gelu_bwd import gelu_bwd_bit, gelu_bwd_x
from prng import RomuTrio32
from q_config import QConfig


class gelu_func(Function):
    """Autograd bridge between the GELU forward and backward kernels.

    ``forward`` delegates to ``gelu_fwd``; ``backward`` delegates to
    ``gelu_bwd_x`` and ``gelu_bwd_bit``. The seed and block size are stored
    on ``ctx`` so the backward kernels receive exactly the values the forward
    call was given.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, out_bit, block_size: int, seed: int):
        """Run ``gelu_fwd`` and record what ``backward`` needs.

        Args:
            ctx: autograd context supplied by ``Function.apply``.
            x: input tensor (saved for backward).
            out_bit: tensor passed to ``gelu_fwd`` as its second argument
                (saved for backward; receives a gradient).
            block_size: forwarded unchanged to the kernels.
            seed: forwarded unchanged to the kernels.

        Returns:
            The first of the two values returned by ``gelu_fwd``.
        """
        noisy_x, out_alpha = gelu_fwd(x, out_bit, block_size, seed)
        # Plain Python values live on ctx; tensors go through
        # save_for_backward so autograd can track them.
        ctx.seed = seed
        ctx.block_size = block_size
        ctx.save_for_backward(x, out_alpha, out_bit)
        return noisy_x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        """Compute gradients for ``x`` and ``out_bit``.

        A gradient is only computed when autograd reports that the
        corresponding forward input needs one; otherwise ``None`` is returned
        in its slot and the kernel call is skipped.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient with respect to the forward output.

        Returns:
            ``(grad_x, grad_bit, None, None)`` -- one entry per forward
            argument; ``block_size`` and ``seed`` never receive a gradient.
        """
        x, out_alpha, out_bit = ctx.saved_tensors
        need_x, need_bit = ctx.needs_input_grad[:2]

        grad_x = gelu_bwd_x(grad_out, x) if need_x else None

        grad_bit = None
        if need_bit:
            grad_bit = gelu_bwd_bit(
                grad_out, out_alpha, out_bit, ctx.seed, ctx.block_size
            )

        return grad_x, grad_bit, None, None


class Gelu(torch.nn.Module):
    """Module front-end for ``gelu_func``.

    Holds a learnable ``bit`` vector with one entry per block of
    ``config.block_size`` along ``N`` (ceil-divided), and a ``RomuTrio32``
    generator that supplies a seed for every forward call.
    """

    def __init__(
        self,
        N,
        seed: int,
        config: QConfig,
    ):
        """Create the parameters and warm up the seed generator.

        Args:
            N: size used to derive the number of ``bit`` entries.
            seed: initial seed handed to ``RomuTrio32``.
            config: provides ``block_size`` and ``init_bit``.
        """
        super().__init__()

        self.N = N
        self.config = config

        # Number of blocks along N, rounded up.
        num_blocks = triton.cdiv(N, config.block_size)
        self.NBN = num_blocks

        # Every block starts at config.init_bit.
        initial_bits = torch.ones((num_blocks,)) * config.init_bit
        self.bit = torch.nn.parameter.Parameter(
            initial_bits, requires_grad=True
        )

        # The first ten generator outputs are drawn and discarded.
        self.romu = RomuTrio32(seed)
        remaining_warmup = 10
        while remaining_warmup > 0:
            self.romu.next()
            remaining_warmup -= 1

        # Boolean flag stored as a non-trainable parameter; it is handed to
        # ``romu.next`` on every forward call.
        # NOTE(review): presumably a true value stops the generator from
        # advancing -- confirm against RomuTrio32.next.
        freeze_flag = torch.zeros((1,), dtype=torch.bool)
        self.prng_freeze = torch.nn.parameter.Parameter(
            freeze_flag, requires_grad=False
        )

    def forward(self, x):
        """Apply ``gelu_func`` to ``x`` with a freshly drawn seed."""
        block_size = self.config.block_size
        call_seed = self.romu.next(self.prng_freeze)
        return gelu_func.apply(x, self.bit, block_size, call_seed)


if __name__ == "__main__":
    from torch.nn import functional as F
    from prng import rand_4bit_fp

    class Gelu_gm(torch.nn.Module):
        """Pure-PyTorch golden model used to check ``Gelu``.

        Computes ``F.gelu`` and adds noise scaled by the absolute maximum of
        each ``block_size x block_size`` tile and by ``2 ** (1 - bit)`` for
        the tile's column block.
        """

        def __init__(self, N, seed: int, config: QConfig):
            """Mirror the constructor arguments of ``Gelu``."""
            super().__init__()

            self.N = N
            self.config = config

            # Number of column blocks, rounded up.
            num_blocks = triton.cdiv(N, config.block_size)
            self.NBN = num_blocks
            self.bit = torch.nn.parameter.Parameter(
                torch.ones((num_blocks,)) * config.init_bit,
                requires_grad=True,
            )

            # The first ten generator outputs are drawn and discarded.
            self.romu = RomuTrio32(seed)
            for _ in range(10):
                self.romu.next()

        def forward(self, x):
            """Return ``gelu(x)`` plus block-scaled noise, shaped like ``x``.

            The tiling ``view`` below requires the flattened row count to be
            a multiple of ``block_size`` and the column count to equal
            ``NBN * block_size``.
            """
            input_shape = x.shape
            bs = self.config.block_size
            nbn = self.NBN

            # Work on a 2-D (rows, N) view of the input.
            rows = x.view((1, -1)) if len(x.shape) == 1 else x
            rows = rows.flatten(0, -2)
            M, N = rows.shape

            activated = F.gelu(rows)
            with torch.no_grad():
                noise = rand_4bit_fp(M, N, seed=self.romu.next())

                # Split into (row_block, bs, col_block, bs) tiles and reduce
                # over both in-tile axes to get one abs-max per tile.
                tiles = activated.abs().view((-1, bs, nbn, bs))
                tile_max = tiles.max(dim=1, keepdim=True).values
                tile_max = tile_max.max(dim=-1, keepdim=True).values

                # Spread each tile's abs-max back over its elements.
                alpha = tile_max.expand(
                    tile_max.shape[0], bs, nbn, bs
                ).reshape(M, self.N)

            # Per-column-block scale 2 ** (1 - bit); kept outside no_grad so
            # the gradient reaches ``self.bit``.
            scale = torch.exp2(1 - self.bit)  # (NBN,)
            scale_per_col = (
                scale.view((1, nbn, 1))
                .expand(M, nbn, bs)
                .reshape((M, nbn * bs))
            )

            noisy = activated + noise * alpha * scale_per_col
            return noisy.view(input_shape)

    # Compare ``Gelu`` with the golden model over a grid of shapes and bit
    # settings, printing the mean absolute error of the forward output and of
    # both gradients. Tensors are dumped when an error exceeds the tolerance.
    print("verify ...")
    seed_stream = RomuTrio32(seed=42)
    for _ in range(10):
        seed_stream.next()

    sizes = [32, 128, 1024, 4096]
    bit_settings = [4, 6, 8]
    tolerance = 1e-5

    for m in sizes:
        for n in sizes:
            for num_bit in bit_settings:
                print(f"{m}-{n}-{num_bit}bit", end="\t")

                # Both models receive the same seed.
                case_seed = seed_stream.next()
                ref_model = Gelu_gm(n, case_seed, QConfig(num_bit)).cuda()
                dut_model = Gelu(n, case_seed, QConfig(num_bit)).cuda()

                # forward: identical data, separate leaf tensors so each
                # model accumulates its own input gradient.
                base = torch.randn([m, n], device="cuda") * 32
                x_ref = torch.nn.parameter.Parameter(base, requires_grad=True)
                x_dut = torch.nn.parameter.Parameter(
                    base.clone().detach(), requires_grad=True
                )

                out_ref = ref_model(x_ref)
                out_dut = dut_model(x_dut)

                fwd_err = F.l1_loss(out_ref, out_dut).item()
                print(f"fwd: {fwd_err}", end="\t")

                # check backprop
                loss_ref = F.l1_loss(out_ref, torch.zeros_like(out_ref))
                loss_dut = F.l1_loss(out_dut, torch.zeros_like(out_dut))
                loss_ref.backward()
                loss_dut.backward()

                bit_err = F.l1_loss(
                    ref_model.bit.grad, dut_model.bit.grad
                ).item()
                x_err = F.l1_loss(x_ref.grad, x_dut.grad).item()
                print(f"bwd_bit: {bit_err}", end="\t")
                print(f"bwd_x: {x_err}")

                # Dump both sides of any quantity that is out of tolerance.
                reports = (
                    (fwd_err, "fwd", out_ref, out_dut),
                    (
                        bit_err,
                        "bwd_bit",
                        ref_model.bit.grad,
                        dut_model.bit.grad,
                    ),
                    (x_err, "bwd_x", x_ref.grad, x_dut.grad),
                )
                for err, tag, ref_val, dut_val in reports:
                    if err > tolerance:
                        print(f"golden {tag}: {ref_val}")
                        print(f"ours {tag}: {dut_val}")

                ref_model.zero_grad()
                dut_model.zero_grad()

    print("benchmark ...")

    def get_bench_configs():
        """Build the single benchmark configuration for the GELU comparison.

        Returns:
            A one-element list holding a ``triton.testing.Benchmark`` that
            sweeps ``M`` and ``N`` together over sizes from 128 to 16384 and
            draws one line per provider ("golden" and "ours").
        """
        providers = ["golden", "ours"]
        line_styles = [("red", "-"), ("cyan", "-")]

        # Sweep values: 128-step up to 896, then coarser steps up to 16384.
        sweep = [128 * k for k in range(1, 8)]
        for step in (256, 512, 1024):
            sweep.extend(step * k for k in range(4, 8))
        sweep.extend(2048 * k for k in range(4, 9))

        gelu_benchmark = triton.testing.Benchmark(
            # Argument names to use as an x-axis for the plot
            x_names=["M", "N"],
            x_vals=sweep,
            y_log=False,
            x_log=False,
            line_arg="provider",
            line_vals=providers,
            line_names=providers,
            styles=line_styles,
            ylabel="Gelem/s",  # Label name for the y-axis
            args={},
            plot_name="gelu",
        )
        return [gelu_benchmark]

    @triton.testing.perf_report(get_bench_configs())
    def benchmark(M, N, provider):
        """Time one forward pass of the selected implementation.

        Args:
            M: number of rows of the random input.
            N: number of columns of the random input.
            provider: "golden" or "ours" (text before the first "-" is used).

        Returns:
            Throughput in Gelem/s computed from the median, 80th-percentile
            and 20th-percentile timings, in that order.

        Raises:
            ValueError: if the provider is not recognised.
        """
        q_part = provider.split("-")[0]

        print(f"(M={M}, N={N}, {q_part}), {provider}")

        implementations = {"golden": Gelu_gm, "ours": Gelu}
        if q_part not in implementations:
            raise ValueError(q_part)
        module_cls = implementations[q_part]
        dut = module_cls(N, seed=42, config=QConfig(init_bit=8.0)).cuda()

        t1 = torch.randn((M, N), device="cuda")

        def run_once():
            return dut(t1)

        # do_bench returns one timing per requested quantile, in order.
        p50_ms, p20_ms, p80_ms = triton.testing.do_bench(
            run_once,
            quantiles=[0.5, 0.2, 0.8],
            rep=1000,
        )

        def throughput(ms):
            # Elements per second, scaled to 1e9 elements.
            return M * N * 1e-9 / (ms * 1e-3)

        # Slower timing first so the tuple reads (value, low, high).
        return throughput(p50_ms), throughput(p80_ms), throughput(p20_ms)

    benchmark.run(save_path="./")
