import os
import time
import argparse
from argparse import ArgumentParser
import os

import json
import random
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.plugins import DeepSpeedPlugin
from models import load_model
from models.GPT2_Model_valid import GPT2Valid
from models.utils import MetricTracker


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``; when CUDA is available it is also
            passed to ``torch.cuda.manual_seed_all``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    # Only reached when CUDA is available: seed the CUDA generators too.
    torch.cuda.manual_seed_all(seed)


def _apply_default_hparams(hparam):
    """Fill in every optional configuration key the config file left out.

    Keys that are already present on ``hparam`` are never overwritten.

    Args:
        hparam: ``argparse.Namespace`` built from the JSON config. It is
            mutated in place.

    Returns:
        The same ``hparam`` namespace, for convenience.

    Raises:
        AttributeError: if ``tokenizer`` is missing and ``model`` is not set,
            or if ``do_init_eval`` is missing and ``mode`` is not set (both
            defaults are derived from those keys).
    """
    # Built on every call so the list defaults are fresh objects.
    static_defaults = dict(
        is_nsml=False,
        grad_norm=0.1,
        weight_decay=0.01,
        output_log=None,
        pred_log=None,
        train_sets=[],
        valid_sets=[],
        valid_type_path=None,
        num_files=1,
        len_data=None,
        learning_rate=None,
        negative_loss=False,
        gradient_accumulation_steps=1,
        num_train_epochs=0,
        use_lr_scheduling=False,
        num_workers=0,
        output_dir=None,
        wandb_log=False,
        accelerator=None,
        checkpoint_path='',
        resume_from_checkpoint=None,
        fp16=False,
        n_experts=1,
        rank=1,
        dataset_path='',
        check_val_every_n_epoch=1,
        val_check_interval=1.0,
        eval_dataset='',
        gpu_id=None,
        overfit_batches=0,
        split_name='train',
        train_hf_datasets=False,
        subset_path=None,
        warmup_steps=3000,
        injected_length=None,
        train_doc_ids=None,
        valid_only_injected=False,
        use_selective_loss=False,
        soft_el_threshold=0,
        ma_threshold=0,
        min_train_epochs=0,
        limit_val_batches=None,
        swap_epoch=[-9],
    )
    for key, value in static_defaults.items():
        if key not in hparam:
            setattr(hparam, key, value)

    # Defaults that depend on other (required) configuration values.
    if 'tokenizer' not in hparam:
        hparam.tokenizer = hparam.model
    if 'do_init_eval' not in hparam:
        hparam.do_init_eval = hparam.mode == 'negative_inject'
    return hparam


def _build_args(hparam):
    """Translate the config namespace into the namespace handed to the models.

    Most values are copied under the same name; a few are renamed
    (``model`` -> ``model_name_or_path``, ``tokenizer`` ->
    ``tokenizer_name_or_path``, ``input_length`` -> ``max_input_length``,
    ``output_length`` -> ``max_output_length``, ``ngpu`` -> ``n_gpu``,
    ``grad_norm`` -> ``max_grad_norm``, ``check_validation`` ->
    ``check_validation_only``) and two are constants (``adam_epsilon``,
    ``opt_level``).

    NOTE: ``dataset``, ``model``, ``method``, ``mode``, ``input_length``,
    ``output_length``, ``train_batch_size``, ``eval_batch_size``, ``ngpu``,
    ``check_validation`` and ``wandb_run_name`` have no default, so the config
    file must provide them; otherwise an ``AttributeError`` is raised here.

    Args:
        hparam: Config namespace after defaults have been applied.

    Returns:
        A new ``argparse.Namespace`` with the model/trainer arguments.
    """
    return argparse.Namespace(
        is_nsml=hparam.is_nsml,
        output_dir=hparam.output_dir,  # Path to save the checkpoints
        dataset=hparam.dataset,
        dataset_path=hparam.dataset_path,
        train_hf_datasets=hparam.train_hf_datasets,
        subset_path=hparam.subset_path,
        train_sets=hparam.train_sets,
        train_doc_ids=hparam.train_doc_ids,
        valid_sets=hparam.valid_sets,
        valid_type_path=hparam.valid_type_path,
        valid_only_injected=hparam.valid_only_injected,
        split_name=hparam.split_name,
        eval_dataset=hparam.eval_dataset,
        num_files=hparam.num_files,
        len_data=hparam.len_data,
        model_name_or_path=hparam.model,
        method=hparam.method,
        n_experts=hparam.n_experts,
        rank=hparam.rank,
        mode=hparam.mode,
        tokenizer_name_or_path=hparam.tokenizer,
        max_input_length=hparam.input_length,
        max_output_length=hparam.output_length,
        injected_length=hparam.injected_length,
        learning_rate=hparam.learning_rate,
        negative_loss=hparam.negative_loss,
        weight_decay=hparam.weight_decay,
        adam_epsilon=1e-8,
        train_batch_size=hparam.train_batch_size,
        eval_batch_size=hparam.eval_batch_size,
        num_train_epochs=hparam.num_train_epochs,
        gradient_accumulation_steps=hparam.gradient_accumulation_steps,
        n_gpu=hparam.ngpu,
        gpu_id=hparam.gpu_id,
        num_workers=hparam.num_workers,
        resume_from_checkpoint=hparam.resume_from_checkpoint,
        use_lr_scheduling=hparam.use_lr_scheduling,
        val_check_interval=hparam.val_check_interval,
        check_val_every_n_epoch=hparam.check_val_every_n_epoch,
        fp16=hparam.fp16,
        # More on optimisation levels:
        # https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
        opt_level='O1',
        # If you enable 16-bit training then set this to a sensible value,
        # 0.5 is a good default.
        max_grad_norm=hparam.grad_norm,
        check_validation_only=hparam.check_validation,
        checkpoint_path=hparam.checkpoint_path,
        accelerator=hparam.accelerator,
        output_log=hparam.output_log,
        pred_log=hparam.pred_log,
        overfit_batches=hparam.overfit_batches,
        warmup_steps=hparam.warmup_steps,
        wandb_run_name=hparam.wandb_run_name,
        use_selective_loss=hparam.use_selective_loss,
        soft_el_threshold=hparam.soft_el_threshold,
        ma_threshold=hparam.ma_threshold,
        min_train_epochs=hparam.min_train_epochs,
        limit_val_batches=hparam.limit_val_batches,
        do_init_eval=hparam.do_init_eval,
        swap_epoch=hparam.swap_epoch,
    )


