from os.path import basename
from argparse import ArgumentParser
from copy import copy
from sys import stdin
from select import select
import torch
import numpy as np
import __main__  # used to get the original execute module

# Message-kind key -> format template used by ``logout``.  Every template is a
# leading space, an ANSI-escaped label, the reset sequence ``\033[00m`` and one
# ``{}`` placeholder that receives the message text.
type2color = dict(
    s=' \033[95mSuccess:\033[00m {}',
    i=' \033[94mInfo:\033[00m {}',
    d=' \033[92mDebug:\033[00m {}',
    w=' \033[93mWarning:\033[00m {}',
    e=' \033[91mError:\033[00m {}',
    # 'f' also prepends the underline (\033[4m) and bold (\033[1m) sequences.
    f=' \033[4m\033[1m\033[91mFatal Error:\033[00m {}',
)


def logout(msg, p_type=''):
    """ Print ``msg`` to stdout behind a coloured label.

    ``p_type`` is matched case-insensitively against the keys of
    ``type2color``; any value that is not a key (including the default
    empty string) falls back to the Debug template.
    """
    kind = p_type.lower()
    template = type2color[kind] if kind in type2color else type2color['d']
    print(template.format(msg))
    
def _log_train(log_dir, log_str, log_value):
    with open(log_dir+"/performances.csv", "a") as f:
        f.write(str(log_str) + str(log_value) + "\n")

def log_train(performances, epoch, session, num_sess, stage, model_size, log_dir):
    """ Append a training record to ``<log_dir>/performances.csv``.

    The first line is ``session,stage,epoch,`` followed by every value of the
    first ``num_sess`` rows of ``performances``.  Rows that ``performances``
    does not have yet are padded with two zeros each, so the line width is
    fixed by ``num_sess``.  Rows beyond ``num_sess`` are not written.

    When ``stage == 'best_epoch'`` a second line
    ``session,stage,epoch,avg_mrr: <m>,avg_hit10: <h>,`` is appended, where
    ``m`` and ``h`` are the means of columns 0 and 1 over *all* rows of
    ``performances``.

    ``model_size`` is accepted for interface compatibility but is not used.
    """
    padding = np.zeros((num_sess, 2))
    with open(log_dir + "/performances.csv", "a") as out:
        for field in (session, stage, epoch):
            out.write(str(field) + ",")
        # Real rows where available, zero rows as placeholders otherwise.
        for idx, pad_row in enumerate(padding):
            cells = performances[idx, :] if idx < performances.shape[0] else pad_row
            for cell in cells:
                out.write(str(cell) + ",")
        out.write("\n")

        if stage != 'best_epoch':
            return

        # Accumulate in row order starting from integer 0, then divide.
        n_rows = performances.shape[0]
        sums = [0, 0]
        for r in range(n_rows):
            sums[0] += performances[r, 0]
            sums[1] += performances[r, 1]
        mrr_mean = sums[0] / n_rows
        hit10_mean = sums[1] / n_rows

        summary = [str(session), str(stage), str(epoch),
                   'avg_mrr: ' + str(mrr_mean),
                   'avg_hit10: ' + str(hit10_mean)]
        out.write(",".join(summary) + ",\n")

def log_test(performances, session, num_sess, log_dir):
    """ Append test results to ``<log_dir>/test.csv``.

    Two lines are written:

    * ``session,test,`` followed by every value of the first ``num_sess``
      rows of ``performances``; rows that do not exist are padded with two
      zeros each.
    * ``session,test,avg_mrr: <m>,avg_hit10: <h>,`` where ``m`` and ``h`` are
      the means of columns 0 and 1 over all rows of ``performances``.
    """
    padding = np.zeros((num_sess, 2))
    with open(log_dir + "/test.csv", "a") as out:
        out.write(str(session) + ",")
        out.write("test,")
        for idx, pad_row in enumerate(padding):
            cells = performances[idx, :] if idx < performances.shape[0] else pad_row
            for cell in cells:
                out.write(str(cell) + ",")
        out.write("\n")

        # Average MRR (column 0) and Hits@10 (column 1), accumulated in row
        # order starting from integer 0.
        n_rows = performances.shape[0]
        sums = [0, 0]
        for r in range(n_rows):
            sums[0] += performances[r, 0]
            sums[1] += performances[r, 1]
        mrr_mean = sums[0] / n_rows
        hit10_mean = sums[1] / n_rows

        summary = [str(session), "test",
                   'avg_mrr: ' + str(mrr_mean),
                   'avg_hit10: ' + str(hit10_mean)]
        out.write(",".join(summary) + ",\n")


