# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import time
import torch
from torch import nn
import os
from tqdm import tqdm
from datasets import Dataset
from models import CP, ComplEx, TransE, RESCAL, TuckER
from regularizers import F2, N3
from utils import avg_both, setup_optimizer, get_git_revision_hash, set_seed
import json
import numpy as np

def setup_ds(opt):
    """Build the project ``Dataset`` from the dataset-related options.

    Only the ``dataset``, ``device``, ``cache_eval`` and ``reciprocal``
    entries of ``opt`` are forwarded (each only if present); every other
    option is ignored here.
    """
    forwarded_keys = ('dataset', 'device', 'cache_eval', 'reciprocal')
    selected = {}
    for name, value in opt.items():
        if name in forwarded_keys:
            selected[name] = value
    return Dataset(selected)


def setup_model(opt):
    """Instantiate the model named by ``opt['model']`` and move it to the device.

    Reads ``opt['size']``, ``opt['rank']`` and ``opt['init']`` for every
    model; ``TuckER`` additionally reads ``opt['rank_r']`` and
    ``opt['dropout']``. The model is moved to ``opt['device']`` before being
    returned.

    Raises:
        ValueError: if ``opt['model']`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = opt['model']
    if name == 'TransE':
        model = TransE(opt['size'], opt['rank'], opt['init'])
    elif name == 'ComplEx':
        model = ComplEx(opt['size'], opt['rank'], opt['init'])
    elif name == 'TuckER':
        model = TuckER(opt['size'], opt['rank'], opt['rank_r'], opt['init'], opt['dropout'])
    elif name == 'RESCAL':
        model = RESCAL(opt['size'], opt['rank'], opt['init'])
    elif name == 'CP':
        model = CP(opt['size'], opt['rank'], opt['init'])
    else:
        raise ValueError(
            "Unknown model {!r}; expected one of 'TransE', 'ComplEx', "
            "'TuckER', 'RESCAL', 'CP'".format(name))
    model.to(opt['device'])
    return model


def setup_loss(opt):
    """Return the fit-loss module matching the training regime ``opt['world']``.

    ``'sLCWA+bpr'`` maps to ``nn.BCEWithLogitsLoss`` and ``'LCWA'`` to
    ``nn.CrossEntropyLoss``, both with mean reduction.

    Raises:
        NotImplementedError: for ``'sLCWA+set'``, which has no loss defined
            (the original branch was an empty ``pass`` that crashed with
            ``UnboundLocalError`` on return).
        ValueError: for any other, unrecognised world name.
    """
    world = opt['world']
    if world == 'sLCWA+bpr':
        return nn.BCEWithLogitsLoss(reduction='mean')
    if world == 'LCWA':
        return nn.CrossEntropyLoss(reduction='mean')
    if world == 'sLCWA+set':
        raise NotImplementedError("No loss is implemented for world 'sLCWA+set'")
    raise ValueError(
        "Unknown world {!r}; expected 'sLCWA+bpr', 'sLCWA+set' or 'LCWA'".format(world))


def setup_regularizer(opt):
    """Build the regularizer named by ``opt['regularizer']``.

    Both supported regularizers (``'F2'`` and ``'N3'``) are constructed with
    ``opt['lmbda']`` as their only argument.

    Raises:
        ValueError: if ``opt['regularizer']`` is neither ``'F2'`` nor ``'N3'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = opt['regularizer']
    if name == 'F2':
        return F2(opt['lmbda'])
    if name == 'N3':
        return N3(opt['lmbda'])
    raise ValueError(
        "Unknown regularizer {!r}; expected 'F2' or 'N3'".format(name))


def _set_exp_alias(opt):
    """Return the experiment alias: ``opt['alias']`` plus a descriptive suffix.

    The suffix encodes, in order, the dataset, model, rank, regularizer,
    lambda and relation weight taken from ``opt``.
    """
    suffix_fields = ('dataset', 'model', 'rank', 'regularizer', 'lmbda', 'w_rel')
    values = [opt[field] for field in suffix_fields]
    descriptor = '{}_{}_Rank{}_Reg{}_Lmbda{}_W{}'.format(*values)
    return opt['alias'] + descriptor