def _build_callbacks(args, wandb_enabled):
    """Create the Lightning callbacks and decide whether to checkpoint.

    Args:
        args: Namespace produced by ``_build_args``.
        wandb_enabled: Whether W&B logging is on; the learning-rate monitor
            is only added when this is true and LR scheduling is in use.

    Returns:
        A ``(callbacks, enable_checkpointing)`` tuple: the list of callbacks
        for the Trainer and the boolean for its ``enable_checkpointing`` flag.
    """
    # These output_dir values mean "do not save model checkpoints".
    enable_checkpointing = args.output_dir not in (
        "", "random", "distributed", "regular", "weighted")

    callbacks = []
    if enable_checkpointing:
        # Keep every checkpoint (save_top_k=-1), saved on the same epoch
        # cadence as validation. Details:
        # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.model_checkpoint.html
        callbacks.append(ModelCheckpoint(
            dirpath=args.output_dir,
            every_n_epochs=args.check_val_every_n_epoch,
            save_top_k=-1))

    # Logging Learning Rate Scheduling
    if args.use_lr_scheduling and wandb_enabled:
        callbacks.append(pl.callbacks.LearningRateMonitor(
            logging_interval='step'))

    callbacks.append(
        MetricTracker(args.wandb_run_name, args.check_validation_only))
    return callbacks, enable_checkpointing


def main():
    """Run validation and/or training as described by a JSON config file.

    Reads ``--config`` from the command line, loads it as JSON, applies the
    default hyper-parameters, builds a ``pl.Trainer`` and then either
    validates a ``GPT2Valid`` model only (``check_validation`` set) or
    optionally runs an initial validation (``do_init_eval``) followed by
    ``trainer.fit`` on the model class returned by ``load_model``.

    Raises:
        NameError: if no ``--config`` argument was given.
    """
    pl.seed_everything(7, workers=True)

    # Parsing Arguments
    parser = ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    cli_args = parser.parse_args()
    if cli_args.config is None:
        raise NameError("Include a config file in the argument please.")

    # Getting configurations
    with open(cli_args.config, encoding="utf-8") as config_file:
        hparam = argparse.Namespace(**json.load(config_file))

    # Init configs that are not given
    _apply_default_hparams(hparam)

    # Logging into WANDB if needed
    if hparam.wandb_log:
        wandb_logger = WandbLogger(
            project=hparam.wandb_project,
            name=hparam.wandb_run_name,
            entity='lklab_kaist')
    else:
        wandb_logger = None

    # Setting configurations
    args = _build_args(hparam)
    callbacks, enable_checkpointing = _build_callbacks(args, hparam.wandb_log)

    # Dataloaders are rebuilt every epoch only when more than one train set
    # is configured.
    reload_dataloaders_every_n_epochs = 1 if len(args.train_sets) > 1 else 0

    # Setting Flags for pytorch lightning trainer. Details:
    # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags
    # NOTE: resume_from_checkpoint / checkpoint_path are forwarded to the
    # model args but are not passed to the Trainer here.
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.gpu_id if args.gpu_id else args.n_gpu,
        max_epochs=int(args.num_train_epochs * args.num_files),
        precision=16 if args.fp16 else 32,
        amp_backend="native",
        gradient_clip_val=args.max_grad_norm,
        enable_checkpointing=enable_checkpointing,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        val_check_interval=args.val_check_interval,
        logger=wandb_logger,
        callbacks=callbacks,
        strategy=args.accelerator,
        overfit_batches=args.overfit_batches,
        num_sanity_val_steps=0,
        reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
        log_every_n_steps=1,
        deterministic=False,
        limit_val_batches=args.limit_val_batches,
    )

    # The model class used for training is picked from the model name.
    if 't5' in args.model_name_or_path:
        Model = load_model('T5')
    else:
        Model = load_model('GPT2')

    trainer = pl.Trainer(**train_params)

    if args.check_validation_only:
        # Validation-only runs always use GPT2Valid.
        trainer.validate(GPT2Valid(args))
        return

    if args.do_init_eval:
        # Optional validation pass before any training step.
        trainer.validate(GPT2Valid(args))
    trainer.fit(Model(args))


if __name__ == '__main__':
    main()