import os
import logging
import numpy as np
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import logging 

from data import BatchType, TestDataset
import itertools
from functools import reduce
from collections import defaultdict as ddict
import pickle

from TensorLog2 import TensorLog
from tqdm import tqdm

def get_confi_for_freq_random(data_reader, args, top=10):
    """Print the mean confidence of rules sampled from relation co-occurrence.

    For every head relation ``r`` in ``range(num_rel)`` a beam of candidate
    rule bodies is grown:

    * the first body relation is one of the 5 relations whose head entities
      co-occur most often with the head entities of ``r``;
    * every further body relation is one of the 5 relations whose head
      entities co-occur most often with the tail entities of the previous
      body relation.

    A rule's weight is the product of the row-normalised co-occurrence
    frequencies along its body. The ``top`` heaviest rules are kept per
    relation and scored with an entity-level confidence: among the entities
    reached by following the body from any entity, the fraction that are also
    a tail of some training triple of ``r``.

    The mean confidence over relations is printed for the first 1, 5 and 10
    rules of every relation; nothing is returned.

    Args:
        data_reader: provides ``entity_dict``, ``relation_dict``,
            ``train_data`` and ``add_inv``. The triples returned by
            ``add_inv`` are indexed here with relation ids in
            ``range(2 * num_rel)`` (presumably originals followed by
            inverses -- confirm against ``add_inv``).
        args: only ``args.rule_length`` (number of body relations) is read.
        top: number of highest-weight rules kept per head relation.

    Notes:
        * Requires a CUDA device (the sparse operators are moved with
          ``.cuda()``).
        * The previous version unpickled ``args.rule_pkl`` and immediately
          overwrote the loaded dictionaries; that dead load (and its unclosed
          file handle) has been removed.
        * Rows of the co-occurrence matrices that sum to zero are left as
          zeros instead of becoming NaN, so dead-end rules sort last rather
          than first.
    """
    num_ent = len(data_reader.entity_dict)
    num_rel = len(data_reader.relation_dict)
    trps_has_inv = data_reader.add_inv(data_reader.train_data)

    # Per-relation head/tail entity counts and the (head, tail) pairs of every
    # relation. int64 keeps the matrix products below exact for large graphs.
    r_head_freq = torch.zeros((num_rel * 2, num_ent), dtype=torch.long)
    r_tail_freq = torch.zeros((num_rel * 2, num_ent), dtype=torch.long)
    r_headtail = ddict(list)
    for h, r, t in trps_has_inv:
        r_headtail[r].append([h, t])
        r_head_freq[r, h] += 1
        r_tail_freq[r, t] += 1

    # tensorlog: one sparse (num_ent x num_ent) operator per relation
    tlog_op = dict()
    for r in range(num_rel * 2):
        pairs = r_headtail[r]
        # reshape(-1, 2) keeps the index tensor 2 x 0 for a relation that has
        # no triples, which sparse_coo_tensor accepts.
        indices = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()
        tlog_op[r] = torch.sparse_coo_tensor(indices, torch.ones(len(pairs)),
                                             (num_ent, num_ent)).cuda()

    all_ent_query = torch.ones(num_ent).reshape(1, -1).cuda()

    def _row_normalise(counts):
        # Counts are non-negative integers, so clamping the denominator only
        # affects all-zero rows (0 / 1 = 0 instead of 0 / 0 = NaN).
        return counts / counts.sum(dim=-1, keepdim=True).clamp(min=1)

    # [r, r'] ~ how often a tail entity of r is a head entity of r'
    rel_rel_freq_body = _row_normalise(torch.mm(r_tail_freq, r_head_freq.t()))
    # [r, r'] ~ how often a head entity of r is a head entity of r'
    rel_rel_freq_head = _row_normalise(torch.mm(r_head_freq, r_head_freq.t()))

    # The 5 most frequent successors of every relation, computed once.
    head_top = np.argsort(-rel_rel_freq_head.numpy(), axis=-1)[:, :5].tolist()
    body_top = np.argsort(-rel_rel_freq_body.numpy(), axis=-1)[:, :5].tolist()

    rule_dict = dict()
    for r in tqdm(range(num_rel)):
        curr_step_rules = [([si], rel_rel_freq_head[r][si]) for si in head_top[r]]
        for _ in range(args.rule_length - 1):
            next_step_rules = []
            for body, weight in curr_step_rules:
                last_r = body[-1]
                for si in body_top[last_r]:
                    next_step_rules.append((body + [si],
                                            weight * rel_rel_freq_body[last_r][si]))
            curr_step_rules = next_step_rules

        last_step_rules = torch.tensor([body for body, _ in curr_step_rules])
        last_step_weights = torch.stack([weight for _, weight in curr_step_rules])

        # re sort: keep the `top` heaviest rule bodies
        _, sort_idx = torch.sort(last_step_weights, descending=True)
        rule_dict[r] = last_step_rules[sort_idx[:top]].tolist()

    confi_dict = dict()
    for r in tqdm(range(num_rel)):
        # tensorlog head: entities that are a tail of some triple of r
        head_res = torch.sparse.mm(torch.transpose(tlog_op[r], 0, 1), all_ent_query.t())
        head_mask = head_res.reshape(-1) != 0

        confi_list = []
        for curr_rule in rule_dict[r]:
            # tensorlog body: entities reached by chaining the body relations
            query = all_ent_query
            for curr_rel in curr_rule:
                query = torch.sparse.mm(torch.transpose(tlog_op[curr_rel], 0, 1), query.t()).t()
            body_mask = query.reshape(-1) != 0

            body_cnt = int(body_mask.sum())
            if body_cnt == 0:
                confi = 0
            else:
                confi = int((head_mask & body_mask).sum()) / body_cnt
            confi_list.append(confi)

        confi_dict[r] = confi_list

    # Fixed reporting cut-offs; `k` avoids shadowing the `top` parameter.
    for k in [1, 5, 10]:
        mean_confi = 0
        for r in range(num_rel):
            mean_confi += np.mean(confi_dict[r][:k])
        mean_confi /= num_rel
        print(f'top{k}: {mean_confi}')