def _set_cache_path(path_template, dataset, alias):
    """Resolve a cache directory from a template and make sure it exists.

    ``path_template`` is formatted with the keyword fields ``dataset`` and
    ``alias``. If nothing exists at the resulting path, the directory (and
    any missing parents) is created. A ``None`` template disables caching
    and yields ``None``.
    """
    if path_template is None:
        return None
    resolved = path_template.format(dataset=dataset, alias=alias)
    already_there = os.path.exists(resolved)
    if not already_there:
        os.makedirs(resolved, exist_ok=True)
    return resolved


class KBCEngine(object):
    """Driver that wires dataset, model, loss, optimizer and regularizer
    together from an options dict and runs the train / validate / test loop.

    Construction mutates the ``opt`` dict it is given: ``opt['cache_eval']``
    is replaced by the resolved cache directory, and ``opt['size']`` and
    ``opt['loss']`` are added.
    """

    def __init__(self, opt):
        """Set up every training component from ``opt``.

        Seeds the RNGs via ``set_seed``, resolves (and creates) the cache
        directories, builds the dataset, model, optimizer, loss and
        regularizer, and opens ``train.log`` inside the model cache
        directory in append mode.
        """
        self.seed = opt['seed']
        set_seed(int(self.seed))
        self.alias = _set_exp_alias(opt)
        self.cache_eval = _set_cache_path(opt['cache_eval'], opt['dataset'], self.alias)
        self.model_cache_path = _set_cache_path(opt['model_cache_path'], opt['dataset'], self.alias)
        # The dataset receives the resolved directory, not the raw template.
        opt['cache_eval'] = self.cache_eval
        # dataset
        self.dataset = setup_ds(opt)
        opt['size'] = self.dataset.get_shape()
        self.batch_size = opt['batch_size']
        # model
        self.model = setup_model(opt)
        self.optimizer = setup_optimizer(self.model, opt['optimizer'], opt['learning_rate'], opt['decay1'], opt['decay2'])
        self.loss = setup_loss(opt)
        opt['loss'] = self.loss
        # regularizer
        self.regularizer = setup_regularizer(opt)
        self.device = opt['device']
        self.max_epochs = opt['max_epochs']
        self.world = opt['world']
        self.num_neg = opt['num_neg']
        self.score_rel = opt['score_rel']
        self.score_rhs = opt['score_rhs']
        self.score_lhs = opt['score_lhs']
        self.w_rel = opt['w_rel']
        self.w_lhs = opt['w_lhs']
        self.opt = opt
        self._epoch_id = 0
        # NOTE(review): a None model_cache_path makes this concatenation fail;
        # the engine effectively requires opt['model_cache_path'] to be set.
        self.writer = open(self.model_cache_path+"/train.log", "a+")
        # wandb experiment tracking is currently disabled.

    def _lcwa_fit_losses(self, predictions, batch, batch_num, batch_np):
        """Return ``(l_fit, att_fit, np_fit)`` for one LCWA training step.

        ``predictions`` is indexed positionally; which slot pairs with which
        target column depends on the enabled score flags. Columns 2, 1 and 0
        of a batch are the targets for rhs, relation and lhs scoring
        respectively. Terms a flag combination does not define are returned
        as 0 so they contribute nothing to the total loss.

        NOTE(review): only the all-three-flags configuration defines
        ``np_fit``; the other configurations previously crashed with a
        ``NameError`` when the total loss was assembled. Confirm whether
        those configurations should also fit the 'train_np' batch.

        Raises:
            NotImplementedError: for flag combinations with no fit loss
                (lhs+rel only, lhs only, or no flag set).
        """
        loss = self.loss
        att_fit = 0
        np_fit = 0
        if self.score_rel and self.score_rhs and self.score_lhs:
            l_fit = loss(predictions[0], batch[:, 2]) \
                    + self.w_rel * loss(predictions[1], batch[:, 1]) \
                    + self.w_lhs * loss(predictions[2], batch[:, 0])
            lossrhs = loss(predictions[3], batch_num[:, 2])
            lossrel = self.w_rel * loss(predictions[4], batch_num[:, 1])
            losslhs = self.w_lhs * loss(predictions[5], batch_num[:, 0])
            # The 'train_num' fit term carries a hard-coded weight of 5.
            att_fit = 5*(lossrhs + lossrel+losslhs)
            np_fit = loss(predictions[6], batch_np[:, 2]) \
                    + self.w_rel * loss(predictions[7], batch_np[:, 1]) \
                    + self.w_lhs * loss(predictions[8], batch_np[:, 0])
        elif self.score_rel and self.score_rhs:
            l_fit = loss(predictions[0], batch[:, 2]) + self.w_rel * loss(predictions[1], batch[:, 1])
            att_fit = loss(predictions[2], batch_num[:, 2]) + self.w_rel * loss(predictions[3], batch_num[:, 1])
        elif self.score_lhs and self.score_rel:
            raise NotImplementedError(
                "No fit loss is implemented for score_lhs + score_rel without score_rhs")
        elif self.score_rhs and self.score_lhs:  # standard
            l_fit = loss(predictions[0], batch[:, 2]) + loss(predictions[1], batch[:, 0])
            # NOTE(review): the lhs term here is weighted by w_rel while the
            # main-task lhs term above is unweighted; kept as in the original.
            att_fit = loss(predictions[2], batch_num[:, 2]) + self.w_rel * loss(predictions[3], batch_num[:, 0])
        elif self.score_rhs:  # only rhs
            l_fit = loss(predictions[0], batch[:, 2])
            att_fit = loss(predictions[1], batch_num[:, 2])
        elif self.score_rel:
            l_fit = loss(predictions[0], batch[:, 1])
            att_fit = loss(predictions[1], batch_num[:, 1])
        else:
            raise NotImplementedError(
                "No fit loss is implemented for this combination of "
                "score_rel / score_rhs / score_lhs")
        return l_fit, att_fit, np_fit

    def _evaluate_splits(self, epoch):
        """Evaluate every dataset split, print and log the results.

        Returns a pair of dicts keyed by split name: the ``avg_both``
        summaries and the third element of each ``dataset.eval`` result
        (the detailed output later cached to disk).
        """
        # These splits are evaluated on 1000 sampled queries; every other
        # split uses n_queries=-1.
        subsampled = ('train', 'valid', 'test', 'train_num',
                      'train_np', 'valid_np', 'test_np')
        res_all, res_all_detailed = [], []
        for split in self.dataset.splits:
            res_s = self.dataset.eval(model=self.model,
                                      split=split,
                                      n_queries=1000 if split in subsampled else -1,
                                      n_epochs=epoch)
            res_all.append(avg_both(res_s[0], res_s[1], split))
            res_all_detailed.append(res_s[2])
        res = dict(zip(self.dataset.splits, res_all))
        res_detailed = dict(zip(self.dataset.splits, res_all_detailed))
        print("\t Epoch: ", epoch)
        self.writer.write("Epoch: "+str(epoch)+"\n")
        for split in self.dataset.splits:
            print("\t {}: {}".format(split.upper(), res[split]))
            self.writer.write("{}: {}\n".format(split.upper(), res[split]))
        # Flush so the log survives an interrupted run.
        self.writer.flush()
        return res, res_detailed

    def episode(self): #KGE Training Function
        """Train for up to ``max_epochs`` epochs, then evaluate the test splits.

        Each step sums three (fit + regularizer) terms: the main 'train'
        batch and, in the LCWA world, the auxiliary 'train_num' and
        'train_np' batches. Every ``opt['valid']`` epochs all splits are
        evaluated, a per-epoch checkpoint is written, and if the mean MRR
        over the non-train splits improves, a 'best_valid' checkpoint (and,
        when ``cache_eval`` is set, the detailed results) is saved. Training
        stops early on the divergence heuristics below.
        """
        best_valid_mrr, init_epoch_id = 0, 0
        exp_train_sampler = self.dataset.get_sampler('train')
        exp_num_train_sampler = self.dataset.get_sampler('train_num')
        exp_np_train_sampler = self.dataset.get_sampler('train_np')

        if self.world == 'LCWA':
            # Auxiliary batch sizes are scaled by the ratio of the auxiliary
            # example count to the main one; loop-invariant, so computed once.
            n_train = self.dataset.examples_train.shape[0]
            ali_num = n_train/self.dataset.examples_num_train.shape[0]
            ali_np = n_train/self.dataset.examples_np_train.shape[0]
            num_batch_size = int(self.batch_size/ali_num)
            np_batch_size = int(self.batch_size/ali_np)

        for e in range(init_epoch_id, self.max_epochs):
            self.model.train()
            # Context manager closes the bar at the end of every epoch.
            with tqdm(total=exp_train_sampler.size) as pbar:
                while exp_train_sampler.is_epoch(e): # iterate through all batches inside an epoch
                    pbar.update(self.batch_size)
                    # Auxiliary terms default to 0 for worlds / flag
                    # combinations that do not compute them.
                    att_fit = att_reg = np_fit = np_reg = 0
                    if self.world == 'LCWA':
                        input_batch_train = exp_train_sampler.batchify(self.batch_size,
                                                                       self.device)
                        input_batch_num_train = exp_num_train_sampler.batchify(num_batch_size,
                                                                               self.device)
                        input_batch_np_train = exp_np_train_sampler.batchify(np_batch_size,
                                                                             self.device)
                        predictions, factors, factors_num, factors_np = self.model.forward(
                            input_batch_train, input_batch_num_train, input_batch_np_train,
                            score_rel=self.score_rel, score_rhs=self.score_rhs, score_lhs=self.score_lhs)
                        l_fit, att_fit, np_fit = self._lcwa_fit_losses(
                            predictions, input_batch_train,
                            input_batch_num_train, input_batch_np_train)
                        l_reg, l_reg_raw, _ = self.regularizer.penalty(input_batch_train, factors) # Note: this shouldn't be included into the computational graph of lambda update
                        att_reg, _, _ = self.regularizer.penalty(input_batch_num_train, factors_num)
                        np_reg, _, _ = self.regularizer.penalty(input_batch_np_train, factors_np)
                    elif self.world == 'sLCWA+bpr':
                        pos_train, neg_train, label = exp_train_sampler.batchify(self.batch_size,
                                                                                 self.device,
                                                                                 num_neg=self.num_neg)
                        predictions, factors = self.model.forward_bpr(pos_train, neg_train)
                        l_fit = self.loss(predictions, label)
                        l_reg, l_reg_raw, _ = self.regularizer.penalty(
                            torch.cat((pos_train, neg_train), dim=0),
                            factors)
                    else:
                        raise ValueError(
                            "Unsupported world {!r} for training".format(self.world))
                    #Loss Function
                    main_loss = l_fit + l_reg
                    att_loss = att_fit + att_reg
                    np_loss = np_fit + np_reg
                    l = main_loss + att_loss + np_loss
                    pbar.set_postfix(Loss1=main_loss, Loss2=att_loss, Loss3=np_loss)
                    self.optimizer.zero_grad()
                    l.backward()
                    self.optimizer.step()

            if e % self.opt['valid'] == 0:
                self.model.eval()
                res, res_detailed = self._evaluate_splits(e)
                self.model.checkpoint(model_cache_path=self.model_cache_path, epoch_id=str(e))

                # Model selection: mean MRR over every split whose name does
                # not contain "train".
                eval_mrrs = [res[split]['MRR'] for split in self.dataset.splits
                             if "train" not in split]
                now_valid_mrr = sum(eval_mrrs)/len(eval_mrrs) if eval_mrrs else 0
                if now_valid_mrr > best_valid_mrr:
                    best_valid_mrr = now_valid_mrr
                    self.model.checkpoint(model_cache_path=self.model_cache_path, epoch_id='best_valid')
                    if self.opt['cache_eval'] is not None:
                        # NOTE(review): directory and file name are joined
                        # without a separator; presumably the cache_eval
                        # template ends with one -- confirm.
                        for s in self.dataset.splits:
                            for m in ['lhs', 'rhs']:
                                torch.save(res_detailed[s][m],
                                           self.opt['cache_eval']+'{s}_{m}.pt'.format(s=s, m=m))

            # Divergence heuristics, checked after every epoch.
            if best_valid_mrr == 1:
                print('MRR 1, diverged!')
                break
            if best_valid_mrr > 0 and best_valid_mrr < 2e-4:
                if l_reg_raw.item() < 1e-4:
                    print('0 embedding weight, diverged!')
                    break

        # Final evaluation uses the model as left by the last epoch, with
        # n_queries=-1.
        self.model.eval()
        mrrs, hits, _ = self.dataset.eval(self.model, 'test', -1)
        print("\n\nTEST : MRR {} Hits {}".format(mrrs, hits))
        mrrs, hits, _ = self.dataset.eval(self.model, 'test_num', -1)
        print("\n\nTEST_NUM : MRR {} Hits {}".format(mrrs, hits))
        mrrs, hits, _ = self.dataset.eval(self.model, 'test_np', -1)
        print("\n\nTEST_NP : MRR {} Hits {}".format(mrrs, hits))
        # Trailing newline keeps the next run's appended entries on their own
        # line; flush so the summary reaches disk even if the writer is never
        # closed explicitly.
        self.writer.write("Best valid mrr: "+str(best_valid_mrr)+"\n")
        self.writer.flush()