from os.path import abspath, dirname
import numpy as np
from copy import copy
from collections import defaultdict
from tqdm import tqdm
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import tensor, from_numpy, no_grad, save, load, arange
from torch.autograd import Variable
import torch.optim as optim

# user module imports
from utils.terminal_utils import logout, log_train
import utils.data_utils as data_utils
from utils.buffer import Buffer
from models.pytorch_modelsize import SizeEstimator
import models.standard_models as std_models
import models.er_models as er_models

class ERTrainBatchProcessor:
    """Drive training epochs over a triple dataset with per-entity neighbor metadata.

    On construction the processor loads the dataset named by ``cmd_args``,
    builds a shuffled ``DataLoader`` over it, creates the epoch ``Buffer`` and
    precomputes a fixed-width neighbor table (``self.connections``) plus the
    number of stored neighbors per entity (``self.e1_degrees``).
    """

    def __init__(self, cmd_args):
        """Load the dataset and build the loader, buffer and neighbor table.

        Args:
            cmd_args: parsed command-line namespace. Reads ``dataset``,
                ``neg_ratio``, ``log_dir``, ``triplet2id``, ``batch_size``,
                ``num_workers``, ``buffer_max_rel``, ``device`` and
                ``max_neighbor``. A shallow copy is stored as ``self.args``.
        """
        self.args = copy(cmd_args)
        self.dataset = data_utils.TripleDataset(self.args.dataset, self.args.neg_ratio, self.args.log_dir)
        self.dataset.load_triple_set(self.args.triplet2id)
        self.dataset.build_rel_triplets_dict()
        self.dataset.load_known_ent_set()
        self.dataset.load_known_rel_set()
        self.dataset.load_task_unique_rels()
        # Single source of truth for the loader configuration.
        self.reset_data_loader()
        self.epoch_buffer = Buffer(cmd_args.buffer_max_rel, cmd_args.device)

        # Padding ids sit one past the last valid entity / relation id.
        self.epad_id = self.dataset.ent_num
        self.rpad_id = self.dataset.rel_num

        self.build_connection(max_=cmd_args.max_neighbor)

    def reset_data_loader(self):
        """(Re)create ``self.data_loader`` as a shuffled loader over ``self.dataset``."""
        self.data_loader = DataLoader(self.dataset,
                                      shuffle=True,
                                      batch_size=self.args.batch_size,
                                      num_workers=self.args.num_workers,
                                      collate_fn=collate_batch,
                                      pin_memory=True)

    def build_connection(self, max_=100):
        """Build the padded neighbor table and the per-entity degree vector.

        Reads ``path_graph.txt`` (tab-separated, four fields per line, string
        names) and ``train2id.txt`` (first line skipped, tab-separated integer
        ids) from the dataset directory. Every edge is recorded in both
        directions in ``self.e1_rele2``.

        Args:
            max_: maximum number of neighbors kept per entity; extra neighbors
                are dropped and unused slots hold ``(rpad_id, epad_id)``.

        Sets:
            self.connections: int array of shape ``(ent_num, max_, 2)`` whose
                last axis is ``(relation id, neighbor entity id)``.
            self.e1_degrees: int array of shape ``(ent_num,)`` holding the
                number of neighbors stored for each entity (at most ``max_``).
            self.e1_rele2: dict mapping entity name to its neighbor list.
        """
        ent_num = self.dataset.ent_num
        self.connections = np.empty((ent_num, max_, 2), dtype=int)
        self.connections[:, :, 0] = self.rpad_id
        self.connections[:, :, 1] = self.epad_id
        self.e1_degrees = np.zeros(ent_num).astype(int)

        self.e1_rele2 = defaultdict(list)
        data_path = abspath(dirname(dirname(__file__))) + "/datasets/" + self.args.dataset

        r2i, e2i = self.dataset.r2i, self.dataset.e2i
        with open(data_path + '/path_graph.txt') as f:
            lines = f.readlines()
        for line in tqdm(lines):
            e1, rel, e2, _ = line.rstrip().split('\t')
            self.e1_rele2[e1].append((r2i[rel], e2i[e2]))
            self.e1_rele2[e2].append((r2i[rel + '_inv'], e2i[e1]))

        i2e = self.dataset.i2e
        with open(data_path + '/train2id.txt') as f:
            lines = f.readlines()
        for line in tqdm(lines[1:]):  # first line is skipped
            e1, rel, e2 = line.rstrip().split('\t')
            head_id, rel_id, tail_id = int(e1), int(rel), int(e2)
            self.e1_rele2[i2e[head_id]].append((rel_id, tail_id))
            # +1 for _inv
            # NOTE(review): assumes the inverse relation id directly follows
            # the forward id in r2i — confirm against the dataset files.
            self.e1_rele2[i2e[tail_id]].append((rel_id + 1, head_id))

        for ent, id_ in e2i.items():
            neighbors = self.e1_rele2[ent][:max_]
            for slot, (rel_id, neighbor_id) in enumerate(neighbors):
                self.connections[id_, slot, 0] = rel_id
                self.connections[id_, slot, 1] = neighbor_id
            self.e1_degrees[id_] = len(neighbors)
        return

    def _gather_meta(self, ent_ids):
        """Return ``(connections, degrees)`` long tensors for ``ent_ids`` on ``args.device``."""
        idx = torch.as_tensor(ent_ids, dtype=torch.long).cpu().numpy()
        # Fancy indexing copies the selected rows, so the cached tables are
        # never aliased by the returned tensors.
        connections = torch.as_tensor(self.connections[idx], dtype=torch.long).to(self.args.device)
        degrees = torch.as_tensor(self.e1_degrees[idx], dtype=torch.long).to(self.args.device)
        return connections, degrees

    def get_meta(self, left, right):
        """Look up neighbor tables and degrees for two sequences of entity ids.

        Args:
            left: 1-D sequence/tensor of entity ids.
            right: 1-D sequence/tensor of entity ids.

        Returns:
            ``(left_connections, left_degrees, right_connections,
            right_degrees)`` as long tensors on ``args.device``. Connections
            have shape ``(len(ids), max_neighbors, 2)``. An empty id sequence
            yields empty tensors.
        """
        return self._gather_meta(left) + self._gather_meta(right)

    def get_negative_meta(self, meta):
        """Corrupt the relation column of both neighbor tables in ``meta``.

        The inputs are cloned; column 0 of every neighbor slot is replaced by
        a relation id drawn uniformly from ``[0, len(dataset.r2i))``. The
        neighbor-entity column and the degree tensors are passed through
        unchanged.

        Args:
            meta: tuple as returned by :meth:`get_meta`.

        Returns:
            ``(false_left_connections, left_degrees,
            false_right_connections, right_degrees)``.
        """
        left_connections, left_degrees, right_connections, right_degrees = meta
        num_rels = len(self.dataset.r2i)
        false_left_connections = left_connections.clone()
        false_right_connections = right_connections.clone()
        # Sampled with the default generator/device and copied into the slice.
        false_left_connections[:, :, 0] = torch.randint(
            0, num_rels, tuple(false_left_connections.shape[:2]))
        false_right_connections[:, :, 0] = torch.randint(
            0, num_rels, tuple(false_right_connections.shape[:2]))

        return (false_left_connections, left_degrees, false_right_connections, right_degrees)

    def process_epoch(self, model, optimizer):
        """Run one optimisation pass over ``self.data_loader``.

        For every batch this builds neighbor metadata, draws
        ``args.max_nn_meta`` corrupted copies of it, runs ``model.forward``,
        back-propagates the returned loss, steps the optimizer, calls
        ``model.update_W()`` and pushes the non-empty selected triplets into
        ``self.epoch_buffer``.

        Args:
            model: module exposing ``forward`` (returning
                ``(loss_tuple, sel_triplets)``) and ``update_W``.
            optimizer: optimizer over the model's parameters.

        Returns:
            ``(total_loss, epoch_buffer)`` where ``total_loss`` is the sum of
            the per-batch loss values.
        """
        if not model.training:
            model.train()

        device = self.args.device
        # Metadata is gathered for every (neg_ratio + 1)-th element only.
        # NOTE(review): presumably these are the true triplets of each group —
        # confirm against TripleDataset's item layout.
        stride = self.args.neg_ratio + 1

        total_loss = 0.0
        for bh, br, bt in self.data_loader:
            # get meta for true triplets
            batch_meta = self.get_meta(bh[::stride], bt[::stride])
            negative_meta = [self.get_negative_meta(batch_meta)
                             for _ in range(self.args.max_nn_meta)]

            optimizer.zero_grad()
            loss_tuple, sel_triplets = model.forward(self.dataset.task_unique_rels.to(device),
                                                     bh.contiguous().to(device),
                                                     br.contiguous().to(device),
                                                     bt.contiguous().to(device),
                                                     batch_meta, negative_meta)
            batch_loss, loss_margin, loss_info, loss_l2 = loss_tuple
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()

            # update SI variables
            model.update_W()

            # update epoch_buffer
            for triplet in sel_triplets:
                if triplet.shape[0] > 0:
                    # Keyed by element [0][1] of the selected triplets.
                    self.epoch_buffer.add_data(triplet[0][1], torch.split(triplet, triplet.shape[0])[0])
            print('Total loss: {}, margin: {}, info: {}, l2: {}'.format(batch_loss, loss_margin, loss_info, loss_l2))
        return total_loss, self.epoch_buffer

