import torch
import triton
from prng import RomuTrio32
from nadd import nadd_func
from q_config import QConfig
from torch.nn import functional as F
from nadd import Nadd, NaddDiffQ

import sys
import os
import inspect

# Absolute path of the directory that contains this source file.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
# Prepend the parent directory to sys.path so that the `import mx` below can
# be resolved from there (presumably the `mx` package lives one level up —
# confirm against the repository layout).
pdir = os.path.dirname(currentdir)
sys.path.insert(0, pdir)

import mx


class MxQ(torch.autograd.Function):
    """Quantize a tensor to MX ``fp8_e3m4`` with a straight-through gradient.

    The forward pass delegates to ``mx.mx_ops.quantize_mx_op`` with
    ``scale_bits=8``, ``shared_exp_method="max"``, a block size of 32 over
    axes ``[-1, -2]`` and round-to-nearest. The backward pass returns the
    incoming gradient unchanged, so quantization is treated as the identity
    for autograd.
    """

    @staticmethod
    def forward(ctx, x):
        # A fresh spec dict is built on every call.
        specs = {
            "scale_bits": 8,
            "shared_exp_method": "max",
            "mx_flush_fp32_subnorms": True,
            "block_size": 32,
            "custom_cuda": False,
        }
        # NOTE(review): blocking over both of the last two axes presumably
        # yields 2-D tiles of the weight — confirm against the mx library.
        quantized = mx.mx_ops.quantize_mx_op(
            x,
            specs,
            elem_format="fp8_e3m4",
            block_size=32,
            axes=[-1, -2],
            round="nearest",
        )
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through untouched.
        return grad_output


class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose weight is noised and then MX-quantized.

    On each forward pass the weight goes through ``self.wgt_sampler`` (a
    ``NaddDiffQ`` or ``Nadd`` module chosen from ``wgt_config``, or an
    identity when ``wgt_config`` is None) and then through ``MxQ`` before the
    affine map is applied to the input.
    """

    def __init__(
        self,
        in_features,
        out_features,
        seed: int,
        wgt_config: QConfig,
        **kwargs,
    ):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            seed: seed for the ``RomuTrio32`` generator. The generator is
                advanced 10 times, then its next value seeds the weight
                sampler.
            wgt_config: weight-noise configuration, or None for no weight
                noise. ``wgt_config.is_diffq`` selects ``NaddDiffQ`` over
                ``Nadd``.
            **kwargs: forwarded to ``torch.nn.Linear`` (e.g. ``bias``).
        """
        super().__init__(in_features, out_features, **kwargs)

        self.seed = seed
        self.wgt_config = wgt_config

        self.romu = RomuTrio32(seed)
        # Discard the generator's first 10 outputs before drawing any seed.
        for _burn_in in range(10):
            self.romu.next()

        if wgt_config is not None:
            sampler_cls = NaddDiffQ if wgt_config.is_diffq else Nadd
            self.wgt_sampler = sampler_cls(
                out_features, in_features, self.romu.next(), wgt_config
            )
        else:
            self.wgt_sampler = torch.nn.Identity()

        # NOTE(review): this attribute is not read anywhere in this class;
        # `MxQ.forward` builds its own dict with the same values.
        self.mx_specs = {
            "scale_bits": 8,
            "shared_exp_method": "max",
            "mx_flush_fp32_subnorms": True,
            "block_size": 32,
            "custom_cuda": False,
        }

    @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
    def forward(self, act):
        """Return ``F.linear(act, MxQ(wgt_sampler(weight)), bias)``."""
        noisy_weight = self.wgt_sampler(self.weight)
        quantized_weight = MxQ.apply(noisy_weight)
        return F.linear(act, quantized_weight, self.bias)


