import torch
from dataloader import *
import numpy as np
import argparse
from pytorch_models.utils.logger import *
import json
import os
import random
from utils import *
from pytorch_models.trainer.trainer import ModelTrainer
from pytorch_models.optimizer.super_adam_pytorch import BertAdam


def _str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty value would enable the option. This converter accepts
    the usual spellings instead. The empty string maps to False, as
    ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def get_params(argv=None):
    """Parse the command-line options for the tomita LSTM training script.

    Args:
        argv: optional list of argument strings. When None (the default),
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--suffix', default="")
    #data
    parser.add_argument('--dataset', type = str, default=None)
    parser.add_argument('--vocab', type = str, default="vocab_10000")
    parser.add_argument('--grammar', type=str, default="tomita-1")
    #training
    parser.add_argument('--iterations', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=1024)
    # -1 means "reuse --batch_size" (resolved by the caller).
    parser.add_argument('--valid_batch_size', type=int, default=-1)
    parser.add_argument('--save_model_iter', type=int, default=-1)
    parser.add_argument('--valid_iter', type=int, default=10)
    parser.add_argument('--load_model', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--partial_strings', action="store_true")
    parser.add_argument('--early_stop', type = float, default = None)
    parser.add_argument('--stopping_patience', type = int, default = 1000)
    #optimizer
    parser.add_argument('--optim', default="sgd")
    parser.add_argument('--noise_var', type=float, default=0)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--lr_scheduler', type=str, default=None)
    parser.add_argument('--alpha', type=float, default=0.95)
    parser.add_argument('--momentum', type=float, default=0.9)
    # Accepts "--rmsprop_centering True/False" (as before, but "False" now really
    # means False) and also a bare "--rmsprop_centering" meaning True.
    parser.add_argument('--rmsprop_centering', type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument('--l2', type=float, default=0)
    parser.add_argument('--l1_lambda', type = float, default=0)
    #model
    parser.add_argument('--state_dim', type=int, default=32)
    parser.add_argument('--in_dim', type=int, default=32)
    parser.add_argument('--model', default="rnn_lm")
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--nlayers', type=int, default=1)
    #rnn
    parser.add_argument('--rnn_cell', default=None)
    parser.add_argument('--mtype', default="seq")
    parser.add_argument('--hinit', type = int, default= 0)

    #others:
    parser.add_argument('--print_iter', type = int, default = 1)

    #parser.add_argument('--attention', type = bool, default=False)

    #transformer
    #parser.add_argument('--position_encoding', default=None)
    parser.add_argument('--heads', type=int, default=8)

    #loss
    parser.add_argument('--loss_func', type=str, default="ce")
    parser.add_argument('--cmargin', type=float, default=0.5)

    # parser.add_argument('--contrastive_pos_lp', type=int, default=1)
    # parser.add_argument('--contrastive_neg_lp', type=int, default=1)

    args = parser.parse_args(argv)
    return args


def _build_optimizer(args, params, logger):
    """Create the optimizer selected by ``args.optim``.

    Args:
        args: parsed session arguments. Reads ``optim``, ``lr``, ``momentum``,
            ``alpha``, ``l2``, ``rmsprop_centering`` and ``iterations``.
        params: iterable of parameters to optimize (e.g. ``model.parameters()``).
        logger: passed to ``BertAdam``; unused by the other optimizers.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``args.optim`` is not "sgd", "bertadam", "rmsprop" or "adam".
            This fails fast, before training, instead of leaving the optimizer undefined.
    """
    if args.optim == "sgd":
        return torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.l2)
    if args.optim == "bertadam":
        return BertAdam(params, lr=args.lr, logger=logger, warmup=0.1, t_total=args.iterations)
    if args.optim == "rmsprop":
        return torch.optim.RMSprop(params, lr=args.lr, alpha=args.alpha,
                                   eps=1e-08, weight_decay=args.l2,
                                   momentum=args.momentum, centered=args.rmsprop_centering)
    if args.optim == "adam":
        return torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.995), eps=1e-09,
                                weight_decay=args.l2, amsgrad=True)
    raise ValueError("Unknown optimizer {!r}; expected one of: sgd, bertadam, rmsprop, adam".format(args.optim))


def _build_lr_scheduler(args, optimizer, logger):
    """Create the learning-rate scheduler selected by ``args.lr_scheduler``.

    Returns None when no scheduler is requested. An unrecognised name also
    returns None, but a warning is logged first so a typo is not silently ignored.
    """
    if args.lr_scheduler is None:
        return None
    if args.lr_scheduler == "one_cycle_lr":
        return torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, total_steps=args.iterations, epochs=None, steps_per_epoch=None, pct_start=0.3,
                                                   anneal_strategy='cos', cycle_momentum=True, base_momentum=0.85, max_momentum=0.95,
                                                   div_factor=25.0, final_div_factor=10000.0, three_phase=False, last_epoch=-1, verbose=True)
    if args.lr_scheduler == "cosine_annealing_lr":
        # One cosine period spans a tenth of the run. Clamp to >= 1 because
        # T_max == 0 (iterations < 10) makes CosineAnnealingLR divide by zero.
        t_max = max(1, int(args.iterations / 10))
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_max, eta_min=0, last_epoch=-1, verbose=True)
    logger.warning("Unknown lr_scheduler {!r}; training without a scheduler".format(args.lr_scheduler))
    return None


if __name__ == "__main__":
    session_args = get_params()

    # Seed every RNG source and force deterministic cuDNN kernels for reproducibility.
    torch.manual_seed(session_args.seed)
    np.random.seed(session_args.seed)
    random.seed(session_args.seed)

    torch.backends.cudnn.deterministic = True
    #torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # The run suffix names the log file, the tensorboard dir and the checkpoint folder.
    model_name = "dfa_stability_{}".format(session_args.model)
    if session_args.rnn_cell is not None:
        suffix = "{0}_{1}_{2}_seed_{3}".format(model_name, session_args.rnn_cell, session_args.suffix, session_args.seed)
    else:
        suffix = "{0}_{1}_seed_{2}".format(model_name, session_args.suffix, session_args.seed)
    lfname = '../logs/train_{0}.log'.format(suffix)
    writer_fname = "../tblogs/{}".format(suffix)

    resuming = session_args.resume is not None
    logger = create_logger(suffix, lfname, seed=session_args.seed, resume=resuming)
    writer = create_summary_writer(writer_fname, resume=resuming)

    logger.info(session_args)

    #================================================================================================================================
    # Dataset, vocab
    #================================================================================================================================
    session_args.shuffle_data = True
    batch_size = session_args.batch_size
    if session_args.valid_batch_size == -1:
        session_args.valid_batch_size = batch_size

    vocab = get_vocab(session_args)

    vocab_size = len(vocab)
    session_args.vocab_size = vocab_size

    train_data, val_data = load_dataset(session_args, logger)

    #================================================================================================================================
    # Model
    #================================================================================================================================

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    session_args.device = device
    model = load_model(session_args)

    _n_params = model.count_parameters()
    logger.info("Model Size : {}".format(_n_params))

    #================================================================================================================================
    # Optimizer and Scheduler
    #================================================================================================================================

    optimizer = _build_optimizer(session_args, model.parameters(), logger)
    logger.debug("optimizer created")

    lr_scheduler = _build_lr_scheduler(session_args, optimizer, logger)
    logger.debug("schedular created")

    #================================================================================================================================
    # Loss Function
    #================================================================================================================================

    logger.info("Loss function : {}".format(session_args.loss_func))
    loss_func = get_loss_func(session_args)
    model.register_loss_function(loss_func, l1_lambda = session_args.l1_lambda)

    #================================================================================================================================
    # Trainer
    #================================================================================================================================
    model_save_path = "../models/{}".format(suffix)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_save_path, exist_ok=True)

    trainer_args = {
        "model": model,
        "trainable_named_params": model.named_parameters,
        "logger": logger,
        "train_data": train_data,
        "valid_data": val_data,
        "iterations": session_args.iterations,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "print_iter": session_args.print_iter,
        "save_model_iter": session_args.save_model_iter,
        "valid_iter": session_args.valid_iter,
        "stopping_patience":session_args.stopping_patience,
        "tb_iter": 1,
        "tb_writer": writer,
        "model_folder": model_save_path,
        "model_prefix": "{}".format(suffix),
        "grad_clip_norm": -1,
        "grad_clip_val": -1,
        "gnoise": [0, 0],
        "epoch_update": False,
        "train_eval": False,
        "device" : device
    }

    if session_args.early_stop is not None:
        criterion = create_early_stopping_criteria(session_args.early_stop)
        trainer_args["early_stopping"] = criterion

    trainer = ModelTrainer(trainer_args)

    # An explicit --load_model takes precedence over --resume.
    if session_args.load_model is not None:
        trainer.load_state(session_args.load_model, train_state=True)
    elif session_args.resume in ("best", "last"):
        trainer.load_state(os.path.join(model_save_path, "{}.pth".format(session_args.resume)), train_state=True)
    elif session_args.resume is not None:
        logger.warning("Unknown resume mode {!r} (expected 'best' or 'last'); no checkpoint loaded".format(session_args.resume))

    #================================================================================================================================
    # Initial Parameter Statistics
    #================================================================================================================================

    # Log summary statistics of every non-bias parameter. Optionally perturb it
    # with Gaussian noise before training starts; the stats are logged pre-noise.
    with torch.no_grad():
        for name, param in trainer.trainable_named_params():
            if name.endswith("bias"):
                continue
            mean = torch.mean(param)
            std = torch.std(param)
            mn = torch.min(param)
            mx = torch.max(param)
            logger.info("==> {0} mean : {1:.5f} std : {2:.5f} min : {3:.5f} max : "
                        "{4:.5f}".format(name, mean, std, mn, mx))
            if session_args.noise_var > 0:
                # NOTE(review): noise_var is passed as the *std* argument of
                # torch.normal; if it is meant to be a variance this should be
                # its square root — confirm intent. Noise is drawn on the default
                # (CPU) generator and then moved, which keeps the seeded RNG stream.
                param.add_(
                    torch.normal(0, session_args.noise_var, size=param.shape).to(
                        param.device))

    #================================================================================================================================
    # Train
    #================================================================================================================================

    trainer.train()
    writer.close()
