import os
os.environ['CUDA_LAUNCH_BLOCKING']="1"
from os.path import abspath, dirname
import numpy as np
from copy import copy
import torch
import os
import __main__  # used to get the original execute module

from models import model_utils
from models.pytorch_modelsize import SizeEstimator
from utils.buffer import Buffer
from utils.terminal_utils import ExperimentArgParse, logout, log_train, log_test


def setup_experiment(args):
    """
    Build the objects shared by every session of the experiment.

    For each of the ``args.num_sess`` sessions this creates:
      * a training batch processor on dataset ``<args.dataset>_<session>``
        reading ``train2id``;
      * an evaluation batch processor on the same dataset with
        ``neg_ratio = 0``, reading ``test2id`` when
        ``args.sess_mode == "TEST"`` and ``valid2id`` otherwise. It also
        receives ``dataset_fps`` taken from the training processors created
        so far (sessions ``0..session``).

    A single model and optimizer are then initialized. Their entity/relation
    counts come from the session-0 training dataset (``e2i`` / ``r2i``).
    ``logs/<checkpoint_name>`` and ``ckps/<checkpoint_name>`` are created next
    to this file if missing.

    Per-session fields are only set on shallow copies of ``args``.

    :param args: parsed experiment arguments.
    :return: ``(buffer, train_bps, valid_bps, model, optimizer, log_dir)``.
    """
    # buffer for triplets carried over from previous sessions
    buffer = Buffer(args.buffer_max_rel, args.device)

    train_bps = []
    valid_bps = []
    for session in range(args.num_sess):
        # training batch processor for this session's dataset
        train_args = copy(args)
        train_args.dataset += "_" + str(session)
        train_args.triplet2id = "train2id"
        train_args.session = session
        train_bps.append(model_utils.ERTrainBatchProcessor(train_args))

        # evaluation batch processor for this session's dataset
        dev_args = copy(args)
        dev_args.neg_ratio = 0
        dev_args.dataset += "_" + str(session)
        if dev_args.sess_mode == "TEST":
            dev_args.triplet2id = "test2id"
        else:
            dev_args.triplet2id = "valid2id"
        # only the training sets of sessions 0..session exist at this point
        dev_args.dataset_fps = [train_bp.dataset.fp for train_bp in train_bps]
        valid_bps.append(model_utils.ValidBatchProcessor(dev_args))

    # one model/optimizer reused across all sessions, sized from session 0
    model_optim_args = copy(args)
    model_optim_args.num_ents = len(train_bps[0].dataset.e2i)
    model_optim_args.num_rels = len(train_bps[0].dataset.r2i)
    model = model_utils.init_model(model_optim_args)
    optimizer = model_utils.init_optimizer(model_optim_args, model)

    base_dir = abspath(dirname(__file__))
    log_dir = base_dir + "/logs/" + args.checkpoint_name
    # NOTE(review): ckp_dir is not returned; presumably checkpoints are written
    # here by model_utils helpers — confirm
    ckp_dir = base_dir + "/ckps/" + args.checkpoint_name
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckp_dir, exist_ok=True)

    return buffer, train_bps, valid_bps, model, optimizer, log_dir


def setup_train_session(buffer, sess, args, model, optim, train_bps):
    """
    Prepare state before training session ``sess``.

    For every session after the first (``sess != 0``) this:
      * loads the model saved for session ``sess - 1`` via
        ``model_utils.load_model``, tagged with the running script's name;
      * extends this session's training set with the buffer contents and
        rebuilds its data loader;
      * re-initializes the optimizer for the loaded model.

    For every session it also creates an early-stop tracker and logs the
    model size reported by ``SizeEstimator``.

    :param buffer: buffer of triplets carried over from earlier sessions.
    :param sess: index of the session about to be trained.
    :param args: parsed experiment arguments (only shallow copies are changed).
    :param model: current model.
    :param optim: current optimizer; replaced when ``sess != 0``.
    :param train_bps: per-session training batch processors.
    :return: ``(buffer, model, optim, train_bp, tracker, model_params_size)``.
    """
    train_bp = train_bps[sess]  # batch processor for the current session
    # checkpoints are keyed by the name of the script being run
    script_tag = os.path.basename(__main__.__file__).split(".")[0]

    if sess:
        # load best model from prior session
        load_args = copy(args)
        load_args.tag = script_tag
        load_args.sess = str(sess - 1)
        model = model_utils.load_model(load_args, model)

        # load previous data from buffer and reinit the dataloader
        train_bp.dataset.load_buffer(buffer)
        train_bp.reset_data_loader()

        # fresh optimizer bound to the reloaded model
        optim = model_utils.init_optimizer(copy(args), model)

    # early stop tracker for this session
    tracker_args = copy(args)
    tracker_args.tag = script_tag
    tracker_args.sess = str(sess)
    tracker = model_utils.EarlyStopTracker(tracker_args)

    # model param size; the estimator is a temporary and is freed right away
    model_params_size = SizeEstimator(copy(args)).estimate_size(model)[0]
    logout("Mem stats:" + str(model_params_size))

    return buffer, model, optim, train_bp, tracker, model_params_size


