import torch
import torch.nn.functional as F
from torch import nn
# from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_tensor

import json
from dataclasses import asdict, dataclass
from typing import List

# from ttt_ops import ttt_linear, ttt_mlp
from ttt_ops.mlp_tk import TkMLP

import torch.utils.benchmark as benchmark

@torch.compiler.disable
def full_tensor(tensor: torch.Tensor | DTensor) -> torch.Tensor:
    """
    Convert a DTensor to a local replicalated tensor
    """
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()

    return tensor


@dataclass
class ModelConfig:
    """Hyperparameters for the model and its TTT sequence-modeling layers.

    Field order defines the positional order of the generated ``__init__``;
    only ``model_dim``, ``num_heads`` and ``num_layers`` are required.
    """

    model_dim: int
    # TTTBase requires model_dim to split evenly into num_heads heads.
    num_heads: int
    num_layers: int

    ssm_layer: str = "ttt_mlp"
    # NOTE(review): TTTBase.post_norm hard-codes eps=1e-6 instead of reading
    # this field (same default) — confirm where this value is consumed.
    layer_norm_eps: float = 1e-6

    # TTT-Specific Configs
    # Tokens per TTT mini-batch; TTTBase.forward asserts that the sequence
    # length is a multiple of it.
    mini_batch_size: int = 64
    # Multiplies the sigmoid-gated learnable inner-loop learning rate
    # (TTTBase.get_eta).
    ttt_base_lr: float = 0.1

    rope_theta: float = 10000
    # Clamped to [1, num_mini_batch] by TTTMLP.ttt and forwarded to the
    # TTT-MLP kernel / reference scan.
    scan_checkpoint_group_size: int = 16

    adapter_method: str = "none"  # none, sft, qkvo

    # Network Config
    time_embed_dim: int = 512
    sigma_interval: int = 1000
    patch_size: int = 2
    in_channels: int = 16
    out_channels: int = 16
    scale_factor: float = 1.0

    # ROPE Config
    latent_height: int = 30
    latent_width: int = 45
    compressed_num_frames: int = 13
    # NOTE(review): same default as rope_theta; which of the two the RoPE code
    # reads is not visible here.
    theta: float = 10000

    # Conditioner Config
    text_dim: int = 512

    # SSM Attn Config
    gating_alpha_init: float = 0.1
    attn_length: int = 12
    prefix_temporal_length: int = 1

    # Remat config
    remat_transformer_layer_group_size: int = 1
    remat_forward_ssm: bool = False
    remat_reverse_ssm: bool = False
    remat_attention: bool = False
    remat_mlp: bool = False
    remat_seq_modeling_block: bool = False
    shard_transformer_inputs: bool = False

    # Named size presets: each maps to a partial set of field overrides.
    # These class attributes have no type annotation, so @dataclass does not
    # turn them into fields; the dicts are shared by the class and every
    # instance, so do not mutate them in place.
    PREDEFINED_CONFIGS = {
        "debug": {
            "model_dim": 512,
            "num_heads": 8,
            "num_layers": 6,
        },
        "5B": {
            "model_dim": 3072,
            "num_heads": 48,
            "num_layers": 42,
            "text_dim": 4096,
        },
    }

    # Duration label -> compressed_num_frames override (also a shared,
    # non-field class attribute).
    VIDEO_DURATION_CONFIGS = {
        "3sec": {
            "compressed_num_frames": 13,
        },
        "9sec": {
            "compressed_num_frames": 37,
        },
        "18sec": {
            "compressed_num_frames": 73,
        },
        "30sec": {
            "compressed_num_frames": 121,
        },
        "63sec": {
            "compressed_num_frames": 253,
        },
    }



