import sys

from lightning import seed_everything

from wavesAI.data.lra import ListOps
from wavesAI.task.sequence_classification import SequenceClassification
from wavesAI.model.mamba_optimized import MambaClassifier

from timeit import default_timer as timer  # this is the library that does timing

import torch

import lightning as L
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.callbacks import StochasticWeightAveraging

import argparse

## This will be a benchmark that uses the lightning trainer


def evaluate_seq_length(options, run_tuner=False, overfit=False, no_compile=False):
    """
    Train a MambaClassifier on ListOps with the Lightning trainer and measure the run.

    The wall-clock time covers the whole ``trainer.fit`` call, i.e. as many epochs as
    ``options['trainer']`` asks for (not necessarily a single epoch).

    :param options: dict of keyword-argument dicts, one per component. The keys read here are
        ``'data'`` (forwarded to ``ListOps``), ``'net'`` (forwarded to ``MambaClassifier``),
        ``'optimizer'`` (forwarded to ``torch.optim.AdamW``), ``'scheduler'`` (forwarded to
        ``torch.optim.lr_scheduler.CosineAnnealingLR``) and ``'trainer'`` (forwarded to
        ``L.Trainer``). ``options['trainer']`` must not repeat a keyword that is already set
        in this function (``enable_checkpointing``, ``logger``, ``enable_progress_bar``,
        ``gradient_clip_val`` and, when ``overfit`` is set, ``limit_val_batches`` /
        ``limit_train_batches``); Python raises ``TypeError`` on duplicated keywords.
    :param run_tuner: if True, run the Lightning learning-rate finder, print its results, save the
        plot to ``lr_finder.png`` and terminate the process with ``sys.exit()`` instead of training.
    :param overfit: if True, restrict training to 10 batches and disable validation batches.
    :param no_compile: if True, call ``torch.compiler.disable`` on ``MambaClassifier`` before
        the network is instantiated.

    :return: Tuple[float, int] (time (in s), memory). Time spent inside ``trainer.fit`` and the
        peak memory reported by ``torch.cuda.max_memory_allocated()`` (0 when CUDA is unavailable).
    """

    ## Initialize the ListOps data module
    data = ListOps(**options['data'])

    ## Initialize the pytorch lightning module and add the MambaClassifier and an optimizer to it
    if no_compile:
        # NOTE(review): the return value is discarded, so this relies on disable() patching the
        # class in place -- confirm this holds for the installed torch version.
        torch.compiler.disable(fn=MambaClassifier, recursive=True)
    net = MambaClassifier(vocab_size=data.vocab_size, **options['net'])

    # Scheduler alternatives that were tried earlier and left commented out in this call:
    # a ChainedScheduler of ExponentialLR + CosineAnnealingLR, and MultiStepLR.
    lmodel = SequenceClassification(network=net,
                                    optimizer=lambda x: torch.optim.AdamW(x, **options['optimizer']),
                                    scheduler=lambda x: torch.optim.lr_scheduler.CosineAnnealingLR(x, eta_min=0, last_epoch=-1, **options['scheduler']),
                                    scheduler_interval="step", scheduler_frequency=1, scheduler_monitor="train_loss"
                                    )

    additional_trainer_args = {}
    if overfit:
        additional_trainer_args['limit_val_batches'] = 0
        additional_trainer_args['limit_train_batches'] = 10

    trainer = L.Trainer(enable_checkpointing=False, logger=False,
                        enable_progress_bar=True,
                        gradient_clip_val=1.0,
                        **additional_trainer_args,
                        **options["trainer"]
                        )

    if run_tuner:
        tuner = Tuner(trainer)

        lr_finder = tuner.lr_find(lmodel, datamodule=data, min_lr=3e-4, num_training=500)

        # Raw results of the learning-rate sweep
        print(lr_finder.results)

        # Save the sweep plot with the suggested learning rate marked
        fig = lr_finder.plot(suggest=True)
        fig.savefig("lr_finder.png")

        print("Suggested learning rate: ", lr_finder.suggestion())
        print("Exiting because tuner is run")
        sys.exit()

    # The CUDA memory statistics require an initialised CUDA runtime; skip them on CPU-only
    # hosts so the benchmark still runs there.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.reset_peak_memory_stats()

    start = timer()
    trainer.fit(lmodel, datamodule=data)
    if cuda_available:
        # CUDA kernels run asynchronously; wait for queued work before stopping the clock.
        torch.cuda.synchronize()
    end = timer()

    peak_memory = torch.cuda.max_memory_allocated() if cuda_available else 0
    return end - start, peak_memory

