import os

import pytest
import torch
import numpy as np
import shutil

from icecream import ic

from wavesAI.data.formal_language import FormalLanguageClassification
from wavesAI.model.mamba_optimized_v2 import MambaBlock, SSMauBlock, MambaSeq2Seq, ModelArgsMamba, ModelArgsSSMau
import wavesAI.model.mamba_optimized as mamba_optimized_v1
from hypothesis import given, settings, strategies as st

# Small hyper-parameters shared by the v1/v2 fixtures in this module.
VOCAB_SIZE: int = 5  # passed to the v1 reference ModelArgs only
D_STATE: int = 6     # state size handed to both the v1 and v2 ModelArgs
D_MODEL: int = 8     # model width handed to both the v1 and v2 ModelArgs
N_LAYER: int = 1     # NOTE(review): not referenced by live code here; fixtures pass a literal 1

@pytest.fixture
def model_mamba_ref():
    """Build the v1 reference ``MambaBlock`` (input-dependent B and C, flags "o").

    The global torch seed is reset to 0 just before construction so that any
    random initialisation done by the constructors is reproducible.
    """
    torch.manual_seed(0)
    reference_args = mamba_optimized_v1.ModelArgs(
        D_MODEL,
        1,
        VOCAB_SIZE,
        D_STATE,
        1,
        constant_B=False,
        constant_C=False,
        convolution1D=True,
        bias=False,
        flags="o",
        diagnostic=True,
    )
    return mamba_optimized_v1.MambaBlock(reference_args)

@pytest.fixture
def model_mamba_v2():
    """Build the v2 ``MambaBlock`` from ``ModelArgsMamba(D_MODEL, D_STATE, 1)``.

    Seeds torch with 0 first so any random initialisation is reproducible.
    """
    torch.manual_seed(0)
    v2_args = ModelArgsMamba(D_MODEL, D_STATE, 1)
    return MambaBlock(v2_args)

@pytest.fixture
def model_ssmau_ref():
    """Build the v1 reference block configured as SSMau (constant B/C, flags "aou").

    ``dt_rank`` is pinned to ``D_STATE`` explicitly. The torch seed is reset to
    0 before construction so any random initialisation is reproducible.
    """
    torch.manual_seed(0)
    reference_args = mamba_optimized_v1.ModelArgs(
        D_MODEL,
        1,
        VOCAB_SIZE,
        D_STATE,
        1,
        dt_rank=D_STATE,
        constant_B=True,
        constant_C=True,
        convolution1D=True,
        bias=False,
        flags="aou",
        diagnostic=True,
    )
    return mamba_optimized_v1.MambaBlock(reference_args)

@pytest.fixture
def model_ssmau_v2():
    """Build the v2 ``SSMauBlock`` from ``ModelArgsSSMau(D_MODEL, D_STATE, 1)``.

    Seeds torch with 0 first so any random initialisation is reproducible.
    """
    torch.manual_seed(0)
    v2_args = ModelArgsSSMau(D_MODEL, D_STATE, 1)
    return SSMauBlock(v2_args)