class ValidBatchProcessor:
    """Evaluate a model on a triple dataset with filtered head/tail ranking.

    ``process_epoch`` ranks the true head and the true tail of every triple
    against all known entities (minus other masked answers) and reports
    Hits@10 and MRR.
    """

    def __init__(self, cmd_args):
        """Load the evaluation dataset, its masks, and build an ordered loader.

        Args:
            cmd_args: parsed command-line namespace. Reads ``dataset``,
                ``neg_ratio``, ``triplet2id``, ``dataset_fps``,
                ``num_workers`` (and later ``device`` / ``cuda``). A shallow
                copy is stored as ``self.args``.
        """
        self.args = copy(cmd_args)
        self.dataset = data_utils.TripleDataset(self.args.dataset, self.args.neg_ratio)
        self.dataset.load_triple_set(self.args.triplet2id)
        self.dataset.load_mask(cmd_args.dataset_fps)
        self.dataset.load_known_ent_set()
        self.dataset.load_known_rel_set()
        # Evaluation scores every entity per triple, so batches are kept small.
        self.batch_size = 10
        self.data_loader = DataLoader(self.dataset,
                                      shuffle=False,
                                      batch_size=self.batch_size,
                                      num_workers=self.args.num_workers,
                                      collate_fn=collate_batch,
                                      pin_memory=True)

    def process_epoch(self, model):
        """Rank every triple in the loader and return ``(hits10, mrr)``.

        Args:
            model: module exposing ``training``, ``eval()`` and
                ``predict(h, r, t)``.

        Returns:
            ``hits10``: mean of head-side and tail-side Hits@10.
            ``mrr``: mean reciprocal rank over all head and tail ranks.

        Raises:
            ZeroDivisionError: if the loader yields no triples.
        """
        if model.training:
            model.eval()

        # Collect into lists and convert once; np.append in the loop would
        # re-copy the whole array on every batch.
        head_ranks, tail_ranks = [], []
        with no_grad():
            for bh, br, bt in self.data_loader:
                if self.args.cuda and torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # get ranks for each triple in the batch
                head_ranks.extend(self._rank_head(model, bh, br, bt))
                tail_ranks.extend(self._rank_tail(model, bh, br, bt))

        h_ranks = np.asarray(head_ranks, dtype=np.float64)
        t_ranks = np.asarray(tail_ranks, dtype=np.float64)

        # calculate hits & mrr
        hits10_h = np.count_nonzero(h_ranks <= 10) / len(h_ranks)
        hits10_t = np.count_nonzero(t_ranks <= 10) / len(t_ranks)
        hits10 = (hits10_h + hits10_t) / 2.0
        mrr = np.mean(np.concatenate((1 / h_ranks, 1 / t_ranks), axis=0))

        return hits10, mrr

    def _all_entity_candidates(self, batch_len):
        """Return a ``(batch_len, num_entities)`` tensor whose rows are ``0..num_entities-1``."""
        return from_numpy(np.arange(len(self.dataset.e2i))).repeat(batch_len, 1)

    def _filtered_ranks(self, scores, targets, mask_keys, mask):
        """Compute the 1-based filtered rank of each target entity.

        Args:
            scores: 2-D array-like, one row per triple and one column per
                entity id; rows are sorted ascending, so a lower score ranks
                first.
            targets: true entity id (Python int) for each row.
            mask_keys: lookup key into ``mask`` for each row.
            mask: mapping from key to a collection of entity ids. When a key
                is present, all of its ids except the row's own target are
                removed from the candidate set. ``remove`` raises if the
                target is missing from that collection.

        Returns:
            List of Python ints, one rank per row.
        """
        ent_ids = np.arange(len(self.dataset.e2i))
        known_ents = np.asarray(self.dataset.known_ents, dtype=np.int64)
        ranks = []
        for i in range(scores.shape[0]):
            target, key = targets[i], mask_keys[i]
            # Column 0: score, column 1: entity id.
            scored = np.stack((copy(scores[i, :]), ent_ids), axis=-1)
            if key in mask:
                excluded = copy(mask[key])
                excluded.remove(target)
                keep = np.isin(known_ents, excluded, assume_unique=True, invert=True)
                ents = known_ents[keep]
            else:
                ents = known_ents
            filtered = scored[np.isin(scored[:, -1], ents, assume_unique=True), :]
            # .item() requires exactly one match and avoids the deprecated
            # implicit conversion of a 1-D array to a Python scalar.
            target_row = np.where(filtered[:, -1] == target)[0].item()
            order = np.argsort(filtered[:, 0], 0)
            ranks.append(np.where(order == target_row)[0].item() + 1)
        return ranks

    def _rank_head(self, model, h, r, t):
        """Return the filtered rank of each true head given ``(r, t)``.

        Every entity is scored as a candidate head; ``dataset.h_mask`` keyed by
        ``(r, t)`` supplies the other heads to filter out.
        """
        device = self.args.device
        rank_heads = self._all_entity_candidates(h.shape[0])
        scores = model.predict(rank_heads.contiguous().to(device),
                               r.unsqueeze(-1).contiguous().to(device),
                               t.unsqueeze(-1).contiguous().to(device))
        mask_keys = list(zip(r.tolist(), t.tolist()))
        return self._filtered_ranks(scores, h.tolist(), mask_keys, self.dataset.h_mask)

    def _rank_tail(self, model, h, r, t):
        """Return the filtered rank of each true tail given ``(h, r)``.

        Every entity is scored as a candidate tail; ``dataset.t_mask`` keyed by
        ``(h, r)`` supplies the other tails to filter out.
        """
        device = self.args.device
        rank_tails = self._all_entity_candidates(t.shape[0])
        scores = model.predict(h.unsqueeze(-1).contiguous().to(device),
                               r.unsqueeze(-1).contiguous().to(device),
                               rank_tails.contiguous().to(device))
        mask_keys = list(zip(h.tolist(), r.tolist()))
        return self._filtered_ranks(scores, t.tolist(), mask_keys, self.dataset.t_mask)

