import math
from functools import wraps
from heapq import heappop, heappush

import torch
import torch.nn.functional as F
from torch import Tensor

def caching_reset(func):
    """Decorate a method so its result is cached on ``self.cache`` between resets.

    The cache slot is keyed by the wrapped function's name only (arguments are
    not part of the key).  The owning object must expose a mapping
    ``self.cache`` and a method ``self.clear_cache(key=...)``.

    The wrapper accepts an extra keyword-only argument ``first_call``
    (default ``True``) that is *not* forwarded to ``func``:

    * If a non-``None`` value is already stored under the function's name it
      is returned and ``func`` is not called.
    * Otherwise ``func`` is called.  When ``first_call`` is true,
      ``self.clear_cache(key=name)`` is invoked and the result is not stored;
      when it is false, the result is stored in ``self.cache[name]``.

    Args:
        func: The method to wrap; called as ``func(self, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying ``func``'s name and docstring.
    """
    name = func.__name__

    @wraps(func)
    def cached_func(self, *args, first_call=True, **kwargs):
        # A stored ``None`` is treated the same as "nothing cached".
        cached = self.cache.get(name)
        if cached is not None:
            return cached

        output = func(self, *args, **kwargs)

        if first_call:
            self.clear_cache(key=name)
        else:
            self.cache[name] = output

        return output

    return cached_func

def caching(func):
    """Decorate a method so its result is memoised in ``self.cache``.

    The cache slot is keyed by the wrapped function's name only, so the
    arguments of later calls are ignored while a value is cached.  A stored
    ``None`` counts as "not cached" and triggers a recomputation.

    Args:
        func: The method to wrap; called as ``func(self, *args, **kwargs)``.
            The owning object must expose a mapping ``self.cache``.

    Returns:
        The wrapping function, carrying ``func``'s name and docstring.
    """
    name = func.__name__

    @wraps(func)
    def cached_func(self, *args, **kwargs):
        cached = self.cache.get(name)
        if cached is not None:
            return cached

        self.cache[name] = func(self, *args, **kwargs)

        return self.cache[name]

    return cached_func

