import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def log1mexp(x):
    """
    Compute ``log(1 - exp(-x))`` elementwise in a numerically stable way.

    Uses the two-branch scheme of Maechler ("Accurately computing
    log(1 - exp(-|a|))"): ``log(-expm1(-x))`` for ``x < log(2)`` and
    ``log1p(-exp(-x))`` otherwise, each being the accurate form on its range.

    Inputs:
        - x : tensor of non-negative values (e.g. negated log-probabilities)

    Returns a tensor shaped like ``x``; entries equal to 0 map to ``-inf``.

    Raises ValueError (a subclass of Exception) if any entry is negative or NaN.
    """
    if not torch.all(x >= 0):
        # Report the offending tensor in the exception instead of printing it.
        raise ValueError(
            "Only non-negative inputs (negative log-probabilities) are accepted ! "
            f"Got: {x}"
        )
    # 0.693... == log(2): the switch point between the two accurate branches.
    return torch.where(
        x < 0.6931471805599453094,
        torch.log(-torch.expm1(-x)),
        torch.log1p(-torch.exp(-x)),
    )

class ClassificationSystem(nn.Module):
    """
    Base classification system

    Treats every output unit as an independent binary variable: the loss is the
    binary cross-entropy of the scores (logits) and prediction thresholds each
    score at zero.

    Inputs:
        - model : the base neural network model
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def bce(self, y, x=None, scores=None, normalize=True):
        """
        Binary cross-entropy of logits ``scores`` against 0/1 targets ``y``.

        Per sample computes ``sum_j log(1 + exp(s_j)) - y_j * s_j`` and averages
        over the batch; with ``normalize`` the result is also divided by the
        number of outputs ``y.shape[1]``.

        Inputs:
            - y : targets, same shape as the scores
            - x : model input, used only when ``scores`` is None
            - scores : precomputed logits
            - normalize : divide by the number of outputs

        Returns a scalar float64 tensor. Raises Exception if neither x nor
        scores is given.
        """
        if scores is None:
            if x is None:
                raise Exception('No input given')
            scores = self.model(x)

        scores = scores.to(torch.float64)
        energies = torch.sum(y.to(torch.float64) * scores, dim=1, keepdim=True)
        # logaddexp(s, 0) == log(exp(s) + 1) without overflowing for large s.
        log_z = torch.sum(
            torch.logaddexp(scores, torch.zeros_like(scores)), dim=1, keepdim=True
        )

        loss = torch.mean(log_z - energies)
        if normalize:
            loss = loss.div(y.shape[1])
        return loss

    def loss(self, y, x=None, scores=None, normalize=True):
        return self.bce(y=y, x=x, scores=scores, normalize=normalize)

    def ibm(self, x=None, scores=None):
        """Return a boolean tensor marking which scores are strictly positive."""
        if scores is None:
            if x is None:
                raise Exception('No input given')
            scores = self.model(x)

        return torch.gt(scores, 0)

    def predict(self, x=None, scores=None):
        return self.ibm(x=x, scores=scores)

    def train(self, mode=True):
        """
        Switch the wrapped model to training (or, with ``mode=False``, eval) mode.

        Accepts ``mode`` like ``nn.Module.train`` so that parent modules calling
        ``child.train(False)`` keep working; returns ``self``.
        """
        self.training = mode
        self.model.train(mode)
        return self

    def eval(self):
        """Switch the wrapped model to evaluation mode; returns ``self``."""
        self.training = False
        self.model.eval()
        return self

class HEXClassificationSystem(ClassificationSystem):
    """
    Neuro-symbolic classification system

    Inputs:
        - model : the base neural network model
        - hexL : the hex-layer for computing both the loss and the inference output
    """

    def __init__(self, model, hexL):
        super().__init__(model=model)
        self.hexL = hexL

    def loss(self, y, x=None, scores=None, mp=True):
        """
        Delegate the loss computation to ``hexL.loss``.

        ``scores`` take precedence over ``x``; when only ``x`` is given the
        scores are computed with the base model.

        Returns the triple ``(hex_loss, base_loss, log_diff)`` produced by
        ``hexL.loss``. Raises Exception if neither x nor scores is given.
        """
        if scores is None:
            if x is None:
                raise Exception('No input given')
            scores = self.model(x)

        hex_loss, base_loss, log_diff = self.hexL.loss(state=y, scores=scores, mp=mp)
        return hex_loss, base_loss, log_diff

    def predict(self, x=None, scores=None, mp=True):
        """
        Decode a prediction with ``hexL.decodeViterbi``.

        NOTE: when ``mp`` is True, ``x`` and ``scores`` are ignored and the
        decoding reuses whatever state ``hexL`` currently holds. Otherwise
        ``hexL.forward`` is run first on ``scores`` (or on ``model(x)``).

        Returns ``torch.from_numpy`` of the transposed decoding. Raises
        Exception if ``mp`` is False and neither x nor scores is given.
        """
        if not mp:
            if scores is None:
                if x is None:
                    raise Exception('No input given')
                scores = self.model(x)
            self.hexL.forward(scores=scores, mp=True)

        decoded = self.hexL.decodeViterbi()
        return torch.from_numpy(decoded.transpose())


class CircuitClassificationSystem(ClassificationSystem):
    """
    Neuro-symbolic classification system

    Inputs:
        - model : the base neural network model
        - circuit : tractable circuit (may be None for subclasses that do not
          use it, e.g. GridClassificationSystem's default)
        - beta : weight of the circuit term in the loss; 0 gives plain
          binary cross-entropy
        - num_reps : value stored in ``circuit.beta.num_reps``
    """

    def __init__(self, model, circuit, beta=0, num_reps=1):
        super().__init__(model=model)
        self.circuit = circuit
        if circuit is not None:
            self.circuit.beta.num_reps = num_reps
        self.beta = beta
        self.bsz = None
        # Scores cached by loss(..., mp=True); subclasses' predict may reuse them.
        self.scores = None

    def _reset_mixing(self, scores):
        """Record the batch size and set ``circuit.beta.mixing`` to ones on the scores' device."""
        self.bsz = scores.size()[0]
        if self.circuit is not None:
            self.circuit.beta.mixing = torch.ones(
                (self.bsz, self.circuit.beta.num_reps), device=scores.device
            )

    def loss(self, y, x=None, scores=None, mp=True):
        """
        Binary cross-entropy, optionally regularised by the circuit.

        Inputs:
            - y : 0/1 targets, same shape as the scores
            - x : model input, used only when ``scores`` is None
            - scores : precomputed logits
            - mp : cache the scores and reset the circuit mixing weights

        Returns ``(loss, base_loss, log_diff)``. With ``beta == 0`` the loss is
        the mean BCE and ``log_diff`` is zeros. Raises Exception if neither x
        nor scores is given.
        """
        if scores is None:
            if x is None:
                raise Exception('No input given')
            scores = self.model(x)

        if mp:
            self.scores = scores
            self._reset_mixing(scores)

        y = y.to(torch.float64)
        scores = scores.to(torch.float64)
        n = scores.shape[1]

        # energies (unnormalized log-probabilities) of each label y
        energies = torch.sum(y * scores, dim=1, keepdim=True).transpose(0, 1)

        # base log-partition function: sum_j log(1 + exp(s_j)), overflow-safe
        base_log_z = torch.sum(
            torch.logaddexp(scores, torch.zeros_like(scores)), dim=1, keepdim=True
        ).transpose(0, 1)

        base_loss = torch.mean(base_log_z - energies).div(n)

        bce = F.binary_cross_entropy_with_logits(scores, y, reduction='none').mean()

        if self.beta == 0:
            return bce, base_loss, torch.zeros_like(base_log_z)

        # Literal weights [log P(x_j = 0), log P(x_j = 1)] in log space; the clamp
        # keeps -lp strictly positive so log1mexp stays finite.
        logprobs = F.logsigmoid(scores).clamp(max=-1e-7)
        litweights = [[log1mexp(-lp), lp] for lp in logprobs.unbind(dim=1)]
        # Negated circuit output in log space; presumably -log P(constraint holds).
        sl = -self.circuit.get_tf_ac(litweights, log_space=True)

        loss = bce + torch.mean(sl).mul(self.beta).div(n)
        log_diff = -torch.add(sl, base_log_z)
        return loss, base_loss, log_diff

    def predict(self, x=None, scores=None, mp=True):
        """
        Most-probable-explanation prediction through the circuit.

        When neither ``x`` nor ``scores`` is given and ``mp`` is True, the
        circuit is queried as-is, presumably reusing the state of the previous
        loss/predict call, and a long tensor is returned. Otherwise the circuit
        is parameterized from the scores and a boolean tensor is returned.

        Raises Exception if no input is available.
        """
        if scores is None:
            if x is not None:
                scores = self.model(x)
            elif mp:
                mpe = self.circuit.get_mpe_inst(self.bsz)
                # NOTE(review): this branch returns long while the other returns bool.
                return (mpe > 0).long()
            else:
                raise Exception('No input given')

        self._reset_mixing(scores)
        logprobs = F.logsigmoid(scores).unsqueeze(-1)
        litweights = [[log1mexp(-lp), lp] for lp in logprobs.unbind(dim=1)]
        self.circuit.parameterize_ff(litweights)
        mpe = self.circuit.get_mpe_inst(self.bsz)
        return mpe.gt(0)

import networkx as nx

def topological_sort_edges(G):
    """
    List the edges of a DAG, grouped by source node in topological order.

    Nodes are visited in ``nx.topological_sort(G)`` order and, for each node,
    its out-edges are emitted in ``G.out_edges`` order.

    Inputs:
        - G : a networkx DiGraph without cycles

    Returns a list of ``(u, v)`` tuples. Raises Exception if ``G`` is not a
    DiGraph or is not acyclic.
    """
    if not isinstance(G, nx.DiGraph):
        raise Exception("Only implemented for networkx DiGraph !")
    if not nx.is_directed_acyclic_graph(G):
        raise Exception("Only implemented for Directed Acyclic Graphs !")

    return [edge for node in nx.topological_sort(G) for edge in G.out_edges(node)]

def grid(n):
    """
    Build the directed ``n x n`` grid graph.

    Nodes are ``(i, j)`` pairs with ``0 <= i, j < n``. Each node has an edge to
    ``(i+1, j)`` and to ``(i, j+1)`` when those nodes exist.

    Inputs:
        - n : side length of the grid

    Returns a networkx DiGraph. For ``n == 1`` it holds the single node
    ``(0, 0)``; for ``n <= 0`` it is empty.
    """
    G = nx.DiGraph()
    for i in range(n):
        for j in range(n):
            # Add the node explicitly so n == 1 still yields (0, 0). For n > 1 this
            # does not change insertion order: every node other than (0, 0) was
            # already added as the target of an earlier edge, and (0, 0) comes
            # first either way. The order matters because the edge indexing is
            # derived from a topological sort of this graph.
            G.add_node((i, j))
            if i < n - 1:
                G.add_edge((i, j), (i + 1, j))
            if j < n - 1:
                G.add_edge((i, j), (i, j + 1))
    return G

class GridClassificationSystem(CircuitClassificationSystem):
    """
    Classification system whose outputs are the edges of an ``n x n`` grid.

    Score coordinate ``k`` is the logit of edge ``self.order[k]``, where
    ``self.order = topological_sort_edges(grid(n))``. Prediction returns the
    0/1 indicator of the path from ``(0, 0)`` to ``(n-1, n-1)`` with the
    largest total logit.

    Inputs:
        - model : the base neural network model
        - n : side length of the grid
        - circuit, beta, num_reps : forwarded to CircuitClassificationSystem
    """

    def __init__(self, model, n, circuit=None, beta=0, num_reps=1):
        super().__init__(model=model, circuit=circuit, beta=beta, num_reps=num_reps)
        self.n = n
        self.grid = grid(n)
        self.order = topological_sort_edges(self.grid)
        # edge -> index of its score coordinate
        self.edges = {edge: idx for idx, edge in enumerate(self.order)}

    def predict(self, x=None, scores=None, mp=True):
        """
        Return a float tensor of shape ``(batch, len(self.order))`` holding the
        0/1 indicator of the highest-scoring grid path for each sample.

        ``scores`` take precedence over ``x``. With neither given and ``mp``
        True, the scores cached by the last ``loss(..., mp=True)`` are used.
        Raises Exception if no input is available.
        """
        if scores is None:
            if x is not None:
                scores = self.model(x)
            elif mp and getattr(self, 'scores', None) is not None:
                scores = self.scores
            else:
                raise Exception('No input given')

        source, target = (0, 0), (self.n - 1, self.n - 1)
        paths = []

        for w in scores:
            # Plain floats keep autograd graphs / device tensors out of networkx.
            logits = w.detach().cpu().tolist()
            # Weight = -logit, so the shortest path maximises the summed logits;
            # negative weights are fine for Bellman-Ford since the grid is acyclic.
            nx.set_edge_attributes(
                self.grid,
                {edge: {"logits": -logits[idx]} for idx, edge in enumerate(self.order)},
            )
            trace = nx.bellman_ford_path(self.grid, source, target, weight="logits")
            path = torch.zeros(len(self.order))
            for u, v in zip(trace[:-1], trace[1:]):
                path[self.edges[(u, v)]] = 1
            paths.append(path)

        return torch.stack(paths, dim=0)


class MCClassificationSystem(ClassificationSystem):
    """
    Classification system that enforces mutual exclusion between the variables.

    Inputs:
        - model : the base neural network model
        - beta : the regularization coefficient, if beta = 0 : standard, if beta = -1 : semantic conditioning, if beta > 0 : semantic regularization
    """

    def __init__(self, model, beta=0):
        super().__init__(model=model)
        self.beta = beta
        # Scores from the last loss()/predict() call; predict() clears it after use.
        self.scores = None

    def loss(self, y, x=None, scores=None, normalize=True, **kwargs):
        """
        Loss of the class indices ``y`` under independent-binary and one-hot models.

        Inputs:
            - y : integer class index per sample (used with ``torch.gather``)
            - x : model input, used only when ``scores`` is None
            - scores : precomputed logits, cached in ``self.scores``
            - normalize : divide the losses by the number of classes

        Returns ``(mc_loss, base_loss, log_diff)`` where ``log_diff`` is the
        per-sample ``log_z - mc_log_z``. Raises Exception if neither x nor
        scores is given.
        """
        if scores is None:
            if x is None:
                raise Exception('No input given')
            self.scores = self.model(x)
        else:
            self.scores = scores
        s = self.scores

        energies = torch.gather(s, 1, y.unsqueeze(dim=1))
        # Log-sum over all binary states: sum_j log(1 + exp(s_j)), overflow-safe.
        log_z = torch.sum(torch.logaddexp(s, torch.zeros_like(s)), dim=1, keepdim=True)
        # Log-sum over the one-hot states only: log sum_j exp(s_j), overflow-safe.
        mc_log_z = torch.logsumexp(s, dim=1, keepdim=True)
        log_diff = log_z - mc_log_z

        n = s.shape[1] if normalize else 1

        base_loss = torch.mean(log_z - energies).div(n)
        if self.beta != 0:
            mc_loss = torch.mean(
                log_z.mul(1 + self.beta) - mc_log_z.mul(self.beta) - energies
            ).div(n)
        else:
            mc_loss = torch.mean(log_z - energies).div(n)

        return mc_loss, base_loss, log_diff

    def predict(self, x=None, scores=None):
        """
        One-hot prediction of the highest-scoring class.

        ``scores`` take precedence over ``x``; with neither given, the scores
        cached by the previous ``loss`` call are used. The cache is cleared
        afterwards so it is not reused twice. Raises Exception if no scores
        are available.
        """
        if scores is not None:
            self.scores = scores
        elif x is not None:
            self.scores = self.model(x)
        elif self.scores is None:
            raise Exception('No input given')

        _, idx_max = torch.max(self.scores, dim=1)
        nc = self.scores.shape[1]
        self.scores = None
        return F.one_hot(idx_max, num_classes=nc)