def collate_batch(batch):
    """Collate a batch of triple groups into flat head/relation/tail tensors.

    ``batch`` is converted to a single tensor and indexed as a 3-D array whose
    last axis holds ``(head, relation, tail)`` at positions 0, 1 and 2; each
    of those columns is flattened to 1-D.

    Returns:
        ``(batch_h, batch_r, batch_t)``.
    """
    stacked = tensor(batch)
    heads, rels, tails = (stacked[:, :, col].flatten() for col in range(3))
    return heads, rels, tails


def init_model(args):
    """Instantiate ``er_models.TransE`` from ``args`` and move it to ``args.device``.

    Reads ``ent_num``, ``rel_num``, ``hidden_size``, ``margin``, ``neg_ratio``,
    ``batch_size``, ``topk``, ``device``, ``epad_id``, ``rpad_id``,
    ``coeff_info`` and ``coeff_l2`` from ``args`` and passes them positionally
    in that order.

    Returns:
        The constructed model.
    """
    transe = er_models.TransE(
        args.ent_num, args.rel_num, args.hidden_size, args.margin,
        args.neg_ratio, args.batch_size, args.topk, args.device,
        args.epad_id, args.rpad_id,
        args.coeff_info, args.coeff_l2,
    )
    transe.to(args.device, non_blocking=True)
    return transe