if __name__ == "__main__":
    from prng import rand_4bit_fp
    from torch.nn import functional as F

    class Linear_gm(torch.nn.Linear):
        """Reference ("golden") noisy linear layer used by the checks below.

        * When ``wgt_config`` is given, the weight is perturbed through
          ``nadd_func`` driven by the learnable per-tile parameter ``wgt_b``.
        * When ``act_config`` is given, noise from ``rand_4bit_fp`` is added
          to the output. Its scale is the per-block absmax of the output times
          ``2 ** (1 - out_b)``, with ``out_b`` learnable per output block.

        Both parameters are initialized to ``config.init_bit``.
        """

        def __init__(
            self,
            in_features: int,
            out_features: int,
            seed: int,
            act_config: QConfig,
            wgt_config: QConfig,
            **kwargs,
        ):
            super().__init__(in_features, out_features, **kwargs)

            self.seed = seed
            self.act_config = act_config
            self.wgt_config = wgt_config

            # Seeded PRNG; the first 10 outputs are discarded (the same
            # warm-up that `Linear` performs).
            self.romu = RomuTrio32(seed)
            for _ in range(10):
                self.romu.next()

            self.out_b = None
            self.wgt_b = None

            if act_config is not None:
                # One value per block of `block_size` output features.
                self.out_b = torch.nn.parameter.Parameter(
                    torch.ones(
                        (triton.cdiv(out_features, act_config.block_size),),
                        dtype=torch.float32,
                        device=self.weight.device,
                    )
                    * act_config.init_bit,
                    requires_grad=True,
                )
            if wgt_config is not None:
                # One value per (block_size x block_size) tile of the weight.
                self.wgt_b = torch.nn.parameter.Parameter(
                    torch.ones(
                        (
                            triton.cdiv(out_features, wgt_config.block_size),
                            triton.cdiv(in_features, wgt_config.block_size),
                        ),
                        dtype=torch.float32,
                        device=self.weight.device,
                    )
                    * wgt_config.init_bit,
                    requires_grad=True,
                )

        @torch.autocast(dtype=torch.bfloat16, device_type="cuda")
        def forward(self, act):
            """Apply the noisy linear map; output shape is ``[..., out_features]``."""
            orig_shape = act.shape
            if len(act.shape) == 1:
                act = act.view((1, -1))
            flatten_act = act.flatten(0, -2)
            M, K = flatten_act.shape  # rows, in_features
            N, K2 = self.weight.shape  # out_features, in_features
            assert K == K2
            out_shape = [*orig_shape[:-1], N]

            # inject noise on W
            if self.wgt_b is None:
                wgt = self.weight
            else:
                wgt = nadd_func.apply(
                    self.weight,
                    self.wgt_b,
                    self.wgt_config.block_size,
                    self.romu.next(),
                )

            linear_out = F.linear(flatten_act, wgt, self.bias)
            if self.out_b is None:
                # Restore the caller's leading dims; previously the flattened
                # (M, N) tensor was returned for N-D inputs.
                return linear_out.view(out_shape)

            NBN = self.out_b.shape[0]
            assert NBN * self.act_config.block_size == N

            # Per-output-block noise scale 2 ** (1 - out_b), expanded to (M, N).
            out_b_calc = torch.exp2(1 - self.out_b)
            out_b_re = (
                out_b_calc.view((1, NBN, 1))
                .broadcast_to((M, NBN, self.act_config.block_size))
                .reshape((M, N))
            )

            m_block = min(self.act_config.block_size, M)

            with torch.no_grad():
                noise = rand_4bit_fp(M, N, self.romu.next()).view((M, N))
                # Absmax over each (m_block x block_size) tile of the output,
                # broadcast back to (M, N). Assumes M is a multiple of m_block.
                absmax = (
                    linear_out.abs()
                    .view((-1, m_block, NBN, self.act_config.block_size))
                    .max(dim=1, keepdims=True)
                    .values.max(dim=-1, keepdims=True)
                    .values
                )
                alpha = absmax.broadcast_to(
                    (absmax.shape[0], m_block, NBN, self.act_config.block_size)
                ).reshape((M, N))

            out = (linear_out + noise * alpha * out_b_re).to(linear_out.dtype)
            return out.view(out_shape)

    # Integrity check: run `Linear_gm` (golden) and `Linear` (ours) on the same
    # inputs/weights and print forward and gradient L1 differences.
    # NOTE: `Linear` MX-quantizes its weight and uses a different noise path
    # (Nadd sampler, no output noise), so exact agreement is not expected.
    print("integrity check ...")
    seed = 42
    init_bit = 8.0
    for num_channel in [1, 32, 64, 256, 1024]:
        for out_features in [32, 256, 1024]:
            for in_features in [32, 256, 1024]:
                t1 = (
                    torch.randn(
                        [num_channel, in_features],
                        dtype=torch.bfloat16,
                        device="cuda",
                    )
                    * 8
                )
                wgt = (
                    torch.randn(
                        [out_features, in_features],
                        dtype=torch.bfloat16,
                        device="cuda",
                    )
                    * 8
                )
                for use_bias in [True, False]:
                    # matmul form : [B, M, K] @ [K, N] --> [B, M, N]
                    # linear form : [B, M, K] @ [N, K] --> [B, M, N]
                    # human form  : [B, M, in] @ [out, in] --> [B, M, out]
                    print(
                        f"{num_channel}-{out_features}-{in_features}",
                        end="",
                    )
                    print("-b" if use_bias else "", end="\t")
                    golden_module = Linear_gm(
                        in_features,
                        out_features,
                        bias=use_bias,
                        seed=seed,
                        act_config=QConfig(init_bit),
                        wgt_config=QConfig(init_bit),
                    ).cuda()
                    # `Linear` has no `act_config` parameter: passing one would
                    # reach `torch.nn.Linear.__init__` via **kwargs and raise.
                    our_module = Linear(
                        in_features,
                        out_features,
                        bias=use_bias,
                        seed=seed,
                        wgt_config=QConfig(init_bit),
                    ).cuda()
                    # weight & bias assignment
                    with torch.no_grad():
                        our_module.weight[:, :] = wgt[:, :]
                        golden_module.weight[:, :] = wgt[:, :]
                        if golden_module.bias is not None:
                            our_module.bias[:] = golden_module.bias[:]

                    golden = golden_module(t1)
                    ours = our_module(t1)

                    forward_l1_error = F.l1_loss(golden, ours).item()
                    print(f"fwd: {forward_l1_error}", end="\t")
                    if forward_l1_error > 0.3:
                        print(golden)
                        print(ours)

                    # run backward
                    loss_golden = (
                        F.l1_loss(golden, torch.zeros_like(golden)) * 128
                    )
                    loss_ours = F.l1_loss(ours, torch.zeros_like(ours)) * 128

                    loss_golden.backward()
                    loss_ours.backward()

                    # compare backward: `Linear` has no `wgt_b` / `out_b`, so
                    # each gradient is compared only when both modules expose
                    # it (and it was populated); entries are labelled.
                    backward_l1 = []
                    for name in ("weight", "wgt_b", "out_b", "bias"):
                        golden_param = getattr(golden_module, name, None)
                        our_param = getattr(our_module, name, None)
                        if golden_param is None or our_param is None:
                            continue
                        if golden_param.grad is None or our_param.grad is None:
                            continue
                        grad_l1 = F.l1_loss(golden_param.grad, our_param.grad)
                        backward_l1.append(f"{name}={grad_l1.item()}")
                    print(f"bwd: {', '.join(backward_l1)}")
                    golden_module.zero_grad()
                    our_module.zero_grad()

    print("benchmark ...")

    def get_bench_configs():
        """Return the list of triton ``Benchmark`` configs used by ``benchmark``.

        A single config sweeps square problems (``M == N == K``) and plots one
        line per ``"<method>-bias<Yes|No>"`` provider. Each method gets its own
        color; the with-bias variant is drawn solid and the no-bias variant dotted.
        """
        providers = []
        line_styles = []
        for method, color in (
            ("baseline", "blue"),
            ("ours", "red"),
            ("vanilla", "green"),
        ):
            for flag, dash in (("Yes", "-"), ("No", ":")):
                providers.append(f"{method}-bias{flag}")
                line_styles.append((color, dash))

        sizes = (
            [128 * i for i in range(1, 8)]
            + [256 * i for i in range(4, 8)]
            + [512 * i for i in range(4, 8)]
            + [1024 * i for i in range(4, 9)]
        )

        return [
            triton.testing.Benchmark(
                x_names=["M", "N", "K"],  # all three follow the x-axis value
                x_vals=sizes,
                x_log=False,
                line_arg="provider",
                line_vals=providers,
                line_names=providers,
                styles=line_styles,
                ylabel="TFLOP/s",
                args={},
                plot_name="linear",
            )
        ]

    class Linear_vanilla(torch.nn.Linear):
        """Plain ``torch.nn.Linear`` evaluated under float16 CUDA autocast.

        Serves as the noise-free, unquantized ``vanilla`` provider in
        ``benchmark``.
        """

        def __init__(
            self,
            in_features: int,
            out_features: int,
            **kwargs,
        ):
            # No extra state; all keyword options go to torch.nn.Linear.
            super().__init__(in_features, out_features, **kwargs)

        @torch.autocast(dtype=torch.float16, device_type="cuda")
        def forward(self, x):
            weight, bias = self.weight, self.bias
            return F.linear(x, weight, bias)

    def test_iter(module, input, num_iter=1):
        module.train()
        for _ in range(num_iter):
            module.zero_grad()
            out = module(input)
            loss = F.l1_loss(out, torch.zeros_like(out)) * 128
            loss.backward()
        return

    @triton.testing.perf_report(get_bench_configs())
    def benchmark(M, N, K, provider):
        """Measure training throughput (TFLOP/s) of one provider.

        Args:
            M: number of input rows.
            N: output features.
            K: input features.
            provider: ``"<method>-bias<Yes|No>"`` where method is
                ``baseline`` (``Linear_gm``), ``ours`` (``Linear``) or
                ``vanilla`` (``Linear_vanilla``).

        Returns:
            ``(median, low, high)`` throughput as expected by
            ``triton.testing.perf_report``. Low uses the 80th-percentile time
            and high the 20th.

        Raises:
            ValueError: if ``provider`` has an unknown method or bias part.
        """
        method_part = provider.split("-")[0]
        bias_part = provider.split("-")[-1]
        init_bit = 8

        use_bias = bias_part[4:]  # strip the "bias" prefix
        if use_bias == "Yes":
            use_bias = True
        elif use_bias == "No":
            use_bias = False
        else:
            raise ValueError(bias_part)

        print(f"(M={M}, N={N}, K={K}), {provider}")
        num_channel = M
        in_features = K
        out_features = N

        if method_part == "baseline":
            dut = Linear_gm(
                in_features,
                out_features,
                bias=use_bias,
                seed=seed,
                act_config=QConfig(init_bit),
                wgt_config=QConfig(init_bit),
            ).to("cuda")
        elif method_part == "vanilla":
            dut = Linear_vanilla(in_features, out_features, bias=use_bias).to(
                "cuda"
            )
        elif method_part == "ours":
            # `Linear` has no `act_config` parameter; forwarding one through
            # **kwargs to `torch.nn.Linear.__init__` raised a TypeError.
            dut = Linear(
                in_features,
                out_features,
                bias=use_bias,
                seed=seed,
                wgt_config=QConfig(init_bit),
            ).to("cuda")
        else:
            raise ValueError(method_part)

        t1 = (
            torch.randn(
                [num_channel, in_features],
                dtype=torch.float16,
                device="cuda",
            )
            * 8
        )

        num_iter = 3

        quantiles = [0.5, 0.2, 0.8]
        ms, p20_ms, p80_ms = triton.testing.do_bench(
            lambda: test_iter(dut, t1, num_iter),
            quantiles=quantiles,
            rep=1000,
        )

        # FLOPs per training step:
        # - forward      : (M, K) @ (K, N)          -> 2 * M * N * K
        # - weight grad  : (N, M) @ (M, K)          -> 2 * M * N * K
        # - input grad   : (M, N) @ (N, K)          -> 2 * M * N * K
        # - bias grad    : reduction over (M, N)    -> M * N (only with bias)
        # Every term is repeated `num_iter` times inside `test_iter`.
        flops = num_iter * (6 * M * N * K + (M * N if use_bias else 0))

        def perf(t_ms):
            return flops * 1e-12 / (t_ms * 1e-3)

        # Longer time -> lower throughput, so p80 gives the lower bound.
        return perf(ms), perf(p80_ms), perf(p20_ms)

    benchmark.run(save_path="./")
