from wavesAI.model.mamba_optimized import MambaClassifier

from lightning.data import StreamingDataset

import numpy as np
import torch

from timeit import default_timer as timer  # this is the library that does timing
import matplotlib.pyplot as plt
from tqdm import tqdm

plt.rcParams["font.family"] = "serif"
# plt.rcParams["font.serif"] = "Computer Modern"
plt.rcParams["font.size"] = 14


import contextlib
import io
import sys


@contextlib.contextmanager
def nostdout():
    """
    Context manager that silences everything written to ``sys.stdout``.

    While the ``with`` body runs, ``sys.stdout`` points at an in-memory text
    buffer that is discarded afterwards. The previous ``sys.stdout`` is put
    back when the body finishes, including when it raises.
    """
    saved_stdout = sys.stdout
    # A text buffer is required here: print() writes str, which a bytes
    # buffer (io.BytesIO) rejects with a TypeError.
    sys.stdout = io.StringIO()
    try:
        yield
    finally:
        # Restore even on error so a failing body does not leave stdout muted.
        sys.stdout = saved_stdout


def evaluate_seq_lengths(net, options, seq_lengths=None):
    """
    Time a forward (and optionally backward) pass of a freshly built model and
    record peak CUDA memory, once per sequence length.

    A single random integer input of the largest requested length is drawn up
    front and sliced down for each shorter length.

    :param net: mapping with a "classname" entry (callable that builds the
        model) and an "init_params" entry (keyword arguments passed to it)
    :param options: object providing ``vocab_size``, ``batch_size`` and
        ``no_backward``
    :param seq_lengths: sequence lengths to evaluate; defaults to
        ``list(range(10, 50, 10))``
    :return: ``(time_taken, memory_used)``, two lists aligned with
        ``seq_lengths``; times are ``timeit.default_timer`` differences
        (seconds) and memory values are ``torch.cuda.max_memory_allocated()``
        readings (bytes)
    """
    if seq_lengths is None:
        seq_lengths = list(range(10, 50, 10))

    # Nothing to measure; also keeps np.max() from failing on an empty input.
    if len(seq_lengths) == 0:
        return [], []

    longest = int(np.max(seq_lengths))
    x = torch.randint(options.vocab_size, size=(options.batch_size, longest)).to("cuda")

    time_taken = []
    memory_used = []

    for seq_length in tqdm(seq_lengths, leave=False):
        x_in = x[:, :seq_length]

        # Reset before building the model so the peak reading includes its parameters.
        torch.cuda.reset_peak_memory_stats()

        model = net["classname"](**net["init_params"]).to("cuda")

        # CUDA work is queued asynchronously: drain pending work before starting
        # the clock, and wait for the timed kernels before stopping it.
        torch.cuda.synchronize()
        start = timer()  # time forward and backward
        y = model(x_in)
        if not options.no_backward:
            y.sum().backward()  # backward
        torch.cuda.synchronize()
        end = timer()

        time_taken.append(end - start)
        memory_used.append(torch.cuda.max_memory_allocated())

    return time_taken, memory_used


def run_evaluation(models, options):
    """
    Benchmark every entry of ``models`` over a sweep of sequence lengths.

    The sweep starts at 10 and stops before ``options.sequence_length``, using
    a step of one twentieth of ``options.sequence_length`` (at least 1).

    :param models: mapping from a display key to the model specification that
        is forwarded unchanged to ``evaluate_seq_lengths``
    :param options: object providing ``sequence_length`` and ``runs``, plus
        whatever ``evaluate_seq_lengths`` reads from it
    :return: ``(seq_lengths, time_dict, memory_dict)``; each dict maps a model
        key to a list with one entry per run, and each entry is the per-length
        list returned by ``evaluate_seq_lengths``
    """
    # Clamp to 1: range() raises ValueError on a zero step, which the plain
    # division produces whenever sequence_length is below 20.
    step = max(1, int(options.sequence_length / 20))
    seq_lengths = list(range(10, options.sequence_length, step))

    time_dict = {}
    memory_dict = {}

    for key, net in models.items():
        print("Running ", key)

        time_taken = []
        memory_taken = []
        for _ in tqdm(range(options.runs)):
            time_elapsed, memory_used = evaluate_seq_lengths(net, options, seq_lengths=seq_lengths)
            time_taken.append(time_elapsed)
            memory_taken.append(memory_used)

        time_dict[key] = time_taken
        memory_dict[key] = memory_taken

        # Release cached allocator blocks before moving on to the next model.
        torch.cuda.empty_cache()

    return seq_lengths, time_dict, memory_dict