def get_confi_for_all_random(data_reader, args):
    """Print the mean confidence over every possible rule body.

    Enumerates all bodies of ``args.rule_length`` relations drawn from
    ``range(2 * num_rel)`` and, for every head relation ``r`` in
    ``range(num_rel)``, scores each body with an entity-level confidence:
    among the entities reached by following the body from any entity, the
    fraction that are also a tail of some training triple of ``r``. Bodies
    that reach no entity score 0. The mean over all bodies and head relations
    is printed; nothing is returned.

    Args:
        data_reader: provides ``entity_dict``, ``relation_dict``,
            ``train_data`` and ``add_inv``.
        args: only ``args.rule_length`` is read. Any length >= 1 is accepted
            (previously only 2 and 3 were; other values failed with an
            unbound-variable error).

    Raises:
        ValueError: if ``args.rule_length`` is smaller than 1.

    Notes:
        Requires a CUDA device. The number of bodies grows as
        ``(2 * num_rel) ** rule_length`` and one boolean mask of length
        ``num_ent`` is kept per body, so this is only practical for small
        relation vocabularies.
    """
    if args.rule_length < 1:
        raise ValueError(f'rule_length must be at least 1, got {args.rule_length}')

    num_ent = len(data_reader.entity_dict)
    num_rel = len(data_reader.relation_dict)
    trps_has_inv = data_reader.add_inv(data_reader.train_data)

    # tensorlog: one sparse (num_ent x num_ent) operator per relation
    r_headtail = ddict(list)
    for h, r, t in trps_has_inv:
        r_headtail[r].append([h, t])
    tlog_op = dict()
    for r in range(num_rel * 2):
        pairs = r_headtail[r]
        # reshape(-1, 2) keeps the index tensor 2 x 0 for a relation that has
        # no triples, which sparse_coo_tensor accepts.
        indices = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()
        tlog_op[r] = torch.sparse_coo_tensor(indices, torch.ones(len(pairs)),
                                             (num_ent, num_ent)).cuda()

    all_ent_query = torch.ones(num_ent).reshape(1, -1).cuda()

    # product(..., repeat=k) yields the same tuples, in the same order, as the
    # explicit two- and three-fold products it replaces.
    num_body_rel = num_rel * 2
    all_rules = itertools.product(range(num_body_rel), repeat=args.rule_length)

    all_body_res = []
    for curr_rule in tqdm(all_rules, total=num_body_rel ** args.rule_length):
        # tensorlog body: entities reached by chaining the body relations
        query = all_ent_query
        for curr_rel in curr_rule:
            query = torch.sparse.mm(torch.transpose(tlog_op[curr_rel], 0, 1), query.t()).t()
        all_body_res.append((query.reshape(-1) != 0).cpu())

    all_body_stack = torch.stack(all_body_res)        # (num_rules, num_ent), bool
    all_body_nz_cnt = all_body_stack.sum(dim=1)       # entities reached per body
    # A body reaching nothing has support 0, so dividing by 1 yields the
    # required confidence of 0 without a special case.
    denom = all_body_nz_cnt.clamp(min=1).double()

    confi_dict = dict()
    for r in tqdm(range(num_rel)):
        # tensorlog head: entities that are a tail of some triple of r
        head_res = torch.sparse.mm(torch.transpose(tlog_op[r], 0, 1), all_ent_query.t())
        head_mask = (head_res.reshape(1, -1) != 0).cpu()

        support = (all_body_stack & head_mask).sum(dim=1)
        confi_dict[r] = (support.double() / denom).tolist()

    mean_confi = 0
    for r in range(num_rel):
        mean_confi += np.mean(confi_dict[r])
    mean_confi /= num_rel
    print(mean_confi)


