import torch
import torch.nn.functional as F
from torch import nn

# from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_tensor

import json
from dataclasses import asdict, dataclass
from typing import List

from ttt_linear_kernels.linear_triton import TritonLinear
import torch.utils.benchmark as benchmark
import time


@torch.compiler.disable
def full_tensor(tensor: torch.Tensor | DTensor) -> torch.Tensor:
    """
    Convert a DTensor to a local replicalated tensor
    """
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()

    return tensor


def make_inputs(
    B=1,
    L=2048,
    model_dim=1024,
    head_dim=64,
    mini_batch_size=16,
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
):
    """
    Build random inputs for ``TritonLinear.apply``.

    Args:
        B: Batch size.
        L: Sequence length; must be a multiple of ``mini_batch_size``.
        model_dim: Model width; must be a multiple of ``head_dim``.
        head_dim: Per-head feature size.
        mini_batch_size: Tokens per mini-batch.
        device: Device on which every tensor is allocated.
        dtype: Dtype of every tensor.

    Returns:
        Tuple ``(ttt_norm_weight, ttt_norm_bias, W1, b1, XQ, XV, XK, eta,
        group_size)``, with ``num_heads = model_dim // head_dim`` and
        ``num_mini_batch = L // mini_batch_size``:

        * ``ttt_norm_weight``: ``(num_heads, head_dim)``, all ones.
        * ``ttt_norm_bias``: ``(num_heads, head_dim)``, all zeros.
        * ``W1``: ``(B, num_heads, head_dim, head_dim)``, ``randn * 0.02``.
        * ``b1``: ``(B, num_heads, 1, head_dim)``, all zeros.
        * ``XQ``, ``XV``, ``XK``:
          ``(B, num_heads, num_mini_batch, mini_batch_size, head_dim)``, ``randn``.
        * ``eta``:
          ``(B, num_heads, num_mini_batch, mini_batch_size, mini_batch_size)``,
          ``rand * 0.1``.
        * ``group_size``: ``min(num_mini_batch, 64)``.

    Raises:
        ValueError: If ``model_dim`` is not a multiple of ``head_dim`` or ``L``
            is not a multiple of ``mini_batch_size``. Floor division would
            otherwise silently drop the remaining heads/tokens.
    """
    if model_dim % head_dim != 0:
        raise ValueError(
            f"model_dim ({model_dim}) must be divisible by head_dim ({head_dim})"
        )
    if L % mini_batch_size != 0:
        raise ValueError(
            f"L ({L}) must be divisible by mini_batch_size ({mini_batch_size})"
        )

    num_heads = model_dim // head_dim
    num_mini_batch = L // mini_batch_size
    factory = {"device": device, "dtype": dtype}
    seq_shape = (B, num_heads, num_mini_batch, mini_batch_size, head_dim)

    # Freshly created tensors (and elementwise results on them) are already
    # contiguous, so no explicit .contiguous() is needed. The creation order
    # below fixes which random values each tensor receives under a fixed seed.
    XQ = torch.randn(*seq_shape, **factory)
    XV = torch.randn(*seq_shape, **factory)
    XK = torch.randn(*seq_shape, **factory)
    # Learning-rate factor: uniform samples scaled by 0.1.
    eta = (
        torch.rand(
            B,
            num_heads,
            num_mini_batch,
            mini_batch_size,
            mini_batch_size,
            **factory,
        )
        * 0.1
    )

    # Norm parameters start as the identity affine transform (scale 1, shift 0).
    ttt_norm_weight = torch.ones(num_heads, head_dim, **factory)
    ttt_norm_bias = torch.zeros(num_heads, head_dim, **factory)

    # Per-head weight matrix and bias.
    W1 = torch.randn(B, num_heads, head_dim, head_dim, **factory) * 0.02
    b1 = torch.zeros(B, num_heads, 1, head_dim, **factory)

    # NOTE(review): presumably the kernel's checkpoint group size (mini-batches
    # per checkpoint group), capped at 64 — confirm against TritonLinear.
    group_size = min(num_mini_batch, 64)

    return (
        ttt_norm_weight,
        ttt_norm_bias,
        W1,
        b1,
        XQ,
        XV,
        XK,
        eta,
        group_size,
    )


def test_forward():
    """
    Smoke-test the ``TritonLinear`` forward pass on a large configuration.

    Builds inputs with ``make_inputs`` (B=8, L=32768, model_dim=1536,
    head_dim=128, mini_batch_size=16), runs ``TritonLinear.apply`` once and
    prints the output shape.
    """
    B = 8
    L = 32768
    model_dim = 1536
    head_dim = 128
    mini_batch_size = 16
    inputs = make_inputs(B, L, model_dim, head_dim, mini_batch_size)

    output = TritonLinear.apply(*inputs)

    # NOTE(review): the output layout is presumably
    # [B, num_heads, num_mini_batch, mini_batch_size, head_dim], i.e.
    # torch.Size([8, 12, 2048, 16, 128]) for this configuration — confirm.
    print(output.shape)


def test_fwd_bwd():
    """
    Smoke-test forward + backward of ``TritonLinear`` with respect to ``XQ``.

    Only ``XQ`` (index 4 of the ``make_inputs`` tuple) is marked as requiring
    a gradient. The loss is the sum of the output. The shape and sum of
    ``XQ.grad`` are printed.
    """
    inputs = make_inputs(
        1,  # B
        16384,  # L
        768,  # model_dim
        256,  # head_dim (other values tried: 64, 128)
        16,  # mini_batch_size
    )
    XQ = inputs[4]
    XQ.requires_grad = True

    TritonLinear.apply(*inputs).sum().backward()

    grad = XQ.grad
    print(grad.shape, grad.sum())


