import torch
from model.mean_field_posterior import FactorizedPosterior
from model.gcn import GCN, TrainableEmbedding
from data_process.dataset import Dataset
from common.cmd_args import cmd_args
from tqdm import tqdm
import torch.optim as optim
from model.graph import KnowledgeGraph
from common.predicate import PRED_DICT
from common.utils import EarlyStopMonitor, get_lr, count_parameters
from common.evaluate import gen_eval_query
from itertools import chain
import random
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from os.path import join as joinpath
import os
import math
from collections import Counter
import time

def train(cmd_args):
    """Train the GCN + factorized posterior with variational EM, then evaluate.

    Each epoch runs:
      * an E-step that optimizes the GCN and posterior-model parameters on the
        batches yielded by ``dataset.get_batch_by_q``;
      * an M-step that takes one gradient-ascent step on the weights of the
        logic rules in ``dataset.rule_ls`` (the rule objects are updated in place);
      * a validation pass that drives LR scheduling, early stopping and
        checkpointing to ``<exp_path>/saved_model``.

    After training, every test query is ranked against its tail- and
    head-corrupted candidates. MR / MRR / Hits@{1,3,10} are written to a
    ``performance_*.txt`` file inside ``cmd_args.exp_path``.

    Args:
        cmd_args: parsed command-line options. Note that
            ``cmd_args.num_epochs`` is set to 0 when ``cmd_args.no_train == 1``.
    """
    os.makedirs(cmd_args.exp_path, exist_ok=True)

    with open(joinpath(cmd_args.exp_path, 'options.txt'), 'w') as f:
        for param, value in vars(cmd_args).items():
            f.write(param + ' = ' + str(value) + '\n')

    logpath = joinpath(cmd_args.exp_path, 'eval.result')
    param_cnt_path = joinpath(cmd_args.exp_path, 'param_count.txt')

    # dataset and KG
    dataset = Dataset(cmd_args.data_root, cmd_args.batchsize, cmd_args.shuffle_sampling)
    kg = KnowledgeGraph(dataset.fact_dict, PRED_DICT, dataset)

    # model
    if cmd_args.use_gcn == 1:
        gcn = GCN(kg, cmd_args.embedding_size - cmd_args.gcn_free_size, cmd_args.gcn_free_size,
                  num_hops=cmd_args.num_hops, num_layers=cmd_args.num_mlp_layers).to(cmd_args.device)
    else:
        gcn = TrainableEmbedding(kg, cmd_args.embedding_size).to(cmd_args.device)

    posterior_model = FactorizedPosterior(kg, cmd_args.embedding_size, cmd_args.slice_dim).to(cmd_args.device)

    if cmd_args.model_load_path is not None:
        gcn.load_state_dict(torch.load(joinpath(cmd_args.model_load_path, 'gcn.model')))
        posterior_model.load_state_dict(torch.load(joinpath(cmd_args.model_load_path, 'posterior.model')))

    # optimizers
    monitor = EarlyStopMonitor(cmd_args.patience)
    all_params = chain.from_iterable([posterior_model.parameters(), gcn.parameters()])
    optimizer = optim.Adam(all_params, lr=cmd_args.learning_rate, weight_decay=cmd_args.l2_coef)
    # The scheduler is stepped with the validation *loss*, so an improvement is a
    # decrease ('min'). 'max' would decay the LR exactly while the loss keeps improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=cmd_args.lr_decay_factor,
                                                     patience=cmd_args.lr_decay_patience,
                                                     min_lr=cmd_args.lr_decay_min)
    start_time = time.time()

    with open(param_cnt_path, 'w') as f:
        cnt_gcn_params = count_parameters(gcn)
        cnt_posterior_params = count_parameters(posterior_model)
        # Mirror the model choice above: anything but use_gcn == 1 is a plain embedding.
        if cmd_args.use_gcn == 1:
            f.write('GCN params count: %d\n' % cnt_gcn_params)
        else:
            f.write('plain params count: %d\n' % cnt_gcn_params)
        f.write('posterior params count: %d\n' % cnt_posterior_params)
        f.write('Total params count: %d\n' % (cnt_gcn_params + cnt_posterior_params))

    if cmd_args.no_train == 1:
        cmd_args.num_epochs = 0

    # Prepare data for the M-step: ground every rule over the observed binary facts
    # and split the derived head atoms into observed and hidden ones.
    tqdm.write('preparing data for M-step...')
    pred_arg1_set_arg2, pred_arg2_set_arg1, pred_fact_set = _build_fact_index(dataset)
    grounded_rules = _ground_rules(dataset, pred_arg1_set_arg2, pred_arg2_set_arg1, pred_fact_set)
    grounded_obs, grounded_hid = _collect_grounded_heads(dataset.rule_ls, grounded_rules, pred_fact_set)
    tqdm.write('observed: %d, hidden: %d' % (len(grounded_obs), len(grounded_hid)))
    pred_aggregated_hid_list, pred_aggregated_hid_args = _aggregate_hidden_by_pred(grounded_hid,
                                                                                  dataset.const2ind)

    for current_epoch in range(cmd_args.num_epochs):
        # E-step: optimize the parameters of the GCN and the posterior model.
        node_embeds = _run_e_step(dataset, gcn, posterior_model, optimizer, current_epoch,
                                  cmd_args.batchsize)

        # M-step: optimize the weights of the logic rules.
        if node_embeds is None:
            # No training batch was produced this epoch; embed the graph once here.
            with torch.no_grad():
                node_embeds = gcn(dataset)
        _update_rule_weights(dataset.rule_ls, posterior_model, node_embeds,
                             pred_aggregated_hid_list, pred_aggregated_hid_args,
                             grounded_obs, grounded_hid, cmd_args.learning_rate_rule_weights)

        # validation
        valid_loss, cnt_batch = _validation_loss(dataset, gcn, posterior_model, cmd_args.batchsize)
        tqdm.write('Epoch %d, valid loss: %.4f' % (current_epoch, valid_loss / max(cnt_batch, 1)))

        should_stop = monitor.update(valid_loss)
        scheduler.step(valid_loss)

        # NOTE(review): relies on EarlyStopMonitor resetting ``cnt`` to 0 when
        # ``valid_loss`` is a new best. Confirm in common.utils.
        if monitor.cnt == 0:
            savepath = joinpath(cmd_args.exp_path, 'saved_model')
            os.makedirs(savepath, exist_ok=True)
            torch.save(gcn.state_dict(), joinpath(savepath, 'gcn.model'))
            torch.save(posterior_model.state_dict(), joinpath(savepath, 'posterior.model'))

        if should_stop or current_epoch + 1 == cmd_args.num_epochs:
            tqdm.write('Early stopping')
            break

    training_time = time.time() - start_time
    training_time_hours = training_time / 3600
    tqdm.write('Total training time: %.2f seconds' % training_time)
    tqdm.write('Total training time: %.2f hours' % training_time_hours)
    with open(joinpath(cmd_args.exp_path, 'training_time_log.txt'), 'w') as f:
        f.write(f"Total training time: {training_time:.2f} seconds\n")
        f.write(f"Total training time (hours): {training_time_hours:.2f} hours\n")

    _evaluate(dataset, kg, gcn, posterior_model, logpath, cmd_args.exp_path)


