import torch 
import torch.nn as nn
import torch.nn.functional as F

# from utils import norm_distribution,customize_lookup

import logging 

# partially copied from numerical_rl
class TensorLog(nn.Module):
    """Project entity distributions through per-relation sparse operators.

    One operator is built per relation id from the supplied triplets and
    stored in ``self.op_list``. ``forward`` picks the operator for each
    query's relation and applies it to that query.

    Args:
        num_entity: Number of entities; each operator is built for this size.
        num_relation: Number of relation slots. ``forward`` treats ids
            ``>= num_relation // 2`` as inverse relations.
        trps: Iterable of ``(head, relation, tail)`` index triplets.
        norm: Stored on the instance; not read anywhere inside this class.
        dropout: Stored on the instance; not read anywhere inside this class.
    """

    def __init__(self, num_entity, num_relation, trps, norm=False, dropout=False):
        super(TensorLog, self).__init__()

        self.op_list = None
        self.num_entity = num_entity
        self.num_relation = num_relation
        self.norm = norm
        self.dropout = dropout

        rel_ops = self.triplet_to_op_list(trps, num_relation, num_entity)
        self.set_ops(rel_ops)

    def set_ops(self, op_list):
        """Replace the registered operators with ``op_list``.

        The operators are wrapped in an ``nn.ModuleList`` and moved to the
        GPU when one is available; on CPU-only hosts they stay on the CPU
        instead of failing at construction time.

        Args:
            op_list: Iterable of ``nn.Module`` operators, indexed by relation id.
        """
        ops = nn.ModuleList(op_list)
        if torch.cuda.is_available():
            ops = ops.cuda()
        # Assigning a Module through nn.Module.__setattr__ replaces any
        # previous value (the initial None or an older ModuleList).
        self.op_list = ops

    def forward(self, queries, rels):
        """Apply each query's relation operator and stack the results.

        Inputs:
        - queries(tensor)(shape=bs*nent): a batch of probability
          distributions over head entities
        - rels(tensor)(shape=bs): the relation id of each input query
        Outputs:
        - res(tensor)(shape=bs*nent): the projected query results
        """
        half = self.num_relation // 2
        # Relation ids in the upper half are inverses of op `id - half`.
        # NOTE(review): for even num_relation this means operators with
        # index >= half are never selected here - confirm that is intended.
        inverse_label = rels >= half
        op_ids = torch.where(inverse_label, rels - half, rels)

        outputs = [
            self.op_list[int(op_ids[i])](queries[i], inverse_label[i])
            for i in range(rels.size(0))
        ]
        if not outputs:
            # torch.cat rejects an empty list; return an empty batch instead.
            return queries.new_zeros((0, self.num_entity))

        res = torch.cat(outputs, 0)
        logging.debug('tensorlog projector results shape: %s', res.shape)
        return res

    @staticmethod
    def triplet_to_op_list(triplet, num_relation, num_entity):
        """Build one ``SparseOp`` per relation from ``(e1, rel, e2)`` triplets.

        Args:
            triplet: Iterable of ``(e1, rel, e2)`` index triplets.
            num_relation: Number of relation slots to create.
            num_entity: Number of entities, forwarded to every ``SparseOp``.

        Returns:
            A list of ``num_relation`` ``SparseOp`` modules, where entry ``r``
            holds the ``(e1, e2)`` pairs seen for relation ``r``.
        """
        ops = [[] for _ in range(num_relation)]
        for e1, rel, e2 in triplet:
            ops[rel].append((e1, e2))

        # Add identity pairs (i, i) to the last two relation slots so those
        # operators always have entries.
        # NOTE(review): the original comment said these results are
        # overwritten during forward; no such overwrite is visible in this
        # class - confirm against the caller.
        for i in range(num_entity):
            ops[num_relation-2].append((i, i))
            ops[num_relation-1].append((i, i))

        # A relation with pairs yields an (n, 2) index tensor; a relation
        # without pairs yields an empty 1-D tensor (shape kept as-is so that
        # previously saved buffers still load).
        return [
            SparseOp(torch.LongTensor(pairs), num_entity, False)
            for pairs in ops
        ]

class SparseOp(nn.Module):
    """Apply a binary relation, stored as index pairs, to an entity vector.

    The pairs define a sparse ``num_entity x num_entity`` matrix ``A`` with a
    one at every ``(row, col)`` listed in ``indices`` (duplicate pairs add up).

    Args:
        indices: Long tensor of shape ``(n, 2)`` holding ``(row, col)`` pairs,
            or an empty tensor for a relation without any pairs.
        num_entity: Side length of the square matrix.
        to_negate: Accepted for interface compatibility; not used.
    """

    def __init__(self, indices, num_entity, to_negate=False):
        super(SparseOp, self).__init__()
        self.register_buffer('indices', indices)
        self.num_entity = num_entity

    def forward(self, memories, inverse_label):
        """Multiply the relation matrix with a single entity vector.

        Args:
            memories: Tensor with ``num_entity`` elements; it is viewed as a
                ``(1, num_entity)`` row vector.
            inverse_label: Truthy value (bool or one-element tensor). When
                true the result is ``(A @ v^T)^T``, otherwise
                ``(A^T @ v^T)^T``.

        Returns:
            A ``(1, num_entity)`` tensor on the same device as ``memories``.
            A relation without pairs yields all zeros.

        Raises:
            RuntimeError: If the stored indices do not fit the matrix shape.
                Such errors are no longer masked as an all-zero result.
        """
        v = memories.view(1, -1)
        if self.indices.numel() == 0:
            # Empty relation: the matrix is all zeros, so is the product.
            return torch.zeros_like(v)

        # Build the matrix on the input's device so the product works on CPU
        # and GPU alike, regardless of where the buffer currently lives.
        indices = self.indices.to(v.device)
        # Match a floating-point input's dtype; fall back to float32 otherwise.
        dtype = v.dtype if v.is_floating_point() else torch.float
        values = torch.ones(indices.size(0), dtype=dtype, device=v.device)
        shape = (self.num_entity, self.num_entity)
        A = torch.sparse_coo_tensor(indices.t(), values, shape)

        if inverse_label:
            return A.mm(v.t()).t()
        return A.t().mm(v.t()).t()