import torch
from time import perf_counter

from mamba_simple import Mamba as MambaSimple
from mamba_torch_new import Mamba as MambaTorchNew
from super_fast_mamba import Mamba

from accelerate import Accelerator

# Alternative configuration that requests fp16 mixed precision:
# accelerator = Accelerator(mixed_precision="fp16")
# Active configuration: an Accelerator built with its default arguments. Every
# accelerator.prepare() / autocast() / backward() call in this script goes
# through this single instance.
accelerator = Accelerator()

# Dimensions of the random input tensor used by every benchmark.
# A smaller alternative for quick runs is (1, 128, 64).
batch_size = 2
seq_len = 2048
feature_dim = 256

# Untimed iterations run before each measurement, then the timed iterations.
warmup_iters, test_iters = 5, 25


def copy_weights(source_model, target_model):
    """Clone the parameters of ``source_model`` into ``target_model``.

    For each of ``in_proj``, ``conv1d``, ``x_proj``, ``dt_proj`` and
    ``out_proj`` the weight is always copied, while the bias is copied only
    when the source layer has one. ``A_log`` and ``D`` are copied
    unconditionally. Every value is cloned, so the two models do not share
    storage afterwards.

    Args:
        source_model: Model whose parameter values are read.
        target_model: Model whose parameter ``.data`` is overwritten in place.
    """
    for layer_name in ("in_proj", "conv1d", "x_proj", "dt_proj", "out_proj"):
        src_layer = getattr(source_model, layer_name)
        dst_layer = getattr(target_model, layer_name)
        dst_layer.weight.data = src_layer.weight.data.clone()
        # Only the source decides whether a bias is copied.
        if src_layer.bias is not None:
            dst_layer.bias.data = src_layer.bias.data.clone()

    for tensor_name in ("A_log", "D"):
        src_tensor = getattr(source_model, tensor_name)
        getattr(target_model, tensor_name).data = src_tensor.data.clone()


def benchmark_model(name, model, x, warmup_iters, test_iters, backward=True):
    """Time ``model`` on ``x`` and print the average milliseconds per iteration.

    The forward pass is always timed under ``torch.no_grad()``. When
    ``backward`` is true, a forward+backward step with ``loss = out.sum()`` is
    timed as well. Each timed loop is preceded by ``warmup_iters`` untimed
    iterations and bracketed by ``torch.cuda.synchronize()`` calls, so the
    timer starts after the warmup work has finished and stops after the timed
    work has finished.

    Args:
        name: Label used in the printed report lines.
        model: Callable model to benchmark; must provide ``zero_grad()`` when
            ``backward`` is true.
        x: Input passed to ``model``.
        warmup_iters: Number of untimed iterations before each measurement.
        test_iters: Number of timed iterations; the elapsed time is divided by
            this value.
        backward: Also time the forward+backward step when true.
    """
    for _ in range(warmup_iters):
        with torch.no_grad(), accelerator.autocast():
            model(x)
    torch.cuda.synchronize()

    start_time = perf_counter()
    for _ in range(test_iters):
        with torch.no_grad(), accelerator.autocast():
            model(x)
    torch.cuda.synchronize()
    end_time = perf_counter()
    print(
        f"{name} forward time: {(end_time - start_time) *1000 / test_iters:.6f} milliseconds per iteration"
    )

    if not backward:
        return

    for _ in range(warmup_iters):
        # Autocast wraps only the forward pass and the loss; running the
        # backward pass under autocast is not recommended by PyTorch.
        with accelerator.autocast():
            loss = model(x).sum()
        accelerator.backward(loss)
        model.zero_grad()
    torch.cuda.synchronize()

    start_time = perf_counter()
    for _ in range(test_iters):
        with accelerator.autocast():
            loss = model(x).sum()
        accelerator.backward(loss)
        # Clear gradients so they do not accumulate across iterations.
        model.zero_grad()
    torch.cuda.synchronize()
    end_time = perf_counter()
    print(
        f"{name} forward+backward time: {(end_time - start_time) *1000 / test_iters:.6f} milliseconds per iteration"
    )


# Shared benchmark input: float32 samples from torch.randn, created on "cuda".
x = torch.randn(
    (batch_size, seq_len, feature_dim), dtype=torch.float32, device="cuda"
)

# Reference model: MambaSimple with default arguments.
# NOTE(review): presumably the fast path is on by default, since the later
# variant passes use_fast_path=False explicitly -- confirm in mamba_simple.
model_simple = accelerator.prepare(MambaSimple(dim=feature_dim).to(x.device))
model_simple.compile()

