import torch
import triton
from torch.autograd import Function
from torch.amp import custom_fwd, custom_bwd
from prng import RomuTrio32
from nadd_fwd import nadd_fwd
from nadd_bwd import nadd_bwd_bit
from q_config import QBN, QConfig


class nadd_func(Function):
    """Autograd bridge around the ``nadd_fwd`` / ``nadd_bwd_bit`` kernels.

    The forward pass delegates entirely to ``nadd_fwd``. The backward pass
    forwards the incoming gradient unchanged to ``x`` and obtains the gradient
    of the bit-width tensor from ``nadd_bwd_bit``.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, x_bit, block_size: int, seed: int, is_uniform: bool):
        """Run ``nadd_fwd`` and remember what ``backward`` will need.

        Args:
            ctx: autograd context.
            x: tensor handed to ``nadd_fwd`` as its first argument.
            x_bit: bit-width tensor handed to ``nadd_fwd``.
            block_size: block edge length forwarded to the kernel.
            seed: PRNG seed forwarded to the kernel.
            is_uniform: noise-distribution flag forwarded to the kernel.

        Returns:
            The first output of ``nadd_fwd``.
        """
        # NOTE(review): the second kernel output is presumably a per-block
        # scale -- confirm against nadd_fwd.
        result, scale = nadd_fwd(x, x_bit, block_size, seed, is_uniform)

        # Non-tensor arguments travel on the context as plain attributes;
        # tensors go through save_for_backward.
        ctx.is_uniform = is_uniform
        ctx.block_size = block_size
        ctx.seed = seed
        ctx.save_for_backward(scale, x_bit)
        return result

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        """Return gradients for ``(x, x_bit, block_size, seed, is_uniform)``.

        ``x`` receives ``grad_out`` unchanged, ``x_bit`` receives the result
        of ``nadd_bwd_bit``, and the three non-tensor arguments get ``None``.
        """
        scale, bits = ctx.saved_tensors
        bit_grad = nadd_bwd_bit(
            grad_out, scale, bits, ctx.seed, ctx.block_size, ctx.is_uniform
        )
        return grad_out, bit_grad, None, None, None


class Nadd(torch.nn.Module):
    """Training-time noise injection with one learnable bit entry per 2-D block.

    ``self.bit`` has shape ``(cdiv(M, block_size), cdiv(N, block_size))`` and
    starts at one. In eval mode the module is the identity.
    """

    def __init__(self, M, N, seed: int, config: QConfig):
        """Create the bit parameter and the per-module seed generator.

        Args:
            M: first dimension used to size the bit grid.
            N: second dimension used to size the bit grid.
            seed: seed for the ``RomuTrio32`` generator.
            config: supplies ``block_size``, ``init_bit``, ``target_bit`` and
                ``is_discrete_uniform``.
        """
        super().__init__()

        self.seed = seed
        self.config = config

        grid_rows = triton.cdiv(M, config.block_size)
        grid_cols = triton.cdiv(N, config.block_size)
        self.bit = torch.nn.Parameter(
            torch.ones((grid_rows, grid_cols)), requires_grad=True
        )

        # Discard the first ten generator outputs before any seed is used.
        self.romu = RomuTrio32(seed)
        warmup_draws = 10
        for _ in range(warmup_draws):
            self.romu.next()

    def _bit(self):
        """Map the raw parameter linearly onto ``[target_bit, init_bit]``.

        A raw value of 1 yields ``init_bit``; a raw value of 0 yields
        ``target_bit``.
        """
        low = self.config.target_bit
        span = self.config.init_bit - low
        return low + self.bit * span

    @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
    def forward(self, x):
        """Return ``x`` with noise applied in training, ``x`` itself otherwise."""
        if not self.training:
            return x

        bits = self._bit()
        block_size = self.config.block_size
        # A fresh seed is drawn for every training-mode call.
        call_seed = self.romu.next()
        return nadd_func.apply(
            x, bits, block_size, call_seed, self.config.is_discrete_uniform
        )


class nadd_func_diffq(Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, x_bit, x_alpha, block_size: int, seed: int):
        torch.manual_seed(seed)
        urand = torch.rand_like(x) - 0.5
        # urand = torch.randn_like(x) * 0.5
        x_bit_re = x_bit.reshape((x_bit.shape[0], x_bit.shape[1], 1, 1))
        blockwise = torch.pow(2, 1 - x_bit_re) * x_alpha
        broadcasted = (
            blockwise.broadcast_to(
                x_bit.shape[0], x_bit.shape[1], block_size, block_size
            )
            .transpose(1, 2)
            .reshape(x.shape)
        )

        noisy_x = x + broadcasted * urand
        ctx.seed = seed
        ctx.block_size = block_size
        ctx.save_for_backward(x_alpha, x_bit)
        return noisy_x

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        x_alpha, x_bit = ctx.saved_tensors

        grad_x = grad_out

        torch.manual_seed(ctx.seed)
        urand = torch.rand_like(grad_out) - 0.5
        # urand = torch.randn_like(grad_out) * 0.5

        # -ln2 * x_alpha * 2^{1-b_t} * block_sum(grad_out * urand)
        elemwise = grad_out * urand
        m, n = grad_out.shape
        blockwise = (
            elemwise.reshape(
                (
                    (m + ctx.block_size - 1) // ctx.block_size,
                    ctx.block_size,
                    (n + ctx.block_size - 1) // ctx.block_size,
                    ctx.block_size,
                )
            )
            .transpose(1, 2)
            .sum(axis=(2, 3))
        )

        minus_ln2 = -0.693147182464599609375  # -ln(2) in fp32

        grad_bit = (
            minus_ln2
            * x_alpha.reshape((x_alpha.shape[0], x_alpha.shape[1]))
            * torch.pow(2, 1 - x_bit)
            * blockwise
        )

        return grad_x, grad_bit, None, None, None


class NaddDiffQ(torch.nn.Module):
    """Training-time DiffQ-style noise with one learnable bit entry per block.

    ``self.bit`` has shape ``(cdiv(M, block_size), cdiv(N, block_size))`` and
    starts at one. In training mode the per-block absolute maximum of the
    input is computed and passed, together with the mapped bit widths, to
    ``nadd_func_diffq``. In eval mode the module is the identity.
    """

    def __init__(self, M, N, seed: int, config: QConfig):
        """Create the bit parameter and the per-module seed generator.

        Args:
            M: first dimension used to size the bit grid.
            N: second dimension used to size the bit grid.
            seed: seed for the ``RomuTrio32`` generator.
            config: supplies ``block_size``, ``init_bit`` and ``target_bit``.
        """
        super().__init__()

        self.seed = seed
        self.config = config

        self.bit = torch.nn.parameter.Parameter(
            torch.ones(
                (
                    triton.cdiv(M, config.block_size),
                    triton.cdiv(N, config.block_size),
                ),
            ),
            requires_grad=True,
        )
        # Discard the first ten generator outputs before any seed is used.
        self.romu = RomuTrio32(seed)
        for _ in range(10):
            self.romu.next()

    def _bit(self):
        """Map the raw parameter linearly onto ``[target_bit, init_bit]``."""
        # bit == 1 --> init
        # bit == 0 --> target
        return self.config.target_bit + self.bit * (
            self.config.init_bit - self.config.target_bit
        )

    @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
    def forward(self, x):
        """Return ``x`` with noise applied in training, ``x`` itself otherwise.

        Raises:
            ValueError: in training mode, if ``x`` is not 2-D or either of its
                dimensions is not a multiple of ``config.block_size``.
        """
        if not self.training:
            return x

        # Explicit checks instead of ``assert`` so they survive ``python -O``
        # and report the offending shape rather than a bare reshape failure.
        if x.dim() != 2:
            raise ValueError(
                f"NaddDiffQ expects a 2-D input, got shape {tuple(x.shape)}"
            )
        block_size = self.config.block_size
        m, n = x.shape
        if m % block_size != 0 or n % block_size != 0:
            raise ValueError(
                f"input shape {tuple(x.shape)} is not a multiple of "
                f"block_size={block_size}"
            )

        # nadd_func_diffq.backward returns no gradient for x_alpha, so there is
        # no point recording an autograd graph for the abs-max reduction.
        with torch.no_grad():
            x_alpha = (
                x.reshape(
                    (m // block_size, block_size, n // block_size, block_size)
                )
                .transpose(1, 2)
                .abs()
                .amax(dim=(2, 3), keepdim=True)
            )

        return nadd_func_diffq.apply(
            x,
            self._bit(),
            x_alpha,
            block_size,
            self.romu.next(),
        )


class NaddX(torch.nn.Module):
    """Training-time noise injection with a 1-D learnable bit vector.

    ``self.bit`` has shape ``(cdiv(N, block_size),)`` and starts at one. In
    eval mode the module is the identity.
    """

    def __init__(self, N, seed: int, config: QConfig):
        """Create the bit parameter and the per-module seed generator.

        Args:
            N: dimension used to size the bit vector.
            seed: seed for the ``RomuTrio32`` generator.
            config: supplies ``block_size``, ``init_bit``, ``target_bit`` and
                ``is_discrete_uniform``.
        """
        super().__init__()

        self.seed = seed
        self.config = config

        n_blocks = triton.cdiv(N, config.block_size)
        self.bit = torch.nn.Parameter(
            torch.ones((n_blocks,)), requires_grad=True
        )

        # Discard the first ten generator outputs before any seed is used.
        self.romu = RomuTrio32(seed)
        warmup_draws = 10
        for _ in range(warmup_draws):
            self.romu.next()

    def _bit(self):
        """Map the raw parameter linearly onto ``[target_bit, init_bit]``.

        A raw value of 1 yields ``init_bit``; a raw value of 0 yields
        ``target_bit``.
        """
        low = self.config.target_bit
        span = self.config.init_bit - low
        return low + self.bit * span

    @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
    def forward(self, x):
        """Return ``x`` with noise applied in training, ``x`` itself otherwise."""
        if not self.training:
            return x

        bits = self._bit()
        block_size = self.config.block_size
        # A fresh seed is drawn for every training-mode call.
        call_seed = self.romu.next()
        return nadd_func.apply(
            x, bits, block_size, call_seed, self.config.is_discrete_uniform
        )


if __name__ == "__main__":
    # Self-test and benchmark: compares the kernel-backed modules above with
    # pure-PyTorch reference ("golden") modules, then measures throughput.
    from torch.nn import functional as F
    from prng import rand_4bit_fp

    is_discrete_uniform = True

    class inject_noise_gm(torch.nn.Module):
        """Pure-PyTorch reference for ``Nadd`` (one bit entry per 2-D block)."""

        def __init__(self, M, N, seed: int, config: QConfig):
            super().__init__()
            self.M = M
            self.N = N
            self.seed = seed
            self.config = config

            # Number of blocks along each dimension.
            self.NBM = triton.cdiv(M, config.block_size)
            self.NBN = triton.cdiv(N, config.block_size)

            self.bit = torch.nn.parameter.Parameter(
                torch.ones((self.NBM, self.NBN)),
                requires_grad=True,
            )
            # Same generator warm-up as the modules under test, so both sides
            # hand out the same per-call seeds.
            self.romu = RomuTrio32(seed)
            for _ in range(10):
                self.romu.next()

        @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
        def forward(self, x):
            """Return ``x + noise * blockwise_absmax * 2**(1 - bit)``."""
            bit_calc = self.config.target_bit + self.bit * (
                self.config.init_bit - self.config.target_bit
            )
            bit_calc = torch.exp2(1 - bit_calc)
            block_size = self.config.block_size
            # Expand the (NBM, NBN) grid to one value per element.
            bit_re = (
                bit_calc.view((self.NBM, 1, self.NBN, 1))
                .broadcast_to((self.NBM, block_size, self.NBN, block_size))
                .reshape((self.NBM * block_size, self.NBN * block_size))
            )
            # Noise and scale are constants w.r.t. autograd; only bit_re (and
            # the identity path through x) carries gradient.
            with torch.no_grad():
                noise = rand_4bit_fp(
                    self.M,
                    self.N,
                    self.romu.next(),
                    is_uniform=self.config.is_discrete_uniform,
                )
                # Absolute maximum over each block_size x block_size tile.
                absmax = (
                    x.abs()
                    .view((self.NBM, block_size, self.NBN, block_size))
                    .max(axis=1, keepdims=True)
                    .values.max(axis=-1, keepdims=True)
                    .values
                )
                alpha = absmax.broadcast_to(
                    (self.NBM, block_size, self.NBN, block_size)
                ).reshape(self.M, self.N)
            return x + noise * alpha * bit_re

    class inject_noise_gm_X(torch.nn.Module):
        """Pure-PyTorch reference for ``NaddX`` (1-D bit vector)."""

        def __init__(self, N, seed: int, config: QConfig):
            super().__init__()
            self.N = N
            self.seed = seed
            self.config = config

            self.NBN = triton.cdiv(N, config.block_size)

            self.bit = torch.nn.parameter.Parameter(
                torch.ones((self.NBN,)),
                requires_grad=True,
            )
            self.romu = RomuTrio32(seed)
            for _ in range(10):
                self.romu.next()

        @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
        def forward(self, x):
            """Apply the reference noise to ``x`` viewed as ``(M, N)`` rows."""
            orig_shape = x.shape
            block_size = self.config.block_size

            if len(x.shape) == 1:
                x = x.view((1, -1))
            # Collapse all leading dimensions into one row dimension.
            flat_x = x.flatten(0, -2)
            M = flat_x.shape[0]

            bit_calc = self.config.target_bit + self.bit * (
                self.config.init_bit - self.config.target_bit
            )
            bit_calc = torch.exp2(1 - bit_calc)
            # One bit value per column block, repeated for every row.
            bit_re = (
                bit_calc.view((1, self.NBN, 1))
                .broadcast_to((M, self.NBN, block_size))
                .reshape((M, self.NBN * block_size))
            )
            with torch.no_grad():
                noise = rand_4bit_fp(
                    M,
                    self.N,
                    self.romu.next(),
                    is_uniform=self.config.is_discrete_uniform,
                )
                # Use the flattened view here and in the final sum so that
                # every operand is (M, N); the unflattened tensor would not
                # broadcast against the noise for inputs of rank > 2.
                absmax = (
                    flat_x.abs()
                    .view((-1, block_size, self.NBN, block_size))
                    .max(axis=1, keepdims=True)
                    .values.max(axis=-1, keepdims=True)
                    .values
                )
                alpha = absmax.broadcast_to(
                    (absmax.shape[0], block_size, self.NBN, block_size)
                ).reshape(M, self.N)
            return (flat_x + noise * alpha * bit_re).view(orig_shape)

    print("integrity check...")
    block_size = QBN
    romu = RomuTrio32(seed=42)
    for _ in range(10):
        romu.next()
    for m in [32, 128, 1024, 4096]:
        for n in [32, 128, 1024, 4096]:
            for num_bit in [4, 6, 8]:
                for activation in [True, False]:
                    print(f"{m}-{n}-{num_bit}bit, X={activation}", end="\t")
                    # Both modules get the same seed so they draw the same
                    # per-call seeds from their own generators.
                    seed = romu.next()
                    if activation:
                        golden = inject_noise_gm_X(
                            n,
                            seed,
                            QConfig(
                                num_bit, is_discrete_uniform=is_discrete_uniform
                            ),
                        ).cuda()
                        ours = NaddX(
                            n,
                            seed,
                            QConfig(
                                num_bit, is_discrete_uniform=is_discrete_uniform
                            ),
                        ).cuda()
                    else:
                        golden = inject_noise_gm(
                            m,
                            n,
                            seed,
                            QConfig(
                                num_bit, is_discrete_uniform=is_discrete_uniform
                            ),
                        ).cuda()
                        ours = Nadd(
                            m,
                            n,
                            seed,
                            QConfig(
                                num_bit, is_discrete_uniform=is_discrete_uniform
                            ),
                        ).cuda()

                    # forward
                    t1 = torch.randn([m, n], device="cuda") * 32

                    out_golden = golden(t1)
                    out_ours = ours(t1)

                    forward_l1 = F.l1_loss(out_golden, out_ours).item()

                    print(f"fwd: {forward_l1}", end="\t")

                    # check backprop
                    loss_golden = (
                        F.l1_loss(out_golden, torch.zeros_like(out_golden))
                        * 128
                    )
                    loss_ours = (
                        F.l1_loss(out_ours, torch.zeros_like(out_ours)) * 128
                    )
                    loss_golden.backward()
                    loss_ours.backward()

                    grad_bit_l1 = F.l1_loss(
                        golden.bit.grad, ours.bit.grad
                    ).item()
                    print(f"bwd: {grad_bit_l1}")
                    # Dump the tensors whenever the mean absolute difference
                    # exceeds the tolerance.
                    if forward_l1 > 1e-5:
                        print(f"golden fwd: {out_golden}")
                        print(f"ours fwd: {out_ours}")
                    if grad_bit_l1 > 1e-5:
                        print(f"golden bwd: {golden.bit.grad}")
                        print(f"ours bwd: {ours.bit.grad}")
                    golden.zero_grad()
                    ours.zero_grad()

    print("benchmark ...")

    def get_bench_configs():
        """Build the single Triton benchmark config used by ``benchmark``."""
        configs = []

        # Four lines: {golden, ours} x {act, wgt}; colour encodes the
        # implementation, line style encodes the module variant.
        line_vals = [
            f"{qtype}-{dtype}"
            for qtype in ["golden", "ours"]
            for dtype in ["act", "wgt"]
        ]
        styles = [
            (color, style) for color in ["red", "cyan"] for style in ["-", ":"]
        ]
        line_names = line_vals

        configs.append(
            triton.testing.Benchmark(
                x_names=[
                    "M",
                    "N",
                ],  # Argument names to use as an x-axis for the plot
                x_vals=[128 * i for i in range(1, 8)]
                + [256 * i for i in range(4, 8)]
                + [512 * i for i in range(4, 8)]
                + [1024 * i for i in range(4, 8)]
                + [2048 * i for i in range(4, 9)],
                y_log=False,
                x_log=False,
                line_arg="provider",
                line_vals=line_vals,
                line_names=line_names,
                styles=styles,
                ylabel="Gelem/s",  # Label name for the y-axis
                args={},
                plot_name="nadd",
            )
        )
        return configs

    @triton.testing.perf_report(get_bench_configs())
    def benchmark(M, N, provider):
        """Time one forward call of the module selected by ``provider``.

        ``provider`` is ``"<golden|ours>-<act|wgt>"``. Returns throughput in
        giga-elements per second for the three timing quantiles.
        """
        q_part = provider.split("-")[0]
        d_part = provider.split("-")[-1]

        print(f"(M={M}, N={N}, {q_part}, {d_part}), {provider}")

        dut = None
        if q_part == "golden" and d_part == "wgt":
            dut = inject_noise_gm(
                M,
                N,
                seed=42,
                config=QConfig(8.0, is_discrete_uniform=is_discrete_uniform),
            ).cuda()
        elif q_part == "golden" and d_part == "act":
            dut = inject_noise_gm_X(
                N,
                seed=42,
                config=QConfig(8.0, is_discrete_uniform=is_discrete_uniform),
            ).cuda()
        elif q_part == "ours" and d_part == "wgt":
            dut = Nadd(
                M,
                N,
                seed=42,
                config=QConfig(8.0, is_discrete_uniform=is_discrete_uniform),
            ).cuda()
        elif q_part == "ours" and d_part == "act":
            dut = NaddX(
                N,
                seed=42,
                config=QConfig(8.0, is_discrete_uniform=is_discrete_uniform),
            ).cuda()
        else:
            raise ValueError(q_part)

        # Median plus the 20th and 80th percentile of the measured times.
        quantiles = [0.5, 0.2, 0.8]
        t1 = torch.randn((M, N), device="cuda")
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: dut(t1),
            quantiles=quantiles,
            rep=1000,
        )

        def gelems_per_sec(elapsed_ms):
            """Convert a time in milliseconds to giga-elements per second."""
            return M * N * 1e-9 / (elapsed_ms * 1e-3)

        # The slower time maps to the lower throughput bound and vice versa.
        return (
            gelems_per_sec(ms),
            gelems_per_sec(max_ms),
            gelems_per_sec(min_ms),
        )

    benchmark.run(save_path="./")