def _build_fact_index(dataset):
    """Index the binary facts in ``dataset.fact_dict_2`` for grounding lookups.

    Returns ``(pred_arg1_set_arg2, pred_arg2_set_arg1, pred_fact_set)``:
      * ``pred_arg1_set_arg2[pred][arg1]`` is the set of ``arg2`` such that
        ``pred(arg1, arg2)`` is a fact;
      * ``pred_arg2_set_arg1[pred][arg2]`` is the reverse map;
      * ``pred_fact_set[pred]`` is the set of ``(arg1, arg2)`` tuples.
    """
    pred_arg1_set_arg2 = dict()
    pred_arg2_set_arg1 = dict()
    pred_fact_set = dict()
    for pred in dataset.fact_dict_2:
        by_arg1 = dict()
        by_arg2 = dict()
        facts = set()
        # each entry is (_, args) with args == (arg1, arg2)
        for _, args in dataset.fact_dict_2[pred]:
            by_arg1.setdefault(args[0], set()).add(args[1])
            by_arg2.setdefault(args[1], set()).add(args[0])
            facts.add(args)
        pred_arg1_set_arg2[pred] = by_arg1
        pred_arg2_set_arg1[pred] = by_arg2
        pred_fact_set[pred] = facts
    return pred_arg1_set_arg2, pred_arg2_set_arg1, pred_fact_set


