import torch
import torch.nn.functional as F
from torch.distributions import Binomial


def create_semisup_graph(log_probs, targets=None,
                         drop_path=False, min_num_paths=1, threshold=0.95):
    """Build a stochastic mask of allowed classes for semi-supervised training.

    Args:
        log_probs: 2-D tensor of per-class log-probabilities, indexed
            ``[row, class]``. ``exp(log_probs)`` is used as a keep
            probability, so it must lie in ``[0, 1]`` (presumably log-softmax
            outputs -- confirm at the call site).
        targets: Unused; accepted so every graph builder shares one signature.
        drop_path: When False, no mask is built and ``None`` is returned.
        min_num_paths: Number of highest-probability classes per row that are
            always kept.
        threshold: If a row's top probability is at least this value, every
            class in that row is kept.

    Returns:
        ``None`` when ``drop_path`` is False. Otherwise, a tensor with the
        shape and dtype of ``log_probs``, holding 1.0 for kept classes and
        0.0 elsewhere. Besides the forced top-k classes and confident rows,
        each class is kept with probability ``exp(log_probs)`` via a
        one-trial binomial draw.
    """
    num_rows, num_classes = log_probs.shape  # the unpack also rejects non-2-D input
    if not drop_path:
        return None

    probs = log_probs.exp()
    sampled = Binomial(torch.ones_like(log_probs), probs).sample().bool()
    top_probs, top_classes = probs.topk(min_num_paths, dim=1)

    # Always keep the top-scoring classes of every row.
    forced = probs.new_zeros((num_rows, num_classes))
    forced.scatter_(1, top_classes, 1.0)

    # A confident row (top-1 probability >= threshold) keeps every class.
    confident_rows = (top_probs[:, 0] >= threshold).unsqueeze(1)

    keep = forced.bool() | sampled | confident_rows
    return keep.to(probs.dtype)


def create_partial_label_graph(log_probs, targets,
                               drop_path=False, min_num_paths=5):
    """Build the mask of allowed classes for partial-label training.

    Args:
        log_probs: 2-D tensor ``(b, k)`` of per-class log-probabilities.
        targets: Tensor shaped like ``log_probs``. Entries ``> 0`` mark each
            row's candidate classes.
        drop_path: When True, also allow each row's top-scoring classes.
        min_num_paths: Upper bound on how many top-scoring classes are added
            per row. It is further capped by the mean candidate count
            (``int(targets.sum(-1).mean())``) and by ``k``.

    Returns:
        Boolean tensor ``(b, k)``, True where a class is allowed.
    """
    b, k = log_probs.shape
    paths = targets > 0

    if drop_path:
        # TODO: a stochastic (binomial) path mask was planned here. It was
        # sampled but never used, so it is omitted.
        mean_candidates = int(targets.sum(dim=-1).mean().item())
        # Cap at k so topk cannot request more classes than exist. A value of
        # 0 (mean candidate count < 1) simply adds nothing.
        num_extra = min(min_num_paths, mean_candidates, k)
        _, topk_indices = torch.topk(log_probs, num_extra, dim=1)
        # Scatter the top-k class indices into a (b, k) mask. F.one_hot would
        # give (b, num_extra, k), which broadcasts incorrectly against `paths`.
        topk_mask = torch.zeros_like(log_probs).scatter_(1, topk_indices, 1.0).bool()
        paths = paths | topk_mask

    return paths



def create_noisy_label_graph(log_probs, targets=None, drop_path=False, min_num_paths=5):
    """Build the mask of allowed classes for noisy-label training.

    Args:
        log_probs: 2-D tensor ``(b, k)`` of per-class log-probabilities.
        targets: Unused; accepted so every graph builder shares one signature.
        drop_path: When False, no mask is built and ``None`` is returned.
        min_num_paths: Number of top-scoring classes allowed per row (capped
            at ``k``).

    Returns:
        ``None`` when ``drop_path`` is False. Otherwise, a boolean ``(b, k)``
        tensor that is True for each row's top-``min_num_paths`` classes.
    """
    b, k = log_probs.shape
    if not drop_path:
        return None

    # TODO: a stochastic (binomial) path mask was planned here. It was sampled
    # but never used, so it is omitted.
    _, topk_indices = torch.topk(log_probs, min(min_num_paths, k), dim=1)
    # Scatter into a (b, k) class mask. F.one_hot would give (b, m, k), whose
    # second axis indexes ranks rather than classes.
    return torch.zeros_like(log_probs).scatter_(1, topk_indices, 1.0).bool()


