import os
import logging
import numpy as np
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data import BatchType, TestDataset
import itertools

class LatentRuleLearner(nn.Module):
    """Holds one learnable embedding per latent rule and scores triples with them.

    The rule embeddings have shape ``(kge.num_relation, num_rules, 2 * kge.rel_dim)``
    and are decoded by a ``RuleDecoder`` into a rule-body sequence plus a head
    and a tail entity vector per rule. ``func`` then follows the selected rule
    body step by step through the KGE model to score each sample.
    """

    def __init__(self, kge, num_rules, rule_length=5, num_layers=1):
        """
        Input Parameters:
            - kge: KGE model; this constructor reads ``rel_dim``, ``ent_dim``
              and ``num_relation`` from it
            - num_rules: number of latent rules per relation
            - rule_length: number of steps in each decoded rule body
            - num_layers: number of LSTM layers in the body decoder
        """
        super(LatentRuleLearner, self).__init__()

        self.length = rule_length
        self.rule_dim = kge.rel_dim * 2
        self.hidden_dim = kge.rel_dim
        ent_dim = kge.ent_dim

        # Total number of rules across all relations.
        self.num_rules = kge.num_relation * num_rules

        self.rule_embedding = nn.Parameter(torch.zeros(kge.num_relation, num_rules, self.rule_dim))
        nn.init.xavier_uniform_(self.rule_embedding)

        self.ruledecoder = RuleDecoder(self.rule_embedding, self.hidden_dim, ent_dim, rule_length, num_layers)
        self.cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
        # Move to the GPU only when one exists, so the module can still be
        # constructed on CPU-only hosts.
        if torch.cuda.is_available():
            self.cuda()

        self.rule_body_reg = None
        # Cache of relation-index sequences used when train_part == 'kge';
        # reset whenever func() runs with train_part == 'rule'.
        self.map2rel = None

    def func(self, head, rel, tail, batch_type, rel_idx, kge, train_part='rule'):
        """Score each sample by chaining its best-matching rule body through the KGE model.

        Input Parameters:
            - head, tail: entity embeddings; the last axis is the entity
              dimension (original notes give e.g. (512, 1, 1000) — TODO confirm)
            - rel: relation embeddings; only their size is logged here
            - batch_type: a ``BatchType`` member (HEAD_BATCH, TAIL_BATCH or SINGLE)
            - rel_idx: index tensor selecting, per sample, the relation whose
              rules are used
            - kge: KGE model providing the similarity / prediction helpers
            - train_part: 'rule' uses the decoded rule bodies directly and
              clears the ``map2rel`` cache; 'kge' replaces the decoded bodies
              with embeddings of their most similar relations (cached in
              ``map2rel``)

        Returns:
            Whatever ``kge.distant_score`` returns for the true end entity
            against the entity predicted by following the rule.

        Raises:
            NotImplementedError: if ``batch_type`` is not one of the three
            supported members.
        """
        num_select = 1
        num_samples = head.size()[0]
        logging.debug(f'head:{head.size()}') # (512, 1, 1000)
        logging.debug(f'rel:{rel.size()}') # (512,  1, 1500)
        logging.debug(f'tail:{tail.size()}') # (512, 1, 1000)

        # STEP1: decode rule sequences
        # outputs: (num_rel, num_rule_per_rel, length, dim)
        # ent_h:   (num_rel, num_rule_per_rel, ent_dim)
        # ent_t:   (num_rel, num_rule_per_rel, ent_dim)
        _, outputs, _, ent_h, ent_t = self.ruledecoder(self.rule_embedding)

        if train_part == 'kge':
            if self.map2rel is None:
                sim = kge.rel_sim_for_explain(outputs, kge.relation_embedding)
                _, top_cand_rel = torch.sort(sim, dim=-1, descending=True)
                # Keep the three most similar relations for every rule step.
                top_cand_rel = top_cand_rel[:, :, :, :3]

                # NOTE(review): squeeze(1) only drops the rule axis when there
                # is exactly one rule per relation — confirm this path is only
                # used in that configuration.
                top_cand_rel = top_cand_rel.squeeze(1)

                # Every combination of candidate relations across the steps.
                self.map2rel = torch.tensor([list(itertools.product(*cand_rel.tolist()))
                              for cand_rel in top_cand_rel]).to(top_cand_rel.device)

            outputs = kge.relation_embedding[self.map2rel]
        elif train_part == 'rule':
            self.map2rel = None

        # STEP2: gather the rules of each sample's relation and keep the one
        # whose decoded entity vector best matches the known entity.
        rule_body = torch.index_select(outputs, dim=0, index=rel_idx)
        rule_h = torch.index_select(ent_h, dim=0, index=rel_idx)
        rule_t = torch.index_select(ent_t, dim=0, index=rel_idx)

        if batch_type == BatchType.HEAD_BATCH or batch_type == BatchType.SINGLE:
            ent_in = tail
            p_ent = kge.ent_similarity(ent_in, rule_t)
        elif batch_type == BatchType.TAIL_BATCH:
            ent_in = head
            p_ent = kge.ent_similarity(ent_in, rule_h)
        else:
            raise NotImplementedError
        # p_ent: (num_samples, 1, num_rule_per_rel) -> (num_samples, num_rule_per_rel)
        p_ent = torch.squeeze(p_ent, 1)
        # Index of the highest-scoring rule for every sample: (num_samples,)
        argsort_o = torch.argsort(p_ent, dim=1, descending=True)[:, 0]
        argsort = torch.flatten(argsort_o)
        # Build the row index on the same device as the column index.
        idx0 = torch.arange(argsort.size()[0], device=argsort.device)
        # rule_body: (num_samples, num_rule_per_rel, length, dim) -> (num_samples, 1, length, dim)
        rule_body = rule_body[idx0, argsort]
        rule_body = torch.unsqueeze(rule_body, 1)

        # STEP3: calculate p(a|G,q,z)
        # 1. turn each sample (h, r, t) into (h, r1, r2, ..., t) using the
        #    selected rule body
        # 2. follow the sequence with the KGE model and compare the predicted
        #    end entity with the true one
        ent_dim = head.size()[-1]
        if batch_type == BatchType.TAIL_BATCH or batch_type == BatchType.SINGLE:
            head = head.view((num_samples, 1, ent_dim))
            head = head.expand(num_samples, num_select, ent_dim)
            head = kge.adapt_ent_emb(head)
            # Walk the rule body front to back, starting from the head entity.
            for i in range(rule_body.size()[-2]):
                logging.debug(f'=={i}th step of rules')
                step_rel = kge.adapt_rel_emb(rule_body[:, :, i, :])
                tail_p = kge.predict_t(head, step_rel)
                logging.debug(f'tail_p: phase {tail_p[0].size()}, mod {tail_p[1].size()}')
                head = tail_p

            tail = kge.adapt_ent_emb(tail)
            rule_score = kge.distant_score(tail, tail_p)
        elif batch_type == BatchType.HEAD_BATCH:
            # Swap roles: start from the tail entity and walk the rule body
            # back to front with inverted relations.
            head_hb = tail
            tail_hb = head
            head_hb = head_hb.view((num_samples, 1, ent_dim))
            head_hb = head_hb.expand(num_samples, num_select, ent_dim)
            head_hb = kge.adapt_ent_emb(head_hb)
            for i in range(rule_body.size()[-2]):
                logging.debug(f'=={-(i+1)}th step of rules')
                rel_hb = rule_body[:, :, -(i+1), :]
                rel_hb = kge.inverse_relation(rel_hb)
                rel_hb = kge.adapt_rel_emb(rel_hb)
                tail_p = kge.predict_t(head_hb, rel_hb)
                logging.debug(f'tail_p: phase {tail_p[0].size()}, mod {tail_p[1].size()}')
                head_hb = tail_p
            tail_hb = kge.adapt_ent_emb(tail_hb)
            rule_score = kge.distant_score(tail_hb, tail_p)
        else:
            raise NotImplementedError

        return rule_score


