import os
import shutil
import warnings
from pathlib import Path

import pytest
import torch
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import RichProgressBar

from wavesAI.data.lra import Pathfinder
from wavesAI.model.elman_rnn import ElmanRNNSeq2Seq
from wavesAI.model.mamba_optimized_v2 import MambaClassifierSansEmbedding
from wavesAI.task.sequence_classification import SequenceClassification

from rich import pretty
pretty.install()

# Per-step input feature size, passed to the classifier as ``input_dim``.
INPUT_DIM: int = 1
# Number of classifier outputs, passed to the classifier as ``output_dim``
# (presumably the two Pathfinder labels -- confirm against the datamodule).
OUTPUT_DIM: int = 2


@pytest.fixture
def ssmau_module():
    """
    Create the Mamba-based sequence classifier under test.

    The model takes ``INPUT_DIM`` features per sequence element and produces
    ``OUTPUT_DIM`` outputs. It uses a minimal state (``d_state=1``,
    ``dt_rank=1``) and a stack of twelve ``"m"`` layers.

    :return: wavesAI.model.mamba_optimized_v2.MambaClassifierSansEmbedding
    """
    # Twelve "m" layer specs joined by "|", i.e. "m|m|m|m|m|m|m|m|m|m|m|m".
    # NOTE(review): "m" presumably selects a Mamba block -- confirm against
    # MambaClassifierSansEmbedding's ``layers`` parsing.
    layer_spec = "|".join(["m"] * 12)

    # The first positional argument (16) is presumably the model width
    # (d_model) -- TODO confirm against the constructor signature.
    return MambaClassifierSansEmbedding(
        16,
        output_dim=OUTPUT_DIM,
        input_dim=INPUT_DIM,
        d_state=1,
        dt_rank=1,
        layers=layer_spec,
    )


@pytest.fixture
def data_module_seq():
    """
    Create the LRA Pathfinder data module used by the overfitting test.

    Data is read from ``$DATA_ROOT/lra_release`` with a batch size of 16.
    When ``DATA_ROOT`` is unset or empty, the requesting test is skipped
    instead of erroring with a bare ``KeyError``.

    :return: wavesAI.data.lra.Pathfinder
    """
    data_root = os.environ.get("DATA_ROOT")
    if not data_root:
        # The dataset lives outside the repo; without it the test cannot run.
        pytest.skip("DATA_ROOT environment variable is not set")

    return Pathfinder(Path(data_root) / "lra_release", batch_size=16)


# @pytest.fixture
# def sequence_task(rnn_module):
#     return SequenceClassification(rnn_module,
#                             lambda x: torch.optim.AdamW(x, lr=1e-3, weight_decay=0.0),
#                             lambda x: torch.optim.lr_scheduler.CosineAnnealingLR(x, T_max=5000),
#                             scheduler_interval="epoch")


@pytest.fixture
def sequence_task_ssmau(ssmau_module):
    """
    Wrap the Mamba classifier fixture in a ``SequenceClassification`` task.

    Optimizer: AdamW with lr=1e-3 and no weight decay.
    Scheduler: CosineAnnealingLR with T_max=5000, with
    ``scheduler_interval="epoch"`` (presumably stepped once per epoch by the
    task -- confirm in SequenceClassification).
    """
    def build_optimizer(params):
        return torch.optim.AdamW(params, lr=1e-3, weight_decay=0.0)

    # NOTE(review): if the scheduler is stepped per epoch, T_max=5000 is far
    # longer than a 150-epoch run, so the LR stays close to 1e-3 throughout.
    def build_scheduler(optimizer):
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

    return SequenceClassification(
        ssmau_module,
        build_optimizer,
        build_scheduler,
        scheduler_interval="epoch",
    )


@pytest.fixture(autouse=True)
def cleanup():
    """
    Remove the ``test_logs_tmp`` directory after each test in this module.

    The fixture is function-scoped and ``autouse``, so the code after
    ``yield`` runs as teardown once per test, not once after the whole run.
    """
    yield

    log_dir = "test_logs_tmp"
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)


# WARNING: THIS WILL TAKE TIME (150 training epochs).
def test_overfitting_ssmau_pathfinder(sequence_task_ssmau, data_module_seq):
    """
    Check that the Mamba classifier can overfit Pathfinder.

    Trains for 150 epochs with one training batch and one validation batch
    per epoch. After training, ``train_accuracy`` in
    ``trainer.logged_metrics`` must be at least 0.98.
    """
    data_module_seq.prepare_data()
    data_module_seq.setup()

    # NOTE(review): limit_train_batches=1 uses the first batch of the train
    # loader each epoch; if that loader shuffles, the batch differs between
    # epochs. Lightning's ``overfit_batches=1`` pins a single batch instead.
    trainer = Trainer(
        default_root_dir="test_logs_tmp",
        max_epochs=150,
        limit_train_batches=1,
        limit_val_batches=1,
        # Uncomment to enable autograd anomaly detection while debugging:
        # detect_anomaly=True,
        callbacks=[RichProgressBar()],
    )

    trainer.fit(sequence_task_ssmau, datamodule=data_module_seq)

    metrics = trainer.logged_metrics
    # Fail with a readable message rather than a bare KeyError if the task
    # logged accuracy under a different name.
    assert "train_accuracy" in metrics, (
        f"train_accuracy was not logged; logged metrics: {sorted(metrics)}"
    )
    train_accuracy = metrics["train_accuracy"].item()
    assert train_accuracy >= 0.98, (
        "trainer and model did not satisfy the overfitting condition "
        f"(train_accuracy={train_accuracy:.4f})"
    )