if __name__ == "__main__":
    # Command-line flags, registered in a fixed order so that the generated
    # --help output lists them in the same sequence as before.
    argument_specs = [
        (("--data_dir",), {"type": str, "default": "./listops-1000"}),
        (("-b", "--batch_size"), {"type": int, "default": 32}),
        (("-l", "--sequence_length"), {"type": int, "default": 2048}),
        (("-d", "--d_model"), {"type": int, "default": 32}),
        (("-n", "--d_state"), {"type": int, "default": 8}),
        (("-nl", "--num_layers"), {"type": int, "default": 6}),
        (("--x_factor",), {"type": float, "default": 0.00001}),
        (("--x_bias_factor",), {"type": float, "default": 0.00001}),
        (("--lr",), {"type": float, "default": 0.001}),
        (("--weight_decay",), {"type": float, "default": 0.005}),
        (("--max_epochs",), {"type": int, "default": 1}),
        (("--T_max",), {"type": int, "default": 7079}),
        (("--dt_rank",), {"type": int, "default": 0}),
        (("--flags",), {"type": str, "default": "aou"}),
        (("--seed",), {"type": int, "default": 0}),
        (("--run_tuner",), {"action": "store_true"}),
        (("--no_constant_B",), {"action": "store_true"}),
        (("--no_constant_C",), {"action": "store_true"}),
        (("--yes_convolution",), {"action": "store_true"}),
        (("--mode",), {"type": str, "default": "last"}),
        (("--overfit",), {"action": "store_true"}),
        (("--no_compile",), {"action": "store_true"}),
    ]

    cli = argparse.ArgumentParser()
    for option_names, option_spec in argument_specs:
        cli.add_argument(*option_names, **option_spec)
    args = cli.parse_args()

    L.seed_everything(args.seed, workers=True)

    # A non-positive --dt_rank means "not given": fall back to the state dimension.
    if args.dt_rank > 0:
        dt_rank = args.dt_rank
    else:
        dt_rank = args.d_state

    # Keyword arguments for the network constructor.
    net_options = {
        "output_dim": 10,
        "n_layer": args.num_layers,
        "d_model": args.d_model,
        "d_state": args.d_state,
        "flags": args.flags,
        "mode": args.mode,
        "x_factor": args.x_factor,
        "x_bias_factor": args.x_bias_factor,
        "dt_rank": dt_rank,
        "constant_B": not args.no_constant_B,
        "constant_C": not args.no_constant_C,
        "convolution1D": args.yes_convolution,
    }
    # Keyword arguments for the data module.
    data_options = {
        "data_dir": args.data_dir,
        "batch_size": args.batch_size,
        "max_length": args.sequence_length,
    }
    trainer_options = {"max_epochs": args.max_epochs}
    optimizer_options = {"lr": args.lr, "weight_decay": args.weight_decay}
    scheduler_options = {"T_max": args.T_max}

    options = {
        "net": net_options,
        "data": data_options,
        "trainer": trainer_options,
        "optimizer": optimizer_options,
        "scheduler": scheduler_options,
    }

    time_taken, memory_used = evaluate_seq_length(
        options,
        run_tuner=args.run_tuner,
        overfit=args.overfit,
        no_compile=args.no_compile,
    )

    # Report the configuration together with the measurements.
    print(options)
    print(f"Time Taken: {time_taken}s")
    print(f"Memory Used: {memory_used/1e6}MB")
