import torch
import torch.nn as nn
import numpy as np

from .HEXgraph import HEXGraph

def local_normalization(t, method="log", rounding=True, states_dim=1):
    """Rescale ``t`` along ``states_dim`` and return the log of the scale used.

    Args:
        t: tensor to normalize. Within this file it is called with
            non-negative float64 tensors laid out as (1, states, batch).
        method: ``"log"`` divides by ``exp`` of the midpoint between the
            largest and smallest log-value along ``states_dim``, bounded from
            below so the largest rescaled entry cannot overflow the dtype;
            ``"mean"`` divides by the mean along ``states_dim``; any other
            value leaves ``t`` untouched.
        rounding: if True, entries that are exactly zero after the division
            are replaced by the smallest positive normal number of the dtype.
            If False, slices whose normalizer is not strictly positive are
            returned unscaled.
        states_dim: dimension along which the normalizer is computed.

    Returns:
        ``(t_n, log_n)`` where ``t_n`` has the shape of ``t`` and ``log_n``
        has the shape of ``t`` with ``states_dim`` removed. For an unknown
        ``method`` the result is ``(t, zeros)`` with that same reduced shape.
    """
    if method == "log":
        log_t = torch.log(t)
        # keepdim=True so the normalizer broadcasts against `t` whatever
        # `states_dim` and the sizes of the other dimensions are.
        max_log, _ = torch.max(log_t, dim=states_dim, keepdim=True)
        min_log, _ = torch.min(log_t, dim=states_dim, keepdim=True)
        # Largest exponent whose exp() is still finite for this dtype
        # (709 for float64).
        log_max_float = float(np.floor(np.log(torch.finfo(log_t.dtype).max)))
        # Never pick a normalizer so small that max(t) / n overflows.
        thresh = max_log.sub(log_max_float)
        avg = torch.add(max_log, min_log).div(2)
        log_n = torch.maximum(avg, thresh)
        n = torch.exp(log_n)
    elif method == "mean":
        n = torch.mean(t, dim=states_dim, keepdim=True)
        log_n = torch.log(n)
    else:
        # No normalization: the log-normalizer is zero, shaped like the
        # log-normalizers returned by the other branches.
        if t.dim() == 0:
            return t, t.new_zeros(())
        reduced_axis = states_dim % t.dim()
        reduced_shape = [
            size for axis, size in enumerate(t.shape) if axis != reduced_axis
        ]
        return t, t.new_zeros(reduced_shape)

    quotient = torch.div(t, n)
    if rounding:
        # Keep every entry strictly positive so a later log() stays finite.
        tiny = torch.tensor(
            torch.finfo(quotient.dtype).tiny,
            dtype=quotient.dtype,
            device=quotient.device,
        )
        t_n = torch.where(quotient != 0, quotient, tiny)
    else:
        t_n = torch.where(n > 0, quotient, t)

    return t_n, log_n.squeeze(states_dim)


