import os

import pytest
import torch
import numpy as np
import shutil

from wavesAI.data.formal_language import FormalLanguageClassification
from wavesAI.model.mamba_optimized import MambaBlock, ModelArgs, MambaClassifier
from hypothesis import given, settings, strategies as st

@pytest.fixture
def model():
    """Return a small ``MambaBlock`` built after seeding torch's global RNG with 0.

    The block is configured with constant B/C, no 1-D convolution, no bias,
    the ``"aouc"`` flag string and diagnostics enabled.
    """
    # Seed first so that whatever random initialisation the block performs
    # starts from the same RNG state on every test run.
    torch.manual_seed(0)

    d_model, vocab_size, d_state = 8, 5, 6
    # Positional values are passed in ModelArgs' declared field order; the
    # meaning of the two literal 1s is defined by ModelArgs (not visible here).
    args = ModelArgs(
        d_model,
        1,
        vocab_size,
        d_state,
        1,
        constant_B=True,
        constant_C=True,
        convolution1D=False,
        bias=False,
        flags="aouc",
        diagnostic=True,
    )
    return MambaBlock(args)


def _snapshot_param_grads(module):
    """Return ``{name: grad.clone()}`` for each trainable parameter of *module* whose ``.grad`` is set."""
    return {
        name: param.grad.clone()
        for name, param in module.named_parameters()
        if param.requires_grad and param.grad is not None
    }


def test_cuda_optimized_alignment(model):
    """Check that the optimized CUDA path of ``model.ssm`` matches the reference path.

    For each seed the same random inputs are run through ``model.ssm`` twice:
    first with ``optimized_cuda=False`` / ``model_name="mamba-aou"`` (the
    reference) and then with ``optimized_cuda=True`` / ``model_name="mamba-aouc"``.
    The forward outputs must agree within ``atol=rtol=0.1`` and the parameter
    gradients of ``sum(|output|)`` must agree within ``atol=rtol=1e-4``.

    The test is skipped (not failed) when CUDA is unavailable. The model's
    ``optimized_cuda`` and ``model_name`` attributes are restored even when an
    assertion fails.
    """
    if not torch.cuda.is_available():
        pytest.skip("GPU not available. Optimized kernels will only work with GPU")

    initial_optimized_cuda = model.optimized_cuda
    initial_model_name = model.model_name

    BATCH_SIZE = 1
    SEQUENCE_LENGTH = 5
    input_shape = (BATCH_SIZE, SEQUENCE_LENGTH, model.args.d_model)
    maes = []

    # Moving the module is loop-invariant, so do it once up front.
    model.to("cuda")
    try:
        for seed in range(1):
            torch.manual_seed(seed)

            # Sample uniformly from [-1, 1) with the CPU generator, then move
            # to the GPU, so the values depend only on the CPU seed.
            input_x = (torch.rand(input_shape) * 2 - 1).to("cuda")
            input_z = (torch.rand(input_shape) * 2 - 1).to("cuda")

            print("----------------Ground Truth: ")
            model.optimized_cuda = False
            model.model_name = "mamba-aou"
            model1_out = model.ssm(input_x, input_z)
            # The output is not a leaf tensor; retain its grad so it can be printed.
            model1_out.retain_grad()

            # Reference backward pass.
            model.zero_grad()
            torch.sum(torch.abs(model1_out)).backward()

            print(f"y_grad_ref {model1_out.grad}")

            backward_gradients_ref = _snapshot_param_grads(model)
            print(backward_gradients_ref)

            print("----------------Optimized CUDA: ")
            model.optimized_cuda = True
            model.model_name = "mamba-aouc"
            model2_out = model.ssm(input_x, input_z)

            for key, arg in model.ssm_args.items():
                if isinstance(arg, torch.Tensor):
                    assert arg.is_contiguous(), f"The argumennt {key} is not contiguous."

            print(model1_out, model2_out)
            assert torch.allclose(model1_out, model2_out, atol=0.1, rtol=0.1), "forward is not within the given tolerance (0.1)"
            model2_out.retain_grad()

            # Optimized backward pass.
            model.zero_grad()
            torch.sum(torch.abs(model2_out)).backward()

            for key, arg in model.ssm_args.items():
                print(key, arg)

            print(f"y_grad: {model2_out.grad}")

            for n, p in model.named_parameters():
                print(f"Comparing gradients for {n}")
                if n not in backward_gradients_ref:
                    continue
                reference_grad = backward_gradients_ref[n]
                print(f"{p.grad} with {reference_grad}")
                # zero_grad() may reset gradients to None, so a parameter the
                # optimized backward never reached would have no grad at all.
                # Compare it as a zero gradient instead of silently skipping it.
                optimized_grad = p.grad if p.grad is not None else torch.zeros_like(reference_grad)
                assert torch.allclose(optimized_grad, reference_grad, atol=1e-4, rtol=1e-4), f"parameter {n} not close in gradients"

            maes.append(torch.mean(torch.abs(model1_out - model2_out)).cpu().item())
            print("MAE: ", maes[-1])

        print("Average error: ", np.mean(maes))
    finally:
        # Revert the attributes this test toggled, whether or not it passed.
        model.optimized_cuda = initial_optimized_cuda
        model.model_name = initial_model_name


def test_cuda_convergence(model):
    """Train a small ``MambaClassifier`` and require a logged train accuracy of at least 0.98.

    A 3-layer classifier is fitted for up to 100 epochs on a
    ``FormalLanguageClassification`` data module with 8 training samples.
    The data module's cache directory is removed afterwards, whether or not
    the run succeeded. The ``model`` fixture argument is not used here; a
    fresh classifier is built instead.
    """
    classifier = MambaClassifier(
        d_model=16,
        d_state=32,
        n_layer=3,
        vocab_size=2,
        output_dim=2,
    )
    length_bins = [(1, 50), (51, 100), (101, 150)]
    datamodule = FormalLanguageClassification(
        train_size=8,
        validation_size=4,
        test_size=8,
        min_train_length=1,
        max_train_length=5,
        test_bins=length_bins,
        data_dir="./test_data",
        batch_size=4,
    )

    try:
        import lightning as L
        from wavesAI.task.sequence_classification import SequenceClassification

        def build_optimizer(x):
            # AdamW with a learning rate of 0.001 over the given parameters.
            return torch.optim.AdamW(x, 0.001)

        def build_scheduler(x):
            # Cosine annealing down to 0 over 1000 scheduler steps.
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                x, T_max=1000, eta_min=0, last_epoch=-1
            )

        task = SequenceClassification(
            network=classifier,
            optimizer=build_optimizer,
            scheduler=build_scheduler,
            scheduler_interval="step",
            scheduler_frequency=1,
            scheduler_monitor="train_loss",
        )
        trainer = L.Trainer(
            max_epochs=100,
            gradient_clip_val=1.0,
            enable_checkpointing=False,
            enable_progress_bar=True,
            logger=False,
        )

        trainer.fit(task, datamodule=datamodule)

        train_accuracy = trainer.logged_metrics['train_accuracy'].item()
        assert train_accuracy >= 0.98, "Model not converged."
    finally:
        # Clean up the data module's cache directory if it exists.
        cache_dir = datamodule.cache_dir
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