def _ground_rules(dataset, pred_arg1_set_arg2, pred_arg2_set_arg1, pred_fact_set):
    """Enumerate the groundings of each rule whose body atoms are all observed facts.

    Body atoms are the negated atoms of a rule (at most two are supported).
    Returns one set per rule in ``dataset.rule_ls``. Each grounding is a tuple of
    ``(var_name, constant)`` pairs sorted by variable name.
    """
    grounded_rules = []
    for rule in dataset.rule_ls:
        groundings = set()
        grounded_rules.append(groundings)
        body_atoms = [atom for atom in rule.atom_ls if atom.neg]
        # atoms in the body must be observed
        assert len(body_atoms) <= 2
        if not body_atoms:
            continue

        body1 = body_atoms[0]
        for _, body1_args in dataset.fact_dict_2[body1.pred_name]:
            var2arg = {body1.var_name_ls[0]: body1_args[0],
                       body1.var_name_ls[1]: body1_args[1]}

            if len(body_atoms) == 1:
                # Single-body rule: each body fact is already a full grounding.
                groundings.add(tuple(sorted(var2arg.items())))
                continue

            body2 = body_atoms[1]
            var0, var1 = body2.var_name_ls[0], body2.var_name_ls[1]
            if var0 in var2arg and var1 in var2arg:
                # Both variables already bound by body1: only check the fact,
                # rather than rebinding (and overwriting) one of them.
                if (var2arg[var0], var2arg[var1]) in pred_fact_set.get(body2.pred_name, ()):
                    groundings.add(tuple(sorted(var2arg.items())))
            elif var0 in var2arg:
                for arg2 in pred_arg1_set_arg2.get(body2.pred_name, {}).get(var2arg[var0], ()):
                    var2arg[var1] = arg2
                    groundings.add(tuple(sorted(var2arg.items())))
            elif var1 in var2arg:
                for arg1 in pred_arg2_set_arg1.get(body2.pred_name, {}).get(var2arg[var1], ()):
                    var2arg[var0] = arg1
                    groundings.add(tuple(sorted(var2arg.items())))
    return grounded_rules


def _collect_grounded_heads(rule_ls, grounded_rules, pred_fact_set):
    """Collect the head atoms derived by the grounded rules.

    Returns ``(grounded_obs, grounded_hid)``. Both map ``(pred, (arg1, arg2))`` to
    the list of rule indices (with repetition) whose groundings derive that atom.
    ``grounded_obs`` holds atoms that are observed facts; ``grounded_hid`` holds
    the rest.
    """
    grounded_obs = dict()
    grounded_hid = dict()
    for rule_idx, (rule, groundings) in enumerate(zip(rule_ls, grounded_rules)):
        if not groundings:
            continue
        head_atom = rule.atom_ls[-1]
        assert not head_atom.neg  # the last atom is the (non-negated) head
        pred = head_atom.pred_name
        observed = pred_fact_set.get(pred, ())
        for grounding in groundings:
            var2arg = dict(grounding)
            args = (var2arg[head_atom.var_name_ls[0]], var2arg[head_atom.var_name_ls[1]])
            target = grounded_obs if args in observed else grounded_hid
            target.setdefault((pred, args), []).append(rule_idx)
    return grounded_obs, grounded_hid


