import argparse


class ARGS(object):
    """Command-line configuration for training and evaluating the model."""

    @staticmethod
    def read_args(argv=None):
        """Parse command-line options into an ``argparse.Namespace``.

        Args:
            argv: Optional list of argument strings to parse. When ``None``
                (the default), ``sys.argv[1:]`` is parsed, as before. Passing
                an explicit list lets tests and notebooks build a config
                without touching the real command line.

        Returns:
            The parsed namespace, extended with a derived ``graph_mode``
            attribute: ``'actor'`` if ``--actor_mode`` was given, otherwise
            ``'country'``.

        Raises:
            SystemExit: On invalid arguments or ``--help`` (standard argparse
                behavior).
        """
        parser = argparse.ArgumentParser()

        # training loop
        parser.add_argument('--batch_size', default=100, type=int)
        parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
        parser.add_argument('--eval', default=50, type=int,
                            help='evaluate model every _ epochs')
        parser.add_argument('--epochs', default=4000, type=int,
                            help='maximum number of iterations')

        # model setting
        # type=int is required: without it a value given on the command line
        # would arrive as a str while the default is an int.
        parser.add_argument('--emb_dim', default=50, type=int,
                            help='dimension of the relation and entity embeddings')
        parser.add_argument('--emb_name', default='random', type=str,
                            help='type of the initialization: '
                                 'random, ReNet, TTransE or TATransE')
        parser.add_argument('--finetune', action='store_true',
                            help='If initialized relation and entity emb should be tuned')

        parser.add_argument('--h_dim', default=50, type=int,
                            help='hidden dimension of the model for gru and gcn')
        parser.add_argument('--out_dim', default=50, type=int,
                            help='output dimension of the seq encoder which also declares '
                                 'the input dimension of the meta learner')
        parser.add_argument('--snap_encoder', default=1, type=int,
                            help='snapshot encoder 0: simple mean, 1: gcn mean')
        parser.add_argument('--enc_dropout', default=0.1, type=float)
        parser.add_argument('--seq_encoder', default='gru', type=str,
                            help='sequence encoder type, gru or att')
        parser.add_argument('--weight_decay', default=0.01, type=float, help='weight_decay')
        parser.add_argument('--n_head', default=2, type=int,
                            help='number of heads for self attention')
        parser.add_argument('--max_n', default=10, type=int, help='number of neighbors')
        parser.add_argument('--hist_len', default=10, type=int, help='history length')
        parser.add_argument('--sequential', action='store_true',
                            help='If True, the time difference is kept')
        parser.add_argument('--mask', action='store_true',
                            help='If True, masking for attention is enabled')

        # meta learner setting
        parser.add_argument('--meta_type', default='matcher', type=str,
                            help='either matcher or protonet')
        parser.add_argument('--meta_dropout', default=0.1, type=float)
        parser.add_argument('--steps', default=3, type=int, help='matching net number of steps')
        parser.add_argument('--shots', default=1, type=int)

        # loss
        parser.add_argument('--margin', default=5, type=float)
        parser.add_argument('--alpha', default=0, type=float,
                            help='regularizer coefficient, set to 0 if not regularizing')
        parser.add_argument('--norm', default='l2', type=str, help='norm type, l2 or l1')

        # general setting
        parser.add_argument('--cuda', action='store_true', help='use gpu if true')
        parser.add_argument('--run_num', default=0, type=int,
                            help='change this parameter when running the model with '
                                 'same params multiple times')
        parser.add_argument('--log', default=5, type=int, help='log every _ epochs')
        parser.add_argument('--actor_mode', action='store_true',
                            help='build the graph in actor mode (default: country mode)')
        parser.add_argument('--dataset', default='icews', type=str,
                            help='dataset can be gdelt or icews')

        # test setting
        parser.add_argument('--test', action='store_true', help='test mode')
        parser.add_argument('--checkpoint', default=0, type=int, help='checkpoint number to test')

        # parse_args(None) falls back to sys.argv[1:], preserving old behavior.
        args = parser.parse_args(argv)

        # Derived option: expose the graph mode as a string for convenience.
        args.graph_mode = 'actor' if args.actor_mode else 'country'
        return args


class PATHS(object):
    """Filesystem locations and file names used by the project.

    Several entries are ``%``-style templates that callers must fill in
    before use; the placeholders are documented next to each one.
    """

    # ---- root data directory and raw source files -----------------------
    DATA_DIR = './data/'
    ICEWS_MAIN_DATA = f'{DATA_DIR}icews.dedup.csv'
    GDELT_MAIN_DATA = f'{DATA_DIR}gdelt.dateadded.2018-01-01.csv'

    # ---- per-dataset directory templates --------------------------------
    # Two %s placeholders; presumably dataset name and graph mode — verify
    # against the callers that format this template.
    DATASET_DIR = f'{DATA_DIR}%s_%s/'

    EMB_PATH = f'{DATASET_DIR}embeddings/'
    BASELINES_DIR = f'{DATASET_DIR}baselines/'
    # Adds two %d placeholders after the two %s inherited from DATASET_DIR.
    HIST_DIR = f'{DATASET_DIR}hist_%d_%d/'

    # ---- bare file names (joined onto one of the directories above) -----
    DATA_2ID = 'data2id.csv'

    PRETRAIN_DATA = 'pretrain.csv'
    FEWSHOT_DATA = 'fewshot.csv'
    FEWSHOT_QUADS = 'fewshot.txt'
    SYMBOL_IDS = 'symbols2id.pkl'
    IDS_SYMBOL = 'id2symbols.pkl'
    TRAIN_ENT = 'train_ent.pkl'
    HIST_GRAPH = 'hist_graphs.pkl'
    HIST_SUBJ = 'hist_s.pkl'
    HIST_OBJ = 'hist_o.pkl'
    HIST_ALL_ENT = 'hist_ent.pkl'
    TRAIN_TASK_POOL = 'meta_train.pkl'
    TEST_TASK_POOL = 'meta_test.pkl'
    VAL_TASK_POOL = 'meta_val.pkl'

    # ---- output locations -------------------------------------------------
    LOG_DIR = './logs/'
    SAVE_PATH = 'saved_models/'
