import os
import re
import torch
import torch.nn.functional as F

try:
    from .ttt_kernel import l2_norm_add_fused, fused_swiglu_bwd
except ImportError:
    from ttt_kernel import l2_norm_add_fused, fused_swiglu_bwd


@torch.compile()
def silu_backprop(dy: torch.Tensor, x: torch.Tensor):
    """
    Backpropagate a gradient through SiLU, elementwise.

    SiLU is ``x * sigmoid(x)``; its derivative is
    ``sigmoid(x) * (1 + x * (1 - sigmoid(x)))``.

    Args:
        dy: gradient of the outer loss w.r.t. ``silu(x)``; broadcast-compatible
            with ``x`` (called with [b, dh, l] tensors in this file).
        x: pre-activation input that was fed to the SiLU.
    Returns:
        dx: gradient of the outer loss w.r.t. ``x``,
            ``dy * sigma * (1 + x * (1 - sigma))`` with ``sigma = sigmoid(x)``.
    """
    s = x.sigmoid()
    # Same evaluation order as the textbook formula: (dy * s) * (1 + x * (1 - s)).
    return dy * s * (1 + x * (1 - s))


@torch.compile
def bidirectional_lact_swiglu_pytorch(
    w0_w2: torch.Tensor,  # [b, 2 * dh, dk]
    w1: torch.Tensor,  # [b, dv, dh]
    q: torch.Tensor,  # [b, l, dk]
    k: torch.Tensor,  # [b, l, dk]
    v: torch.Tensor,  # [b, l, dv]
    lr0: torch.Tensor,  # [b, l, 1]
    lr1: torch.Tensor,  # [b, l, 1]
    lr2: torch.Tensor,  # [b, l, 1]
) -> torch.Tensor:
    """
    Bidirectional LaCT with SwiGLU fast weight function (PyTorch reference path).

    w0, w1, w2 are the fast weights (w0 and w2 stacked along dim 1 as
    ``w0_w2``), and f(x) = w1 @ (silu(w0 @ x) * (w2 @ x)).

    One test-time-training step is taken on (k, v) with per-token learning
    rates lr0 (for w0), lr1 (for w1) and lr2 (for w2). The step is applied
    with ``l2_norm_add_fused`` using the row norms of the pre-update weights.
    The updated f is then applied to q.

    About precision:
        w0, w1, w2 are most likely fp32 (``make_inputs`` defaults to bf16).
        q, k, v are fp16/bf16; lr0, lr1, lr2 are fp32.
        The lr-scaled gradient terms are cast back to v's / k's dtype before
        the weight-gradient ``bmm`` calls, so those two matmuls never receive
        mixed operand dtypes, even outside autocast.
        Intended precision (original author's note): bf16 gradients, fp32
        updated fast weights, bf16 output.

    FLOPS:
        (assume dk=dv denoted as D, hidden dimension of swiglu-mlp is H, ignore muon)
        Forward pass with key: 4 * D * H * L * B
        Backward pass: 8 * D * H * L * B
        Forward with Query: 6 * D * H * L * B
        Total: 18 * D * H * L * B
    Outputs:
        o: [b, l, dv]
    """

    # Row norms of the pre-update weights, passed to l2_norm_add_fused below.
    # (Adding .detach() here sometimes improves stability.)
    w0_w2_norm = w0_w2.norm(dim=2, keepdim=True)  # [b, 2 * dh, 1]
    w1_norm = w1.norm(dim=2, keepdim=True)  # [b, dv, 1]

    q = q.transpose(1, 2)  # [b, dk, l]
    v = v.transpose(1, 2)  # [b, dv, l]

    ######### update the fast weight w0, w1, w2 with test-time training #########

    #### Forward pass with key
    # [b, 2 * dh, dk] @ [b, dk, l] -> [b, 2 * dh, l]
    y0_y2 = torch.bmm(w0_w2, k.transpose(1, 2))
    gate_before_act, hidden_before_mul = y0_y2.chunk(2, dim=1)  # [b, dh, l] each
    # silu(gate) feeds both the hidden state and the gradient w.r.t. w2;
    # compute it once.
    gate_act = F.silu(gate_before_act)
    hidden = gate_act * hidden_before_mul

    #### Backward pass to compute fast weight gradients
    # [b, dh, dv] @ [b, dv, l] -> [b, dh, l]
    dhidden = torch.bmm(w1.transpose(1, 2), v)

    dhidden_before_mul = dhidden * gate_act
    dgate = dhidden * hidden_before_mul
    dgate_before_act = silu_backprop(dgate, gate_before_act)

    # [b, dv, l] @ [b, l, dh] -> [b, dv, dh]
    dw1 = torch.bmm(v, (hidden.transpose(1, 2) * lr1).type_as(v))
    # Scale per-token gradients by their learning rates ([b, 1, l] broadcast).
    # Multiplying by the fp32 lr promotes to fp32, so cast back to k's dtype
    # (mirroring dw1 above); otherwise bmm gets mixed dtypes outside autocast.
    dy0_dy2 = torch.cat(
        [
            dgate_before_act * lr0.transpose(1, 2),
            dhidden_before_mul * lr2.transpose(1, 2),
        ],
        dim=1,
    ).type_as(k)
    # [b, 2 * dh, l] @ [b, l, dk] -> [b, 2 * dh, dk]
    dw0_dw2 = torch.bmm(dy0_dy2, k)

    # Intended as a fused form of:
    #   w = w + dw
    #   w = w / (w.norm(dim=2, keepdim=True) + 1e-5) * w_norm
    w1 = l2_norm_add_fused(w1, dw1, w1_norm)
    w0_w2 = l2_norm_add_fused(w0_w2, dw0_dw2, w0_w2_norm)

    ######### apply the updated fast weights to the query #########

    # [b, 2 * dh, dk] @ [b, dk, l] -> [b, 2 * dh, l]
    q_y0_y2 = torch.bmm(w0_w2, q)
    gate, h = q_y0_y2.chunk(2, dim=1)
    # [b, dv, dh] @ [b, dh, l] -> [b, dv, l] -> [b, l, dv]
    o = torch.bmm(w1, F.silu(gate) * h).transpose(1, 2)

    return o