def init_optimizer(args, model):
    """Create the optimizer selected by ``args.opt_method``.

    Only parameters with ``requires_grad`` set are handed to the optimizer.

    Args:
        args: namespace providing ``opt_method`` (one of ``"adagrad"``,
            ``"adadelta"``, ``"adam"``, ``"sgd"``) and ``lr``. ``wd`` is read
            only for ``"adam"``, where it is used as ``weight_decay``.
        model: module whose trainable parameters are optimised.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        SystemExit: with status 1 after logging, if ``opt_method`` is not one
            of the supported names.
    """
    optim_params = [param for param in model.parameters() if param.requires_grad]

    # Builders are lazy so that args.wd is only required when Adam is chosen.
    builders = {
        "adagrad": lambda: optim.Adagrad(optim_params, lr=args.lr),
        "adadelta": lambda: optim.Adadelta(optim_params, lr=args.lr),
        "adam": lambda: optim.Adam(optim_params, lr=args.lr, weight_decay=args.wd),
        "sgd": lambda: optim.SGD(optim_params, lr=args.lr),
    }

    build = builders.get(args.opt_method)
    if build is None:
        logout("Optimization options are 'adagrad','adadelta','adam','sgd'", "f")
        # The site-provided exit() is meant for the interactive shell and
        # would report success (status 0); signal failure explicitly instead.
        raise SystemExit(1)

    return build()