class TTTBase(nn.Module):
    """Shared machinery for test-time-training (TTT) sequence layers.

    ``process_input`` projects the input into per-head queries/keys/values,
    L2-normalizes Q and K, builds the layer-norm reconstruction target from V
    and K, splits the sequence into mini-batches and computes the learnable
    inner-loop learning rate ``eta``. Subclasses implement :meth:`ttt`, which
    consumes that dict; :meth:`forward` then applies ``post_norm`` (a LayerNorm
    over ``model_dim``) and the output projection ``wo``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # num_heads * head_dim must equal model_dim: wo maps model_dim to
        # num_heads * head_dim, and subclasses reshape per-head outputs back to
        # model_dim (e.g. TTTMLP.ttt). Fail here instead of at the first forward.
        if config.model_dim % config.num_heads != 0:
            raise ValueError(
                f"model_dim ({config.model_dim}) must be divisible by num_heads ({config.num_heads})."
            )

        self.config = config
        self.width = config.model_dim
        self.num_heads = config.num_heads
        self.head_dim = config.model_dim // config.num_heads
        self.mini_batch_size = config.mini_batch_size

        self.ttt_base_lr = config.ttt_base_lr
        self.scan_checkpoint_group_size = config.scan_checkpoint_group_size

        # Set by init_device_mesh() when tensor parallelism is enabled.
        self.tp_mesh: None | DeviceMesh = None

        self._init_qkvo_proj()
        self._init_ttt_lr_gate()
        self._init_ttt_ln()

        # NOTE(review): eps is hard-coded rather than read from
        # config.layer_norm_eps (same default, 1e-6) — confirm intent.
        self.post_norm = nn.LayerNorm(self.width, eps=1e-6)

    # We must reinitialize after meta initialization
    def init_weights(self):
        """(Re)initialize every parameter owned by this class.

        Required after meta-device construction (e.g. followed by
        ``Module.to_empty``), where parameter storage is allocated but holds
        uninitialized memory.
        """
        for linear in (self.wq, self.wk, self.wv):
            nn.init.normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.wo.weight, mean=0.0, std=0.02)

        # The projection biases must be rewritten as well, otherwise they keep
        # whatever memory to_empty() left behind. Use nn.Linear's own default
        # bias distribution U(-1/sqrt(in_features), 1/sqrt(in_features)) so the
        # non-meta path keeps the same initial bias distribution as before.
        for linear in (self.wq, self.wk, self.wv, self.wo):
            bound = linear.in_features ** -0.5
            nn.init.uniform_(linear.bias, -bound, bound)

        self.post_norm.reset_parameters()
        nn.init.ones_(self.ttt_norm_weight)
        nn.init.zeros_(self.ttt_norm_bias)
        nn.init.normal_(self.learnable_ttt_lr_weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.learnable_ttt_lr_bias)

    def _init_qkvo_proj(self):
        """Create the Q/K/V/O projections (all with bias)."""
        self.wq = nn.Linear(self.width, self.num_heads * self.head_dim, bias=True)
        self.wk = nn.Linear(self.width, self.num_heads * self.head_dim, bias=True)
        self.wv = nn.Linear(self.width, self.num_heads * self.head_dim, bias=True)
        # in/out sizes coincide because num_heads * head_dim == width.
        self.wo = nn.Linear(self.width, self.num_heads * self.head_dim, bias=True)

    def _init_ttt_lr_gate(self):
        """Create the per-head learning-rate gate.

        ``learnable_ttt_lr_weight``: ``[num_heads, 1, width]``, N(0, 0.02^2).
        ``learnable_ttt_lr_bias``: ``[num_heads, 1]``, zeros.
        """
        linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data
        self.learnable_ttt_lr_weight = nn.Parameter(
            torch.stack(
                [torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)],
                dim=0,
            )
        )

        linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data
        self.learnable_ttt_lr_bias = nn.Parameter(
            torch.stack(
                [torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)],
                dim=0,
            )
        )

    def _init_ttt_ln(self):
        """Create per-head LayerNorm affine params, each ``[num_heads, head_dim]`` (ones / zeros)."""
        ln_weight_data = nn.LayerNorm(self.head_dim).weight.data
        self.ttt_norm_weight = nn.Parameter(torch.tile(ln_weight_data.unsqueeze(0), (self.num_heads, 1)))
        ln_bias_data = nn.LayerNorm(self.head_dim).bias.data
        self.ttt_norm_bias = nn.Parameter(torch.tile(ln_bias_data.unsqueeze(0), (self.num_heads, 1)))

    def init_device_mesh(self, tp_mesh: DeviceMesh):
        """Convert this class's TTT parameters to DTensors on ``tp_mesh``.

        The per-head LayerNorm params are sharded along dim 0 (the head axis);
        the learning-rate gate params are replicated.
        """
        self.tp_mesh = tp_mesh

        self.ttt_norm_weight = nn.Parameter(distribute_tensor(self.ttt_norm_weight, tp_mesh, [Shard(0)]))
        self.ttt_norm_bias = nn.Parameter(distribute_tensor(self.ttt_norm_bias, tp_mesh, [Shard(0)]))

        self.learnable_ttt_lr_weight = nn.Parameter(
            distribute_tensor(self.learnable_ttt_lr_weight, tp_mesh, [Replicate()])
        )
        self.learnable_ttt_lr_bias = nn.Parameter(distribute_tensor(self.learnable_ttt_lr_bias, tp_mesh, [Replicate()]))

    @torch.compile
    def get_qkv_projections(self, hidden_states):
        """Apply the Q, K and V projections to ``hidden_states``."""
        XQ, XK, XV = (
            self.wq(hidden_states),
            self.wk(hidden_states),
            self.wv(hidden_states),
        )
        return XQ, XK, XV

    @torch.compile
    def get_eta(self, X):
        """Compute the gated per-head, per-token inner-loop learning rate.

        Args:
            X: mini-batched hidden states ``[B, num_mini_batch, mini_batch_size, width]``.

        Returns:
            ``ttt_base_lr * sigmoid(X @ w_h + b_h) / head_dim`` with shape
            ``[B, num_heads, num_mini_batch, 1, mini_batch_size]``.
        """
        # The gate params may be DTensors under TP; gather them to plain tensors.
        learnable_ttt_lr_weight = full_tensor(self.learnable_ttt_lr_weight)
        learnable_ttt_lr_bias = full_tensor(self.learnable_ttt_lr_bias)

        ttt_lr = torch.einsum("bnkc,hdc->bhnkd", X, learnable_ttt_lr_weight) + learnable_ttt_lr_bias.reshape(
            1, -1, 1, 1, 1
        )  # [B,nc,cs,c] @ [nh,1,c] -> [B,nh,nc,cs,1] + [1,nh,1,1,1] -> [B,nh,nc,cs,1]

        ttt_lr = F.sigmoid(ttt_lr)  # [B,H,nc,K,1]

        # Move the token axis last: [B,H,nc,1,K].
        ttt_lr = ttt_lr.permute(0, 1, 2, 4, 3)
        return self.ttt_base_lr * ttt_lr / self.head_dim

    @torch.compile
    def ln_reconstruction_target(self, XV, XK):
        """Return ``XK + LN_h(XV - XK)`` with a per-head affine LayerNorm.

        Args:
            XV, XK: ``[B, L, num_heads, head_dim]``.

        NOTE: normalization uses ``Tensor.std`` (unbiased, correction=1) and adds
        ``eps`` to the std, which differs slightly from ``nn.LayerNorm``.
        """
        XV = XV - XK
        eps = 1e-8
        # Compute mean and std over the head dimension (last dimension)
        mean = XV.mean(dim=-1, keepdim=True)
        std = XV.std(dim=-1, keepdim=True)

        # Normalize
        XV = (XV - mean) / (std + eps)

        # Apply per-head weight and bias.
        # self.ttt_norm_weight and self.ttt_norm_bias have shape [num_heads, head_dim].
        # We unsqueeze to make them broadcastable with XV_norm which is [B, L, num_heads, head_dim].
        XV = self.ttt_norm_weight.unsqueeze(0).unsqueeze(0) * XV + self.ttt_norm_bias.unsqueeze(0).unsqueeze(0)

        return XV + XK

    @torch.compile
    def reshape_to_mini_batch(self, X, XQ, XK, XV):
        """Split the sequence axis into ``[num_mini_batch, mini_batch_size]``.

        Args:
            X: ``[B, L, width]``.
            XQ, XK, XV: ``[B, L, num_heads, head_dim]``.

        Returns:
            ``X`` as ``[B, num_mini_batch, mini_batch_size, width]`` and
            ``XQ``, ``XK``, ``XV`` as
            ``[B, num_heads, num_mini_batch, mini_batch_size, head_dim]``.
        """
        B, L = X.shape[:2]
        num_mini_batch = L // self.mini_batch_size

        XQ, XK, XV = XQ.transpose(1, 2), XK.transpose(1, 2), XV.transpose(1, 2)

        X = X.reshape(B, num_mini_batch, self.mini_batch_size, self.width)

        XQ = XQ.reshape(B, self.num_heads, num_mini_batch, self.mini_batch_size, self.head_dim)
        XK = XK.reshape(B, self.num_heads, num_mini_batch, self.mini_batch_size, self.head_dim)
        XV = XV.reshape(B, self.num_heads, num_mini_batch, self.mini_batch_size, self.head_dim)

        return X, XQ, XK, XV

    def process_input(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor):
        """Build the TTT inputs dict from ``hidden_states`` ``[B, L, width]``.

        ``freqs_cis`` is accepted for interface compatibility but is not used:
        no rotary embedding is applied here.

        Returns:
            dict with ``XQ``, ``XK``, ``XV`` of shape
            ``[B, num_heads, num_mini_batch, mini_batch_size, head_dim]`` and
            ``eta`` of shape
            ``[B, num_heads, num_mini_batch, mini_batch_size, mini_batch_size]``.
        """
        B, L = hidden_states.shape[:2]
        mini_batch_size = self.mini_batch_size

        XQ, XK, XV = self.get_qkv_projections(hidden_states)

        XQ = XQ.view(B, L, -1, self.head_dim)
        XK = XK.view(B, L, -1, self.head_dim)
        XV = XV.view(B, L, -1, self.head_dim)

        # L2 Norm
        XQ = torch.nn.functional.normalize(XQ, p=2, dim=-1)
        XK = torch.nn.functional.normalize(XK, p=2, dim=-1)

        XV = self.ln_reconstruction_target(XV, XK)

        hidden_states, XQ, XK, XV = self.reshape_to_mini_batch(hidden_states, XQ, XK, XV)

        ttt_lr_eta = self.get_eta(hidden_states)  # [B, nh, nc, 1, cs]

        # We do not use token_eta for non-causal chunks: every row of the
        # [cs, cs] block carries the same per-token rates, scaled by 1/cs.
        eta = 1 / mini_batch_size * ttt_lr_eta.repeat(1, 1, 1, mini_batch_size, 1)

        inputs = {
            "XQ": XQ,
            "XK": XK,
            "XV": XV,
            "eta": eta,
        }

        return inputs

    def ttt(
        self,
        inputs,
    ):
        """Run the TTT inner loop on the dict from ``process_input``; subclasses must override."""
        raise NotImplementedError("ttt method must be implemented in TTTBase subclasses.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor = None,
    ):
        """Apply the TTT layer: ``wo(post_norm(ttt(process_input(x))))``.

        Args:
            hidden_states: ``[B, L, width]``; ``L`` must be a multiple of
                ``config.mini_batch_size``.
            freqs_cis: forwarded to ``process_input`` (currently unused there).

        Returns:
            The projected output, gathered to a plain tensor if it is a DTensor.
        """
        assert (
            hidden_states.size(1) % self.config.mini_batch_size == 0
        ), "Sequence len must be multiple of mini batch size."

        hidden_states = self.ttt(self.process_input(hidden_states, freqs_cis))

        hidden_states = self.post_norm(hidden_states)
        hidden_states = self.wo(hidden_states)

        hidden_states = full_tensor(hidden_states)

        return hidden_states
    


class TTTMLP(TTTBase):
    """TTT layer whose per-head fast weights form a two-layer MLP.

    Per head: ``W1 [head_dim, 4*head_dim]``, ``b1 [1, 4*head_dim]``,
    ``W2 [4*head_dim, head_dim]``, ``b2 [1, head_dim]``. The shared initial
    weights are tiled over the batch and passed either to ``TkMLP.apply``
    (``use_kernel=True``) or to the reference ``ttt_mlp`` from ``ttt_ops``.
    """

    def __init__(self, config: ModelConfig, use_kernel: bool = True):
        super().__init__(config)
        self.W1 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, self.head_dim, 4 * self.head_dim)))
        self.b1 = nn.Parameter(torch.zeros(self.num_heads, 1, 4 * self.head_dim))
        self.W2 = nn.Parameter(torch.normal(0, 0.02, size=(self.num_heads, 4 * self.head_dim, self.head_dim)))
        self.b2 = nn.Parameter(torch.zeros(self.num_heads, 1, self.head_dim))

        self.use_kernel = use_kernel

    def init_weights(self):
        """Re-initialize base-class params plus the MLP fast weights (normal / zeros)."""
        super().init_weights()
        nn.init.normal_(self.W1, mean=0.0, std=0.02)
        nn.init.zeros_(self.b1)
        nn.init.normal_(self.W2, mean=0.0, std=0.02)
        nn.init.zeros_(self.b2)

    def init_device_mesh(self, tp_mesh: DeviceMesh):
        """Shard the MLP fast weights along the head axis (dim 0) on ``tp_mesh``.

        Only supported with the kernel path.
        """
        assert self.use_kernel, "Tensor parallel is not currently supported for TTTMLP without kernel."
        super().init_device_mesh(tp_mesh)

        self.W1 = nn.Parameter(distribute_tensor(self.W1, tp_mesh, [Shard(0)]))
        self.b1 = nn.Parameter(distribute_tensor(self.b1, tp_mesh, [Shard(0)]))
        self.W2 = nn.Parameter(distribute_tensor(self.W2, tp_mesh, [Shard(0)]))
        self.b2 = nn.Parameter(distribute_tensor(self.b2, tp_mesh, [Shard(0)]))

        # NOTE(review): set on the TkMLP class, not this instance, so it
        # presumably affects every TkMLP call in the process.
        TkMLP.sharded_mode = True

    def ttt(self, inputs):
        """Run the TTT-MLP scan over all mini-batches.

        Args:
            inputs: dict from ``process_input`` with ``XQ``, ``XK``, ``XV`` of
                shape ``[B, num_heads, num_mini_batch, mini_batch_size, head_dim]``
                and ``eta`` of shape
                ``[B, num_heads, num_mini_batch, mini_batch_size, mini_batch_size]``.

        Returns:
            ``[B, L, width]`` with ``L = num_mini_batch * mini_batch_size``.
        """
        B = inputs["XV"].shape[0]
        num_mini_batch = inputs["XV"].shape[2]
        L = inputs["XV"].shape[2] * inputs["XV"].shape[3]

        # One copy of the initial fast weights per batch element.
        # [B, nheads, hdim, 4*hdim]
        W1_states = torch.tile(self.W1.unsqueeze(0), dims=(B, 1, 1, 1))
        # [B, nheads, 1, 4*hdim]
        b1_states = torch.tile(self.b1.unsqueeze(0), dims=(B, 1, 1, 1))
        # [B, nheads, 4*hdim, hdim]
        W2_states = torch.tile(self.W2.unsqueeze(0), dims=(B, 1, 1, 1))
        # [B, nheads, 1, hdim]
        b2_states = torch.tile(self.b2.unsqueeze(0), dims=(B, 1, 1, 1))

        # Clamp the configured group size to [1, num_mini_batch].
        checkpoint_group_size = min(max(self.config.scan_checkpoint_group_size, 1), num_mini_batch)

        if self.use_kernel:
            XQW_batch = TkMLP.apply(
                self.ttt_norm_weight,  # [num_heads, hdim]
                self.ttt_norm_bias,  # [num_heads, hdim]
                W1_states,  # [B, nheads, hdim, 4*hdim]
                b1_states,  # [B, nheads, 1, 4*hdim]
                W2_states,  # [B, nheads, 4*hdim, hdim]
                b2_states,  # [B, nheads, 1, hdim]
                inputs["XQ"],  # [B, num_heads, num_mini_batch, mini_batch_size, hdim]
                inputs["XV"],  # [B, num_heads, num_mini_batch, mini_batch_size, hdim]
                inputs["XK"],  # [B, num_heads, num_mini_batch, mini_batch_size, hdim]
                inputs["eta"],  # [B, num_heads, num_mini_batch, mini_batch_size, mini_batch_size]
                checkpoint_group_size,
            )

            # Kernel output is head-major; move heads next to head_dim:
            # [B, num_mini_batch, mini_batch_size, num_heads, hdim].
            XQW_batch = XQW_batch.permute(0, 2, 3, 1, 4)
        else:
            # ttt_mlp is not imported at module level, so this branch used to
            # raise NameError. Import it lazily (assumes ttt_ops exposes
            # ttt_mlp — TODO confirm) so the kernel path does not depend on it.
            from ttt_ops import ttt_mlp

            # NOTE(review): no permute here, so ttt_mlp presumably already
            # returns [B, num_mini_batch, mini_batch_size, num_heads, hdim].
            XQW_batch = ttt_mlp(
                inputs["XK"],
                inputs["XQ"],
                inputs["XV"],
                inputs["eta"],
                self.ttt_norm_weight,
                self.ttt_norm_bias,
                W1_states,
                b1_states,
                W2_states,
                b2_states,
                checkpoint_group_size,
            )

        XQW_batch = XQW_batch.reshape(B, L, self.width)
        return XQW_batch

def make_inputs(
    B=1,
    L=2048,
    model_dim=1024,
    head_dim=64,
    mini_batch_size=16,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    max_checkpoint_group_size=8,
):
    """Build a random positional-argument tuple for ``TkMLP.apply``.

    Args:
        B: batch size.
        L: sequence length; must be a multiple of ``mini_batch_size``.
        model_dim: model width; must be a multiple of ``head_dim``.
        head_dim: per-head width.
        mini_batch_size: tokens per TTT mini-batch.
        device: device for every tensor.
        dtype: dtype for every tensor.
        max_checkpoint_group_size: cap on the returned checkpoint group size
            (previously hard-coded; the default 8 keeps the old behavior).

    Returns:
        ``(ttt_norm_weight, ttt_norm_bias, W1, b1, W2, b2, XQ, XV, XK, eta,
        checkpoint_group_size)`` — the argument order ``TTTMLP.ttt`` uses for
        ``TkMLP.apply``. With ``H = model_dim // head_dim``,
        ``N = L // mini_batch_size`` and ``C = mini_batch_size``:

        - ``ttt_norm_weight`` / ``ttt_norm_bias``: ones / zeros, ``[H, head_dim]``
        - ``W1``: ``[B, H, head_dim, 4*head_dim]``, ``randn * 0.02``; ``b1``: zeros ``[B, H, 1, 4*head_dim]``
        - ``W2``: ``[B, H, 4*head_dim, head_dim]``, ``randn * 0.02``; ``b2``: zeros ``[B, H, 1, head_dim]``
        - ``XQ``, ``XV``, ``XK``: ``randn``, ``[B, H, N, C, head_dim]``
        - ``eta``: ``rand * 0.1``, ``[B, H, N, C, C]``
        - ``checkpoint_group_size``: ``min(N, max_checkpoint_group_size)``

    Raises:
        ValueError: if ``L`` is not a multiple of ``mini_batch_size`` or
            ``model_dim`` is not a multiple of ``head_dim`` (floor division
            used to drop the remainder silently).
    """
    if L % mini_batch_size != 0:
        raise ValueError(f"L ({L}) must be a multiple of mini_batch_size ({mini_batch_size}).")
    if model_dim % head_dim != 0:
        raise ValueError(f"model_dim ({model_dim}) must be a multiple of head_dim ({head_dim}).")

    num_heads = model_dim // head_dim
    num_mini_batch = L // mini_batch_size
    factory = {"device": device, "dtype": dtype}

    # Sequence inputs, already split into mini-batches.
    seq_shape = (B, num_heads, num_mini_batch, mini_batch_size, head_dim)
    XQ = torch.randn(*seq_shape, **factory)
    XV = torch.randn(*seq_shape, **factory)
    XK = torch.randn(*seq_shape, **factory)
    # Learning-rate factor, one [C, C] block per mini-batch.
    eta = torch.rand(B, num_heads, num_mini_batch, mini_batch_size, mini_batch_size, **factory) * 0.1

    # Per-head LayerNorm affine params.
    ttt_norm_weight = torch.ones(num_heads, head_dim, **factory)
    ttt_norm_bias = torch.zeros(num_heads, head_dim, **factory)

    # Per-batch MLP fast weights and biases.
    W1 = torch.randn(B, num_heads, head_dim, 4 * head_dim, **factory) * 0.02
    b1 = torch.zeros(B, num_heads, 1, 4 * head_dim, **factory)
    W2 = torch.randn(B, num_heads, 4 * head_dim, head_dim, **factory) * 0.02
    b2 = torch.zeros(B, num_heads, 1, head_dim, **factory)

    checkpoint_group_size = min(num_mini_batch, max_checkpoint_group_size)
    return ttt_norm_weight, ttt_norm_bias, W1, b1, W2, b2, XQ, XV, XK, eta, checkpoint_group_size

def test_forward():
    """Smoke-test the TkMLP forward pass and print the output shape."""
    B = 8
    L = 16384
    model_dim = 768
    head_dim = 64
    mini_batch_size = 16
    # The last element of the tuple is the checkpoint group size chosen by
    # make_inputs (min(num_mini_batch, 8) with its default cap).
    inputs = make_inputs(B, L, model_dim, head_dim, mini_batch_size)

    output = TkMLP.apply(*inputs)

    # Presumably [B, num_heads, num_mini_batch, mini_batch_size, head_dim]
    # (the layout TTTMLP.ttt permutes from), i.e. torch.Size([8, 12, 1024, 16, 64])
    # for the sizes above.
    print(output.shape)

def test_bwd():
    """Smoke-test the TkMLP backward pass: forward, sum to a scalar, backprop."""
    batch, seq_len = 8, 16384
    dim, head_dim, chunk = 1536, 64, 64
    args = make_inputs(batch, seq_len, dim, head_dim, chunk)

    # Index 5 is b2 in make_inputs' return order; marking it as requiring grad
    # makes the kernel output part of an autograd graph.
    args[5].requires_grad = True

    out = TkMLP.apply(*args)
    print(out.shape)

    out.sum().backward()

def benchmark_fwd(compiled_fwd_fn, inputs, n_repeats=10):
    """Time ``compiled_fwd_fn(*inputs)`` with ``torch.utils.benchmark``.

    Args:
        compiled_fwd_fn: callable to benchmark.
        inputs: sequence of positional arguments unpacked into the call.
        n_repeats: number of timed executions passed to ``Timer.timeit``.

    Returns:
        The ``Measurement`` produced by ``Timer.timeit``.
    """
    timer = benchmark.Timer(
        stmt="compiled_fwd_fn(*inputs)",
        globals=dict(compiled_fwd_fn=compiled_fwd_fn, inputs=inputs),
        num_threads=torch.get_num_threads(),
    )
    return timer.timeit(n_repeats)

def test_speed():
    """Benchmark TkMLP forward throughput over a sweep of model widths.

    For each configuration, prints the sizes and the achieved TFLOP/s based on
    a heuristic FLOP count and the mean time from ``benchmark_fwd``.
    """
    B = 4
    L_range = [65536]  # extend with e.g. 4096, 8192, 16384, 32768 as needed
    model_dim_range = [768, 1536, 3072, 4096]
    head_dim = 64
    mini_batch_size = 64
    for L in L_range:
        for model_dim in model_dim_range:
            print(f"Configuration: B: {B}, L: {L}, D: {model_dim}, H: {head_dim}, mini_batch_size: {mini_batch_size}")
            inputs = make_inputs(B, L, model_dim, head_dim, mini_batch_size)

            compiled_fwd_fn = torch.compile(TkMLP.apply)

            measurement = benchmark_fwd(compiled_fwd_fn, inputs, n_repeats=10)

            # Heuristic FLOP count for one call over the whole batch
            # (14 * D * head_dim * 4 per token, times L tokens and B samples).
            total_flops = 14 * model_dim * head_dim * 4 * L * B
            tflops = total_flops / measurement.mean / 1e12
            print(f"Achieved throughput  : ({tflops:.2f} TFLOP/s)")

            # Free this configuration's (multi-GB) inputs before the next
            # iteration allocates its own; otherwise both sets are alive at once.
            del inputs, measurement


def compute_fwd_iters_per_second():
    """Report TkMLP forward latency and iterations/second across model widths.

    Uses B=1, L=65536, head_dim=64, mini_batch_size=64 and ``make_inputs``'s
    default dtype (bfloat16).
    """
    B = 1
    L = 65536

    head_dim = 64

    model_dim_range = [768, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576, 32768]

    mini_batch_size = 64
    n_repeats = 10

    # make_inputs() is called with its default dtype, bfloat16.
    bytes_per_element = torch.finfo(torch.bfloat16).bits // 8

    for model_dim in model_dim_range:
        # Fast-weight state per sample: W1 (hd x 4hd) + W2 (4hd x hd) per head,
        # i.e. 8 * head_dim * model_dim elements (biases ignored). The old code
        # printed this element count / 1e6 labelled as "MB"; report both.
        state_params = model_dim * head_dim * 8
        state_size = state_params * bytes_per_element / 1e6  # MB
        print(
            f"Configuration: B: {B}, L: {L}, D: {model_dim}, H: {head_dim}, mini_batch_size: {mini_batch_size}, "
            f"state_size: {state_size:.2f} MB ({state_params / 1e6:.2f}M params)"
        )
        inputs = make_inputs(B, L, model_dim, head_dim, mini_batch_size)

        compiled_fwd_fn = torch.compile(TkMLP.apply)

        measurement = benchmark_fwd(compiled_fwd_fn, inputs, n_repeats)

        time_per_call = measurement.mean

        print(f"Time per call: {time_per_call*1e3:.2f} ms")
        print(f"Iterations per second: {1.0 / time_per_call:.2f}")

        # Release this width's inputs before the next, larger allocation.
        del inputs, measurement



if __name__ == "__main__":
    # Default entry point: the TkMLP forward-throughput sweep. Other
    # available checks: test_forward(), test_bwd(),
    # compute_fwd_iters_per_second().
    test_speed()