import os
import time
import random
import datetime
import argparse

import numpy as np
import torch
import torch.nn.functional as F

from model import Model
from datasets import KnowledgeGraph


def _build_dataset_and_model(args, device):
    """Load the dataset named by ``args.dataset`` and build a model on *device*.

    Returns:
        ``(dataset, model)`` where *model* has already been moved to *device*.
    """
    file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets')
    dataset = KnowledgeGraph(file_path, args.dataset)
    model = Model(args.model_name, dataset.num_entity, dataset.num_relation,
                  args.dimension, dataset.train_data, args.confounder).to(device)
    return dataset, model


def main(args):
    """Run training or evaluation according to ``args.mode``.

    ``'train'``: builds the dataset and model, trains with Adam plus an
    exponential LR decay, prints the wall-clock time and returns the negated
    best validation MRR (presumably so a minimizing hyper-parameter search can
    call this directly -- TODO confirm).

    ``'test'``: loads ``epoch_best.pth`` from the save directory and evaluates
    it on the test split; returns ``None``. When ``args.save_dir`` is empty,
    ``get_save_dir`` creates a fresh timestamped directory, which will not
    contain a checkpoint. Point ``--save_dir`` at a training run's directory
    instead.

    Raises:
        RuntimeError: on an unknown ``args.mode``, or in ``'test'`` mode when
            the checkpoint file does not exist.
    """
    if args.mode not in ('train', 'test'):
        raise RuntimeError('wrong mode')

    t0 = time.time()
    device = torch.device(args.device)
    save_dir = get_save_dir(args)

    if args.mode == 'train':
        print(save_dir)
        # Use the ``args`` parameter rather than the script-level global so
        # main() also works when imported and called programmatically.
        dataset, model = _build_dataset_and_model(args, device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, args.decay_rate)

        best_mrr = train(args, device, save_dir, dataset, model, optimizer, scheduler)
        print(f'time: {time.time() - t0}s')
        return -best_mrr

    # args.mode == 'test': check the checkpoint before the (potentially slow)
    # dataset load so a wrong --save_dir fails immediately.
    state_file = os.path.join(save_dir, 'epoch_best.pth')
    if not os.path.isfile(state_file):
        raise RuntimeError('file {0} is not found'.format(state_file))
    dataset, model = _build_dataset_and_model(args, device)

    print('load checkpoint {0}'.format(state_file))
    checkpoint = torch.load(state_file, map_location=device)
    model.load_state_dict(checkpoint['model'])
    test(args, device, dataset, model, checkpoint['epoch'], -1.0, is_test=True)


def _log_degree(degree, device):
    """Return ``log(max(degree, 1))`` element-wise as a float tensor on *device*.

    Clamping at 1 maps zero-degree entries to ``log(1) = 0`` instead of ``-inf``.
    """
    degree = torch.FloatTensor(degree).to(device)
    return torch.log(torch.maximum(degree, torch.ones_like(degree)))