def save_model(args, model):
    """Save ``model.state_dict()`` under ``<project root>/ckps/``.

    The target path is ``<root>/ckps/<args.checkpoint_name>/sess_<args.sess>``
    where ``<root>`` is the parent of this file's directory. Both
    ``args.checkpoint_name`` and ``args.sess`` must be strings.
    """
    state = model.state_dict()
    project_root = abspath(dirname(dirname(__file__)))
    target = "".join((project_root, "/ckps/", args.checkpoint_name, "/sess_", args.sess))
    save_checkpoint(state, target)


def save_checkpoint(model_params,filename):
    """Serialize ``model_params`` to ``filename`` with ``torch.save``.

    On any failure a warning is logged via ``logout`` and the original
    exception is re-raised.
    """
    try:
        torch.save(model_params, filename)
    except Exception:
        warning = "Could not save: " + filename
        logout(warning, "w")
        raise


def load_model(args, model):
    """Load weights into ``model`` from ``<project root>/ckps/``.

    The source path is ``<root>/ckps/<args.checkpoint_name>/sess_<args.sess>``
    where ``<root>`` is the parent of this file's directory. Both
    ``args.checkpoint_name`` and ``args.sess`` must be strings.

    Returns:
        The model returned by ``load_checkpoint``.
    """
    project_root = abspath(dirname(dirname(__file__)))
    source = "".join((project_root, "/ckps/", args.checkpoint_name, "/sess_", args.sess))
    return load_checkpoint(model, source)