def test_ssmau_v1_v2_alignment(model_ssmau_ref, model_ssmau_v2):
    """Check that the v2 ``SSMauBlock`` reproduces the v1 reference block.

    For each seed, the v1 reference block runs forward and backward on a random
    input, and its per-parameter gradients are snapshotted. The v2 block then has
    several sub-modules/parameters replaced by the *same objects* as the
    reference, so both compute with shared weights. The test then checks three
    things:

    * v2's output is within atol=rtol=0.1 of the reference output;
    * every parameter (by name) that received a reference gradient also
      receives a v2 gradient within atol=rtol=1e-4;
    * every tensor in v2's ``ssm_args`` is contiguous.

    Requires a CUDA device.
    """
    BATCH_SIZE = 1
    SEQUENCE_LENGTH = 5
    maes = []
    for seed in range(1):
        torch.manual_seed(seed)
        model = model_ssmau_ref
        # Uniform input in [-1, 1).
        input_x = torch.rand((BATCH_SIZE, SEQUENCE_LENGTH, model.args.d_model)) * 2 - 1

        assert torch.cuda.is_available(), "GPU not available. Optimized kernels will only work with GPU"

        input_x = input_x.to("cuda")
        model.to("cuda")

        print("----------------Ground Truth: ")
        model.optimized_cuda = False
        model.model_name = "mamba-aou"
        model1_out = model.forward(input_x)
        model1_out.retain_grad()

        # Reference backward pass on a scalar loss (sum of absolute outputs).
        model.zero_grad()
        torch.sum(torch.abs(model1_out)).backward()

        print(f"y_grad_ref {model1_out.grad}")

        # Snapshot the reference gradients now: the modules assigned to v2 below
        # are the same objects, so their .grad fields will be overwritten.
        backward_gradients_ref = {
            n: p.grad.clone()
            for n, p in model.named_parameters()
            if p.requires_grad and p.grad is not None
        }

        print(backward_gradients_ref)

        model.zero_grad()

        ###############################
        # Matching Parameters: v2 reuses the reference's module/parameter objects.
        model_ssmau_v2.in_proj = model_ssmau_ref.in_proj
        model_ssmau_v2.conv1d = model_ssmau_ref.conv1d
        model_ssmau_v2.out_proj = model_ssmau_ref.out_proj
        model_ssmau_v2.x_proj = model_ssmau_ref.x_proj
        model_ssmau_v2.dt_proj = model_ssmau_ref.dt_proj
        model_ssmau_v2.D = model_ssmau_ref.D
        model_ssmau_v2.C = model_ssmau_ref.C
        model_ssmau_v2.B = model_ssmau_ref.B
        model_ssmau_v2.xA_proj = model_ssmau_ref.xA_proj
        ##############################

        print("----------------V2 ")
        model = model_ssmau_v2
        model.to("cuda")
        model2_out = model.forward(input_x)

        # Every tensor in v2's ssm_args (presumably what it passes to its scan
        # kernel) must be contiguous.
        for key, arg in model.ssm_args.items():
            if isinstance(arg, torch.Tensor):
                assert arg.is_contiguous(), f"The argument {key} is not contiguous."

        assert torch.allclose(model1_out, model2_out, atol=0.1,
                              rtol=0.1), "forward is not within the given tolerance (0.1)"
        model2_out.retain_grad()

        # v2 backward pass on the same scalar loss.
        model.zero_grad()
        torch.sum(torch.abs(model2_out)).backward()

        print(f"y_grad: {model2_out.grad}")

        for n, p in model.named_parameters():
            print(f"Comparing gradients for {n}")
            if n not in backward_gradients_ref:
                # No reference gradient under this name: nothing to compare against.
                continue
            print(f"{p.grad} with {backward_gradients_ref[n]}")
            # A parameter that got a gradient in the reference must get one in v2
            # too; otherwise a broken backward path would go unnoticed.
            assert p.grad is not None, f"parameter {n} received no gradient in v2"
            assert torch.allclose(p.grad, backward_gradients_ref[n], atol=1e-4,
                                  rtol=1e-4), f"parameter {n} not close in gradients"

        maes.append(torch.mean(torch.abs(model1_out - model2_out)).cpu().item())
        print("MAE: ", maes[-1])

    print("Average error: ", np.mean(maes))