def treatment_1(model):
    """Set ``x_factor`` and ``x_bias_factor`` on ``model.args`` to 0; return the same model."""
    for field in ("x_factor", "x_bias_factor"):
        setattr(model.args, field, 0)
    return model

def treatment_2(model):
    """Set ``model.args.x_factor`` to 0 and return the same model."""
    config = model.args
    config.x_factor = 0
    return model

def treatment_3(model):
    """Set ``model.args.x_bias_factor`` to 0 and return the same model."""
    config = model.args
    config.x_bias_factor = 0
    return model

def treatment_4(model):
    """Switch ``model.args.binary`` to True and return the same model."""
    config = model.args
    config.binary = True
    return model

def treatment_5(model):
    """Set ``binary`` and ``no_graded_rotation`` on ``model.args`` to True; return the same model."""
    for flag in ("binary", "no_graded_rotation"):
        setattr(model.args, flag, True)
    return model

def treatment_6(model):
    """Switch ``model.args.convolution1D`` to False and return the same model."""
    config = model.args
    config.convolution1D = False
    return model

def treatment_7(model):
    """Switch ``model.args.unit_z`` to True and return the same model."""
    config = model.args
    config.unit_z = True
    return model

def treatment_8(model):
    """Switch ``model.args.unit_B`` to True and return the same model."""
    config = model.args
    config.unit_B = True
    return model

def treatment_9(model):
    """Switch ``model.args.unit_C`` to True and return the same model."""
    config = model.args
    config.unit_C = True
    return model

def treatment_10(model):
    """Switch ``model.args.zero_D`` to True and return the same model."""
    config = model.args
    config.zero_D = True
    return model

def treatment_11(model):
    """Switch ``model.args.unit_dt`` to True and return the same model."""
    config = model.args
    config.unit_dt = True
    return model


## tests that can be used to check for compatibility between modules

def check_seqtask_dataloader_compatibility(dataloader: torch.utils.data.DataLoader):
    """
    Utility function to check if a dataloader is compatible with wavesAI.task.sequence_modeling

    Only the first batch is inspected. It must hold at least ``x`` and ``y``
    (in that order); a third element, when present, is treated as an options
    mapping that may carry a "masks" entry.

    :param dataloader: (torch.utils.data.DataLoader)
    :raises AssertionError: if the batch has fewer than two elements, ``x`` or
        ``y`` is not a long tensor, ``x`` is not 2-dimensional, the shapes of
        ``x`` and ``y`` differ, or "masks" does not share the shape of ``x``
    """
    batch = next(iter(dataloader))

    # Check the length before unpacking: unpacking into three names first would
    # raise ValueError on a plain (x, y) batch and make this assertion useless.
    assert len(batch) >= 2, "atleast x and y should be present in the batch"
    x, y = batch[0], batch[1]
    options = batch[2] if len(batch) > 2 else {}

    assert x.dtype == torch.long, "x should be a long tensor"
    assert y.dtype == torch.long, "y should be a long tensor"
    assert len(x.shape) == 2, "x should have 2 dimensions (batch x sequence length)"
    assert x.shape == y.shape, "x and y should have the same shape"

    if "masks" in options:
        assert x.shape == options["masks"].shape, "mask should have the same shape as x and y"


def check_seqtask_datamodule_compatiblity(data_module):
    """
    Verify that a data module can be used with wavesAI.task.sequence_modeling.

    Asserts that ``output_vocab_size`` exists on the module, then runs
    ``check_seqtask_dataloader_compatibility`` on its train, val and test
    dataloaders, in that order.

    :param data_module: obj
    """
    has_vocab_size = hasattr(data_module, "output_vocab_size")
    assert has_vocab_size, "output_vocab_size should be implemented in the class"

    # Look up and call each factory lazily so the order of attribute access
    # and dataloader construction stays train -> val -> test.
    for split in ("train", "val", "test"):
        make_loader = getattr(data_module, f"{split}_dataloader")
        check_seqtask_dataloader_compatibility(make_loader())