class fastHEXLayer(nn.Module):
    """Custom HEX layer.

    Runs sum-product (and optionally max-product) message passing over the
    junction tree built by the wrapped ``HEXGraph`` to obtain the log
    partition function restricted to the states listed by the graph, and
    exposes a loss that mixes it with the unconstrained log partition.

    All per-call state (messages, log partitions, Viterbi tables) is stored on
    the instance and rebuilt by every call to ``forward``.
    """

    def __init__(self, hexg, beta=0, loss_normalize=True, device=None, verbose=False):
        """Store the configuration and precompute the junction-tree tables.

        Args:
            hexg: the ``HEXGraph`` describing the label constraints.
            beta: default weight of the HEX partition term in ``loss``.
            loss_normalize: default for ``loss``; if True the loss is divided
                by ``hexg.numV``.
            device: torch device on which the message tensors are created.
            verbose: if True, print diagnostics when NaNs are detected.
        """
        super().__init__()
        assert(isinstance(hexg, HEXGraph))

        self.hexg = hexg
        self.device = device
        self.beta = beta
        self.loss_normalize = loss_normalize
        self.verbose = verbose

        # Per-call state, filled by initMessages().
        self.batch_size = None
        self.sp_messages = None
        self.log_partition = None
        self.mp_messages = None
        self.transition_matrices = None
        self.viterbi_forward = None
        self.viterbi_backward = None
        self.viterbi_maxstate = None

        # Junction-tree description, filled by initHEXgraph().
        self.cliques = None
        self.numC = None
        self.statesSpace = None
        self.sp = None

        self.initHEXgraph()

    def initHEXgraph(self):
        """Build the junction tree and cache the tables derived from it."""
        self.hexg.buildJT()
        self.cliques = self.hexg.jt.cliques
        self.numC = self.hexg.jt.numC
        self.statesSpace = self.hexg.listStatesSpace()
        self.sp = self.hexg.recordSumProduct()

    def set_device(self, device):
        """Set the device used for tensors created by later forward passes."""
        self.device = device

    def initMessages(self, scores=None):
        """Reset all message tables and load the clique potentials.

        Args:
            scores: tensor whose first dimension is the batch and whose second
                dimension is indexed by the variable ids stored in each
                clique. Despite the ``None`` default it is required.
        """
        self.batch_size = scores.shape[0]

        # Torch tensors for autodiff and GPU computing
        self.sp_messages = {}
        self.log_partition = {}
        self.mp_messages = {}
        self.transition_matrices = {}

        # Numpy arrays for the Viterbi algorithm (no need of autodiff or GPU computing)
        self.viterbi_forward = {}
        self.viterbi_backward = {}
        self.viterbi_maxstate = np.zeros((self.hexg.numV, self.batch_size), dtype=bool)

        for i in range(self.numC):
            # NOTE(review): statesSpace[i] is presumably laid out as
            # (num_states, clique_size) - confirm against HEXGraph.
            states = torch.tensor(self.statesSpace[i], dtype=torch.float64).to(self.device)
            local_scores = torch.index_select(scores, 1, torch.tensor(self.cliques[i]).to(self.device))

            # Each variable score must only contribute once: keep only the
            # scores of the variables assigned to this clique (jt.assignedVar).
            assigned = torch.tensor([(v in self.hexg.jt.assignedVar[i]) for v in self.cliques[i]], dtype=torch.uint8)
            assigned_scores = torch.mul(local_scores, assigned.to(self.device)).to(dtype=torch.float64)
            assigned_scores = assigned_scores.transpose(0, 1)

            # Potentials, with a leading axis of size 1 along which incoming
            # messages are later concatenated; normalized to stay in range.
            msg = torch.unsqueeze(torch.exp(torch.matmul(states, assigned_scores)), dim=0)
            msg, log_norm = local_normalization(msg, method="log", rounding=True)

            self.sp_messages[i] = msg.to(torch.float64)
            self.log_partition[i] = log_norm
            # The max-product copy is detached: it only feeds Viterbi decoding.
            self.mp_messages[i] = self.sp_messages[i].detach().to(torch.float64)

            # One 0/1 row per entry of sp[i][-1], marking the states of clique
            # i selected by that entry.
            # NOTE(review): presumably one row per state of the parent
            # clique - confirm against HEXGraph.recordSumProduct().
            tm = []
            for selected in self.sp[i][-1]:
                line = torch.zeros(self.statesSpace[i].shape[0], device=self.device)
                line[selected.tolist()] = 1
                tm.append(line)

            self.transition_matrices[i] = torch.stack(tm, dim=0).to(torch.float64)

    def collectMessages(self, i, mp=True):
        """Combine every message currently stored for clique ``i``.

        Args:
            i: clique index.
            mp: if True also combine the max-product messages.

        Returns:
            ``(sp_prod, log_sum, mp_prod)``: the product of the sum-product
            messages over the message axis, the sum of the matching log
            normalizers, and the product of the max-product messages (``None``
            when ``mp`` is False). The reduced axis is kept with size 1.
        """
        if self.verbose:
            if torch.any(torch.isnan(self.sp_messages[i])):
                print("Messages to {} has NaNs".format(i))

        sp_prod = torch.prod(self.sp_messages[i], dim=0, keepdim=True)
        log_sum = torch.sum(self.log_partition[i], dim=0, keepdim=True)

        if mp:
            mp_prod = torch.prod(self.mp_messages[i], dim=0, keepdim=True)
        else:
            mp_prod = None

        if self.verbose:
            if torch.any(torch.isnan(sp_prod)):
                print(i)

        return sp_prod, log_sum, mp_prod

    def distributeMessages(self, i, mp=True):
        """Send the combined message of clique ``i`` to its parent clique.

        The outgoing messages are appended to the parent's message tables;
        with ``mp`` True the argmax table used by ``decodeViterbi`` is also
        recorded for clique ``i``.
        """
        sp_prod, log_sum, mp_prod = self.collectMessages(i, mp)

        j = self.hexg.jt.cliqParents[i]
        msg = torch.matmul(self.transition_matrices[i], sp_prod)
        # Rescale the message; the scale is carried in log space below.
        msg, log_norm = local_normalization(msg, method="log", rounding=True)
        self.sp_messages[j] = torch.cat([self.sp_messages[j], msg], dim=0)

        if mp:
            # Max over the states of clique i, one result per transition row.
            mp_msg, argmax_idx = torch.max(torch.mul(torch.unsqueeze(self.transition_matrices[i], dim=-1), mp_prod), dim=1, keepdim=True)
            # Only the argmax matters for decoding, so the scale is discarded.
            mp_msg, _ = local_normalization(mp_msg.transpose(0, 1))
            self.mp_messages[j] = torch.cat([self.mp_messages[j], mp_msg], dim=0)
            self.viterbi_forward[i] = argmax_idx.reshape(-1, self.batch_size).cpu().numpy()

        # Log partition message: what was accumulated so far plus the scale
        # just removed from the outgoing message.
        log_msg = torch.add(log_sum, log_norm)
        self.log_partition[j] = torch.cat([self.log_partition[j], log_msg], dim=0)

    def decodeViterbi(self):
        """Backtrack the max-product pass and return the decoded states.

        Must be called after ``forward(..., mp=True)``.

        Returns:
            The boolean numpy array ``viterbi_maxstate`` of shape
            ``(hexg.numV, batch_size)``.
        """
        for i in reversed(self.hexg.jt.order):
            max_state = self.viterbi_backward[i]
            self.viterbi_maxstate[np.ix_(self.cliques[i])] = self.statesSpace[i][np.ix_(max_state)].transpose()

            # The selected state of clique i picks, for each child, the row of
            # the child's argmax table to follow.
            max_state_idx = np.expand_dims(max_state, axis=0)
            for j in self.hexg.jt.cliqChildren[i]:
                self.viterbi_backward[j] = np.take_along_axis(self.viterbi_forward[j], max_state_idx, axis=0).reshape(self.batch_size).astype(np.int32)

        return self.viterbi_maxstate

    def forward(self, scores, mp=True):
        """Run message passing and return the HEX log partition function.

        Args:
            scores: per-variable scores, batch first.
            mp: if True also run max-product and record the state selected at
                the last clique so that ``decodeViterbi`` can be called.

        Returns:
            ``log_z``, a tensor of shape ``(1, batch_size)``.
        """
        self.initMessages(scores)

        for i in range(self.numC-1):
            self.distributeMessages(self.hexg.jt.order[i], mp)

        root = self.hexg.jt.order[self.numC-1]
        sp_prod, log_sum, mp_prod = self.collectMessages(root, mp)

        # Drop only the message axis: a bare squeeze() would also drop a
        # state or batch axis of size 1 and make the sum run over the wrong
        # dimension.
        sum_msg = torch.sum(sp_prod.squeeze(0), dim=0)
        assert(torch.all(sum_msg))
        log_z = torch.add(torch.log(sum_msg), log_sum)

        if mp:
            _, max_state = torch.max(mp_prod.squeeze(0), dim=0, keepdim=True)
            self.viterbi_backward[root] = max_state.reshape(self.batch_size).cpu().numpy()

        return log_z

    def computeStateLogLikelyhood(self, state, scores):
        """Return the log-likelihood of ``state`` under the HEX model.

        Computed as the energy of ``state`` (its dot product with the scores)
        minus the HEX log partition function returned by ``forward``.
        """
        # matmul does not promote dtypes, so both operands are cast.
        state_energy = torch.matmul(state.to(dtype=torch.float64), scores.to(dtype=torch.float64).transpose(0, 1))
        # forward() returns the log partition tensor only.
        log_z = self.forward(scores, mp=False)

        return torch.sub(state_energy, log_z)

    def loss(self, state, scores, mp=False, beta=None, log_z=None, normalize=None):
        """Compute the HEX loss and the plain binary cross-entropy loss.

        Args:
            state: target label state, same layout as ``scores``.
            scores: per-variable scores, batch first.
            mp: forwarded to ``forward``; with ``beta == 0`` the forward pass
                is only run when ``mp`` is True.
            beta: weight of the HEX partition term; defaults to ``self.beta``.
            log_z: precomputed HEX log partition; computed with ``forward``
                when omitted and ``beta != 0``.
            normalize: if True divide the losses by ``hexg.numV``; defaults to
                ``self.loss_normalize``.

        Returns:
            ``(hex_loss, base_loss, log_diff)``. With ``beta == 0``,
            ``hex_loss`` is ``base_loss`` and ``log_diff`` is all zeros;
            otherwise ``log_diff`` is ``log_z - base_log_z``.
        """
        if normalize is None:
            normalize = self.loss_normalize
        n = self.hexg.numV if normalize else 1

        if beta is None:
            beta = self.beta

        # compute the energies (unnormalized log probabilities) of each label state
        energies = torch.sum(torch.mul(state.to(torch.float64), scores), dim=1, keepdim=True).transpose(0, 1)

        # compute the base (or full) partition function, i.e. the sum over
        # all unconstrained states; logaddexp(x, 0) is log(1 + exp(x))
        # without overflowing for large scores
        base_log_z = torch.sum(torch.logaddexp(scores, torch.zeros_like(scores)), dim=1, keepdim=True).transpose(0, 1)

        # compute the base loss
        base_loss = torch.mean(torch.sub(base_log_z, energies)).div(n)

        # If beta is zero we simply compute the standard binary cross-entropy
        if beta == 0:
            if mp:
                # run only for its side effect: filling the Viterbi tables
                _ = self.forward(scores, mp=mp)

            return base_loss, base_loss, torch.zeros_like(base_log_z)

        # if not provided, compute the hex partition function (i.e. the sum
        # over the states of the HEX states space)
        if log_z is None:
            log_z = self.forward(scores, mp=mp)

        # combine the base partition and the hex partition
        hex_loss = torch.mean(torch.sub(torch.add(base_log_z.mul(1+beta), log_z.mul(-beta)), energies)).div(n)

        log_diff = torch.sub(log_z, base_log_z)

        # report which term produced a NaN loss
        if self.verbose:
            if torch.isnan(hex_loss).item():
                print("State energy nan : ", torch.any(torch.isnan(energies)).item())
                print("Base log partition nan : ", torch.any(torch.isnan(base_log_z)).item())
                print("Hex log partition nan : ", torch.any(torch.isnan(log_z)).item())

        return hex_loss, base_loss, log_diff