def test_mamba_v1_v2_alignment(model_mamba_ref, model_mamba_v2):
    """Check that the v2 ``MambaBlock`` reproduces the v1 reference block.

    For each seed, the v1 reference block runs forward and backward on a random
    input, and its per-parameter gradients are snapshotted. The v2 block then has
    several sub-modules/parameters (including ``A_log``) replaced by the *same
    objects* as the reference, so both compute with shared weights. The test
    then checks two things:

    * v2's output is within atol=rtol=0.1 of the reference output;
    * every parameter (by name) that received a reference gradient also
      receives a v2 gradient within atol=rtol=1e-4.

    Requires a CUDA device.
    """
    BATCH_SIZE = 1
    SEQUENCE_LENGTH = 5
    maes = []
    for seed in range(1):
        torch.manual_seed(seed)
        model = model_mamba_ref
        # Uniform input in [-1, 1).
        input_x = torch.rand((BATCH_SIZE, SEQUENCE_LENGTH, model.args.d_model)) * 2 - 1

        assert torch.cuda.is_available(), "GPU not available. Optimized kernels will only work with GPU"

        input_x = input_x.to("cuda")
        model.to("cuda")

        print("----------------Ground Truth: ")
        model.optimized_cuda = False
        # NOTE(review): the model_mamba_ref fixture is built with flags="o", yet
        # model_name is set to "mamba-aou" (same as the SSMau test) -- confirm
        # this is intended and not a copy-paste leftover.
        model.model_name = "mamba-aou"
        model1_out = model.forward(input_x)
        model1_out.retain_grad()

        # Reference backward pass on a scalar loss (sum of absolute outputs).
        model.zero_grad()
        torch.sum(torch.abs(model1_out)).backward()

        print(f"y_grad_ref {model1_out.grad}")

        # Snapshot the reference gradients now: the modules assigned to v2 below
        # are the same objects, so their .grad fields will be overwritten.
        backward_gradients_ref = {
            n: p.grad.clone()
            for n, p in model.named_parameters()
            if p.requires_grad and p.grad is not None
        }

        print(backward_gradients_ref)

        model.zero_grad()

        ###############################
        # Matching Parameters: v2 reuses the reference's module/parameter objects.
        model_mamba_v2.in_proj = model_mamba_ref.in_proj
        model_mamba_v2.conv1d = model_mamba_ref.conv1d
        model_mamba_v2.out_proj = model_mamba_ref.out_proj
        model_mamba_v2.x_proj = model_mamba_ref.x_proj
        model_mamba_v2.dt_proj = model_mamba_ref.dt_proj
        model_mamba_v2.D = model_mamba_ref.D
        model_mamba_v2.A_log = model_mamba_ref.A_log
        ##############################

        print("----------------V2 ")
        model = model_mamba_v2
        model.to("cuda")
        model2_out = model.forward(input_x)

        assert torch.allclose(model1_out, model2_out, atol=0.1,
                              rtol=0.1), "forward is not within the given tolerance (0.1)"
        model2_out.retain_grad()

        # v2 backward pass on the same scalar loss.
        model.zero_grad()
        torch.sum(torch.abs(model2_out)).backward()

        print(f"y_grad: {model2_out.grad}")

        for n, p in model.named_parameters():
            print(f"Comparing gradients for {n}")
            if n not in backward_gradients_ref:
                # No reference gradient under this name: nothing to compare against.
                continue
            print(f"{p.grad} with {backward_gradients_ref[n]}")
            # A parameter that got a gradient in the reference must get one in v2
            # too; otherwise a broken backward path would go unnoticed.
            assert p.grad is not None, f"parameter {n} received no gradient in v2"
            assert torch.allclose(p.grad, backward_gradients_ref[n], atol=1e-4,
                                  rtol=1e-4), f"parameter {n} not close in gradients"

        maes.append(torch.mean(torch.abs(model1_out - model2_out)).cpu().item())
        print("MAE: ", maes[-1])

    print("Average error: ", np.mean(maes))