def train(args, device, save_dir, dataset, model, optimizer, scheduler):
    """Train *model* on ``dataset.train_data`` and keep the best checkpoint.

    Every ``args.test_interval`` epochs the model is evaluated on the
    validation split, which drives model selection, and on the test split,
    which is logged only. ``epoch_best.pth`` in *save_dir* is written only
    when validation MRR improves, so it always holds the best model seen so
    far.

    Loss per ``args.confounder`` (all variants add ``args.lambda1 * reg``):
        'no'     -- plain cross-entropy.
        'IPS'    -- per-example cross-entropy weighted by
                    exp(-(log deg_h + log deg_r + log deg_t)), i.e. the inverse
                    product of the (clamped) train degrees, normalized to sum
                    to 1 over the batch.
        'degree' -- gamma1/2/3-scaled log-degree terms added to the scores.
        '1-MLP'  -- gamma1/2/3-scaled model-provided factors added to the scores.

    Returns:
        The best validation MRR (0.0 if no evaluation ran or none exceeded 0).

    Raises:
        ValueError: on an unknown ``args.confounder`` (raised at the first batch).
    """
    best_mrr = 0.0
    best_epoch = 0
    data = dataset.train_data
    batch_size = args.batch_size

    degree_h = _log_degree(dataset.degree_train_h, device)
    degree_r = _log_degree(dataset.degree_train_r, device)
    degree_t = _log_degree(dataset.degree_train_t, device)

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        total_loss = 0.
        # ``data`` is dataset.train_data itself, so this shuffles it in place.
        np.random.shuffle(data)
        model.train()
        for start in range(0, len(data), batch_size):
            end = min(start + batch_size, len(data))
            batch_data = data[start:end]
            heads = torch.LongTensor(batch_data[:, 0]).to(device)
            relations = torch.LongTensor(batch_data[:, 1]).to(device)
            tails = torch.LongTensor(batch_data[:, 2]).to(device)

            scores, reg, factor1, factor2, factor3 = model(heads, relations, tails)
            if args.confounder == 'no':
                loss = F.cross_entropy(scores, tails) + args.lambda1 * reg
            elif args.confounder == 'IPS':
                weight = torch.exp(-degree_h[heads] - degree_r[relations] - degree_t[tails])
                weight = weight / weight.sum()
                loss = torch.sum(weight * F.cross_entropy(scores, tails, reduction='none')) + args.lambda1 * reg
            elif args.confounder == 'degree':
                # Head/relation terms are one value per query (unsqueezed to a
                # column). The tail term is the full per-entity vector broadcast
                # across candidate columns, which assumes scores has one column
                # per entity -- TODO confirm.
                scores = (scores
                          + args.gamma1 * degree_h[heads].unsqueeze(-1)
                          + args.gamma2 * degree_r[relations].unsqueeze(-1)
                          + args.gamma3 * degree_t)
                loss = F.cross_entropy(scores, tails) + args.lambda1 * reg
            elif args.confounder == '1-MLP':
                scores = scores + args.gamma1 * factor1 + args.gamma2 * factor2 + args.gamma3 * factor3
                loss = F.cross_entropy(scores, tails) + args.lambda1 * reg
            else:
                raise ValueError('wrong confounder')
            # Weight by batch size so the epoch loss is a per-example mean.
            total_loss += loss.item() * (end - start)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss /= len(data)
        scheduler.step()
        t1 = time.time()
        print('\n[train: epoch {0}], loss: {1}, time: {2}s'.format(epoch, total_loss, t1 - t0))

        if epoch % args.test_interval == 0:
            metric = test(args, device, dataset, model, epoch, -1.0, is_test=False)
            test(args, device, dataset, model, epoch, -1.0, is_test=True)
            if metric['mrr'] > best_mrr:
                best_mrr = metric['mrr']
                best_epoch = epoch
                # Save only on improvement: epoch_best.pth is what test mode
                # loads, so it must hold the best model, not the latest one.
                save(save_dir, epoch, model)
    print('best mrr: {} at epoch {}'.format(best_mrr, best_epoch))
    return best_mrr


def test(args, device, dataset, model, epoch, power, is_test=True):
    if is_test:
        data = dataset.test_data
        degree_h = dataset.degree_test_h
        degree_r = dataset.degree_test_r
        degree_t = dataset.degree_test_t
    else:
        data = dataset.valid_data
        degree_h = dataset.degree_valid_h
        degree_r = dataset.degree_valid_r
        degree_t = dataset.degree_valid_t
    total_weight = 0.
    num_batch = len(data) // args.batch_size + int(len(data) % args.batch_size > 0)
    metric = {'mr': 0.0, 'mrr': 0.0, 'hit@1': 0.0, 'hit@3': 0.0, 'hit@10': 0.0}

    t0 = time.time()
    model.eval()

    with torch.no_grad():
        for i in range(num_batch):
            start = i * args.batch_size
            end = min((i + 1) * args.batch_size, len(data))
            batch_data = data[start:end]
            heads = torch.LongTensor(batch_data[:, 0]).to(device)
            relations = torch.LongTensor(batch_data[:, 1]).to(device)
            tails = torch.LongTensor(batch_data[:, 2]).to(device)

            scores = model(heads, relations, tails)[0]
            scores = scores.detach().cpu().numpy()
            weight_h = degree_h[batch_data[:, 0]]
            weight_r = degree_r[batch_data[:, 1]]
            weight_t = degree_t[batch_data[:, 2]]

            for j in range(end-start):
                weight = (weight_h[j] ** power) * (weight_r[j] ** power) * (weight_t[j] ** power)
                total_weight += weight
                target = scores[j, batch_data[j][2]]
                scores[j, dataset.hr_vocab[(batch_data[j][0], batch_data[j][1])]] = -1e8
                rank = np.sum(scores[j] >= target) + 1
                metric['mr'] += (rank * weight)
                metric['mrr'] += ((1.0 / rank) * weight)
                if rank == 1:
                    metric['hit@1'] += weight
                if rank <= 3:
                    metric['hit@3'] += weight
                if rank <= 10:
                    metric['hit@10'] += weight

    metric['mr'] /= total_weight
    metric['mrr'] /= total_weight
    metric['hit@1'] /= total_weight
    metric['hit@3'] /= total_weight
    metric['hit@10'] /= total_weight
    t1 = time.time()
    print('[test: epoch {}], power: {}, mrr: {}, mr: {}, hit@1: {}, hit@3: {}, hit@10: {}, time: {}s'
          .format(epoch, power, metric['mrr'], metric['mr'], metric['hit@1'], metric['hit@3'], metric['hit@10'], t1-t0))
    return metric


