import torch
import torch.nn.functional as F
# import numpy as np

# def calc_mutual_info(logits, labels, stochastic_label=False, input_probs_not_logits = False):
#     """Calculate I(X;T) and I(T;Y) from logits and labels

#     Here, we assume p(X_i) = 1/N with the number of data samples N.

#     Args:
#         logits (torch.tensor): N x d tensor with the number of classes d
#         labels (torch.tensor): N or N x d tensor with the number of classes d
#         stochastic_label (bool): Is labels N x d tensor? (default: False)
#         input_probs_not_logits (bool): Input probs instead of logits  (default: False)

#     Returns:
#         tuple: I(X;T) and I(T;Y)

#     """
#     N = len(labels)
    
#     if stochastic_label:
#         p_y_given_x = labels #p(Y|x_i)
#     else:
#         p_y_given_x = F.one_hot(labels, num_classes=logits.shape[1]).float() #p(Y|x_i)
#     p_y = torch.mean(p_y_given_x, dim=0) #P(Y)
    
#     if input_probs_not_logits:
#         p_t_given_x = logits #p(T|x_i)
#     else:
#         p_t_given_x = F.softmax(logits, dim=1) #p(T|x_i)
    
#     p_t = torch.mean(p_t_given_x, dim=0) #p(T)
#     h_t = torch.sum(-p_t*torch.log(p_t)) #H(T)
#     p_t_given_y = p_y_given_x.T@p_t_given_x/(N*p_y) #p(T|Y)
    
#     h_t_given_x = torch.sum(-p_t_given_x*torch.log(p_t_given_x))/len(p_t_given_x) #H(T|X)
#     h_t_given_y = torch.sum(-p_t_given_y*torch.log(p_t_given_y))/len(p_t_given_y) #H(T|Y)
    
#     i_x_t = h_t - h_t_given_x #I(X;T)
#     i_t_y = h_t - h_t_given_y #I(T;Y)
    
#     return (i_x_t, i_t_y)



def calc_mutual_info(logits, labels, stochastic_label=False, input_probs_not_logits=False):
    """Calculate empirical mutual information I(X;T) and I(T;Y) from logits and labels.

    Here, we assume p(X_i) = 1/N with the number of input data samples N.
    Entropies are computed with the natural logarithm, so results are in nats.

    Args:
        logits (torch.Tensor): N x d tensor with the number of classes d.
            Used directly as p(T|X) when ``input_probs_not_logits`` is True.
            Tensors that are not float32/float64 are converted to float32.
        labels (torch.Tensor): N or N x d tensor with the number of classes d.
            Hard labels (N) may use any integer dtype; they are cast to int64
            before one-hot encoding.
        stochastic_label (bool): Is labels N x d tensor? (default: False)
        input_probs_not_logits (bool): Input probs instead of logits? (default: False)

    Returns:
        tuple: I(X;T) and I(T;Y), each a 0-dim tensor.
    """
    # Smoothing term that keeps log() and the division by p(Y) finite.
    eps = 1e-10
    N = len(labels)

    # The smoothing term rounds to zero in half precision and integer tensors
    # cannot be averaged, so fall back to float32 for anything that is not
    # already float32/float64.
    if logits.dtype not in (torch.float32, torch.float64):
        logits = logits.float()

    # p(T|X)
    if input_probs_not_logits:
        p_t_given_x = logits
    else:
        p_t_given_x = F.softmax(logits, dim=1)

    # p(Y|X)
    if stochastic_label:
        p_y_given_x = labels
    else:
        # F.one_hot only accepts int64 index tensors.
        p_y_given_x = F.one_hot(labels.long(), num_classes=logits.shape[1])
    # The matrix product below needs both operands to share dtype and device.
    p_y_given_x = p_y_given_x.to(p_t_given_x)

    # p(Y) computation by marginalization
    p_y = torch.mean(p_y_given_x, dim=0)

    # p(T|Y) computation by Bayes' theorem and Markov chain of IB
    p_t_y = p_y_given_x.T @ p_t_given_x / N  # p(T,Y), rows indexed by Y
    p_t_given_y = p_t_y / (p_y.unsqueeze(1) + eps)  # p(T|Y)

    # p(T) computation by marginalization
    p_t = torch.mean(p_t_given_x, dim=0)

    # Entropy H(T)
    h_t = torch.sum(-p_t * torch.log(p_t + eps))

    # Conditional entropies
    h_t_given_x = torch.sum(-p_t_given_x * torch.log(p_t_given_x + eps)) / N  # H(T|X)
    h_t_given_y = torch.sum(-p_t_y * torch.log(p_t_given_y + eps))  # H(T|Y)

    # Mutual information
    i_x_t = h_t - h_t_given_x  # I(X;T)
    i_t_y = h_t - h_t_given_y  # I(T;Y)

    return i_x_t, i_t_y

    

def logits_labels(data_loader, net, device):
    """Run ``net`` over every batch of ``data_loader`` and gather the results.

    The network is moved to ``device`` and switched to evaluation mode (it is
    left in that mode on return). Each batch is evaluated with gradient
    tracking disabled, and the outputs are brought back to the CPU.

    Args:
        data_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        net: Callable module mapping a batch of inputs to logits.
        device: Device on which the forward passes are executed.

    Returns:
        tuple: CPU tensors ``(logits, labels)``, each the concatenation of the
        per-batch tensors along the first dimension.
    """
    net.to(device)
    net.eval()

    logit_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_logits = net(batch_inputs)
            logit_chunks.append(batch_logits.detach().cpu())
            label_chunks.append(batch_labels.detach().cpu())

    return torch.cat(logit_chunks), torch.cat(label_chunks)


    
    
        
    