def load_checkpoint(model, filename):
    """Load a state dict from ``filename`` into ``model`` (non-strict).

    ``strict=False`` is passed to ``load_state_dict``. On any failure a
    warning is logged via ``logout`` and the original exception is re-raised.

    Returns:
        The same ``model`` object, updated in place.
    """
    try:
        state = load(filename)
        model.load_state_dict(state, strict=False)
    except Exception:
        warning = "Could not load: " + filename
        logout(warning, "w")
        raise
    return model


def evaluate_model(args, sess, batch_processors, model):
    """Evaluate ``model`` with the processors of sessions ``0..sess``.

    Args:
        args: unused here; kept for the call signature.
        sess: index of the last session to evaluate (inclusive).
        batch_processors: indexable collection whose items expose
            ``process_epoch(model)`` returning a pair of numbers.
        model: model handed to every processor.

    Returns:
        Array of shape ``(sess + 1, 2)`` with one row per evaluated session,
        in session order.
    """
    rows = [np.empty((0, 2))]
    for valid_sess in range(sess + 1):
        result = batch_processors[valid_sess].process_epoch(model)
        rows.append([result])
    return np.concatenate(rows, axis=0)


class EarlyStopTracker:
    """Track epochs, the best validation result and an early-stop counter.

    ``early_stop_trigger`` starts at ``-int(patience / valid_freq)``, is
    incremented on every validation without sufficient improvement, and
    training stops once it becomes positive or the epoch count exceeds
    ``num_epoch``.
    """

    def __init__(self, args):
        """Initialise from ``args.num_epochs``, ``args.valid_freq`` and ``args.patience``."""
        self.args = args
        self.epoch = 0
        self.num_epoch = args.num_epochs
        self.patience = args.patience
        self.valid_freq = args.valid_freq
        self.early_stop_trigger = -int(self.patience / self.valid_freq)
        self.last_early_stop_value = 0.0
        self.best_measure = 0.0
        self.best_epoch = None
        self.best_performances = None

    def continue_training(self):
        """Return ``False`` once the epoch limit is passed or the trigger is positive."""
        if self.epoch > self.num_epoch:
            return False
        return not bool(self.early_stop_trigger > 0)

    def get_epoch(self):
        """Return the current epoch counter."""
        return self.epoch

    def validate(self):
        """Return ``True`` when the current epoch is a multiple of ``valid_freq``."""
        remainder = self.epoch % self.valid_freq
        return bool(remainder == 0)

    def update_best(self, sess, performances, model):
        """Record a validation result and update the early-stop state.

        Args:
            sess: row of ``performances`` to read the measure from.
            performances: 2-D array; the measure is column 1 of row ``sess``.
            model: saved through ``save_model`` when the measure is a new best.
        """
        score = performances[sess, 1]

        # New best: remember it and persist the model.
        if score > self.best_measure:
            self.best_measure = copy(score)
            self.best_epoch = copy(self.epoch)
            self.best_performances = np.copy(performances)
            save_model(self.args, model)

        # The trigger is reset only by an improvement of more than 0.01 over
        # the value stored at the last reset.
        if score - 0.01 > self.last_early_stop_value:
            self.last_early_stop_value = copy(score)
            self.early_stop_trigger = -int(self.patience / self.valid_freq)
        else:
            self.early_stop_trigger += 1

        # Validation gets sparser late in training; the trigger is rescaled by
        # old_freq / new_freq. Only the first matching threshold applies.
        for min_epoch, new_freq in ((400, 50), (200, 25), (50, 10)):
            if self.epoch >= min_epoch:
                self.early_stop_trigger = self.early_stop_trigger * self.valid_freq / float(new_freq)
                self.valid_freq = new_freq
                break

    def step_epoch(self):
        """Advance the epoch counter by one."""
        self.epoch += 1

    def get_best(self):
        """Return ``(best_performances, best_epoch)``; both ``None`` before any new best."""
        return self.best_performances, self.best_epoch



if __name__ == "__main__":
    # Running this module directly is currently a no-op placeholder.
    # TODO: add unit tests below.
    pass