def _aggregate_hidden_by_pred(grounded_hid, const2ind):
    """Group hidden atoms by predicate so the posterior can score them in bulk.

    Returns ``(pred_aggregated_hid_list, pred_aggregated_hid_args)``:
      * ``pred_aggregated_hid_list`` is ``[[pred, [(idx1, idx2), ...]], ...]``,
        sorted by predicate, with constants mapped through ``const2ind``;
      * ``pred_aggregated_hid_args[pred]`` lists the original ``(arg1, arg2)``
        in the same order.
    """
    pred_aggregated_hid = dict()
    pred_aggregated_hid_args = dict()
    for pred, args in grounded_hid:
        pred_aggregated_hid.setdefault(pred, []).append((const2ind[args[0]], const2ind[args[1]]))
        pred_aggregated_hid_args.setdefault(pred, []).append(args)
    pred_aggregated_hid_list = [[pred, pred_aggregated_hid[pred]] for pred in sorted(pred_aggregated_hid)]
    return pred_aggregated_hid_list, pred_aggregated_hid_args


def _batch_loss(dataset, posterior_model, node_embeds, batch, weight_by_rule):
    """Average the per-rule loss terms of one batch from ``get_batch_by_q``.

    Rules whose negative mask is empty are skipped. For each remaining rule the
    term is ``-(sum(potential) [* rule.weight] + entropy) / n + obs_xent``. The
    rule-weight factor is applied only when ``weight_by_rule`` is true.

    Returns ``(loss, r_cnt)``. ``r_cnt`` is the number of contributing rules;
    ``loss`` is the float ``0.0`` when ``r_cnt == 0``.
    """
    samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r = batch
    loss = 0.0
    r_cnt = 0
    for ind, samples in enumerate(samples_by_r):
        neg_mask = neg_mask_by_r[ind]
        if sum(len(e[1]) for e in neg_mask) == 0:
            continue

        potential, posterior_prob, obs_xent = posterior_model(
            [samples, neg_mask, latent_mask_by_r[ind], obs_var_by_r[ind], neg_var_by_r[ind]],
            node_embeds, fast_mode=True)
        entropy = compute_entropy(posterior_prob)

        potential_sum = potential.sum()
        if weight_by_rule:
            potential_sum = potential_sum * dataset.rule_ls[ind].weight
        loss += -(potential_sum + entropy) / (potential.size(0) + 1e-6) + obs_xent
        r_cnt += 1

    if r_cnt > 0:
        loss /= r_cnt
    return loss, r_cnt


def _run_e_step(dataset, gcn, posterior_model, optimizer, current_epoch, batchsize):
    """Run one E-step epoch over the training batches.

    Returns the node embeddings computed for the last batch (before that batch's
    optimizer step), or ``None`` if no batch was produced.
    """
    pbar = tqdm()
    acc_loss = 0.0
    cur_batch = 0
    node_embeds = None
    for batch in dataset.get_batch_by_q(batchsize):
        node_embeds = gcn(dataset)
        loss, r_cnt = _batch_loss(dataset, posterior_model, node_embeds, batch, weight_by_rule=True)
        if r_cnt > 0:
            acc_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.update()
        cur_batch += 1
        pbar.set_description(
            'Epoch %d, train loss: %.4f, lr: %.4g' % (current_epoch, acc_loss / cur_batch, get_lr(optimizer)))
    pbar.close()
    return node_embeds


def _validation_loss(dataset, gcn, posterior_model, batchsize):
    """Sum the per-batch mean losses over the validation batches (no gradients).

    Returns ``(valid_loss, cnt_batch)``.
    NOTE(review): unlike the E-step loss, potentials are not scaled by the rule
    weights here. Confirm this asymmetry is intended.
    """
    with torch.no_grad():
        node_embeds = gcn(dataset)
        valid_loss = 0.0
        cnt_batch = 0
        for batch in dataset.get_batch_by_q(batchsize, validation=True):
            loss, r_cnt = _batch_loss(dataset, posterior_model, node_embeds, batch, weight_by_rule=False)
            if r_cnt > 0:
                valid_loss += loss.item()
            cnt_batch += 1
    return valid_loss, cnt_batch