def setup_test_session(sess, args, model):
    """
    Restore the model saved for session ``sess`` before it is evaluated.

    The load is driven by a shallow copy of ``args`` with two fields set:
      * ``tag``: the running script's file name up to its first dot;
      * ``sess``: ``str(sess)``.

    :param sess: index of the session whose saved model should be loaded.
    :param args: parsed experiment arguments (not modified).
    :param model: model instance passed to ``model_utils.load_model``.
    :return: the model returned by ``model_utils.load_model``.
    """
    script_name = os.path.basename(__main__.__file__)
    session_args = copy(args)
    session_args.tag, session_args.sess = script_name.split(".")[0], str(sess)
    return model_utils.load_model(session_args, model)


if __name__ == "__main__":
    # Entry point: parse arguments and select the device. Then either train
    # every session in order (TRAIN) or evaluate the model saved for each
    # session (TEST).
    exp_parser = ExperimentArgParse("Continual setting experiment")
    exp_args = exp_parser.parse()

    # selects hardware to use; CPU is used when CUDA is disabled or unavailable
    if exp_args.cuda and torch.cuda.is_available():
        logout("Running with CUDA")
        exp_args.device = torch.device('cuda')
    else:
        logout("Running with CPU", "w")
        exp_args.device = torch.device('cpu')

    if exp_args.sess_mode == "TRAIN":
        logout("Training running...", "i")
        buffer, exp_train_bps, exp_valid_bps, exp_model, exp_optim, log_dir = \
            setup_experiment(exp_args)

        # sessions are trained sequentially; model and buffer carry over
        for exp_sess in range(exp_args.num_sess):
            buffer, exp_model, exp_optim, exp_train_bp, exp_tracker, model_stats = \
                setup_train_session(buffer, exp_sess, exp_args, exp_model,
                                    exp_optim, exp_train_bps)

            while exp_tracker.continue_training():
                # validate on epochs the tracker selects
                if exp_tracker.validate():
                    inf_metrics = model_utils.evaluate_model(
                        exp_args, exp_sess, exp_valid_bps, exp_model)
                    # log inference metrics
                    log_label = "epoch" + str(exp_tracker.get_epoch())
                    log_train(inf_metrics, exp_tracker.get_epoch(),
                              exp_sess, exp_args.num_sess, log_label,
                              model_stats, log_dir)
                    # update tracker for early stopping & model saving
                    exp_tracker.update_best(exp_sess, inf_metrics, exp_model)

                # train one epoch (the returned epoch loss is not used here)
                _, epoch_buffer = exp_train_bp.process_epoch(exp_model, exp_optim)
                # update the cross-session buffer with this epoch's buffer
                # NOTE(review): whether this merge is a union or an
                # intersection is defined by Buffer.update_per_epoch — confirm
                buffer.update_per_epoch(epoch_buffer)

                exp_tracker.step_epoch()

            # logs the final performance for session (i.e. best)
            best_metrics, best_epoch = exp_tracker.get_best()
            log_train(best_metrics, best_epoch, exp_sess, exp_args.num_sess,
                      "best_epoch", model_stats, log_dir)

    elif exp_args.sess_mode == "TEST":
        logout("Testing running...", "i")
        # setup_experiment returns six values; log_dir is needed by log_test
        buffer, exp_train_bps, exp_valid_bps, exp_model, exp_optim, log_dir = \
            setup_experiment(exp_args)

        for exp_sess in range(exp_args.num_sess):
            exp_model = setup_test_session(exp_sess, exp_args, exp_model)
            inf_metrics = model_utils.evaluate_model(
                exp_args, exp_sess, exp_valid_bps, exp_model)
            log_test(inf_metrics, exp_sess, exp_args.num_sess, log_dir)

    else:
        logout("Mode not recognized for this setting.", "f")