class RuleDecoder(nn.Module):
    """
    Decode rule embeddings into a head relation, a rule-body sequence and
    a head/tail entity vector per rule.

    Must define
        `self.rule_embedding`
    """
    def __init__(self, rule_embedding, hidden_dim, ent_dim, rule_length=5, num_layers=1):
        """
        Input Parameters:
            - rule_embedding: tensor whose first axis is the number of
              relations, second axis the number of rules per relation and
              last axis the rule embedding dimension
            - hidden_dim: output size of the MLP in HeadDecoder and hidden
              size of the LSTM in BodyDecoder
            - ent_dim: output size of EntityDecoder
            - rule_length: length of latent rules
            - num_layers: number of layers for LSTM
        """
        super(RuleDecoder, self).__init__()
        self.num_rel = rule_embedding.size()[0]
        self.num_rule_per_rel = rule_embedding.size()[1]
        self.rule_dim = rule_embedding.size()[-1]
        self.num_rules = self.num_rel * self.num_rule_per_rel

        self.hidden_dim = hidden_dim
        self.ent_dim = ent_dim
        self.length = rule_length

        # Flattened (num_rules, rule_dim) view captured at construction time.
        # forward() does not read it; it works on the tensor it is given.
        self.rule_embedding = rule_embedding.view((-1, self.rule_dim))

        self.headdecoder = HeadDecoder(self.rule_dim, self.hidden_dim)
        self.bodydecoder = BodyDecoder(self.rule_dim, self.hidden_dim, rule_length, num_layers)
        self.entitydecoder = EntityDecoder(self.rule_dim, self.ent_dim)

    def forward(self, rule_embedding):
        """
        Decode every rule embedding.

        Input:
            - rule_embedding: tensor that can be viewed as (num_rules, rule_dim)

        Output (in order):
            - head relation:  (num_rel, num_rule_per_rel, hidden_dim)
            - rule body:      (num_rel, num_rule_per_rel, length, hidden_dim)
            - hidden:         final state returned by the body decoder
            - head entity:    (num_rel, num_rule_per_rel, ent_dim)
            - tail entity:    (num_rel, num_rule_per_rel, ent_dim)
        """
        rule_embedding_in = rule_embedding.view((self.num_rules, -1))

        head_r_decoded = self.headdecoder(rule_embedding_in)

        # The body decoder receives the same rule embedding at every step.
        body_input = rule_embedding_in.view(self.num_rules, 1, self.rule_dim)
        body_input = body_input.expand(self.num_rules, self.length, self.rule_dim)
        outputs, hidden = self.bodydecoder(body_input)

        ent_h, ent_t = self.entitydecoder(rule_embedding_in)

        # Restore the (relation, rule-per-relation) layout on every result.
        # head_r_decoded: (num_rules, dim)         -> (num_rel, num_rule_per_rel, dim)
        # outputs:        (num_rules, length, dim) -> (num_rel, num_rule_per_rel, length, dim)
        # ent_h / ent_t:  (num_rules, ent_dim)     -> (num_rel, num_rule_per_rel, ent_dim)
        head_r_decoded = head_r_decoded.view((self.num_rel, self.num_rule_per_rel, -1))
        outputs = outputs.view((self.num_rel, self.num_rule_per_rel, self.length, -1))
        ent_h = ent_h.view((self.num_rel, self.num_rule_per_rel, -1))
        ent_t = ent_t.view((self.num_rel, self.num_rule_per_rel, -1))

        return head_r_decoded, outputs, hidden, ent_h, ent_t

