import os
import time
import random
import datetime
import argparse
import math

import numpy as np
import torch
import torch.nn.functional as F

from model import Model
from datasets import KnowledgeGraph


def main(args):
    """Run the experiment selected by ``args.mode``.

    Both modes resolve the save directory, load the knowledge graph from the
    ``datasets`` folder next to this file and build the model on
    ``args.device``. ``'train'`` then optimises the model with Adagrad;
    ``'test'`` restores ``epoch_best.pth`` from the save directory and
    evaluates it on the test split.

    Raises:
        RuntimeError: if ``args.mode`` is neither ``'train'`` nor ``'test'``,
            or if test mode finds no checkpoint file.
    """
    mode = args.mode
    # Reject an unknown mode before touching the filesystem or the dataset.
    if mode != 'train' and mode != 'test':
        raise RuntimeError('wrong mode')

    # Setup shared by both modes, in the same order for either one.
    device = torch.device(args.device)
    save_dir = get_save_dir(args)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    dataset = KnowledgeGraph(os.path.join(script_dir, 'datasets'), args.dataset)
    model = Model(
        args.model_name,
        dataset.num_entity,
        dataset.num_relation,
        args.part,
        args.dimension,
        args.regularization,
        args.alpha,
    )
    model = model.to(device)

    if mode == 'train':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr)
        train(args, device, save_dir, dataset, model, optimizer)
        return

    # Test mode: a previously saved best checkpoint is required.
    state_file = os.path.join(save_dir, 'epoch_best.pth')
    if not os.path.isfile(state_file):
        raise RuntimeError(f'file {state_file} is not found')
    print(f'load checkpoint {state_file}')
    # The second positional argument of torch.load is map_location.
    checkpoint = torch.load(state_file, device)
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    test(args, device, dataset, model, epoch, is_test=True)


def train(args, device, save_dir, dataset, model, optimizer):
    """Train ``model`` on ``dataset.train_data`` with early stopping.

    Each epoch shuffles the training triples in place, runs one optimiser
    step per mini-batch on cross-entropy over tails plus the four
    lambda-weighted factors returned by the model, and prints the
    sample-weighted epoch averages. Every ``args.test_interval`` epochs the
    model is evaluated on the validation and test splits; a checkpoint is
    saved whenever validation MRR improves. Training stops after 50
    consecutive evaluations without improvement.
    """
    best_mrr = 0.0
    best_epoch = 0
    evals_without_gain = 0
    patience = 50

    samples = dataset.train_data
    sample_count = len(samples)
    batch_size = args.batch_size
    full_batches, remainder = divmod(sample_count, batch_size)
    num_batch = full_batches + int(remainder > 0)

    print('start training')
    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()
        loss_sum = 0.
        factor_sums = [0., 0., 0., 0.]
        np.random.shuffle(samples)
        model.train()
        for batch_index in range(num_batch):
            lo = batch_index * batch_size
            # Slicing clamps at the end of the array, so the last batch may be short.
            batch = samples[lo:lo + batch_size]
            rows = batch.shape[0]
            heads, relations, tails = (
                torch.LongTensor(batch[:, column]).to(device) for column in range(3)
            )
            scores, reg1, reg2, reg3, reg4 = model(heads, relations, tails)
            factors = (reg1, reg2, reg3, reg4)

            # Add the weighted factors left to right, one term at a time.
            loss = F.cross_entropy(scores, tails)
            weights = (args.lambda1, args.lambda2, args.lambda3, args.lambda4)
            for weight, factor in zip(weights, factors):
                loss = loss + weight * factor

            # Weight by batch size so the epoch average is per sample.
            loss_sum += loss.item() * rows
            for slot, factor in enumerate(factors):
                factor_sums[slot] += factor.item() * rows

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        mean_loss = loss_sum / sample_count
        mean_factors = [total / sample_count for total in factor_sums]
        epoch_end = time.time()
        print(f'\n[train: epoch {epoch}], loss: {mean_loss}, total factor1: {mean_factors[0]}, total factor2: {mean_factors[1]}, total factor3: {mean_factors[2]}, total factor4: {mean_factors[3]}, time: {epoch_end - epoch_start}s')

        if epoch % args.test_interval == 0:
            valid_metric = test(args, device, dataset, model, epoch, is_test=False)
            # The test-split run only prints its metrics; model selection uses validation.
            test(args, device, dataset, model, epoch, is_test=True)
            valid_mrr = valid_metric['mrr']
            if valid_mrr > best_mrr:
                best_mrr, best_epoch = valid_mrr, epoch
                evals_without_gain = 0
                save(save_dir, epoch, model)
            else:
                evals_without_gain += 1
        if evals_without_gain == patience:
            break
    print(f'best mrr: {best_mrr} at epoch {best_epoch}')