class KGExplainer(nn.Module):
    """Scores triples with both a KGE model and a rule learner.

    ``forward`` returns the two scores side by side. The static methods
    implement one optimisation step (``train_step``), filtered ranking
    evaluation (``test_step``) and rule extraction with confidence reporting
    (``explain``). All of them move their inputs with ``.cuda()`` and
    therefore require a CUDA device.
    """

    def __init__(self, kge, rulelearner, args):
        """Store the two scorers and the run configuration.

        Args:
            kge: embedding model; ``sample2embedding`` and ``func`` are used
                by ``forward``.
            rulelearner: rule model; ``func`` is used by ``forward``.
            args: configuration namespace; ``train_part`` is read in
                ``forward`` and ``num_rules`` in ``explain``.
        """
        super(KGExplainer, self).__init__()
        self.kge = kge
        self.rulelearner = rulelearner
        self.args = args

    def forward(self, sample, batch_type=BatchType.SINGLE):
        """Return ``(score_kge, score_rule)`` for ``sample``.

        Args:
            sample: passed unchanged to ``kge.sample2embedding``.
            batch_type: ``BatchType`` member describing the layout of
                ``sample``.
        """
        head, relation, tail, rel_idx = self.kge.sample2embedding(sample, batch_type)
        # Tensor.cuda() returns the moved tensor rather than working in place,
        # so the result has to be kept.
        rel_idx = rel_idx.cuda()

        # score from kge
        score_kge = self.kge.func(head, relation, tail, batch_type)

        # score from rulelearner
        score_rule = self.rulelearner.func(head, relation, tail, batch_type, rel_idx, self.kge,
                                           self.args.train_part)

        logging.debug(f'score_kge:{score_kge.size()}')
        logging.debug(f'score_rule:{score_rule.size()}')
        return score_kge, score_rule

    @staticmethod
    def _sample_losses(positive_score_ori, negative_score_ori, subsampling_weight, temperature):
        """Aggregate raw scores into weighted positive/negative losses.

        Negative scores are combined with detached softmax weights
        (temperature-scaled) over ``dim=1``; both losses are averaged with
        ``subsampling_weight``.

        Returns:
            ``(positive_score, negative_score, positive_loss, negative_loss)``
            where the scores are the per-sample log-sigmoid terms.
        """
        negative_score = (F.softmax(negative_score_ori * temperature, dim=1).detach()
                          * F.logsigmoid(-negative_score_ori)).sum(dim=1)
        positive_score = F.logsigmoid(positive_score_ori).squeeze(dim=1)

        weight_sum = subsampling_weight.sum()
        positive_loss = - (subsampling_weight * positive_score).sum() / weight_sum
        negative_loss = - (subsampling_weight * negative_score).sum() / weight_sum
        return positive_score, negative_score, positive_loss, negative_loss

    @staticmethod
    def train_step(model, optimizer, train_iterator, args):
        '''
        A single train step. Apply back-propagation and return the losses.

        The optimised loss depends on ``args.train_part``:

        * ``'rule'``: rule loss + distillation loss, where the distillation
          loss is the mean absolute difference between the sigmoids of the
          KGE and rule scores (positives plus negatives);
        * ``'kge'``: KGE loss + rule loss.

        Args:
            model: the ``KGExplainer`` to train.
            optimizer: optimizer stepping the model's parameters.
            train_iterator: iterator yielding ``(positive_sample,
                negative_sample, subsampling_weight, batch_type)``.
            args: reads ``train_part`` and ``adversarial_temperature``.

        Returns:
            dict of Python floats with the individual loss terms and
            ``'loss'``; ``'distill_loss'`` is present only for ``'rule'``.

        Raises:
            ValueError: if ``args.train_part`` is neither ``'rule'`` nor
                ``'kge'`` (checked before the iterator is advanced).
        '''
        if args.train_part not in ('rule', 'kge'):
            raise ValueError('train_part %s not supported' % args.train_part)

        model.train()

        optimizer.zero_grad()

        positive_sample, negative_sample, subsampling_weight, batch_type = next(train_iterator)

        positive_sample = positive_sample.cuda()
        negative_sample = negative_sample.cuda()
        subsampling_weight = subsampling_weight.cuda()

        positive_score_kge_ori, positive_score_rule_ori = model(positive_sample)

        # negative score
        negative_score_kge_ori, negative_score_rule_ori = model((positive_sample, negative_sample),
                                                                batch_type=batch_type)

        # loss for kge
        (positive_score_kge, negative_score_kge,
         positive_sample_loss_kge, negative_sample_loss_kge) = KGExplainer._sample_losses(
            positive_score_kge_ori, negative_score_kge_ori,
            subsampling_weight, args.adversarial_temperature)
        loss_kge = (positive_sample_loss_kge + negative_sample_loss_kge) / 2

        # loss for rule
        (positive_score_rule, negative_score_rule,
         positive_sample_loss_rule, negative_sample_loss_rule) = KGExplainer._sample_losses(
            positive_score_rule_ori, negative_score_rule_ori,
            subsampling_weight, args.adversarial_temperature)
        loss_rule = (positive_sample_loss_rule + negative_sample_loss_rule) / 2

        logging.debug(f'negative_score_kge: {negative_score_kge.size()}')
        logging.debug(f'negative_score_rule:{negative_score_rule.size()}')
        logging.debug(f'positive_score_kge:{positive_score_kge.size()}')
        logging.debug(f'positive_score_rule: {positive_score_rule.size()}')

        log = {
            'positive_sample_loss_kge': positive_sample_loss_kge.item(),
            'negative_sample_loss_kge': negative_sample_loss_kge.item(),
            'positive_sample_loss_rule': positive_sample_loss_rule.item(),
            'negative_sample_loss_rule': negative_sample_loss_rule.item(),
        }

        if args.train_part == 'rule':
            # loss for distill: pull the rule scores towards the kge scores
            loss_distill = torch.abs(torch.sigmoid(positive_score_kge_ori)
                                     - torch.sigmoid(positive_score_rule_ori)).mean() \
                           + torch.abs(torch.sigmoid(negative_score_kge_ori)
                                       - torch.sigmoid(negative_score_rule_ori)).mean()
            loss = loss_rule + loss_distill
            log['distill_loss'] = loss_distill.item()
        else:
            loss = loss_kge + loss_rule

        loss.backward()

        optimizer.step()

        log['loss'] = loss.item()
        return log

    @staticmethod
    def test_step(model, data_reader, mode, args, target='explainer'):
        """Evaluate ranking metrics with head- and tail-corruption batches.

        Every test triple is scored against its candidate entities, the
        filter bias supplied by the dataset is added, and the position of the
        true entity in the descending score order is its rank.

        Args:
            model: the ``KGExplainer`` to evaluate.
            data_reader: passed to ``TestDataset``.
            mode: dataset split selector passed to ``TestDataset``.
            args: reads ``test_batch_size``, ``cpu_num`` and
                ``test_log_steps``.
            target: ``'kge'`` ranks with the KGE score, ``'explainer'`` with
                the rule score.

        Returns:
            dict with the averages of ``MRR``, ``MR``, ``HITS@1``,
            ``HITS@3`` and ``HITS@10`` over all evaluated samples.

        Raises:
            ValueError: for an unknown ``target`` or an unsupported batch
                type coming out of the data loader.
        """
        if target not in ('kge', 'explainer'):
            raise ValueError('target %s not supported' % target)

        model.eval()

        # Head-corruption batches are evaluated first, then tail-corruption.
        test_dataset_list = [
            DataLoader(
                TestDataset(
                    data_reader,
                    mode,
                    loader_batch_type
                ),
                batch_size=args.test_batch_size,
                num_workers=max(1, args.cpu_num // 2),
                collate_fn=TestDataset.collate_fn
            )
            for loader_batch_type in (BatchType.HEAD_BATCH, BatchType.TAIL_BATCH)
        ]

        logs = []

        step = 0
        total_steps = sum([len(dataset) for dataset in test_dataset_list])

        with torch.no_grad():
            for test_dataset in test_dataset_list:
                for positive_sample, negative_sample, filter_bias, batch_type in test_dataset:
                    positive_sample = positive_sample.cuda()
                    negative_sample = negative_sample.cuda()
                    filter_bias = filter_bias.cuda()

                    batch_size = positive_sample.size(0)

                    score_kge, score_rule = model((positive_sample, negative_sample), batch_type)
                    score = score_kge if target == 'kge' else score_rule
                    score += filter_bias

                    # Explicitly sort all the entities to ensure that there is no test exposure bias
                    argsort = torch.argsort(score, dim=1, descending=True)

                    # Column of the triple that holds the corrupted (true) entity.
                    if batch_type == BatchType.HEAD_BATCH:
                        positive_arg = positive_sample[:, 0]
                    elif batch_type == BatchType.TAIL_BATCH:
                        positive_arg = positive_sample[:, 2]
                    else:
                        raise ValueError('batch type %s not supported' % batch_type)

                    # Notice that argsort is not ranking: locate the true entity
                    # in every row. nonzero() returns (row, column) pairs in
                    # row-major order, so the result lines up with the batch.
                    matches = (argsort == positive_arg.unsqueeze(1)).nonzero()
                    # Each row of argsort is a permutation, so there is exactly
                    # one match per sample when the entity ids are valid.
                    assert matches.size(0) == batch_size

                    # column + 1 is the true ranking used in evaluation metrics
                    for ranking in (matches[:, 1] + 1).tolist():
                        logs.append({
                            'MRR': 1.0 / ranking,
                            'MR': float(ranking),
                            'HITS@1': 1.0 if ranking <= 1 else 0.0,
                            'HITS@3': 1.0 if ranking <= 3 else 0.0,
                            'HITS@10': 1.0 if ranking <= 10 else 0.0,
                        })

                    if step % args.test_log_steps == 0:
                        logging.info('Evaluating the model... ({}/{})'.format(step, total_steps))

                    step += 1

        metrics = {}
        for metric in logs[0].keys():
            metrics[metric] = sum([log[metric] for log in logs]) / len(logs)

        return metrics

    @staticmethod
    def explain(model, data_reader, mode, args, top=10):
        """Decode rules from the rule learner and print their mean confidence.

        For every relation ``r`` and every rule slot ``n`` the 10 most
        similar relations are taken at each body position, all combinations
        are weighted by the product of their similarities, and the ``top``
        heaviest bodies are kept. Each kept body gets an entity-level
        confidence: among the entities reached by following the body from any
        entity, the fraction that are also a tail of some training triple of
        ``r``. The mean confidence of rule slot 0 is printed for the first
        1, 5 and 10 bodies.

        Args:
            model: a ``KGExplainer``; its ``rulelearner``, ``kge`` and
                ``args.num_rules`` are used.
            data_reader: provides ``entity_dict``, ``relation_dict``,
                ``train_data``, ``dict_for_explain`` and ``add_inv_and_self``.
            mode: not used in this method.
            args: not used in this method (``model.args`` is read instead).
            top: number of highest-weight bodies kept per relation and slot.
        """
        # NOTE(review): hr2t / trinv2h are only consumed by apply_rule below,
        # which is not called in the visible body; kept because it is not
        # visible here whether dict_for_explain has side effects.
        hr2t, trinv2h = data_reader.dict_for_explain(data_reader.train_data)

        # tensorlog: one sparse (num_ent x num_ent) operator per relation id
        # (originals, inverses and self loops as produced by add_inv_and_self)
        num_ent = len(data_reader.entity_dict)
        num_rel = len(data_reader.relation_dict)
        trps_has_inv_and_self = data_reader.add_inv_and_self(data_reader.train_data)
        tlog_op = dict()
        r_headtail = ddict(list)
        for h, r, t in trps_has_inv_and_self:
            r_headtail[r].append([h, t])
        for r in range(len(r_headtail)):
            pairs = r_headtail[r]
            # reshape(-1, 2) keeps the index tensor 2 x 0 if an id in the
            # range has no triples, which sparse_coo_tensor accepts.
            indices = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()
            tlog_op[r] = torch.sparse_coo_tensor(indices, torch.ones(len(pairs)),
                                                 (num_ent, num_ent)).cuda()

        all_ent_query = torch.ones(num_ent).reshape(1, -1).cuda()

        def apply_rule(h, rule_body, hr2t, trinv2h):
            # Follow rule_body from entity h through the lookup dicts and
            # return the entities reached (duplicates are kept).
            # NOTE(review): unused in the visible body of explain.
            start = [h]
            end = []
            for rel in rule_body:
                for ent in start:
                    if (ent, rel) in hr2t.keys():
                        end.extend(hr2t[(ent, rel)])
                    elif (ent, rel) in trinv2h.keys():
                        end.extend(trinv2h[(ent, rel)])
                start = end
                end = []
            return start

        model.eval()

        head_r, outputs, _, ent_h, ent_t = model.rulelearner.ruledecoder(model.rulelearner.rule_embedding)
        rule_rel_emb = outputs
        model.kge.add_inverse_relation()
        model.kge.add_self_loop()
        rel_emb = model.kge.relation_embedding_with_inverse_and_self

        sim = model.kge.rel_sim_for_explain(rule_rel_emb, rel_emb)

        rule_dict = ddict(dict)
        weight_dict = ddict(dict)
        confi_dict = ddict(dict)
        for r in range(model.kge.num_relation):
            # sim[r] is indexed as [rule slot, body position, candidate relation]
            r_rule_wight = sim[r]
            top_cand_weight, top_cand_rel = torch.sort(r_rule_wight, dim=-1, descending=True)
            # Take enough candidates at every position first; the combinations
            # are re-sorted by their joint weight below.
            top_cand_weight = top_cand_weight[:, :, :10]
            top_cand_rel = top_cand_rel[:, :, :10]

            # tensorlog head: entities that are a tail of some triple of r.
            # Depends only on r, so it is computed once per relation.
            head_res = torch.sparse.mm(torch.transpose(tlog_op[r], 0, 1), all_ent_query.t())
            head_mask = head_res.reshape(-1) != 0

            for n in range(model.args.num_rules):
                curr_cand_rel = top_cand_rel[n]
                curr_rules = torch.tensor(list(itertools.product(*curr_cand_rel.tolist())))

                curr_cand_weight = top_cand_weight[n]
                curr_weights_tuple = list(itertools.product(*curr_cand_weight.tolist()))
                curr_weights = torch.tensor([reduce(lambda x, y: x*y, w) for w in curr_weights_tuple])

                # re sort
                curr_weights, sort_idx = torch.sort(curr_weights, descending=True)
                curr_weights = curr_weights[:top]
                curr_rules = curr_rules[sort_idx[:top]]

                # calculate confidence for rules
                confi_list = []
                for curr_rule in curr_rules.tolist():
                    # tensorlog body: entities reached by chaining the body relations
                    query = all_ent_query
                    for curr_rel in curr_rule:
                        query = torch.sparse.mm(torch.transpose(tlog_op[curr_rel], 0, 1), query.t()).t()
                    body_mask = query.reshape(-1) != 0

                    body_cnt = int(body_mask.sum())
                    if body_cnt == 0:
                        confi = 0
                    else:
                        confi = int((head_mask & body_mask).sum()) / body_cnt

                    confi_list.append(confi)

                rule_dict[r][n] = curr_rules
                weight_dict[r][n] = curr_weights
                confi_dict[r][n] = confi_list

        # Fixed reporting cut-offs; `k` avoids shadowing the `top` parameter.
        for k in [1, 5, 10]:
            mean_confi = 0
            for r in range(num_rel):
                mean_confi += np.mean(confi_dict[r][0][:k])
            mean_confi /= num_rel
            print(f'top{k}: {mean_confi}')