@torch.compile
def bidirectional_lact_swiglu_fused(
    w0_w2: torch.Tensor,  # [b, 2 * dh, dk]
    w1: torch.Tensor,  # [b, dv, dh]
    q: torch.Tensor,  # [b, l, dk]
    k: torch.Tensor,  # [b, l, dk]
    v: torch.Tensor,  # [b, l, dv]
    lr0: torch.Tensor,  # [b, l, 1]
    lr1: torch.Tensor,  # [b, l, 1]
    lr2: torch.Tensor,  # [b, l, 1]
) -> torch.Tensor:
    """
    Bidirectional LaCT with SwiGLU fast weights, using the fused Triton backward.

    The fast-weight function is f(x) = w1 @ (silu(w0 @ x) * (w2 @ x)), with
    w0 and w2 stacked along dim 1 as ``w0_w2``.

    The key-side forward and the upstream hidden gradient are computed in
    token-major layout ([b, l, ...]). The SwiGLU backward that yields the
    weight deltas is delegated to ``fused_swiglu_bwd``; it is expected to
    produce the same dw0_dw2 / dw1 as the unfused reference (confirm against
    the kernel).

    About precision (author's intent):
        w0, w1, w2 are most likely fp32; q, k, v are fp16/bf16; lr0, lr1, lr2
        are fp32. Gradients are bf16, updated fast weights fp32, output bf16.

    FLOPS:
        (assume dk=dv denoted as D, hidden dimension of swiglu-mlp is H, ignore muon)
        Forward pass with key: 4 * D * H * L * B
        Backward pass: 8 * D * H * L * B
        Forward with Query: 6 * D * H * L * B
        Total: 18 * D * H * L * B
    Outputs:
        o: [b, l, dv]
    """
    # Pre-update row norms, handed to l2_norm_add_fused below
    # (adding detach here sometimes improves stability).
    norm_w0_w2 = w0_w2.norm(dim=2, keepdim=True)
    norm_w1 = w1.norm(dim=2, keepdim=True)

    q_t = q.transpose(1, 2)  # [b, dk, l]

    # ---- test-time update of the fast weights ----
    # Key forward: [b, l, dk] @ [b, dk, 2 * dh] -> [b, l, 2 * dh]
    key_proj = torch.bmm(k, w0_w2.transpose(1, 2))
    # Upstream gradient w.r.t. the hidden state: [b, l, dv] @ [b, dv, dh] -> [b, l, dh]
    grad_hidden = torch.bmm(v, w1)
    delta_w0_w2, delta_w1 = fused_swiglu_bwd(grad_hidden, key_proj, lr0, lr1, lr2, k, v)

    # Fused "add the delta, then rescale rows to the saved norms" step.
    w1 = l2_norm_add_fused(w1, delta_w1, norm_w1)
    w0_w2 = l2_norm_add_fused(w0_w2, delta_w0_w2, norm_w0_w2)

    # ---- apply the updated fast weights to the query ----
    # [b, 2 * dh, dk] @ [b, dk, l] -> [b, 2 * dh, l], split into gate / value halves
    gate, h = torch.bmm(w0_w2, q_t).chunk(2, dim=1)
    # [b, dv, dh] @ [b, dh, l] -> [b, dv, l]
    out = torch.bmm(w1, F.silu(gate) * h)
    return out.transpose(1, 2)  # [b, l, dv]


