import argparse
import os
import sys
import time
import pickle
import torch
import logging
import json
import numpy as np

from collections import defaultdict
from torch.utils.data import DataLoader
from collections import OrderedDict
from tqdm import tqdm

from cl_frameworks.edge_augmenter import EdgeAugmenter
from dataset.baseDataset import baseDataset, QuadruplesDataset
from model.agent import Agent
from model.enhancement_layer import EnhancementLayer
from model.environment import Env
from model.episode import Episode
from utils import *
from utils.tester import Tester

# Parse command-line arguments: the start and end indices of the dataset snapshots, the start and end indices of
# the model checkpoints, the checkpoint path, and the dataset path.

def parse_prefix(prefix, args):
    """Override attributes of *args* with key/value pairs encoded in *prefix*.

    The prefix is an underscore-separated string of alternating keys and
    values followed by one trailing component that is ignored, e.g.
    ``'weighted_True_prob_mean_fraction_0.5_0'`` sets ``args.weighted = True``,
    ``args.prob = np.mean`` and ``args.fraction = 0.5``. A prefix with no
    pairs (such as the default ``'None'``) leaves *args* unchanged.

    Each value is evaluated as a Python expression so numbers and booleans
    get their natural types; ``mean`` resolves to ``np.mean`` and builtin
    names such as ``max`` resolve to the builtin. Values that cannot be
    evaluated (unknown names, invalid literals such as ``08``, empty
    components) are kept as plain strings.

    Args:
        prefix: the encoded experiment prefix.
        args: namespace whose attributes are set in place.

    Returns:
        The same *args* object.

    Raises:
        ValueError: if the prefix, minus its last component, does not consist
            of complete key/value pairs.
    """
    params = prefix.split('_')[:-1]
    if len(params) % 2:
        # Validate up front so args is never left partially updated.
        raise ValueError(
            'prefix {!r} must be key_value pairs followed by one trailing '
            'component; found unpaired key {!r}'.format(prefix, params[-1]))
    local_vars = {"mean": np.mean}
    for key, raw_value in zip(params[::2], params[1::2]):
        # NOTE(review): eval executes arbitrary expressions (builtins are
        # reachable through the empty globals dict). This is only acceptable
        # because the prefix comes from the local command line; never pass
        # untrusted strings here.
        try:
            value = eval(raw_value, {}, local_vars)
        except (NameError, SyntaxError):
            value = raw_value
        setattr(args, key, value)
    return args

# Command-line interface. Index ranges are half-open ([start, end) and
# [start_model, end_model)) because the evaluation loop walks them with range().
parser = argparse.ArgumentParser(description='Continual learning for KGE')
parser.add_argument('--cuda', action='store_true', help='whether to use GPU or not.')
parser.add_argument('--start', type=int, default=1, help='start index of the dataset')
parser.add_argument('--end', type=int, default=15, help='end index (exclusive) of the dataset')
parser.add_argument('--start_model', type=int, default=1, help='start index of the model')
# Previously this flag was unconditionally overwritten with start_model + 1,
# so it had no effect. Keep that as the default, but honour an explicit value.
parser.add_argument('--end_model', type=int, default=None,
                    help='end index (exclusive) of the model; defaults to start_model + 1 (a single checkpoint)')

parser.add_argument('--checkpoint_path', type=str, default='checkpoints', help='path to save the checkpoint')
parser.add_argument('--data_path', type=str, default='data', help='path to the dataset')
parser.add_argument('--dataset_name', type=str, default='ICEWS1807', help='name of the dataset')
parser.add_argument('--state_actions_path', default='state_actions_space.pkl', type=str,
                    help='the file stores preprocessed candidate action array.')
parser.add_argument('--num_workers', default=8, type=int, help='workers number used for dataloader.')
parser.add_argument('--time_span', default=24, type=int, help='24 for ICEWS, 1 for WIKI and YAGO')

# Edge Augmentation Params
parser.add_argument('--augment_fname', default='augmented_edges.txt', type=str, help='augmented edges file name.')

parser.add_argument('--test_batch_size', default=1, type=int,
                    help='test batch size, it needs to be set to 1 when using IM module.')
parser.add_argument('--beam_size', default=100, type=int, help='the beam number of the beam search.')
parser.add_argument('--test_inductive', action='store_true', help='whether to verify inductive inference performance.')
parser.add_argument('--IM', action='store_true', help='whether to use IM module.')
parser.add_argument('--IIM', action='store_true', help='whether to use Inference IM module (original implementation).')
parser.add_argument('--mu', default=0.1, type=float, help='the hyperparameter of IM module.')
parser.add_argument('--prefix', default='None', type=str, help='Prefix for the experiment name that contains other hyperparameters')

args = parser.parse_args()
# Hyperparameters encoded in the experiment prefix override/extend the CLI args.
args = parse_prefix(args.prefix, args)
if args.end_model is None:
    args.end_model = args.start_model + 1
# Print after all overrides so the output reflects the effective configuration.
print(args)

# Optional precomputed candidate-action cache; when the file is absent the
# environment is built with state_action_space=None.
state_actions_file = os.path.join(args.data_path, args.state_actions_path)
if os.path.exists(state_actions_file):
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # point --state_actions_path at trusted, locally generated files.
    with open(state_actions_file, 'rb') as f:
        state_action_space = pickle.load(f)
    logging.info('State action space loaded from {}'.format(args.state_actions_path))
    print('State action space loaded from {}'.format(args.state_actions_path))
