import argparse
import logging

import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler

from cl_frameworks.edge_augmenter import EdgeAugmenter
from cl_frameworks.ewc import EWC
from cl_frameworks.experience_replay import ExperienceReplay
from dataset.DualSampler import DualSampler
from model.enhancement_layer import EnhancementLayer
from utils.logger import *
from utils.trainer import Trainer
from utils.tester import Tester
from dataset.baseDataset import baseDataset, QuadruplesDataset
from model.agent import Agent
from model.environment import Env
from model.episode import Episode
from model.policyGradient import PG
from model.dirichlet import Dirichlet
import os
import pickle
import wandb


# Weights & Biases setup, executed at import time: the API-key environment
# variable is set to an empty string and wandb.login() is called with an empty
# string as its first positional argument.
# NOTE(review): presumably the real credentials were stripped before this file
# was shared — confirm, and prefer providing WANDB_API_KEY from the shell
# environment rather than from source.
os.environ['WANDB_API_KEY'] = ""
wandb.login("")

import torch
import numpy as np
import random

def set_seed(seed):
    """Seed every random number generator used by the run.

    The same ``seed`` is given to Python's ``random`` module, to NumPy and to
    PyTorch (CPU generator and all CUDA devices). cuDNN is additionally put
    into deterministic mode with benchmarking switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags: no autotuned kernel selection, deterministic kernels only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def parse_prefix(prefix, args):
    """Apply hyperparameter overrides encoded in an experiment-name prefix.

    ``prefix`` is split on ``'_'`` and its last token is discarded. The
    remaining tokens are read as alternating ``name, value`` pairs, and each
    value is stored on ``args`` under its name. For example
    ``'lr_0.01_weighted_True_run'`` sets ``args.lr = 0.01`` and
    ``args.weighted = True``. Because ``'_'`` is the separator, neither names
    nor values can contain an underscore.

    Each value is evaluated as a Python expression, with ``mean`` bound to
    ``numpy.mean``. A value that refers to an unknown name or is not a valid
    expression is stored as the raw string instead. A trailing name without a
    value is ignored with a warning.

    Args:
        prefix: Encoded ``name_value_..._suffix`` string; the default CLI
            value ``'None'`` yields no overrides.
        args: Namespace-like object that receives the attributes; it is
            mutated in place.

    Returns:
        The same ``args`` object.
    """
    tokens = prefix.split('_')[:-1]
    local_vars = {"mean": np.mean}

    if len(tokens) % 2:
        # An odd token count means the last name has no value to pair with.
        logging.warning('Ignoring prefix parameter %r: no value given.', tokens[-1])

    # zip() pairs names (even positions) with values (odd positions) and
    # silently stops before an unpaired trailing name.
    for name, raw_value in zip(tokens[0::2], tokens[1::2]):
        try:
            # SECURITY: eval() executes arbitrary expressions (builtins are
            # still reachable with empty globals). The prefix must only come
            # from a trusted source such as the operator's own command line.
            value = eval(raw_value, {}, local_vars)
        except (NameError, SyntaxError):
            # Not an expression we can evaluate: keep the text as-is.
            value = raw_value
        setattr(args, name, value)
    return args





def parse_args(args=None):
    """Parse the command line and apply the ``--prefix`` overrides.

    Args:
        args: Optional list of argument strings. It is forwarded unchanged to
            ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace after ``parse_prefix`` has been applied to it
        using the value of ``--prefix``.
    """
    parser = argparse.ArgumentParser(
        description='Training and Testing Temporal Knowledge Graph Forecasting Models',
        usage='main.py [<args>] [-h | --help]',
    )
    add = parser.add_argument  # short alias; options are registered in help order

    # --- General run options ---
    add('--cuda', action='store_true', help='whether to use GPU or not.')
    add('--data_path', type=str, default='data/ICEWS14', help='Path to data.')
    add('--do_train', action='store_true', help='whether to train.')
    add('--do_test', action='store_true', help='whether to test.')
    add('--train_on_test', action='store_true', help='whether to train on validation and test.')
    add('--save_path', type=str, default='checkpoints', help='log and model save path.')
    add('--load_path', type=str, default='checkpoints', help='trained model checkpoint path.')

    # --- Edge augmentation ---
    add('--augment_fname', type=str, default='augmented_edges.txt', help='augmented edges file name.')

    # --- Experience replay ---
    add('--er', action='store_true', help='whether to use experience replay.')
    add('--max_buffer_size', type=int, default=12000, help='experience replay size.')
    add('--buffer_fname', type=str, default='buffer.txt', help='experience replay file name.')
    add('--er_sample_ratio', type=float, default=0.2, help='sample ratio for experience replay.')
    add('--er_batch_ratio', type=float, default=0.2, help='batch ratio for experience replay.')

    # --- EWC ---
    add('--ewc', action='store_true', help='whether to use EWC.')
    add('--ewc_lambda', type=float, default=0.01, help='EWC lambda.')
    add('--ewc_params_path', type=str, default='ewc_params.pkl', help='EWC params path.')

    # --- Training ---
    add('--batch_size', type=int, default=512, help='training batch size.')
    add('--start', type=int, default=1, help='training start timestamp.')
    add('--end', type=int, help='training end timestamp.')
    add('--max_epochs', type=int, default=400, help='max training epochs.')
    add('--start_epoch', type=int, default=0, help='start training epochs.')
    add('--num_workers', type=int, default=8, help='workers number used for dataloader.')
    add('--valid_epoch', type=int, default=300, help='validation frequency.')
    add('--lr', type=float, default=0.001, help='learning rate.')
    add('--save_epoch', type=int, default=30, help='model saving frequency.')
    add('--clip_gradient', type=float, default=10.0, help='for gradient crop.')

    # --- Testing ---
    add('--test_batch_size', type=int, default=1,
        help='test batch size, it needs to be set to 1 when using IM module.')
    add('--beam_size', type=int, default=100, help='the beam number of the beam search.')
    add('--test_inductive', action='store_true',
        help='whether to verify inductive inference performance.')
    add('--IM', action='store_true', help='whether to use IM module.')
    add('--IIM', action='store_true',
        help='whether to use Inference IM module (original implementation).')
    add('--mu', type=float, default=0.5, help='the hyperparameter of IM module.')

    # --- Agent ---
    add('--ent_dim', type=int, default=100, help='Embedding dimension of the entities')
    add('--rel_dim', type=int, default=100, help='Embedding dimension of the relations')
    add('--state_dim', type=int, default=100, help='dimension of the LSTM hidden state')
    add('--hidden_dim', type=int, default=100, help='dimension of the MLP hidden layer')
    add('--time_dim', type=int, default=20, help='Embedding dimension of the timestamps')
    add('--entities_embeds_method', type=str, default='dynamic',
        help='representation method of the entities, dynamic or static')
    add('--num_layers', type=int, default=2, help='Number of transformer layers')
    add('--nheads', type=int, default=2, help='Number of attention heads')

    # --- Environment ---
    add('--state_actions_path', type=str, default='state_actions_space.pkl',
        help='the file stores preprocessed candidate action array.')

    # --- Episode ---
    add('--path_length', type=int, default=3, help='the agent search path length.')
    add('--max_action_num', type=int, default=50, help='the max candidate actions number.')

    # --- Policy gradient ---
    add('--Lambda', type=float, default=0.0, help='update rate of baseline.')
    add('--Gamma', type=float, default=0.95, help='discount factor of Bellman Eq.')
    add('--Ita', type=float, default=0.01, help='regular proportionality constant.')
    add('--Zita', type=float, default=0.9, help='attenuation factor of entropy regular term.')

    # --- Reward shaping ---
    add('--reward_shaping', action='store_true', help='whether to use reward shaping.')
    add('--time_span', type=int, default=15, help='24 for ICEWS, 1 for WIKI and YAGO')
    add('--alphas_pkl', type=str, default='dirchlet_alphas.pkl',
        help='the file storing the alpha parameters of the Dirichlet distribution.')
    add('--k', type=int, default=300, help='statistics recent K historical snapshots.')

    # --- Extra hyperparameters packed into the experiment name ---
    add('--prefix', type=str, default='None',
        help='Prefix for the experiment name that contains other hyperparameters')

    namespace = parser.parse_args(args)
    return parse_prefix(namespace.prefix, namespace)



def get_model_config(args, num_ent, num_rel):
    """Collect the model hyperparameters into a plain dictionary.

    Args:
        args: Parsed options; the attributes read are listed below.
            ``lmbda`` and ``degree`` are optional and default to ``1`` and
            ``None`` when absent.
        num_ent: Number of entities.
        num_rel: Number of relations.

    Returns:
        A dict with the configuration keys in a fixed insertion order.
    """
    config = {}

    # Runtime settings and data set sizes.
    config['cuda'] = args.cuda
    config['batch_size'] = args.batch_size
    config['num_ent'] = num_ent
    config['num_rel'] = num_rel

    # Embedding widths and the sizes derived from them.
    config['ent_dim'] = args.ent_dim
    config['rel_dim'] = args.rel_dim
    config['time_dim'] = args.time_dim
    config['state_dim'] = args.state_dim
    config['action_dim'] = args.ent_dim + args.rel_dim
    config['mlp_input_dim'] = args.ent_dim + args.rel_dim + args.state_dim
    config['mlp_hidden_dim'] = args.hidden_dim

    # Path search limits.
    config['path_length'] = args.path_length
    config['max_action_num'] = args.max_action_num

    # Policy-gradient coefficients (CLI names are capitalised, keys are not).
    config['lambda'] = args.Lambda
    config['gamma'] = args.Gamma
    config['ita'] = args.Ita
    config['zita'] = args.Zita

    # Enhancement layer parameter, beam search and encoder settings.
    config['mu'] = args.mu
    config['beam_size'] = args.beam_size
    config['entities_embeds_method'] = args.entities_embeds_method
    config['num_layers'] = args.num_layers
    config['nheads'] = args.nheads
    config['IM'] = args.IM

    # Optional values that are not registered CLI options.
    config['lmbda'] = getattr(args, 'lmbda', 1)
    config['degree'] = getattr(args, 'degree', None)
    return config

def get_data_loader(trainDataset, args, step, weights=None):
    """Build the training ``DataLoader`` and pick its sampling strategy.

    The sampler is chosen as follows:

    * ``weights`` is ``None``: no sampler, the loader shuffles. This is the
      case whenever no replay buffer was mixed into the dataset, e.g. on the
      first experience-replay step or in the train-on-valid/test phase.
    * ``weights`` given and ``args.weighted`` is truthy:
      ``DualSampler(weights, fraction_tail=args.fraction)``.
    * ``weights`` given and ``args.er`` is truthy: ``WeightedRandomSampler``
      drawing ``len(weights)`` samples with replacement.

    Args:
        trainDataset: Dataset to iterate over.
        args: Options providing ``batch_size``, ``num_workers``, ``er`` and
            optionally ``weighted`` / ``fraction``.
        step: Current continual-learning step (kept for interface
            compatibility; not used to choose the sampler).
        weights: Optional per-sample weights, one entry per dataset item.

    Returns:
        The configured ``DataLoader``.

    Raises:
        AssertionError: If ``args.weighted`` is set but no weights are given,
            or if weights are given while neither ``weighted`` nor ``er`` is
            enabled.
    """
    use_weighted = getattr(args, 'weighted', False)

    # Errors are raised explicitly (not via ``assert``) so that the checks
    # still run under ``python -O`` and ``sampler`` is never left unbound.
    if weights is None:
        if use_weighted:
            raise AssertionError('Weights must be provided for weighted (tail) sampling')
        sampler = None
    elif use_weighted:
        sampler = DualSampler(weights, fraction_tail=args.fraction)
    elif args.er:
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    else:
        raise AssertionError('Weights provided but no sampling method specified')

    # ``shuffle`` and ``sampler`` are mutually exclusive in DataLoader.
    return DataLoader(
        trainDataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        num_workers=args.num_workers,
        sampler=sampler,
    )


def continual_main(args):
    """Train and evaluate the model over a sequence of graph snapshots.

    For every ``step`` in ``range(args.start, args.end)`` the function reads
    the snapshot directory ``g<step>`` under ``args.data_path``, builds the
    train/valid/test loaders, creates a new agent, environment and episode,
    restores a checkpoint when one exists, and then runs the phases enabled
    on ``args`` (``do_train``, ``do_test``, ``train_on_test``). Experience
    replay (``er``), EWC (``ewc``), weighted sampling (``weighted``) and edge
    augmentation (``augment``) are applied when their flags are set. Progress
    is reported through ``logging`` and ``wandb``.

    The number of training epochs per step is ``args.max_epochs`` scaled by
    the size of the step's training set relative to a reference size. The
    reference is the training set of step 1; when the run starts at a later
    step it is the training set of the first processed step.

    Args:
        args: Parsed options as returned by ``parse_args``. ``args.cuda`` is
            overwritten with whether CUDA was requested and is available.
    """
    ####################### Seed and experiment tracking #######################
    set_seed(100)
    exp_name = args.data_path.split('/')[-1] + '_' + args.save_path.split('/')[-1]
    wandb.init(project='TiTer', config=args, name=exp_name)

    ####################### Output directory and logger ########################
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Keep CUDA enabled only when it was requested and is actually available.
    args.cuda = bool(args.cuda and torch.cuda.is_available())
    set_logger(args)

    baseData = None
    # Reference training-set size used to scale the epoch count per step.
    g1_len = None

    if getattr(args, 'augment', False):
        edge_augmentor = EdgeAugmenter(data_dir=args.save_path, augmented_edges_path=args.augment_fname)

    for step in range(args.start, args.end):
        snapshot_dir = 'g%d' % step
        max_epoch = args.max_epochs

        ####################### Create DataLoader #################################
        train_path = os.path.join(args.data_path, '%s/train.txt' % snapshot_dir)
        test_path = os.path.join(args.data_path, '%s/test.txt' % snapshot_dir)
        valid_path = os.path.join(args.data_path, '%s/valid.txt' % snapshot_dir)
        stat_path = os.path.join(args.data_path, 'stat.txt')

        # The dataset object is created once and updated for later snapshots.
        if baseData is None:
            baseData = baseDataset(train_path, test_path, stat_path, valid_path, args)
        else:
            baseData.update(train_path, test_path, valid_path)

        # Edges generated at the end of the previous step are added to the data.
        if getattr(args, 'augment', False) and step > 1:
            augmented_edges = edge_augmentor.load_generated_edges(step=step-1)
            baseData.add_augmented_edges(augmented_edges)

        if getattr(args, 'IM', False):
            trainRelAdj = baseData.getBatchRelAdj(baseData.trainQuadruples, getattr(args, 'neighbors', 10))
            validRelAdj = baseData.getBatchRelAdj(baseData.validQuadruples, getattr(args, 'neighbors', 10))
            testRelAdj = baseData.getBatchRelAdj(baseData.testQuadruples, getattr(args, 'neighbors', 10))
        else:
            trainRelAdj = []
            validRelAdj = []
            testRelAdj = []

        trainDataset = QuadruplesDataset(baseData.trainQuadruples, baseData.num_r, trainRelAdj, args.IM)

        weights = None
        if args.er:
            replayer = ExperienceReplay(args.max_buffer_size, args.data_path, args.buffer_fname)
            if step > 1:
                # Prepend the replay buffer to the training data; every buffer
                # item gets weight er_batch_ratio, every new item the rest.
                replayer.load_buffer(args.save_path, args.buffer_fname)
                memoryRelAdj = baseData.getBatchRelAdj(replayer.buffer, getattr(args, 'neighbors', 10))
                memoryDataset = QuadruplesDataset(replayer.buffer, baseData.num_r, memoryRelAdj, args.IM)
                weights = [args.er_batch_ratio] * len(memoryDataset) + [
                    1 - args.er_batch_ratio] * len(trainDataset)
                trainDataset = ConcatDataset([memoryDataset, trainDataset])
                logging.info('Step %d, Using Experience Replay with batch ratio %f'%(step, args.er_batch_ratio))

        if args.ewc:
            ewc = EWC(os.path.join(args.save_path, args.ewc_params_path))
            ewc_args = {
                'ewc': ewc,
                'step': step
            }

        # Weighted sampling replaces any replay weights computed above.
        if getattr(args, 'weighted', False):
            weights = baseData.getWeights(trainDataset)

        train_dataloader = get_data_loader(trainDataset, args, step, weights=weights)

        validDataset = QuadruplesDataset(baseData.validQuadruples, baseData.num_r, validRelAdj, args.IM)
        valid_dataloader = DataLoader(
            validDataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )

        testDataset = QuadruplesDataset(baseData.testQuadruples, baseData.num_r, testRelAdj, args.IM)
        test_dataloader = DataLoader(
            testDataset,
            batch_size=args.test_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )

        # Epoch scaling. Step 1 defines the reference size; if the run starts
        # at a later step, the first processed step is used instead so that
        # g1_len is always bound before it is read.
        if step == 1 or g1_len is None:
            if step != 1:
                logging.warning(
                    'Step 1 was not processed in this run; using the training set of step %d '
                    'as the reference size for epoch scaling.', step)
            epoch_ratio = 1
            g1_len = len(trainDataset)
        else:
            epoch_ratio = len(trainDataset) / g1_len

        ###################### Create the agent and the environment ###############
        config = get_model_config(args, baseData.num_e, baseData.num_r)
        logging.info(config)
        logging.info(args)

        agent = Agent(config)

        state_actions_file = os.path.join(args.data_path, args.state_actions_path)
        if not os.path.exists(state_actions_file):
            state_action_space = None
            logging.info('State action space not found')
        else:
            # NOTE: pickle can execute arbitrary code while loading; this file
            # must come from a trusted preprocessing run.
            with open(state_actions_file, 'rb') as f:
                state_action_space = pickle.load(f)
            logging.info('State action space loaded from {}'.format(args.state_actions_path))
        env = Env(baseData.allQuadruples, config, state_action_space)

        if getattr(args, 'augment', False) and step > 1:
            env.update_state_actions_space(baseData.augmented_edges)

        enhancement_layer = EnhancementLayer(ent_embeds=agent.ent_embs,
                                             rel_embeds=agent.rel_embs,
                                             degree=baseData.degree,
                                             time_span=args.time_span,
                                             config=config)

        if args.cuda:
            enhancement_layer = enhancement_layer.cuda()

        # Create episode controller
        episode = Episode(env, agent, enhancement_layer, config)
        if args.cuda:
            episode = episode.cuda()
        pg = PG(config)  # Policy Gradient
        optimizer = torch.optim.Adam(episode.parameters(), lr=args.lr, weight_decay=0.00001)

        wandb.watch(episode, log='all')

        # Load the model parameters.
        # NOTE(review): within this function 'checkpoint_final_step_*' is only
        # written by the train_on_test phase; without that phase later steps
        # find no checkpoint here — confirm this is intended.
        if step > 1:
            checkpoint_path = os.path.join(args.load_path, 'checkpoint_final_step_%d.pth'%(step-1))
        else:
            checkpoint_path = os.path.join(args.load_path, 'checkpoint_%d_step_%d.pth'%(args.start_epoch, step))
        if os.path.isfile(checkpoint_path):
            params = torch.load(checkpoint_path)
            episode.load_state_dict(params['model_state_dict'])
            optimizer.load_state_dict(params['optimizer_state_dict'])
            logging.info('Checkpoint loaded from {}'.format(checkpoint_path))
        else:
            logging.info('No checkpoint found at {}; using newly initialised parameters'.format(checkpoint_path))

        ###################### Training and Testing ###########################
        if args.reward_shaping:
            # NOTE: same pickle trust requirement as the state-action file.
            with open(os.path.join(args.data_path, args.alphas_pkl), 'rb') as f:
                alphas = pickle.load(f)
            distributions = Dirichlet(alphas, args.k)
        else:
            distributions = None
        trainer = Trainer(episode, pg, optimizer, args, distributions)
        tester = Tester(episode, args, baseData.seen_entities, baseData.RelEntCooccurrence)
        if args.do_train:
            logging.info('Start Training......')
            for i in range(0, max(1, int(max_epoch*epoch_ratio))):
                if step == 1 and i < args.start_epoch:
                    # Resuming: consume the loader without training so the
                    # skipped epochs still draw their batches.
                    logging.info('Skipping epoch %d'%(i))
                    for _ in train_dataloader:
                        pass
                    continue

                loss, reward, ewc_loss = trainer.train_epoch(train_dataloader, len(trainDataset), ewc_args=ewc_args if args.ewc and step > 1 else None)
                logging.info('Step {} Epoch {}/{} Loss: {}, reward: {}'.format(step, i, max_epoch, loss, reward))
                wandb.log({'Training Loss': loss, 'Training Reward': reward, 'EWC Loss': ewc_loss})

                if i % args.save_epoch == 0 and i != 0:
                    trainer.save_model('checkpoint_{}_step_{}.pth'.format(i, step))
                    logging.info('Save Model in {}'.format(args.save_path))

                if i % args.valid_epoch == 0 and i != 0:
                    logging.info('Start Val......')
                    metrics, _ = tester.test(valid_dataloader,
                                          len(validDataset),
                                          baseData.skip_dict,
                                          config['num_ent'])
                    for mode in metrics.keys():
                        logging.info('{} at epoch {}: {}'.format(mode, i, metrics[mode]))
                        wandb.log({mode: metrics[mode]})

            trainer.save_model('checkpoint_step_{}.pth'.format(step))
            logging.info('Save Model in {}'.format(args.save_path))

            if args.er:
                replayer.update_buffer(baseData.trainQuadruples, args.er_sample_ratio)
                replayer.save_buffer(args.save_path, args.buffer_fname)

            if args.ewc:
                logging.info('Estimating Fisher......')
                ewc.estimate_fisher(model=episode, data_loader=train_dataloader, run_batch_fn=trainer.run_batch, task_id=step)
                logging.info('Saving EWC......')
                ewc.save_dicts()

        if args.do_test:
            logging.info('Start Testing......')
            metrics, _ = tester.test(test_dataloader,
                                  len(testDataset),
                                  baseData.skip_dict,
                                  config['num_ent'])
            for mode in metrics.keys():
                logging.info('Test {} : {}'.format(mode, metrics[mode]))

        if args.train_on_test:
            logging.info('Train on valid/test....')

            postRelAdj = baseData.getBatchRelAdj(baseData.validQuadruples + baseData.testQuadruples, getattr(args, 'neighbors', 10))
            postTrainDataset = QuadruplesDataset(baseData.validQuadruples + baseData.testQuadruples , baseData.num_r, postRelAdj, args.IM)
            if getattr(args, 'weighted', False):
                weights = baseData.getWeights(postTrainDataset)
            else:
                weights = None
            post_train_dataloader = get_data_loader(postTrainDataset, args, step, weights=weights)

            epoch_ratio = len(postTrainDataset) / g1_len
            for i in range(max(1, int(args.max_epochs*epoch_ratio))):
                loss, reward, ewc_loss = trainer.train_epoch(post_train_dataloader, len(postTrainDataset))
                logging.info('Step {} Epoch {}/{} Loss: {}, reward: {}'.format(step, i, args.max_epochs, loss, reward))
                wandb.log({'Training Loss': loss, 'Training Reward': reward})

            trainer.save_model('checkpoint_final_step_{}.pth'.format(step))
            logging.info('Save Model in {}'.format(args.save_path))

        if getattr(args, 'augment', False):
            if getattr(args, 'name', 'all_quads') != 'current':
                quads = baseData.allQuadruples
            else:
                quads = baseData.trainQuadruples + baseData.testQuadruples + baseData.validQuadruples
            # Drop the edges added for this step before generating new ones
            # (assumes they sit at the end of `quads` — TODO confirm against
            # baseDataset.add_augmented_edges). The length check matters:
            # `quads[:-0]` would be an empty list, not the full list.
            if step > 1 and len(augmented_edges) > 0:
                quads = quads[:-len(augmented_edges)]

            new_edges = edge_augmentor.add_edges_based_on_relation(episode, all_quads= quads,threshold=args.threshold, num_r=config['num_rel'])
            edge_augmentor.save_generated_edges(new_edges, step)

# Script entry point: parse the command line, then run the continual pipeline.
if __name__ == "__main__":
    cli_args = parse_args()
    print('hi')
    continual_main(cli_args)