def make_inputs(
    B,
    L,
    D,
    H,
    device="cuda",
    act_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    lr_dtype=torch.float32,
    requires_grad=False,
):
    """
    Create random inputs for the bidirectional LaCT SwiGLU functions.

    The key/value width is dk = dv = D and the SwiGLU hidden size is dh = H.
    Every tensor is drawn with ``torch.randn``, so the learning rates can be
    negative. The draw order is fixed, so results are reproducible under a
    fixed torch seed.

    Args:
        B: batch size.
        L: sequence length.
        D: key/value width (dk = dv).
        H: SwiGLU hidden size.
        device: device every tensor is allocated on.
        act_dtype: dtype of q, k, v.
        param_dtype: dtype of the fast weights w0_w2 and w1 (bf16 by default;
            pass torch.float32 for fp32 fast weights).
        lr_dtype: dtype of lr0, lr1, lr2.
        requires_grad: applied to w0_w2, w1, q, k, v; never to the learning rates.

    Returns:
        Tuple ``(w0_w2, w1, q, k, v, lr0, lr1, lr2)`` with shapes
          w0_w2: [B, 2 * H, D]  (w0 and w2 stacked along dim 1; param_dtype)
             w1: [B, D, H]      (param_dtype)
        q, k, v: [B, L, D]      (act_dtype)
            lr*: [B, L, 1]      (lr_dtype)
    """
    g = requires_grad
    w0_w2 = torch.randn(B, 2 * H, D, device=device, dtype=param_dtype, requires_grad=g)
    w1 = torch.randn(B, D, H, device=device, dtype=param_dtype, requires_grad=g)
    q = torch.randn(B, L, D, device=device, dtype=act_dtype, requires_grad=g)
    k = torch.randn(B, L, D, device=device, dtype=act_dtype, requires_grad=g)
    v = torch.randn(B, L, D, device=device, dtype=act_dtype, requires_grad=g)
    lr0 = torch.randn(B, L, 1, device=device, dtype=lr_dtype, requires_grad=False)
    lr1 = torch.randn(B, L, 1, device=device, dtype=lr_dtype, requires_grad=False)
    lr2 = torch.randn(B, L, 1, device=device, dtype=lr_dtype, requires_grad=False)
    return w0_w2, w1, q, k, v, lr0, lr1, lr2


def flops_counter(B, L, D, H):
    """
    Return the model FLOP count of one bidirectional LaCT SwiGLU call.

    Per sequence the cost is 4*D*H*L (key forward) + 8*D*H*L (backward)
    + 6*D*H*L (query forward) = 18*D*H*L; this is multiplied by batch size B.
    """
    per_sequence = 18 * L * D * H
    return per_sequence * B


def run_benchmark_fwd(output_folder: str):
    """
    Benchmark forward throughput and memory of the PyTorch reference and the
    Triton-fused bidirectional LaCT SwiGLU implementations.

    For each (B, D, H) combination, a sweep over sequence lengths is passed to
    ``BenchmarkFigureHelper``, constructed with ``output_folder`` as its
    output location.
    """
    try:
        from benchmark_figure_helper import BenchmarkFigureHelper
    except ImportError:
        from .benchmark_figure_helper import BenchmarkFigureHelper

    implementations = {
        "bidirectional_lact_swiglu_pytorch": bidirectional_lact_swiglu_pytorch,
        "bidirectional_lact_swiglu_triton_fused": bidirectional_lact_swiglu_fused,
    }
    metrics = ["fwd_tflops", "memory_fwd"]

    helper = BenchmarkFigureHelper(
        metrics,
        implementations,
        make_inputs,
        output_folder,
        repeats=10,
        amp=True,
    )

    # Shape sweep: edit these lists to see how shapes affect speed.
    # Typical settings:
    #   NVS experiments:       B=8, D=768, H=1536, L=16384 to ~0.5M tokens
    #   LLM experiments:       B=4, D=768, H=768, L=2048 or 4096
    #   Video-gen experiments: B=1 or 4, D=1536, H=1536 or 3072, L=4680
    batch_sizes = [4]
    model_dims = [1024, 2048]
    seq_lens = [2048, 4096, 8192, 16384]
    hidden_ratios = [1.0]

    for B in batch_sizes:
        for D in model_dims:
            for ratio in hidden_ratios:
                H = int(D * ratio)
                # One entry per sequence length; "x" is the plotted x-coordinate.
                data_list = [
                    {
                        "FLOPS": flops_counter(B, L, D, H),
                        "shape_params": [B, L, D, H],
                        "name": f"B={B}_D={D}_H={H}_L={L}",
                        "x": L,
                    }
                    for L in seq_lens
                ]
                helper(data_list, f"B={B}_D={D}_H={H}", "Sequence Length")


if __name__ == "__main__":
    import argparse

    def get_device_suffix():
        """Return a filesystem-safe name for the active CUDA device, or "cpu"."""
        if not torch.cuda.is_available():
            return "cpu"
        try:
            device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        except Exception:
            # Best effort only: fall back to a generic tag if the lookup fails.
            device_name = "cuda"
        # Replace every run of characters outside [A-Za-z0-9_.-] with "_".
        sanitized = re.sub(r"[^A-Za-z0-9_.-]+", "_", device_name).strip("_")
        return sanitized if sanitized else "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_folder",
        type=str,
        default="results/default_output_lact_bidirectional_triton_fused",
    )
    args = parser.parse_args()
    # Keep results from different GPUs in separate subfolders.
    args.output_folder = os.path.join(args.output_folder, get_device_suffix())

    run_benchmark_fwd(args.output_folder)
