import os
import pytest
import torch
import shutil
import torch
import warnings
from wavesAI.data.formal_language import FormalLanguageClassification
from wavesAI.model.mamba_optimized import MambaSeq2Seq, ModelArgs
from wavesAI.model.elman_rnn import ElmanRNNSeq2Seq
from wavesAI.data.formal_language import FormalLanguageSeq2Seq
from wavesAI.data.formal_language import FormalLanguageSequence
from wavesAI.task.sequence_modeling import SequenceModeling
from lightning import Trainer

from lightning.pytorch.callbacks import RichProgressBar

from rich import pretty
# Import-time side effect: installs rich's pretty-printing hook for this session.
# NOTE(review): presumably only for nicer interactive/debug output — confirm the tests rely on it.
pretty.install()

# Vocabulary size passed both to the model fixtures and to the data-module fixtures in this file.
VOCAB_SIZE = 8

@pytest.fixture
def rnn_module():
    """
    Build the Elman RNN seq2seq model shared by the sequence-modeling fixtures.

    :return: wavesAI.model.elman_rnn.ElmanRNNSeq2Seq
    """
    # NOTE(review): presumably (input size, hidden size, output size, bias) —
    # confirm against the ElmanRNNSeq2Seq signature.
    constructor_args = (VOCAB_SIZE, 64, VOCAB_SIZE, True)
    return ElmanRNNSeq2Seq(*constructor_args)


@pytest.fixture
def mamba_module():
    """
    Build a two-layer Mamba seq2seq model whose input and output vocabulary sizes are ``VOCAB_SIZE``.

    :return: wavesAI.model.mamba_optimized.MambaSeq2Seq
    """
    # A previous configuration, kept for reference:
    # MambaSeq2Seq(32, 2, VOCAB_SIZE, VOCAB_SIZE, flags='o', bias=True)
    mamba_kwargs = dict(
        d_state=16,
        dt_rank=16,
        n_layer=2,
        vocab_size=VOCAB_SIZE,
        output_vocab_size=VOCAB_SIZE,
        flags='aou',
        convolution1D=False,
        constant_B=True,
        constant_C=True,
        x_factor=.001,
        x_bias_factor=1,
    )
    # The single positional argument is left positional because its parameter
    # name is not visible from this file.
    return MambaSeq2Seq(16, **mamba_kwargs)


@pytest.fixture
def data_module_seq2seq():
    """
    Build a small seq2seq data module for the "repetition" task.

    Data is written under ``./test_data_tmp``; the train length bounds and every
    test bin are fixed at 5, and the batch size equals the train set size (32).

    :return: wavesAI.data.formal_language.FormalLanguageSeq2Seq
    """
    config = {
        "data_dir": "./test_data_tmp",
        "train_size": 32,
        "validation_size": 32,
        "test_size": 64,
        "min_train_length": 5,
        "max_train_length": 5,
        "test_type": "SR",
        "test_bins": [(5, 5), (5, 5), (5, 5)],
        "task_name": "repetition",
        "batch_size": 32,
        "num_workers": 1,
        "requested_vocab_size": VOCAB_SIZE,
        "pin_memory": False,
    }
    return FormalLanguageSeq2Seq(**config)


@pytest.fixture
def data_module_seq():
    """
    Build a small sequence-style data module for the "repetition" task.

    Data is written under ``./test_data_tmp_seq``; apart from the data directory
    and the data-module class, the settings match the seq2seq fixture.

    :return: wavesAI.data.formal_language.FormalLanguageSequence
    """
    config = {
        "data_dir": "./test_data_tmp_seq",
        "train_size": 32,
        "validation_size": 32,
        "test_size": 64,
        "min_train_length": 5,
        "max_train_length": 5,
        "test_type": "SR",
        "test_bins": [(5, 5), (5, 5), (5, 5)],
        "task_name": "repetition",
        "batch_size": 32,
        "num_workers": 1,
        "requested_vocab_size": VOCAB_SIZE,
        "pin_memory": False,
    }
    return FormalLanguageSequence(**config)


