import os
import re
import torch
import torch.nn.functional as F


@torch.compile()
def zeropower_via_newtonschulz5(G):
    """
    Batched Newton-Schulz iteration approximating the zeroth power
    (orthogonalization) of every matrix in G.

    The input is cast to bfloat16 and each matrix is divided by its Frobenius
    norm, which keeps its spectral norm at most 1. Five quintic steps
        X <- a * X + (b * A + c * A @ A) @ X,   with A = X @ X^T
    are then applied, each with its own (a, b, c) coefficients. When d > d'
    the matrices are transposed before iterating (so A is the smaller
    [d', d'] product) and transposed back before returning.

    Args:
        G: [b, d, d']
    Returns:
        X: [b, d, d'], in bfloat16
    FLOPS:  When d=d', Total FLOPS=30 * b * d^3
    """
    assert G.dim() == 3
    is_tall = G.size(1) > G.size(2)

    X = G.bfloat16()
    if is_tall:
        X = X.transpose(1, 2)

    # Frobenius norm upper-bounds the spectral norm; epsilon guards all-zero input.
    fro_norm = X.norm(dim=(1, 2), keepdim=True)
    X = X / (fro_norm + 1e-7)

    # One (linear, cubic, quintic) coefficient triple per iteration.
    schedule = (
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    )
    for lin, cub, quin in schedule:
        gram = X @ X.transpose(1, 2)
        poly = cub * gram + (quin * gram) @ gram
        X = lin * X + poly @ X

    return X.transpose(1, 2) if is_tall else X


@torch.compile()
def silu_backprop(dy: torch.Tensor, x: torch.Tensor):
    """
    Backward of y = silu(x) = x * sigmoid(x), evaluated elementwise.

    Args:
        dy: [b, d, l], gradient of the outer loss wrt the y
        x: [b, d, l], input of the silu activation
    outs:
        dx: [b, d, l], gradient of the outer loss wrt the x
        dx = dy * sigma * (1 + x * (1 - sigma)), where sigma = sigmoid(x)
    """
    sig = torch.sigmoid(x)
    # d silu / dx = sigma * (1 + x * (1 - sigma)); the sigma factor is applied first.
    slope_term = 1 + x * (1 - sig)
    return dy * sig * slope_term