else:
    state_action_space = None
    # logging is not configured in this script, so echo to stdout as well,
    # mirroring the "loaded" branch.
    logging.info('State action space not found')
    print('State action space not found')

# Built from the first snapshot inside the evaluation loop, then updated.
baseData = None
ranks_path = os.path.join(args.checkpoint_path, 'ranks')
os.makedirs(ranks_path, exist_ok=True)

# One (num_datasets x num_models) matrix per metric; row = snapshot offset
# from args.start, column = checkpoint offset from args.start_model.
results = {
    k: np.zeros((args.end - args.start, args.end_model - args.start_model))
    for k in ['MRR', 'HITS@1', 'HITS@3', 'HITS@10']
}

if getattr(args, 'augment', False):
    edge_augmentor = EdgeAugmenter(data_dir=args.checkpoint_path, augmented_edges_path=args.augment_fname)

# Evaluate every checkpoint m in [start_model, end_model) on the test split of
# every snapshot d in [start, end). Per-pair ranks and metrics are saved as they
# are produced; the aggregated metric matrices are saved at the end.
use_augment = getattr(args, 'augment', False)
for d in range(args.start, args.end):
    # Load snapshot g<d>: the first one constructs the dataset, later ones
    # are passed to baseData.update().
    snapshot = 'g%d' % d
    testpath = os.path.join(args.data_path, '%s/test.txt' % snapshot)
    trainpath = os.path.join(args.data_path, '%s/train.txt' % snapshot)
    validpath = os.path.join(args.data_path, '%s/valid.txt' % snapshot)
    if baseData is None:
        baseData = baseDataset(trainpath, testpath, os.path.join(args.data_path, 'stat.txt'), validpath, args)
    else:
        baseData.update(trainpath, testpath, validpath)

    testRelAdj = baseData.getBatchRelAdj(baseData.testQuadruples, getattr(args, 'neighbor', 10))
    testDataset = QuadruplesDataset(baseData.testQuadruples, baseData.num_r, testRelAdj, getattr(args, 'IM', False))
    test_dataloader = DataLoader(
        testDataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    for m in range(args.start_model, args.end_model):
        print(args.checkpoint_path)
        model_name = 'checkpoint_step_%d.pth' % m
        model_path = os.path.join(args.checkpoint_path, model_name)
        # Without --cuda, map tensors to CPU so checkpoints saved on a GPU can
        # still be loaded on a CPU-only machine; load_state_dict copies them
        # onto the module's device either way.
        params = torch.load(model_path, map_location=None if args.cuda else 'cpu')

        if use_augment and m > 1:
            # NOTE(review): this runs for every (d, m) pair against the shared
            # baseData; confirm add_augmented_edges is idempotent or is reset
            # by baseData.update().
            augmented_edges = edge_augmentor.load_generated_edges(step=m - 1)
            baseData.add_augmented_edges(augmented_edges)

        # Model config saved with the checkpoints, completed with the entity /
        # relation counts of the currently loaded data.
        with open(os.path.join(args.checkpoint_path, 'config.json'), 'r') as f:
            config = json.load(f)
        config['num_ent'] = baseData.num_e
        config['num_rel'] = baseData.num_r
        config['action_dim'] = config['ent_dim'] + config['rel_dim']  # dimension of the actions
        config['mlp_input_dim'] = config['ent_dim'] + config['rel_dim'] + config['state_dim']
        config['mlp_hidden_dim'] = config['hidden_dim']

        agent = Agent(config)
        env = Env(baseData.allQuadruples, config, state_action_space)

        if use_augment and m > 1:
            env.update_state_actions_space(baseData.augmented_edges)

        enhancement_layer = EnhancementLayer(ent_embeds=agent.ent_embs,
                                             rel_embeds=agent.rel_embs,
                                             degree=baseData.degree,
                                             time_span=args.time_span,
                                             config=config)
        if args.cuda:
            enhancement_layer = enhancement_layer.cuda()

        episode = Episode(env, agent, enhancement_layer, config)
        if args.cuda:
            episode = episode.cuda()

        episode.load_state_dict(params['model_state_dict'])
        logging.info('Checkpoint loaded from {}'.format(model_path))

        # test model
        logging.info('Start Testing model {} on data {}'.format(m, d))
        tester = Tester(episode, args, baseData.seen_entities, baseData.RelEntCooccurrence)
        metrics, ranks = tester.test(test_dataloader,
                                     len(testDataset),
                                     baseData.skip_dict,
                                     config['num_ent'])
        # save ranks
        np.save(os.path.join(ranks_path, 'ranks_model_%d_data_%d_inductive_%s.npy' % (m, d, args.test_inductive)), ranks)

        for k, value in metrics.items():
            results[k][d - args.start, m - args.start_model] = value
            logging.info('Test {} : {}'.format(k, value))

        # A dict is stored as a pickled 0-d object array; read it back with
        # np.load(path, allow_pickle=True).item().
        np.save(os.path.join(ranks_path,
                             'results__model_%d_data_%d_inductive_%s.npy' % (m, d, args.test_inductive)), metrics)

# Save the aggregated (dataset x model) metric matrices.
np.save(os.path.join(ranks_path, 'results_inductive_%s.npy' % (args.test_inductive)), results)