class HeadDecoder(nn.Module):
    """Two-layer ReLU MLP mapping a rule embedding to its head relation."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # First layer changes the width, second layer keeps it.
        self.f1 = nn.Linear(input_dim, output_dim)
        self.f2 = nn.Linear(output_dim, output_dim)

    def forward(self, rule_embedding):
        """
        Decode the head relation of each rule:
            r_h = relu(f2(relu(f1(rule_embedding))))

        input:  rule_embedding (_, input_dim)
        output: r_h (_, output_dim)
        """
        activation = rule_embedding
        for layer in (self.f1, self.f2):
            activation = torch.relu(layer(activation))
        return activation

class EntityDecoder(nn.Module):
    """Pair of independent two-layer ReLU MLPs: one for the head entity, one for the tail."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Head-entity branch.
        self.f1_h = nn.Linear(input_dim, output_dim)
        self.f2_h = nn.Linear(output_dim, output_dim)

        # Tail-entity branch.
        self.f1_t = nn.Linear(input_dim, output_dim)
        self.f2_t = nn.Linear(output_dim, output_dim)

    @staticmethod
    def _two_layer_relu(first, second, x):
        """Apply ``relu(second(relu(first(x))))``."""
        return torch.relu(second(torch.relu(first(x))))

    def forward(self, rule_embedding):
        """
        Given the rule embedding, output a head entity vector and a tail
        entity vector for each rule. Each is decoded by its own branch:
            h = 2-MLP_h(rule_embedding)
            t = 2-MLP_t(rule_embedding)
        """
        head_vec = self._two_layer_relu(self.f1_h, self.f2_h, rule_embedding)
        tail_vec = self._two_layer_relu(self.f1_t, self.f2_t, rule_embedding)
        return head_vec, tail_vec