def _update_rule_weights(rule_ls, posterior_model, node_embeds, pred_aggregated_hid_list,
                         pred_aggregated_hid_args, grounded_obs, grounded_hid, lr_rule_weights):
    """M-step: take one gradient-ascent step on every rule weight, in place.

    Each rule deriving an observed head atom contributes ``1 - p_mb``. Each rule
    deriving a hidden head atom contributes ``q - p_mb``. Here ``q`` is the
    posterior probability of the atom and ``p_mb`` is ``compute_MB_proba`` over
    the rules deriving it.
    """
    with torch.no_grad():
        posterior_prob = posterior_model(pred_aggregated_hid_list, node_embeds, fast_inference_mode=True)
        grounded_hid_score = dict()
        for pred_i, (pred, _) in enumerate(pred_aggregated_hid_list):
            for var_i, args in enumerate(pred_aggregated_hid_args[pred]):
                grounded_hid_score[(pred, args)] = posterior_prob[pred_i][var_i]

        rule_weight_gradient = torch.zeros(len(rule_ls))
        # p_mb does not depend on rule_idx, and weights stay fixed until the update
        # below, so it is computed once per atom instead of once per rule.
        for rule_idx_ls in grounded_obs.values():
            mb_proba = compute_MB_proba(rule_ls, rule_idx_ls)
            for rule_idx in set(rule_idx_ls):
                rule_weight_gradient[rule_idx] += 1.0 - mb_proba
        for key, rule_idx_ls in grounded_hid.items():
            delta = (grounded_hid_score[key] - compute_MB_proba(rule_ls, rule_idx_ls)).cpu()
            for rule_idx in set(rule_idx_ls):
                rule_weight_gradient[rule_idx] += delta

        for rule_idx, rule in enumerate(rule_ls):
            # .item() keeps rule.weight a plain Python number rather than turning it
            # into a 0-dim tensor after the first M-step.
            rule.weight += lr_rule_weights * rule_weight_gradient[rule_idx].item()


# Metric names in the order used by the final report.
_METRICS = ('mr', 'mrr', 'hits1', 'hits3', 'hits10')
# Line order used by the periodic progress dump (kept for output compatibility).
_PROGRESS_ORDER = ('mr', 'mrr', 'hits10', 'hits1', 'hits3')


def _format_metrics(stats, denom, order, pred_name=None):
    """Render ``stats[name] / denom`` for each name in ``order``, one line each.

    Lines look like ``'mr 12.3456'``, or ``'mr <pred_name> 12.3456'`` when
    ``pred_name`` is given.
    """
    if pred_name is None:
        return ''.join('%s %.4f\n' % (name, stats[name] / denom) for name in order)
    return ''.join('%s %s %.4f\n' % (name, pred_name, stats[name] / denom) for name in order)