class ExperimentArgParse:
    """ Command-line option parser for an experiment run.

    ``__init__`` registers all options on ``self.parser``; ``parse`` parses
    them and attaches a derived ``checkpoint_name`` to the result.
    """

    def __init__(self, description):
        """ Build ``self.parser`` (an ``ArgumentParser``) with every experiment
        option and its default value.

        :param description: text shown at the top of ``--help`` output.
        """
        self.parser = ArgumentParser(description=description)
        self.parser.add_argument('--dataset', type=str, default="icews05-15_class_il6",
                                 required=False, help='Dataset name: icews05-15_class_il6 by default')
        self.parser.add_argument('--sess_mode', type=str, default="TRAIN",
                                 required=False, help='Session Mode: TRAIN,TEST')
        self.parser.add_argument('--neg_ratio', type=int, default=25,
                                 required=False, help='Negative sampling Ratio')
        self.parser.add_argument('--batch_size', type=int, default=64,
                                 required=False, help='Batch size')
        self.parser.add_argument('--hidden_size', type=int, default=20,
                                 required=False, help='hidden dim size')
        self.parser.add_argument('--margin', type=float, default=5.0,
                                 required=False, help='ranking difference margin')
        self.parser.add_argument('--opt_method', type=str, default='adam',
                                 required=False, help='Optimization Method to use')
        self.parser.add_argument('--lr', type=float, default=3e-4,
                                 required=False, help='learning rate')
        self.parser.add_argument('--wd', type=float, default=0, help='weight decay')
        self.parser.add_argument('--num_workers', type=int, default=16,
                                 required=False, help='Number of cpu Worker threads batching data')
        self.parser.add_argument('--num_epochs', type=int, default=1000,
                                 required=False, help='Number of Epochs to train')
        # Integer flag rather than store_true so it can be passed as 0/1.
        self.parser.add_argument('--cuda', type=int, default=1,
                                 required=False, help='1 indicates Run on GPU')
        self.parser.add_argument('--valid_freq', type=int, default=1,
                                 required=False, help='Evaluation Frequency')
        self.parser.add_argument('--num_sess', type=int, default=6,
                                 required=False, help='Number of learning Sessions for icews05-15')
        self.parser.add_argument('--exp_name', type=str, default="dev",
                                 required=False, help='experiment description')
        self.parser.add_argument('--patience', type=int, default=30,
                                 required=False, help='Early stop Patience')
        self.parser.add_argument('--buffer_max_rel', type=int, default=100,
                                 required=False, help='Max num of rels can be stored in buffer')
        self.parser.add_argument('--replay_per_rel', type=int, default=32,
                                 required=False, help='number of triplets replayed for each relation')
        self.parser.add_argument('--topk', type=int, default=8,
                                 required=False, help='number of triplets selected for each relation in one batch')
        self.parser.add_argument("--max_neighbor", type=int, default=50,
                                 required=False, help='maximum number considered in aggregator')
        self.parser.add_argument("--max_nn_meta", type=int, default=1,
                                 required=False, help='number of negative meta')
        self.parser.add_argument("--job_id", type=int, default=0, help='jobid')
        self.parser.add_argument("--coeff_info", type=float, default=1, help='coefficient of info loss')
        self.parser.add_argument("--coeff_l2", type=float, default=1, help='coefficient of l2 loss')

    def parse(self, args=None):
        """ Parse the options, log them, and attach a checkpoint name.

        This does not pause or wait for input; it returns immediately after
        logging the parsed options.

        :param args: list of argument strings to parse.  ``None`` (the
            default) parses ``sys.argv[1:]``, as ``ArgumentParser.parse_args``
            does.
        :returns: the parsed ``argparse.Namespace`` with an extra
            ``checkpoint_name`` attribute of the form
            ``<dataset>_TransE_ER_replay<replay_per_rel>_<exp_name>_JOBID_<job_id>``.
        """
        parsed_args = self.parser.parse_args(args)
        logout("The current " + parsed_args.sess_mode + "ing parameters are: \n" + str(parsed_args),"i")

        checkpoint_name = str(parsed_args.dataset) + "_TransE_ER_replay"
        checkpoint_name += str(parsed_args.replay_per_rel) + "_"
        checkpoint_name += parsed_args.exp_name
        checkpoint_name += "_JOBID_" + str(parsed_args.job_id)

        parsed_args.checkpoint_name = checkpoint_name

        return parsed_args


from models import model_utils
from utils import data_utils