@torch.compile
def bidirectional_lact_swiglu(
    w0: torch.Tensor,  # [b, dh, dk]
    w1: torch.Tensor,  # [b, dv, dh]
    w2: torch.Tensor,  # [b, dh, dk]
    q: torch.Tensor,  # [b, l, dk]
    k: torch.Tensor,  # [b, l, dk]
    v: torch.Tensor,  # [b, l, dv]
    lr0: torch.Tensor,  # [b, l, 1]
    lr1: torch.Tensor,  # [b, l, 1]
    lr2: torch.Tensor,  # [b, l, 1]
    use_muon: bool = False,
) -> torch.Tensor:
    """
    Bidirectional LaCT with SwiGLU fast weight function.
    w0, w1, w2 are the fast weights. f(x) =  w1 @ (silu(w0 @ x) * (w2 @ x))

    Steps:
        1. Forward pass of the fast-weight MLP on all keys.
        2. Manual backward pass producing per-token-lr-weighted updates
           dw0, dw1, dw2 (the upstream gradient at the MLP output is v).
        3. If use_muon, each update is orthogonalized with
           zeropower_via_newtonschulz5 before being applied.
        4. w <- w + dw, then every row of each fast weight is rescaled back
           to the L2 norm it had on entry.
        5. The updated fast weights are applied to all queries.

    About precision:
        w0, w1, w2 are most likely fp32.
        q, k, v are low precision (bf16 in this file's benchmark inputs).
        lr0, lr1, lr2 are fp32.
        The forward, backward produce bf16 gradients, updated fast weights are fp32.
        The final output is bf16.
        NOTE(review): these dtypes presumably rely on the caller running under
        autocast; nothing in this function enforces them.

    FLOPS:
        (assume dk=dv denoted as D, hidden dimension of swiglu-mlp is H, ignore muon)
        Forward pass with key: 4 * D * H * L * B
        Backward pass: 8 * D * H * L * B
        Forward with Query: 6 * D * H * L * B
        Total: 18 * D * H * L * B
    Outputs:
        o: [b, l, dv]
    """

    # Row norms of the incoming fast weights, restored after the update.
    # adding detach here sometimes improves stability.
    w0_norm = w0.norm(dim=2, keepdim=True)
    w1_norm = w1.norm(dim=2, keepdim=True)
    w2_norm = w2.norm(dim=2, keepdim=True)

    q = q.transpose(1, 2)  # [b, dk, l]
    v = v.transpose(1, 2)  # [b, dv, l]
    k_t = k.transpose(1, 2)  # [b, dk, l]

    ######### update the fast weight w0, w1, w2 with test-time training #########

    #### Forward pass with key
    # [b, dh, dk] @ [b, dk, l] -> [b, dh, l]
    gate_before_act = torch.bmm(w0, k_t)
    hidden_before_mul = torch.bmm(w2, k_t)
    gate_act = F.silu(gate_before_act, inplace=False)
    hidden = gate_act * hidden_before_mul

    #### Backward pass to compute fast weight gradients
    # [b, dh, dv] @ [b, dv, l] -> [b, dh, l]
    dhidden = torch.bmm(w1.transpose(1, 2), v)

    dhidden_before_mul = dhidden * gate_act
    dgate = dhidden * hidden_before_mul
    dgate_before_act = silu_backprop(dgate, gate_before_act)

    # [b, dv, l] @ [b, l, dh] -> [b, dv, dh]
    dw1 = torch.bmm(v, (hidden.transpose(1, 2) * lr1).type_as(v))  # [b, d, d]
    # [b, dh, l] @ [b, l, dk] -> [b, dh, dk]
    dw0 = torch.bmm(dgate_before_act, (k * lr0).type_as(dgate_before_act))
    dw2 = torch.bmm(dhidden_before_mul, (k * lr2).type_as(dhidden_before_mul))

    if use_muon:
        # Orthogonalize the UPDATES, not the weights: assigning the result to
        # w* would throw away the incoming fast weights and then add the raw,
        # un-orthogonalized gradient on top of it.
        dw0 = zeropower_via_newtonschulz5(dw0)
        dw1 = zeropower_via_newtonschulz5(dw1)
        dw2 = zeropower_via_newtonschulz5(dw2)

    w1 = w1 + dw1
    w0 = w0 + dw0
    w2 = w2 + dw2

    # Rescale each row back to its pre-update L2 norm.
    w0 = w0 / (w0.norm(dim=2, keepdim=True) + 1e-5) * w0_norm
    w1 = w1 / (w1.norm(dim=2, keepdim=True) + 1e-5) * w1_norm
    w2 = w2 / (w2.norm(dim=2, keepdim=True) + 1e-5) * w2_norm

    ######### apply the updated fast weights to the query #########

    # [b, dh, dk] @ [b, dk, l] -> [b, dh, l]
    h = torch.bmm(w2, q)
    # In-place is safe here: the bmm result is a fresh temporary.
    gate = F.silu(torch.bmm(w0, q), inplace=True)
    # [b, dv, dh] @ [b, dh, l] -> [b, dv, l] -> [b, l, dv]
    o = torch.bmm(w1, gate * h).transpose(1, 2)

    return o