def logsumexp(tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    """Compute ``log(sum(exp(tensor)))`` along ``dim`` with max-shifting.

    Slices whose entries are all ``-inf`` return ``-inf``; slices containing
    ``+inf`` return ``+inf``.

    Args:
        tensor: Input tensor of log-values.
        dim: Dimension to reduce.
        keepdim: If true, keep ``dim`` in the output with size 1.

    Returns:
        The reduced tensor.
    """
    # The shift is a constant for autograd purposes, hence no_grad.
    with torch.no_grad():
        m, _ = torch.max(tensor, dim=dim, keepdim=True)
        # An infinite maximum must not be subtracted: (-inf) - (-inf) and
        # (+inf) - (+inf) are both nan.  Use a shift of 0 for those slices.
        m = m.masked_fill_(torch.isinf(m), 0.)

    z = (tensor - m).exp_().sum(dim=dim, keepdim=True)
    # A zero sum means log(0); replace it by 1 before the log and write the
    # -inf result back afterwards.
    mask = z == 0
    z = z.masked_fill_(mask, 1.).log_().add_(m)
    z = z.masked_fill_(mask, -float('inf'))

    if not keepdim:
        z = z.squeeze(dim=dim)
    return z

def logsum(logprobs):
    """Return ``log(sum(exp(logprobs)))`` reduced over dimension 1.

    Uses ``torch.logsumexp`` so that large-magnitude inputs do not overflow
    to ``inf`` or underflow to ``-inf`` the way a literal
    ``log(sum(exp(x)))`` does.

    Args:
        logprobs: Tensor with at least two dimensions.

    Returns:
        Tensor with dimension 1 reduced away.
    """
    return torch.logsumexp(logprobs, dim=1)

# @torch.jit.script
def log1mexp(logprobs):
    """Compute ``log(1 - exp(logprobs))`` element-wise without precision loss.

    A direct ``log1p(-exp(x))`` collapses to ``-inf`` for ``x`` close to 0
    because ``exp(x)`` rounds to 1.  This uses two branches:
    ``log(-expm1(x))`` for ``x > -log(2)`` and ``log1p(-exp(x))`` otherwise.

    Args:
        logprobs: Tensor of values ``<= 0``.

    Returns:
        Tensor of the same shape as ``logprobs``.
    """
    near_zero = logprobs > -math.log(2.0)
    # Feed each branch a harmless constant where it is not selected, so the
    # unused branch cannot inject nan gradients through torch.where.
    upper = logprobs.masked_fill(~near_zero, -1.0)
    lower = logprobs.masked_fill(near_zero, -1.0)
    return torch.where(
        near_zero,
        torch.log(-torch.expm1(upper)),
        torch.log1p(-torch.exp(lower)),
    )


def enumerate_states(variables):
    """Yield every assignment of ``+1`` / ``-1`` to the given variables.

    Each assignment is a dict mapping variable -> ``1`` or ``-1``.  The first
    variable flips fastest.  An empty ``variables`` sequence yields exactly
    one empty assignment (previously it recursed without end).

    Args:
        variables: Sequence of hashable variable identifiers supporting
            ``len``, indexing and slicing.

    Yields:
        dict: One assignment per combination, ``2 ** len(variables)`` in total.
    """
    if len(variables) == 0:
        yield {}
        return

    head = variables[0]
    for state in enumerate_states(variables[1:]):
        yield {**state, head: 1}
        yield {**state, head: -1}


def ind_topk(probs, k=2, boolean=False, variables=None):
    """Keep the ``k`` highest-scoring sign assignments per batch row.

    Variables are processed one at a time.  Every kept partial assignment is
    extended with ``+1`` (weight ``probs[:, i]``) and ``-1`` (weight
    ``1 - probs[:, i]``) for the current variable, and the candidates are
    then cut back to the ``k`` largest products, sorted in descending order.

    Args:
        probs: Tensor of shape ``(bs, n)``.
        k: Maximum number of assignments kept per batch row.
        boolean: If true, return ``states >= 0`` instead of signed integers.
        variables: Column indices to process, in order.  Defaults to
            ``range(n)``.

    Returns:
        Tuple ``(states, p)``.  ``states`` has shape ``(bs, n, m)`` holding
        ``+1`` / ``-1`` for processed variables and ``0`` for unprocessed ones
        (int64, or bool when ``boolean`` is true).  ``p`` has shape
        ``(bs, m)`` with the matching products.  ``m`` is at most ``k``.
    """
    bs, n = probs.shape
    # Allocate the work tensors next to ``probs`` so non-CPU inputs do not
    # hit a device mismatch.
    device = probs.device

    if variables is None:
        variables = list(range(n))

    states = torch.zeros((bs, n, 1), dtype=torch.long, device=device)
    p = torch.ones((bs, 1), device=device)
    for i in variables:
        # (1, n, 1) indicator of variable i; broadcasts over batch and candidates.
        index = torch.tensor([i], dtype=torch.long, device=device)
        value = F.one_hot(index, num_classes=n).unsqueeze(-1)
        states = torch.cat([states + value, states - value], dim=2)

        p_i = probs[:, i].unsqueeze(-1)
        p = torch.cat([p * p_i, p * (1 - p_i)], dim=1)

        if states.shape[2] > k:
            p, idx = torch.topk(p, k=k, dim=1)
        else:
            p, idx = torch.sort(p, dim=1, descending=True)
        # Reorder the candidate axis of ``states`` to match ``p``.
        idx = idx.unsqueeze(1).expand(-1, n, -1)
        states = torch.gather(states, dim=2, index=idx)

    if boolean:
        states = states.ge(0)

    return states, p

def enumerate(probs, thresh):
    """Lazily yield boolean assignments whose score reaches ``thresh``.

    NOTE: this function shadows the built-in ``enumerate`` for the rest of
    the module; the name is kept because callers may import it.

    The search starts from the assignment that picks the more likely value
    for every variable and expands heap entries by flipping one further
    variable at a time.  Candidates whose score is below ``thresh`` are never
    pushed.  ``heapq`` is a min-heap, so the lowest-scored pending entry is
    expanded first.

    Args:
        probs: Tensor indexed along dimension 0 — presumably 1-D
            probabilities in ``[0, 1]``; TODO confirm against callers.
        thresh: Minimum score for a candidate to be explored.

    Yields:
        Tensor: Boolean tensor of length ``probs.shape[0]``; ``True`` where
        the assignment is ``+1``.  Nothing is yielded when even the best
        score is below ``thresh``.
    """
    n = probs.shape[0]
    # maxs[j] = product over i >= j of max(probs[i], 1 - probs[i]);
    # the trailing 1 makes maxs[n] valid.
    maxs = torch.where(probs.ge(0.5), probs, 1 - probs)
    maxs = torch.stack([torch.prod(maxs[j:]) for j in range(n)], dim=0)
    maxs = torch.concat([maxs, torch.ones(1, device=maxs.device)], dim=0)
    mpe = maxs[0]
    if mpe >= thresh:
        # Heap entries are (score, push order, assignment).  The push order
        # breaks ties between equal scores; without it heapq would go on to
        # compare the assignment tensors, which raises for n > 1.
        pushed = 0
        queue = [(mpe, pushed, torch.zeros(n))]
        while queue:
            m, _, gamma = heappop(queue)
            # The entries of gamma that are already set are non-zero, so their
            # count is the index of the first undecided variable.
            for j in range(int(gamma.abs().sum()), n):
                # NOTE(review): the pushed score multiplies the popped score
                # ``m`` by further factors; confirm this is the intended bound.
                if probs[j] >= 0.5:
                    p = m * (1 - probs[j]) * maxs[j + 1]
                    if p >= thresh:
                        gamma[j] = -1
                        pushed += 1
                        heappush(queue, (p, pushed, gamma.clone()))
                    gamma[j] = 1
                else:
                    p = m * probs[j] * maxs[j + 1]
                    if p >= thresh:
                        gamma[j] = 1
                        pushed += 1
                        heappush(queue, (p, pushed, gamma.clone()))
                    gamma[j] = -1
            yield gamma.ge(0)
    