def _evaluate(dataset, kg, gcn, posterior_model, logpath, exp_path):
    """Rank each test query against its tail- and head-corrupted candidates.

    A rank is ``1 + #(candidate scores >= true score)``. Running averages are
    dumped to ``logpath`` every 100 ranked queries. The final averages are then
    written to ``logpath``, which is renamed to
    ``performance_hits1_..._mr_....txt`` inside ``exp_path``.
    """
    totals = dict.fromkeys(_METRICS, 0.0)
    per_pred = {pred_name: dict.fromkeys(_METRICS, 0.0) for pred_name in PRED_DICT}
    cnt_pred = dict.fromkeys(PRED_DICT, 0.0)
    cnt = 0

    # Inference only: no autograd graph is needed for ranking.
    with torch.no_grad():
        node_embeds = gcn(dataset)

        pbar = tqdm(total=len(dataset.test_fact_ls))
        pbar.write('*' * 10 + ' Evaluation ' + '*' * 10)

        for pred_name, X, invX, sample in gen_eval_query(dataset, const2ind=kg.ent2idx):
            tail_score, head_score, true_score = posterior_model(
                [pred_name, np.array(X), np.array(invX), np.array(sample)], node_embeds)

            for scores in (tail_score, head_score):
                rank = torch.sum(scores >= true_score).item() + 1
                for stats in (totals, per_pred[pred_name]):
                    stats['mr'] += rank
                    stats['mrr'] += 1.0 / rank
                    stats['hits1'] += 1 if rank <= 1 else 0
                    stats['hits3'] += 1 if rank <= 3 else 0
                    stats['hits10'] += 1 if rank <= 10 else 0

            cnt_pred[pred_name] += 2
            cnt += 2

            if cnt % 100 == 0:
                with open(logpath, 'w') as f:
                    f.write('%i sample eval\n' % cnt)
                    f.write(_format_metrics(totals, cnt, _PROGRESS_ORDER))
                    f.write('\n')
                    for name in PRED_DICT:
                        if cnt_pred[name] == 0:
                            continue
                        f.write(_format_metrics(per_pred[name], cnt_pred[name], _PROGRESS_ORDER, name))

            pbar.update()

    denom = max(cnt, 1)  # avoid ZeroDivisionError when there are no test queries
    with open(logpath, 'w') as f:
        f.write('complete\n')
        f.write(_format_metrics(totals, denom, _METRICS))
        f.write('\n')

        for name in _METRICS:
            tqdm.write('%s %.4f\n' % (name, totals[name] / denom))

        for name in PRED_DICT:
            if cnt_pred[name] == 0:
                continue
            f.write(_format_metrics(per_pred[name], cnt_pred[name], _METRICS, name))

    perf_name = 'performance_hits1_%.4f_hits3_%.4f_hits10_%.4f_mrr_%.4f_mr_%.4f.txt' % (
        totals['hits1'] / denom, totals['hits3'] / denom, totals['hits10'] / denom,
        totals['mrr'] / denom, totals['mr'] / denom)
    # os.replace instead of shelling out to `mv`: no quoting problems with spaces or
    # shell metacharacters in exp_path, and failures raise instead of passing silently.
    os.replace(logpath, joinpath(exp_path, perf_name))

    pbar.close()



def compute_entropy(posterior_prob, eps=1e-6):
    """Return the summed binary entropy of independent Bernoulli probabilities.

    H = -sum(p * log(p) + (1 - p) * log(1 - p)), with ``p`` clamped to
    ``[eps, 1 - eps]`` so both logarithms stay finite.

    Args:
        posterior_prob: tensor of probabilities (any shape). It is NOT modified.
        eps: clamping margin. Defaults to 1e-6, the previously hard-coded value.

    Returns:
        A 0-dim tensor holding the total entropy.
    """
    # Out-of-place clamp: the former in-place ``clamp_`` rewrote the caller's
    # tensor and can trip autograd's in-place check if that tensor was saved
    # for the backward pass.
    prob = torch.clamp(posterior_prob, eps, 1 - eps)
    compl_prob = 1 - prob
    return -(prob * torch.log(prob) + compl_prob * torch.log(compl_prob)).sum()


def compute_MB_proba(rule_ls, ls_rule_idx):
    """Return ``S / (S + 1)``, where ``S = sum_r exp(min(w_r * n_r, 700))``.

    ``ls_rule_idx`` lists rule indices, possibly repeated. ``n_r`` is how often
    rule ``r`` occurs in it and ``w_r`` is ``rule_ls[r].weight``. The name
    suggests this is a Markov-blanket probability for the atom derived by those
    rules. An empty list yields ``0.0``.
    """
    # math.exp overflows just above 709.78, so each exponent is capped at 700.
    exponent_cap = 700
    score = sum(
        math.exp(min(rule_ls[idx].weight * multiplicity, exponent_cap))
        for idx, multiplicity in Counter(ls_rule_idx).items()
    )
    return score / (score + 1.0)



if __name__ == '__main__':
    # Seed every random number generator the script relies on, for reproducible runs.
    seed = cmd_args.seed
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    train(cmd_args)
