from wavesAI.model.mamba_optimized import MambaClassifier

from lightning.data import StreamingDataset

import numpy as np
import torch
import os

from timeit import default_timer as timer  # this is the library that does timing
import matplotlib.pyplot as plt
from tqdm import tqdm

# Global matplotlib styling applied to every figure this script saves.
# (Optionally pin a specific serif face with "font.serif": "Computer Modern".)
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 14,
})


import contextlib
import io
import sys
import argparse


@contextlib.contextmanager
def nostdout():
    """Silence everything written to ``sys.stdout`` inside the ``with`` block.

    Output goes to a throw-away in-memory text buffer. The original stream is
    restored on exit, even if the body raises.
    """
    # A text buffer (StringIO) is required: print() writes str, and
    # BytesIO.write(str) raises TypeError on Python 3.
    # redirect_stdout restores the previous sys.stdout in its __exit__.
    with contextlib.redirect_stdout(io.StringIO()):
        yield


def evaluate_seq_lengths(net, options, seq_lengths=None):
    """Time one forward (and optionally backward) pass for each sequence length.

    params: net: model spec dict with keys "classname" (model class) and
        "init_params" (keyword arguments for it); a fresh model is built on
        "cuda" for every sequence length.
    params: options: parsed CLI options; reads vocab_size, batch_size and
        no_backward.
    params: seq_lengths: sequence lengths to benchmark; defaults to [10, 20, 30, 40].
    returns: (time_taken, memory_used) -- wall-clock seconds of the timed pass
        and torch.cuda.max_memory_allocated() in bytes, one entry per sequence
        length, in the order given.
    """
    if seq_lengths is None:
        seq_lengths = list(range(10, 50, 10))

    # One random token batch at the longest length; shorter runs use prefixes of it.
    x = torch.randint(options.vocab_size, size=(options.batch_size, max(seq_lengths))).to("cuda")

    time_taken = []
    memory_used = []

    for seq_length in tqdm(seq_lengths, leave=False):
        x_in = x[:, :seq_length]

        # The peak counter is reset before model construction, so the reported
        # peak includes the model's own parameters.
        torch.cuda.reset_peak_memory_stats()
        model = net["classname"](verbose=False, **net["init_params"]).to("cuda")

        # CUDA kernels launch asynchronously. Synchronize on both sides of the
        # timed region so the timer covers the GPU work itself, not just the
        # kernel launches.
        torch.cuda.synchronize()
        start = timer()
        y = model(x_in)
        if not options.no_backward:
            y.sum().backward()
        torch.cuda.synchronize()
        end = timer()

        time_taken.append(end - start)
        memory_used.append(torch.cuda.max_memory_allocated())

        # Release this iteration's model, gradients, output and autograd graph
        # before the next reset_peak_memory_stats(). Otherwise they stay
        # allocated and inflate the next sequence length's peak reading.
        del model, y

    return time_taken, memory_used