@pytest.fixture
def sequence_task(rnn_module):
    """
    Wrap the ``rnn_module`` fixture in a ``SequenceModeling`` task.

    The task is handed factories rather than instances: one builds an AdamW
    optimizer (lr=1e-3, weight_decay=0.0) and one builds a CosineAnnealingLR
    schedule (T_max=5000). The scheduler interval is "epoch".

    :return: wavesAI.task.sequence_modeling.SequenceModeling
    """
    def make_optimizer(params):
        return torch.optim.AdamW(params, lr=1e-3, weight_decay=0.0)

    def make_scheduler(optimizer):
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

    return SequenceModeling(rnn_module, make_optimizer, make_scheduler, scheduler_interval="epoch")


@pytest.fixture
def sequence_task_mamba(mamba_module):
    """
    Wrap the ``mamba_module`` fixture in a ``SequenceModeling`` task.

    The task is handed factories rather than instances: one builds an AdamW
    optimizer (lr=1e-3, weight_decay=0.0) and one builds a CosineAnnealingLR
    schedule (T_max=5000). The scheduler interval is "epoch".

    :return: wavesAI.task.sequence_modeling.SequenceModeling
    """
    def make_optimizer(params):
        return torch.optim.AdamW(params, lr=1e-3, weight_decay=0.0)

    def make_scheduler(optimizer):
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

    return SequenceModeling(mamba_module, make_optimizer, make_scheduler, scheduler_interval="epoch")


@pytest.fixture(autouse=True)
def cleanup(data_module_seq2seq):
    """Remove the seq2seq data directory and ``test_logs_tmp`` once a test has finished."""
    yield
    # Everything after the yield is teardown. The fixture is autouse with the
    # default (function) scope, so this runs after each test, not once per session.
    for leftover in (data_module_seq2seq.data_dir, "test_logs_tmp"):
        if os.path.exists(leftover):
            shutil.rmtree(leftover)


@pytest.fixture(autouse=True)
def cleanup_seq(data_module_seq):
    """Remove the sequence data directory and ``test_logs_tmp_seq`` once a test has finished."""
    yield
    # Everything after the yield is teardown. The fixture is autouse with the
    # default (function) scope, so this runs after each test, not once per session.
    for leftover in (data_module_seq.data_dir, "test_logs_tmp_seq"):
        if os.path.exists(leftover):
            shutil.rmtree(leftover)


def test_trainer_compatibility(sequence_task, data_module_seq2seq):
    """Smoke test: the RNN task and the seq2seq data module run under a Lightning ``Trainer``.

    The test passes if ``fit`` completes without raising; nothing is asserted
    about the resulting metrics.
    """
    # Prepare and set up the data explicitly in case it has not been done yet.
    data_module_seq2seq.prepare_data()
    data_module_seq2seq.setup()

    # fast_dev_run=7 asks Lightning for a short debugging run limited to 7 batches.
    Trainer(fast_dev_run=7).fit(sequence_task, datamodule=data_module_seq2seq)


def test_overfitting_rnn_seq2seq(sequence_task, data_module_seq2seq):
    """Check that the RNN task can overfit one training batch to at least 98% train accuracy."""
    data_module_seq2seq.prepare_data()
    data_module_seq2seq.setup()

    trainer_options = {
        # A single train batch and a single validation batch per epoch.
        "limit_train_batches": 1,
        "limit_val_batches": 1,
        # Logs go to a directory that the autouse cleanup fixture removes.
        "default_root_dir": "test_logs_tmp",
        "max_epochs": 200,
        "gradient_clip_val": 0.5,
        "gradient_clip_algorithm": "value",
        "detect_anomaly": True,
    }
    trainer = Trainer(**trainer_options)
    trainer.fit(sequence_task, datamodule=data_module_seq2seq)

    # The last logged train accuracy is the overfitting criterion.
    train_accuracy = trainer.logged_metrics["train_accuracy"].item()
    assert train_accuracy >= 0.98, "trainer and model did not satisfy the overfitting condition"