benchmark_model("Simple Mamba", model_simple, x, warmup_iters, test_iters)

# One extra forward/backward pass outside the benchmark: test_model() compares
# other models against out_simple and against the gradients this leaves on
# model_simple's parameters.
out_simple = model_simple(x)
loss = out_simple.sum()
accelerator.backward(loss)

print("Simple Mamba output shape:", out_simple.shape)


def test_model(name, model, x, backward=True):
    """Compare ``model`` against the reference ``model_simple`` and print the result.

    The forward output on ``x`` is checked against the module-level
    ``out_simple``. When ``backward`` is true, a backward pass on
    ``model(x).sum()`` is run and each parameter gradient is checked against
    the gradient of the parameter at the same position in
    ``model_simple.named_parameters()``. Mismatches are printed, not raised.

    Args:
        name: Label of the model under test, used in the printed output.
        model: Model to compare against the reference.
        x: Input passed to ``model``.
        backward: Also compare gradients when true.
    """
    print("Comparing outputs...")
    with accelerator.autocast():
        with torch.no_grad():
            out = model(x)

    print(f"{name} output shape:", out.shape)
    try:
        torch.testing.assert_close(out_simple, out, rtol=1e-4, atol=1e-4)
        print("Outputs are close!")
    except AssertionError as e:
        print("Outputs are not close!")
        print(e)

    if not backward:
        return

    print("Comparing gradients...")
    model.zero_grad()
    # Autocast wraps only the forward pass and the loss; running the backward
    # pass under autocast is not recommended by PyTorch.
    with accelerator.autocast():
        loss = model(x).sum()
    accelerator.backward(loss)

    try:
        # Parameters are paired by position, not by name; zip() stops at the
        # shorter of the two parameter lists.
        for (name1, param1), (name2, param2) in zip(
            model_simple.named_parameters(), model.named_parameters()
        ):
            if param1.grad is None and param2.grad is None:
                continue
            torch.testing.assert_close(
                param1.grad,
                param2.grad,
                rtol=1e-4,
                atol=1e-4,
                # Prefix the generated report with the parameter names so a
                # failure identifies which pair of gradients differs.
                msg=lambda generated, ref=name1, other=name2: (
                    f"Gradient mismatch for {ref} and {other}:\n{generated}"
                ),
            )
        print("Gradients are close!")
    except AssertionError as e:
        print("Gradients are not close!")
        print(e)


# MambaSimple variant constructed with use_fast_path=False.
slow_model_simple = accelerator.prepare(
    MambaSimple(dim=feature_dim, use_fast_path=False).to(x.device)
)
slow_model_simple.compile()

# Load the reference model's parameters so that outputs and gradients of the
# two variants can be compared directly.
copy_weights(model_simple, slow_model_simple)

benchmark_model(
    "Simple Mamba no fast path", slow_model_simple, x, warmup_iters, test_iters
)
test_model("Simple Mamba no fast path", slow_model_simple, x)


# model_torch_new = MambaTorchNew(d_model=feature_dim).to(x.device)
# model_torch_new = accelerator.prepare(model_torch_new)
# # model_torch_new.compile()

# copy_weights(model_simple, model_torch_new)

# benchmark_model(
#     "Mamba Torch New", model_torch_new, x, warmup_iters, test_iters
# )

# test_model("Mamba Torch New", model_torch_new, x)

# Mamba from super_fast_mamba (reported below as "Mamba Torch")

model = accelerator.prepare(
    Mamba(input_dim=feature_dim, qk_dim=16, v_dim=2 * feature_dim).to(x.device)
)
model.compile()

# Load the reference model's parameters. Layer names match except that the
# reference x_proj is loaded into this model's dtqk_proj. A bias is copied
# only when the reference layer has one.
for _dst_name, _src_name in (
    ("in_proj", "in_proj"),
    ("conv1d", "conv1d"),
    ("dtqk_proj", "x_proj"),
    ("dt_proj", "dt_proj"),
    ("out_proj", "out_proj"),
):
    _src_layer = getattr(model_simple, _src_name)
    _dst_layer = getattr(model, _dst_name)
    _dst_layer.weight.data = _src_layer.weight.data.clone()
    if _src_layer.bias is not None:
        _dst_layer.bias.data = _src_layer.bias.data.clone()

model.A_log.data = model_simple.A_log.data.clone()
model.D.data = model_simple.D.data.clone()

benchmark_model("Mamba Torch", model, x, warmup_iters, test_iters)

test_model("Mamba Torch", model, x)
