# Logic for MCMI loss function
import torch.nn as nn
import torch.nn.functional as F
import torch
import json


def load_vector_from_file(file_path):
    """Read a JSON file and return its contents as a ``torch.long`` tensor.

    Args:
        file_path: Location of the JSON file; handed directly to ``open``.

    Returns:
        A tensor with dtype ``torch.long`` built from the parsed JSON value.
    """
    with open(file_path, "r") as handle:
        parsed = json.loads(handle.read())
    return torch.tensor(parsed, dtype=torch.long)


class CMILoss(nn.Module):
    """KL-divergence loss between softmax predictions and supplied centroids."""

    def __init__(self):
        super().__init__()

    def forward(self, output, centroids):
        """Return the batch-mean KL divergence of softmax(output) from centroids.

        Args:
            output: Raw scores; softmax is taken along dim 1.
            centroids: Positive reference probabilities; presumably the same
                shape as ``output`` -- TODO confirm against the caller.

        Returns:
            Scalar tensor: sum of ``p * (log p - log centroids)`` with
            ``p = softmax(output)``, divided by the batch size.
        """
        # NOTE: centroids are used as given; no row normalization happens here.
        log_predictions = F.log_softmax(output, dim=1)
        log_centroids = centroids.log()
        # With log_target=True, kl_div expects both arguments in log space.
        return F.kl_div(
            log_centroids,
            log_predictions,
            reduction="batchmean",
            log_target=True,
        )


class MCMILoss(nn.Module):
    """MCMI loss built from softmax predictions, centroids and class counts.

    The per-class training counts are read from
    ``./save/<dataset>/train_class_counts.json``. The file is read lazily on
    the first forward pass and then cached, instead of being re-read from
    disk on every call.
    """

    def __init__(self, dataset):
        """Store the dataset name used to locate the class-count file.

        Args:
            dataset: Name substituted into the class-count file path.
        """
        super(MCMILoss, self).__init__()
        # Not used by this class's forward pass; kept so existing code that
        # accesses the attribute keeps working.
        self.ce_criterion = nn.CrossEntropyLoss()
        self.dataset = dataset
        # (dataset name, counts tensor) for the last file read, or None.
        self._class_counts_cache = None

    def _get_class_counts(self, device):
        """Return the per-class counts on ``device``, reading the file once."""
        # getattr guards instances that were created without this __init__
        # (e.g. restored from an older pickle).
        cache = getattr(self, "_class_counts_cache", None)
        if cache is None or cache[0] != self.dataset:
            counts = load_vector_from_file(
                "./save/{}/train_class_counts.json".format(self.dataset)
            )
        else:
            counts = cache[1]
        # .to() is a no-op when the tensor already lives on ``device``.
        counts = counts.to(device)
        self._class_counts_cache = (self.dataset, counts)
        return counts

    def forward(self, logits, centroids):
        """Compute ``-sum(p * log(p / (centroids * n_y))) / batch_size``.

        Args:
            logits: Tensor of shape [batch_size, num_classes].
            centroids: Tensor multiplied elementwise with the class counts;
                assumed to be [batch_size, num_classes] -- TODO confirm.

        Returns:
            Scalar tensor holding the MCMI loss.
        """
        bs, num_classes = logits.shape  # [batch_size, num_classes]
        probs = F.softmax(logits, dim=1)  # [batch_size, num_classes]
        log_probs = F.log_softmax(logits, dim=1)  # [batch_size, num_classes]

        # to get ny
        class_counts = self._get_class_counts(logits.device)  # [num_classes]
        class_counts = class_counts.unsqueeze(0).expand(bs, num_classes)

        # log(p / (c * n)) == log p - log(c * n). Taking log p from
        # log_softmax keeps the term finite when p underflows to zero, where
        # probs * log(probs / ...) would evaluate 0 * -inf = NaN.
        log_q = log_probs - torch.log(centroids * class_counts)
        mcmi_loss = -torch.sum(probs * log_q) / bs  # scalar

        return mcmi_loss