# def test_cuda_optimized_alignment(model):
#     inital_optimized_cuda = model.optimized_cuda
#     initial_model_name = model.model_name
#
#     BATCH_SIZE = 1
#     SEQUENCE_LENGTH = 5
#     maes = []
#     for seed in range(1):
#         torch.manual_seed(seed)
#
#         input_x = torch.rand((BATCH_SIZE, SEQUENCE_LENGTH, model.args.d_model)) * 2 - 1
#         input_z = torch.rand((BATCH_SIZE, SEQUENCE_LENGTH, model.args.d_model)) * 2 - 1
#
#         assert torch.cuda.is_available(), "GPU not available. Optimized kernels will only work with GPU"
#
#         input_x = input_x.to("cuda")
#         input_z = input_z.to("cuda")
#         model.to("cuda")
#
#         print("----------------Ground Truth: ")
#         model.optimized_cuda = False
#         model.model_name = "mamba-aou"
#         model1_out = model.ssm(input_x, input_z)
#         # print(model1_out)
#         model1_out.retain_grad()
#
#         # test backward pass
#         model.zero_grad()
#         torch.sum(torch.abs(model1_out)).backward()
#
#         print(f"y_grad_ref {model1_out.grad}")
#
#         backward_gradients_ref = {}
#
#         for n, p in model.named_parameters():
#             if p.requires_grad and p.grad is not None:
#                 backward_gradients_ref[n] = p.grad.clone()
#
#         print(backward_gradients_ref)
#
#         print("----------------Optimized CUDA: ")
#         model.optimized_cuda = True
#         model.model_name = "mamba-aouc"
#         model2_out = model.ssm(input_x, input_z)
#         # print(model2_out)
#
#         for key, arg in model.ssm_args.items():
#             if isinstance(arg, torch.Tensor):
#                 assert arg.is_contiguous(), f"The argumennt {key} is not contiguous."
#
#         print(model1_out, model2_out)
#         assert torch.allclose(model1_out, model2_out, atol=0.1,
#                               rtol=0.1), "forward is not within the given tolerance (0.1)"
#         model2_out.retain_grad()
#
#         # test backward pass
#         model.zero_grad()
#         torch.sum(torch.abs(model2_out)).backward()
#
#         for key, arg in model.ssm_args.items():
#             print(key, arg)
#
#         print(f"y_grad: {model2_out.grad}")
#
#         for n, p in model.named_parameters():
#             print(f"Comparing gradients for {n}")
#             if n not in backward_gradients_ref:
#                 continue
#             print(f"{p.grad} with {backward_gradients_ref[n]}")
#             if p.requires_grad and p.grad is not None:
#                 assert torch.allclose(p.grad, backward_gradients_ref[n], atol=1e-4,
#                                       rtol=1e-4), f"parameter {n} not close in gradients"
#
#         maes.append(torch.mean(torch.abs((model1_out - model2_out))).cpu().item())
#         print("MAE: ", maes[-1])
#
#     print("Average error: ", np.mean(maes))
#
#     # revert model
#     model.optimized_cuda = inital_optimized_cuda
#     model.model_name = initial_model_name

# def test_cuda_convergence(model):
#     model = MambaClassifier(d_model=16,
#                             d_state=32,
#                             n_layer=3,
#                             vocab_size=2,
#                             output_dim=2)
#     data_module = FormalLanguageClassification(
#         train_size=8,
#         validation_size=4,
#         test_size=8,
#         min_train_length=1,
#         max_train_length=5,
#         test_bins=[(1, 50), (51, 100), (101, 150)],
#         data_dir="./test_data",
#         batch_size=4
#     )
#     try:
#         import lightning as L
#         from wavesAI.task.sequence_classification import SequenceClassification
#         lmodel = SequenceClassification(network=model,
#                                         optimizer=lambda x: torch.optim.AdamW(x, 0.001),
#                                         scheduler=lambda x: torch.optim.lr_scheduler.CosineAnnealingLR(x, eta_min=0,
#                                                                                                        last_epoch=-1,
#                                                                                                        T_max=1000),
#                                         scheduler_interval="step", scheduler_frequency=1, scheduler_monitor="train_loss"
#                                         )
#         trainer = L.Trainer(enable_checkpointing=False, logger=False,
#                             enable_progress_bar=True,
#                             gradient_clip_val=1.0, max_epochs=100)
#
#         trainer.fit(lmodel, datamodule=data_module)
#         assert trainer.logged_metrics['train_accuracy'].item() >= 0.98, "Model not converged."
#     finally:
#         if os.path.exists(data_module.cache_dir):
#             shutil.rmtree(data_module.cache_dir)