def make_inputs(
    B,
    L,
    D,
    H,
    device="cuda",
    act_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    lr_dtype=torch.float32,
    requires_grad=False,
):
    """
    Draw standard-normal random inputs for bidirectional_lact_swiglu.

    Args:
        B, L, D, H: batch size, sequence length, q/k/v width, fast-weight
            hidden width.
        device: device all tensors are created on.
        act_dtype: dtype of q, k, v.
        param_dtype: dtype of the fast weights w0, w1, w2.
        lr_dtype: dtype of the per-token learning rates.
        requires_grad: forwarded to every created tensor.
    Returns:
        (w0, w1, w2, q, k, v, lr0, lr1, lr2) with shapes
        w0, w2: [B, H, D]; w1: [B, D, H]; q, k, v: [B, L, D]; lr*: [B, L, 1]
    """

    def _randn(shape, dtype):
        # Every tensor shares the same device / requires_grad settings.
        return torch.randn(
            *shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    # Tensors are drawn in the same order as they are returned.
    fast_weight_shapes = ((B, H, D), (B, D, H), (B, H, D))
    w0, w1, w2 = (_randn(shape, param_dtype) for shape in fast_weight_shapes)
    q, k, v = (_randn((B, L, D), act_dtype) for _ in range(3))
    lr0, lr1, lr2 = (_randn((B, L, 1), lr_dtype) for _ in range(3))
    return w0, w1, w2, q, k, v, lr0, lr1, lr2


def flops_counter(B, L, D, H):
    """
    FLOP estimate used by the benchmark for one bidirectional_lact_swiglu call
    (Muon excluded): 18 * L * D * H * B.
    """
    # 4 (forward with key) + 8 (backward) + 6 (forward with query)
    flops_per_unit = 4 + 8 + 6
    return flops_per_unit * L * D * H * B


def run_benchmark_fwd(output_folder: str):
    """
    Sweep bidirectional_lact_swiglu over a grid of shapes with
    BenchmarkFigureHelper.

    One helper call is made per (B, D, H) combination; within each call the
    sequence length L is the x-axis variable.

    Args:
        output_folder: passed through to BenchmarkFigureHelper
            (presumably where results/figures are written -- confirm in the helper).
    """
    # Works both when executed as a script and when imported from a package.
    try:
        from benchmark_figure_helper import BenchmarkFigureHelper
    except ImportError:
        from .benchmark_figure_helper import BenchmarkFigureHelper

    helper = BenchmarkFigureHelper(
        ["fwd_tflops", "memory_fwd"],
        {"bidirectional_lact_swiglu": bidirectional_lact_swiglu},
        make_inputs,
        output_folder,
        repeats=10,
        amp=True,
    )

    # Shape grid for the sweep; edit these to see the effect on speed.
    # Typical settings in the experiments:
    #   NVS:       B=8,    D=768,  H=1536,      L=16384 ~ 0.5M tokens
    #   LLM:       B=4,    D=768,  H=768,       L=2048, 4096
    #   Video Gen: B=1, 4, D=1536, H=1536/3072, L=4680
    batch_sizes = [4]
    model_dims = [512, 1024, 2048]
    seq_lens = [2048, 4096, 8192, 16384]
    hidden_ratios = [1.0, 2.0]

    for B in batch_sizes:
        for D in model_dims:
            for ratio in hidden_ratios:
                H = int(D * ratio)
                # One entry per sequence length for this (B, D, H) curve.
                sweep = [
                    {
                        "FLOPS": flops_counter(B, L, D, H),
                        "shape_params": [B, L, D, H],
                        "name": f"B={B}_D={D}_H={H}_L={L}",
                        "x": L,
                    }
                    for L in seq_lens
                ]
                helper(sweep, f"B={B}_D={D}_H={H}", "Sequence Length")


if __name__ == "__main__":
    import argparse

    def get_device_suffix():
        """Return a filesystem-safe device label: the CUDA device name, "cuda", or "cpu"."""
        if not torch.cuda.is_available():
            return "cpu"
        try:
            device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        except Exception:
            # Best effort: fall back to a generic label if the name query fails.
            device_name = "cuda"
        # Collapse every run of characters outside [A-Za-z0-9_.-] into one "_".
        sanitized = re.sub(r"[^A-Za-z0-9_.-]+", "_", device_name).strip("_")
        return sanitized or "cuda"

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--output_folder",
        type=str,
        default="results/default_output_lact_bidirectional_pytorch",
    )
    cli_args = cli.parse_args()

    # Results for each device land in their own sub-folder.
    run_benchmark_fwd(os.path.join(cli_args.output_folder, get_device_suffix()))