def benchmark_fwd(compiled_fwd_fn, inputs, n_repeats=10):
    """
    Time ``compiled_fwd_fn(*inputs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        compiled_fwd_fn: Callable under test.
        inputs: Positional arguments splatted into ``compiled_fwd_fn``.
        n_repeats: Number of calls per timed run, passed to ``Timer.timeit``.

    Returns:
        The ``Measurement`` produced by ``Timer.timeit``. Its ``.mean`` is the
        average time per call in seconds.
    """
    timer_globals = {"compiled_fwd_fn": compiled_fwd_fn, "inputs": inputs}
    timer = benchmark.Timer(
        stmt="compiled_fwd_fn(*inputs)",
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    return timer.timeit(n_repeats)


@torch.no_grad()
def test_speed():
    """
    Sweep forward-pass throughput of ``TritonLinear`` over model/head widths.

    For every (L, head_dim, model_dim) configuration this builds fresh inputs
    with ``make_inputs`` and compiles ``TritonLinear.apply`` with
    ``torch.compile``. It then times 10 calls with ``benchmark_fwd`` and
    prints the achieved TFLOP/s. Autograd is disabled for the whole sweep.
    """
    B = 4
    L_range = [65536]
    model_dim_range = [768, 1536, 3072]
    head_dim_range = [64, 128, 256]
    mini_batch_size = 16  # alternative: 64
    for L in L_range:
        for head_dim in head_dim_range:
            for model_dim in model_dim_range:
                print(
                    f"Configuration: B: {B}, L: {L}, D: {model_dim}, H: {head_dim}, mini_batch_size: {mini_batch_size}"
                )
                inputs = make_inputs(B, L, model_dim, head_dim, mini_batch_size)

                compiled_fwd_fn = torch.compile(TritonLinear.apply)

                measurement = benchmark_fwd(compiled_fwd_fn, inputs, n_repeats=10)

                # Approximate FLOP model: 4 * model_dim * head_dim per token.
                # NOTE(review): presumably two head_dim x head_dim matmuls per
                # head at 2 FLOPs per multiply-add — confirm against the kernel.
                flops = model_dim * head_dim * 4 * L * B
                tflops = flops / measurement.mean / 1e12
                print(f"Achieved throughput  : ({tflops:.2f} TFLOP/s)")

                # Release this configuration's large input tensors before the
                # next make_inputs call, so two configurations never coexist
                # in memory (which would double the peak footprint).
                del inputs, compiled_fwd_fn, measurement


def compute_fwd_iters_per_second(model_dim_range=None, head_dim_range=None):
    """
    Report forward latency and iterations/second of ``TritonLinear``.

    Uses B=1, L=65536, mini_batch_size=16 and bf16 inputs. For each
    (model_dim, head_dim) pair it:

    * prints the configuration and the W1 state size in MB,
    * builds fresh inputs with ``make_inputs``,
    * compiles ``TritonLinear.apply`` with ``torch.compile``,
    * times 10 calls with ``benchmark_fwd``,
    * prints the mean time per call and the resulting iterations per second.

    Args:
        model_dim_range: Model widths to sweep. Defaults to ``[40960]``. A
            wider sweep is e.g. ``[768, 1536, 2048, 3072, 4096, 6144, 8192,
            12288, 16384, 24576, 32768]``.
        head_dim_range: Head sizes to sweep. Defaults to ``[128]``.
    """
    B = 1
    L = 65536
    mini_batch_size = 16
    dtype = torch.bfloat16
    n_repeats = 10

    if model_dim_range is None:
        model_dim_range = [40960]
    if head_dim_range is None:
        head_dim_range = [128]

    bytes_per_elem = torch.finfo(dtype).bits // 8

    for model_dim in model_dim_range:
        for head_dim in head_dim_range:
            # W1 holds (model_dim // head_dim) * head_dim**2 elements per batch
            # entry, i.e. model_dim * head_dim when head_dim divides model_dim.
            # Convert that element count to megabytes of `dtype`.
            state_size = model_dim * head_dim * bytes_per_elem / 1e6  # MB
            print(
                f"Configuration: B: {B}, L: {L}, D: {model_dim}, H: {head_dim}, mini_batch_size: {mini_batch_size}, state_size: {state_size} MB"
            )
            inputs = make_inputs(
                B, L, model_dim, head_dim, mini_batch_size, dtype=dtype
            )

            compiled_fwd_fn = torch.compile(TritonLinear.apply)

            measurement = benchmark_fwd(compiled_fwd_fn, inputs, n_repeats)
            time_per_call = measurement.mean

            print(f"Time per call: {time_per_call*1e3:.2f} ms")
            print(f"Forward iterations per second: {1.0 / time_per_call:.2f}")

            # Free this configuration's inputs before allocating the next ones.
            del inputs, compiled_fwd_fn, measurement


if __name__ == "__main__":
    # Script entry point: runs the throughput sweep by default. Uncomment
    # any of the alternatives below to run the smoke tests or the
    # large-width latency benchmark instead.
    # test_forward()
    # test_fwd_bwd()
    # compute_fwd_iters_per_second()
    test_speed()