def get_save_dir(args):
    """Return the checkpoint directory for this run, creating it if needed.

    Uses ``args.save_dir`` when it is non-empty. Otherwise uses a timestamped
    directory ``save/YYYY-MM-DD-HH-MM-SS`` next to this script.

    Raises:
        FileExistsError: if the chosen path exists but is not a directory.
    """
    if args.save_dir:
        save_dir = args.save_dir
    else:
        # strftime instead of slicing str(datetime): str() drops the ".ffffff"
        # suffix when microsecond == 0, so a fixed [:-7] slice would cut into
        # the time-of-day part of the name.
        name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'save', name)
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def save(save_dir, epoch, model):
    """Write *model*'s ``state_dict`` and *epoch* to ``<save_dir>/epoch_best.pth``.

    The file is a dict with keys ``'epoch'`` and ``'model'``. Any existing
    file at that path is overwritten.
    """
    checkpoint_path = os.path.join(save_dir, 'epoch_best.pth')
    torch.save(dict(epoch=epoch, model=model.state_dict()), checkpoint_path)


if __name__ == '__main__':
    # Command-line entry point: parse options, seed every RNG in use, run main().
    parser = argparse.ArgumentParser(description='Causal Inference for Knowledge Graph Completion')
    parser.add_argument('--mode', type=str, default='test',
                        choices=['train', 'test'],
                        help='mode')
    parser.add_argument('--device', type=str, default='cuda:0',
                        choices=['cuda:0', 'cpu'],
                        help='device')
    parser.add_argument('--dataset', type=str, default='FB15k-237',
                        choices=['WN18RR', 'FB15k-237', 'YAGO3-10'],
                        help='dataset')
    parser.add_argument('--model_name', type=str, default='ComplEx',
                        choices=['TransE', 'RotatE', 'DistMult', 'ComplEx'],
                        help='model name')
    parser.add_argument('--confounder', type=str, default='no',
                        choices=['no', 'degree', 'IPS', '1-MLP'],
                        help='how the confounder is handled during training')
    parser.add_argument('--save_dir', type=str, default='',
                        help='save directory')
    parser.add_argument('--test_interval', type=int, default=1,
                        help='number of epochs to test')

    parser.add_argument('--dimension', type=int, default=2048,
                        help='the dimension of embedding')
    parser.add_argument('--epochs', type=int, default=50,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=0.005,
                        help='learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.93,
                        help='decay rate of learning rate')
    parser.add_argument('--gamma1', type=float, default=0.0,
                        help='coefficient of the first confounder term (head degree in "degree" mode)')
    parser.add_argument('--gamma2', type=float, default=0.0,
                        help='coefficient of the second confounder term (relation degree in "degree" mode)')
    parser.add_argument('--gamma3', type=float, default=0.0,
                        help='coefficient of the third confounder term (tail degree in "degree" mode)')
    parser.add_argument('--lambda1', type=float, default=0.0,
                        help='coefficient of regularization')

    parser.add_argument('--seed', type=int, default=0,
                        help='seed')
    # Keep the global name ``parse_args``: older versions of main() read it directly.
    parse_args = parser.parse_args()
    random.seed(parse_args.seed)
    np.random.seed(parse_args.seed)
    torch.manual_seed(parse_args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(parse_args.seed)
    print(parse_args.__dict__)

    main(parse_args)