def run_evaluation(models, options):
    """Benchmark every model spec over a sweep of sequence lengths.

    params: models: dict mapping a display name to a model spec accepted by
        evaluate_seq_lengths ({"classname": ..., "init_params": ...}).
    params: options: parsed CLI options; reads sequence_length and runs here
        and forwards the rest to evaluate_seq_lengths.
    returns: (seq_lengths, time_dict, memory_dict); time_dict[name] and
        memory_dict[name] hold one list per run, each aligned with seq_lengths.
    raises: ValueError: if options.sequence_length leaves no length to benchmark.
    """
    # Roughly 20 evenly spaced lengths starting at 10. Clamp the step to >= 1:
    # for sequence_length < 20 the unclamped step is 0 and range() raises.
    step = max(1, options.sequence_length // 20)
    seq_lengths = list(range(10, options.sequence_length, step))
    if not seq_lengths:
        raise ValueError(
            f"sequence_length must be greater than 10 to benchmark anything, "
            f"got {options.sequence_length}"
        )

    time_dict = {}
    memory_dict = {}

    for name, spec in models.items():
        print("Running ", name)

        time_taken = []
        memory_taken = []
        for _ in tqdm(range(options.runs)):
            time_elapsed, memory_used = evaluate_seq_lengths(spec, options, seq_lengths=seq_lengths)
            time_taken.append(time_elapsed)
            memory_taken.append(memory_used)

        time_dict[name] = time_taken
        memory_dict[name] = memory_taken

        # Return cached-but-unused allocator blocks to the driver between models.
        torch.cuda.empty_cache()

    return seq_lengths, time_dict, memory_dict


# plotting utility

def plot_results(x, ys):
    """Plot each series' mean with a shaded +/- one standard deviation band.

    params: x: x-axis values, one per column of each series.
    params: ys: dict mapping a legend label to a 2-D array-like whose rows are
        repeated runs; the mean and std are taken across runs (axis 0), so each
        row must have len(x) entries.
    Draws onto the current matplotlib axes and returns nothing.
    """
    for label, runs in ys.items():
        runs = np.asarray(runs)
        mean = runs.mean(axis=0)
        std = runs.std(axis=0)

        plt.plot(x, mean, label=label)
        plt.fill_between(x, mean - std, mean + std, alpha=0.5)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark runtime and peak CUDA memory of MambaClassifier variants across sequence lengths."
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="batch size of the random input")
    parser.add_argument("-l", "--sequence_length", type=int, default=2000,
                        help="upper bound (exclusive) of the sequence-length sweep")
    parser.add_argument("-d", "--d_model", type=int, default=32, help="model width")
    parser.add_argument("-n", "--d_state", type=int, default=64, help="state dimension")
    parser.add_argument("-v", "--vocab_size", type=int, default=16, help="token vocabulary size")
    parser.add_argument("-o", "--output_dim", type=int, default=2, help="classifier output dimension")
    parser.add_argument("-nl", "--num_layers", type=int, default=1, help="number of layers")
    parser.add_argument("-r", "--runs", type=int, default=30, help="repetitions per model")
    parser.add_argument("-nb", "--no_backward", action="store_true", help="time the forward pass only")

    options = parser.parse_args()

    # Every measurement below allocates on and synchronizes with the GPU, so
    # fail early with a clear message when no CUDA device is available.
    if not torch.cuda.is_available():
        parser.error("a CUDA device is required to run this benchmark")

    # All variants share the same hyper-parameters and differ only in the
    # `flags` passed to MambaClassifier.
    common_params = {
        "d_model": options.d_model,
        "d_state": options.d_state,
        "output_dim": options.output_dim,
        "n_layer": options.num_layers,
        "expand": 1,
        "vocab_size": options.vocab_size,
    }
    variant_flags = {"mamOG": "o", "mamAOCU": "aocu", "mamAOC": "aoc"}
    models = {
        name: {"classname": MambaClassifier, "init_params": {**common_params, "flags": flags}}
        for name, flags in variant_flags.items()
    }

    seq_lengths, time_taken, memory_used = run_evaluation(models, options)

    # Extrapolate measured seconds per step to hours for 10M steps.
    time_taken = {key: (10000000 / (60 * 60)) * np.array(val) for key, val in time_taken.items()}
    memory_used = {key: np.array(val) / 1e6 for key, val in memory_used.items()}  # bytes -> MB

    x = np.array(seq_lengths)
    figures = [
        (time_taken, "time taken for 10M steps (in hours)", "allComparisonsOptimizedTime.png"),
        (memory_used, "runtime memory used (in MB)", "allComparisonsOptimizedMemory.png"),
    ]
    for results, ylabel, filename in figures:
        plt.clf()
        plot_results(x, results)
        plt.xlabel("seq length")
        plt.ylabel(ylabel)
        plt.yscale("log")
        plt.legend(loc="upper left")
        plt.tight_layout()
        plt.savefig(filename)