def test(args, device, dataset, model, epoch, is_test=True):
    if is_test:
        data = dataset.test_data
    else:
        data = dataset.valid_data
    number = len(data)
    batch_size = args.batch_size
    num_batch = number // batch_size + int(number % batch_size > 0)
    metric = {'mr': 0.0, 'mrr': 0.0, 'hit@1': 0.0, 'hit@3': 0.0, 'hit@10': 0.0}

    t0 = time.time()
    model.eval()
    with torch.no_grad():
        for i in range(num_batch):
            start = i * batch_size
            end = min((i + 1) * batch_size, number)
            batch_data = data[start:end]

            heads = torch.LongTensor(batch_data[:, 0]).to(device)
            relations = torch.LongTensor(batch_data[:, 1]).to(device)
            tails = torch.LongTensor(batch_data[:, 2]).to(device)
            scores = model(heads, relations, tails)[0]
            scores = scores.detach().cpu().numpy()

            for j in range(batch_data.shape[0]):
                target = scores[j, batch_data[j][2]]
                scores[j, dataset.hr_vocab[(batch_data[j][0], batch_data[j][1])]] = -1e8
                rank = np.sum(scores[j] >= target) + 1
                metric['mr'] += rank
                metric['mrr'] += (1.0 / rank)
                if rank == 1:
                    metric['hit@1'] += 1.0
                if rank <= 3:
                    metric['hit@3'] += 1.0
                if rank <= 10:
                    metric['hit@10'] += 1.0

    metric['mr'] /= number
    metric['mrr'] /= number
    metric['hit@1'] /= number
    metric['hit@3'] /= number
    metric['hit@10'] /= number
    mr = metric['mr']
    mrr = metric['mrr']
    hit1 = metric['hit@1']
    hit3 = metric['hit@3']
    hit10 = metric['hit@10']
    t1 = time.time()
    print(f'[test: epoch {epoch}], mrr: {mrr}, mr: {mr}, hit@1: {hit1}, hit@3: {hit3}, hit@10: {hit10}, time: {t1 - t0}s')
    return metric


def get_save_dir(args):
    """Return the directory used for checkpoints, creating it if needed.

    Uses ``args.save_dir`` when it is non-empty. Otherwise a new directory
    named after the current local time (``YYYY-MM-DD-HH-MM-SS``) is placed
    under ``save/`` next to this file.

    Returns:
        The path of the (now existing) save directory.
    """
    if args.save_dir:
        save_dir = args.save_dir
    else:
        # strftime always gives second resolution. Slicing str(now) to drop
        # the microseconds breaks when microsecond == 0, because str() then
        # omits the fractional part and the slice cuts into the time itself.
        name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'save', name)
    # exist_ok avoids the check-then-create race when runs start concurrently.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def save(save_dir, epoch, model):
    """Write the epoch number and model weights to ``epoch_best.pth``.

    The file is placed in ``save_dir`` and overwritten on every call.
    """
    checkpoint_path = os.path.join(save_dir, 'epoch_best.pth')
    checkpoint = dict(epoch=epoch, model=model.state_dict())
    torch.save(checkpoint, checkpoint_path)


if __name__ == '__main__':
    # Command-line entry point: parse options, seed every RNG, then run main().
    parser = argparse.ArgumentParser(description='Knowledge Graph Completion by Intermediate Variables Regularization')

    # Run configuration.
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'test'],
                        help='mode')
    parser.add_argument('--device', type=str, default='cuda:0',
                        choices=['cuda:0', 'cpu'],
                        help='device')
    parser.add_argument('--dataset', type=str, default='WN18RR',
                        choices=['WN18RR', 'FB15k-237', 'YAGO3-10', 'Kinship'],
                        help='dataset')
    parser.add_argument('--model_name', type=str, default='CP',
                        choices=['CP', 'ComplEx', 'SimplE', 'ANALOGY', 'QuatE', 'TuckER'],
                        help='model')
    parser.add_argument('--regularization', type=str, default='TNRR',
                        choices=['w/o', 'F2', 'N3', 'DURA', 'TNRR'],
                        help='regularization')
    parser.add_argument('--save_dir', type=str, default='',
                        help='save directory')
    parser.add_argument('--test_interval', type=int, default=1,
                        help='number of epochs to test')

    # Model size.
    parser.add_argument('--part', type=int, default=1,
                        help='part')
    parser.add_argument('--dimension', type=int, default=2000,
                        help='the dimension of each part')

    # Optimisation hyper-parameters.
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--alpha', type=float, default=3,
                        help='power')
    # lambdaN scales the N-th factor returned by the model in the training loss.
    parser.add_argument('--lambda1', type=float, default=0.0,
                        help='the coefficient of regularization factor 1')
    parser.add_argument('--lambda2', type=float, default=0.0,
                        help='the coefficient of regularization factor 2')
    parser.add_argument('--lambda3', type=float, default=0.0,
                        help='the coefficient of regularization factor 3')
    parser.add_argument('--lambda4', type=float, default=0.0,
                        help='the coefficient of regularization factor 4')

    parser.add_argument('--seed', type=int, default=0,
                        help='seed')
    parse_args = parser.parse_args()

    # Seed Python, NumPy and PyTorch with the same value before any work starts.
    random.seed(parse_args.seed)
    np.random.seed(parse_args.seed)
    torch.manual_seed(parse_args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(parse_args.seed)

    main(parse_args)
