import os
import pytest
import torch
import shutil
import torch
import warnings
from wavesAI.data.formal_language import FormalLanguageStreaming
from wavesAI.model.mamba_optimized_v2 import MambaSeq2Seq
from wavesAI.model.elman_rnn import ElmanRNNSeq2Seq
from wavesAI.task.sequence_modeling import SequenceModeling
from lightning import Trainer

from lightning.pytorch.callbacks import RichProgressBar

from rich import pretty
pretty.install()

# Vocabulary sizes shared by the model and data-module fixtures below.
VOCAB_SIZE = 64
OUTPUT_VOCAB_SIZE = 64


@pytest.fixture
def ssmau_module():
    """
    Create the small MambaSeq2Seq model used by the overfitting test.

    The model is constructed with ``d_state=8``, ``dt_rank='auto'`` and
    ``layers="m|a"``, and receives ``VOCAB_SIZE`` and ``OUTPUT_VOCAB_SIZE``
    as its second and third positional arguments.
    NOTE(review): the leading positional ``8`` is presumably the model
    width — confirm against ``MambaSeq2Seq`` and consider passing it by keyword.

    :return: wavesAI.model.mamba_optimized_v2.MambaSeq2Seq
    """
    return MambaSeq2Seq(8,
                        VOCAB_SIZE,
                        OUTPUT_VOCAB_SIZE,
                        d_state=8, dt_rank='auto',
                        layers="m|a")


@pytest.fixture
def data_module_seq():
    """
    Create a small streaming formal-language data module for the "repetition" task.

    The module is configured with ``batch_size=1``, ``num_workers=1`` and
    ``vocab_size=VOCAB_SIZE``.
    NOTE(review): the meaning of the positional arguments ``2, 10, 20, 11, 20``
    is defined by ``FormalLanguageStreaming`` (presumably length bounds and/or
    dataset sizes) — confirm and consider passing them by keyword.

    :return: wavesAI.data.formal_language.FormalLanguageStreaming
    """
    return FormalLanguageStreaming(2,
                                   10,
                                   20,
                                   11,
                                   20,
                                   task_name="repetition",
                                   batch_size=1,
                                   num_workers=1,
                                   vocab_size=VOCAB_SIZE)


@pytest.fixture
def sequence_task_ssmau(ssmau_module):
    """
    Wrap the model fixture in a ``SequenceModeling`` task.

    The task receives two factories: one building an AdamW optimizer
    (``lr=1e-3``, ``weight_decay=0.0``) from the parameters it is given, and one
    building a ``CosineAnnealingLR`` scheduler (``T_max=5000``) from the
    optimizer. ``scheduler_interval="epoch"`` is forwarded as-is
    (presumably the scheduler is stepped once per epoch — confirm in
    ``SequenceModeling``).
    NOTE(review): the overfitting test trains for at most 250 epochs, so if the
    scheduler is stepped per epoch it only covers a small part of ``T_max``.

    :param ssmau_module: model produced by the ``ssmau_module`` fixture
    :return: wavesAI.task.sequence_modeling.SequenceModeling
    """
    def make_optimizer(params):
        return torch.optim.AdamW(params, lr=1e-3, weight_decay=0.0)

    def make_scheduler(optimizer):
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

    return SequenceModeling(
        ssmau_module,
        make_optimizer,
        make_scheduler,
        scheduler_interval="epoch",
    )


@pytest.fixture(autouse=True)
def cleanup():
    """Remove the ``test_logs_tmp`` directory after each test.

    The fixture is autouse and function-scoped, so the code after ``yield``
    runs as teardown once per test in this module, not once after all tests.
    """
    yield

    # EAFP: a single removal attempt; a directory that was never created is not an error.
    try:
        shutil.rmtree("test_logs_tmp")
    except FileNotFoundError:
        pass


def test_overfitting_ssmau_flseq2seq(sequence_task_ssmau, data_module_seq):
    """
    Check that the model can overfit the repetition task.

    WARNING: this test is slow. It runs up to 250 epochs, with one training
    batch and one validation batch per epoch and autograd anomaly detection
    enabled. It passes when the final logged ``train_accuracy`` is at least 0.98.
    """
    data_module_seq.prepare_data()
    data_module_seq.setup()

    trainer = Trainer(
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir="test_logs_tmp",  # removed afterwards by the autouse ``cleanup`` fixture
        max_epochs=250,
        detect_anomaly=True,  # enables torch autograd anomaly detection; slows training
        callbacks=[RichProgressBar()],
    )

    trainer.fit(sequence_task_ssmau, datamodule=data_module_seq)

    metrics = trainer.logged_metrics
    # Fail with a clear message rather than a bare KeyError if the metric name changes.
    assert "train_accuracy" in metrics, (
        f"'train_accuracy' was not logged; available metrics: {sorted(metrics)}"
    )
    train_accuracy = metrics["train_accuracy"].item()
    assert train_accuracy >= 0.98, (
        "trainer and model did not satisfy the overfitting condition "
        f"(train_accuracy={train_accuracy:.4f})"
    )