def create_weak_label_graph():
    """Placeholder for the weak-label graph builder (not implemented yet).

    Always returns ``None``.
    """
    return None



class NFA:
    """Compute EM targets by a forward-backward pass over a label graph.

    The label graph is a per-row mask of allowed classes returned by the
    configured ``create_*_graph`` builder (or ``None`` for "every class
    allowed"). :meth:`compute` treats the rows of ``log_probs`` as
    consecutive steps. Every allowed class at row ``i`` is reachable from
    every allowed class at row ``i - 1``.
    """

    def __init__(self, label_config, drop_path=False, min_num_paths=5):
        self.label_config = label_config
        # Misspelled alias kept for backward compatibility with existing readers.
        self.lable_config = label_config
        if label_config == 'semisup':
            self.create_label_graph = create_semisup_graph
        elif label_config == 'partial_label':
            self.create_label_graph = create_partial_label_graph
        elif label_config == 'noisy_label':
            self.create_label_graph = create_noisy_label_graph
        elif label_config == 'weak_label':
            raise NotImplementedError("label_config 'weak_label' is not implemented yet")
        else:
            raise NotImplementedError(f"unknown label_config: {label_config!r}")
        self.drop_path = drop_path
        self.min_num_path = min_num_paths

    def compute(self, log_probs, targets=None):
        """Return per-row posteriors over allowed classes and the mean path count.

        Args:
            log_probs: 2-D tensor ``(b, k)`` of per-class log-probabilities,
                indexed ``[row, class]``.
            targets: Forwarded unchanged to the label-graph builder.

        Returns:
            ``(em_targets, avg_paths)``:

            * ``em_targets`` has shape ``(b, k)`` and rows summing to 1. Its
              dtype is ``promote_types(log_probs.dtype, float32)``.
            * ``avg_paths`` is the mean number of allowed classes per row. It
              is the int ``k`` when the builder returns ``None``, otherwise a
              0-d tensor.

        Raises:
            ValueError: If the label graph leaves some row with no allowed
                class, so that no complete path exists.

        Note:
            With fully connected transitions, the result mathematically equals
            each row's allowed probabilities renormalized within that row.
        """
        b, k = log_probs.size()

        # Compute in at least float32 (half inputs are upcast), without
        # downcasting float64 input.
        work_dtype = torch.promote_types(log_probs.dtype, torch.float32)
        emissions = log_probs.to(work_dtype)

        # create paths
        paths = self.create_label_graph(log_probs, targets, self.drop_path, self.min_num_path)
        if paths is None:
            avg_paths = k
        else:
            paths = paths.bool()
            if paths.dim() == 3:
                # Tolerate a (b, m, k) stack of one-hot rows (e.g. F.one_hot of
                # top-k indices) by merging it into a (b, k) class mask.
                paths = paths.any(dim=1)
            if not bool(paths.any(dim=1).all()):
                raise ValueError("label graph leaves a row with no allowed class")
            avg_paths = paths.float().sum(dim=-1).mean()
            # Disallowed classes get -inf, so they carry no probability mass.
            emissions = emissions.masked_fill(~paths, float("-inf"))

        # Forward: log_alpha[i, c] is the log-mass of all valid paths over
        # rows 0..i that end in class c at row i.
        alphas = [emissions[0]]
        for i in range(1, b):
            alphas.append(torch.logsumexp(alphas[-1], dim=0) + emissions[i])
        log_alpha = torch.stack(alphas)  # (b, k)

        # Backward: log_beta[i] is the log-mass of all valid continuations over
        # rows i+1..b-1. It excludes row i's own emission, so it is identical
        # for every class in row i. This avoids the `-inf - (-inf) = NaN` that
        # subtracting log_probs back out would produce.
        suffix = emissions.new_zeros(())
        betas = [suffix]
        for i in range(b - 1, 0, -1):
            suffix = torch.logsumexp(suffix + emissions[i], dim=0)
            betas.append(suffix)
        log_beta = torch.stack(betas[::-1]).unsqueeze(1)  # (b, 1)

        # Normalize each row on its own scale; this is numerically safer than a
        # single global max.
        em_targets = torch.softmax(log_alpha + log_beta, dim=1)

        return em_targets, avg_paths


