import os
import pytest
import torch
import shutil
import torch
import warnings
from wavesAI.data.lra import ImageClassification
from wavesAI.model.mamba_optimized_v2 import MambaClassifier
from wavesAI.model.elman_rnn import ElmanRNNSeq2Seq
from wavesAI.task.sequence_classification import SequenceClassification
from lightning import Trainer

from lightning.pytorch.callbacks import RichProgressBar

from rich import pretty
pretty.install()

# Input vocabulary size handed to both model fixtures
# (presumably one token per 8-bit pixel value -- confirm against the data module).
VOCAB_SIZE = 256
# Number of output classes; used as the classifier's output dimension (CIFAR-10 has 10).
OUTPUT_VOCAB_SIZE = 10

@pytest.fixture
def rnn_module():
    """
    Create the Elman RNN model wrapped by the ``sequence_task`` fixture.

    The model is built as ``ElmanRNNSeq2Seq(VOCAB_SIZE, 64, OUTPUT_VOCAB_SIZE, True)``.
    64 is presumably the hidden size and ``True`` a model option flag --
    confirm against the ``ElmanRNNSeq2Seq`` constructor signature.

    :return: wavesAI.model.elman_rnn.ElmanRNNSeq2Seq
    """
    return ElmanRNNSeq2Seq(VOCAB_SIZE, 64, OUTPUT_VOCAB_SIZE, True)


@pytest.fixture
def ssmau_module():
    """
    Create the small Mamba-based classifier used by the CIFAR-10 overfitting test.

    Built with first positional argument 8, ``d_state=8``, ``layers="a|m"``,
    ``vocab_size=VOCAB_SIZE``, ``output_dim=OUTPUT_VOCAB_SIZE`` and
    ``mode="last"``.

    NOTE(review): the meaning of the ``"a|m"`` layer spec and of
    ``mode="last"`` (presumably classify from the last position) is defined
    by ``MambaClassifier``; confirm there.

    :return: wavesAI.model.mamba_optimized_v2.MambaClassifier
    """
    return MambaClassifier(8, output_dim=OUTPUT_VOCAB_SIZE, d_state=8,
                           layers="a|m", vocab_size=VOCAB_SIZE, mode="last")


@pytest.fixture
def data_module_seq():
    """
    Create the CIFAR-10 image-classification data module for the overfitting test.

    The dataset location passed to the module is ``$DATA_ROOT/cifar10`` and the
    batch size is 64. If ``DATA_ROOT`` is unset or empty, the requesting test is
    skipped rather than erroring during fixture setup.

    :return: wavesAI.data.lra.ImageClassification
    """
    # pathlib is not imported at module level in this file; import it here so
    # the fixture does not fail with NameError.
    from pathlib import Path

    data_root = os.environ.get("DATA_ROOT")
    if not data_root:
        pytest.skip("DATA_ROOT environment variable is not set; cannot locate CIFAR-10 data")

    return ImageClassification(Path(data_root) / "cifar10", batch_size=64)


def _make_classification_task(model):
    """Wrap *model* in a SequenceClassification task with the shared optimiser setup.

    Both task fixtures use the same configuration:
    - an optimiser factory building AdamW with lr=1e-3 and weight_decay=0.0;
    - a scheduler factory building CosineAnnealingLR with T_max=5000;
    - ``scheduler_interval="epoch"``.
    """
    # NOTE(review): if the scheduler is stepped once per epoch, T_max=5000 is
    # ten times the max_epochs=500 used by test_overfitting_cifar10, so only
    # the start of the cosine curve is ever reached -- confirm this is intended.
    return SequenceClassification(
        model,
        lambda x: torch.optim.AdamW(x, lr=1e-3, weight_decay=0.0),
        lambda x: torch.optim.lr_scheduler.CosineAnnealingLR(x, T_max=5000),
        scheduler_interval="epoch",
    )


@pytest.fixture
def sequence_task(rnn_module):
    """Classification task wrapping the Elman RNN from the ``rnn_module`` fixture."""
    return _make_classification_task(rnn_module)


@pytest.fixture
def sequence_task_ssmau(ssmau_module):
    """Classification task wrapping the MambaClassifier from the ``ssmau_module`` fixture."""
    return _make_classification_task(ssmau_module)


@pytest.fixture(autouse=True)
def cleanup():
    """Remove the ``test_logs_tmp`` directory after each test in this module.

    The fixture is autouse and uses pytest's default function scope, so the
    teardown after ``yield`` runs once per test, not once per session.
    ``test_logs_tmp`` is the ``default_root_dir`` passed to the Lightning
    ``Trainer`` in this module. The path is relative, so it resolves against
    the current working directory.
    """
    yield

    # Only delete when present so tests that never created logs don't fail in teardown.
    if os.path.exists("test_logs_tmp"):
        shutil.rmtree("test_logs_tmp")


# WARNING: THIS WILL TAKE TIME -- it trains for up to 500 epochs with anomaly detection on.
def test_overfitting_cifar10(sequence_task_ssmau, data_module_seq):
    """Check that the Mamba classifier can overfit a tiny slice of CIFAR-10.

    Training is limited to one training batch and one validation batch per
    epoch, for up to 500 epochs. The test then requires the logged
    ``train_accuracy`` to reach at least 0.98.
    """
    data_module_seq.prepare_data()
    data_module_seq.setup()

    trainer = Trainer(
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir="test_logs_tmp",
        max_epochs=500,
        detect_anomaly=True,
        callbacks=[RichProgressBar()],
    )

    trainer.fit(sequence_task_ssmau, datamodule=data_module_seq)

    metrics = trainer.logged_metrics
    # Fail with a readable message (listing what *was* logged) instead of a bare KeyError.
    assert "train_accuracy" in metrics, (
        f"'train_accuracy' was never logged; logged metrics: {sorted(metrics)}"
    )
    accuracy = metrics["train_accuracy"].item()
    assert accuracy >= 0.98, (
        "trainer and model did not satisfy the overfitting condition "
        f"(train_accuracy={accuracy:.4f}, required >= 0.98)"
    )