class BodyDecoder(nn.Module):
    """LSTM that turns a (repeated) rule embedding into a rule-body sequence."""

    def __init__(self, input_dim, output_dim, length, layers):
        """
        Input Parameters:
            - input_dim: size of each input step fed to the LSTM
            - output_dim: hidden size of the LSTM (size of each output step)
            - length: rule length; stored but not used by this class
            - layers: number of stacked LSTM layers
        """
        super(BodyDecoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = output_dim
        self.num_layers = layers
        self.length = length
        # Legacy switch: only consulted by zero_state() when no explicit
        # device is given. forward() always follows its input's device.
        self.use_cuda = True

        self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True)

    def zero_state(self, batch_size, device=None, dtype=None):
        """
        Build an all-zero initial (h0, c0) pair for the LSTM.

        Input Parameters:
            - batch_size: number of sequences in the batch
            - device: where to allocate the state; when omitted, CUDA is used
              if ``self.use_cuda`` is set and a GPU is available, else the CPU
            - dtype: dtype of the state; defaults to torch's default dtype

        Returns:
            Tuple (h0, c0), each of shape (num_layers, batch_size, hidden_dim).
        """
        state_shape = (self.num_layers, batch_size, self.hidden_dim)
        if device is None:
            on_gpu = self.use_cuda and torch.cuda.is_available()
            device = torch.device('cuda' if on_gpu else 'cpu')
        # Two separate tensors so h0 and c0 never alias each other.
        h0 = torch.zeros(*state_shape, device=device, dtype=dtype)
        c0 = torch.zeros(*state_shape, device=device, dtype=dtype)
        return (h0, c0)

    def forward(self, rule_embedding, head_relation=None):
        """
        Given the rule embedding, output the sequence of rule body
        decoding function:
            LSTM(rule_embedding)

        Input Parameters:
            - rule_embedding: batch-first input, (batch, steps, input_dim)
            - head_relation: accepted for interface compatibility; not used

        Returns:
            (outputs, hidden) exactly as returned by ``nn.LSTM``.
        """
        # Allocate the initial state where the input lives so the module runs
        # on CPU and GPU alike.
        hidden = self.zero_state(
            rule_embedding.size(0),
            device=rule_embedding.device,
            dtype=rule_embedding.dtype,
        )
        outputs, hidden = self.rnn(rule_embedding, hidden)

        return outputs, hidden