

import os
from torch.utils.data import Dataset

import h5py

import torch

from torchvision import transforms
from PIL import Image

from torch.utils.data import DataLoader,Dataset

import numpy as np


import torch
import torch.nn as nn
import  torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial

import librosa

import pickle

import librosa.display

from collections import Counter

from torchsummary import summary


import cv2

import torchvision


def mle(f_no_missing, g_no_missing, h_no_missing, f_missing, h_missing, h_eye, p):
    """Return the summed negative log-likelihood loss for one batch.

    Addition is used as the form of the fusion function phi. The shapes
    noted below are inferred from the matrix products in the body and
    should be confirmed against the caller: N modality-complete samples,
    M modality-missing samples, C classes, feature width d.

    Args:
        f_no_missing: features of one modality of the modality-complete
            samples, presumably (N, d).
        g_no_missing: features of the other modality of the
            modality-complete samples, same shape as ``f_no_missing``.
        h_no_missing: label features of the modality-complete samples,
            same shape as ``f_no_missing``.
        f_missing: features of the available modality of the
            modality-missing samples, presumably (M, d).
        h_missing: label features of the modality-missing samples, same
            shape as ``f_missing``.
        h_eye: features of each category label, presumably (C, d).
        p: empirical label distribution \\hat{R}_Z, a 1-D tensor with one
            entry per class.

    Returns:
        A scalar tensor ``z_given_xy_loss + z_given_x_loss``. Terms are
        summed (not averaged) over samples. The result lives on the same
        device as the inputs; nothing is moved to CUDA implicitly.

    Note:
        The modality-missing terms average over the modality-complete
        samples, so an empty ``g_no_missing`` yields NaN. An empty
        ``f_missing`` simply contributes zero.
    """
    h_eye_t = h_eye.transpose(0, 1)

    # ---- Modality-complete samples -------------------------------------
    f_g_combine = f_no_missing + g_no_missing
    ip = torch.matmul(f_g_combine, h_eye_t)                      # (N, C)
    # Subtract the row-wise max before exp() for numerical stability; it is
    # added back after the log.
    ip_max = ip.max(dim=1, keepdim=True)[0]                      # (N, 1)
    ip_exp = torch.exp(ip - ip_max) * p
    log_partition_xy = torch.log(ip_exp.sum(dim=1)) + ip_max.squeeze(1)
    z_given_xy_loss = (
        -(f_g_combine * h_no_missing).sum() + log_partition_xy.sum()
    )

    # ---- Modality-missing samples: numerator ---------------------------
    # The missing modality is marginalised by averaging over the features
    # g_no_missing of the modality-complete samples.
    f_h_combine_missing = (f_missing * h_missing).sum(dim=1, keepdim=True)
    ip_missing = (
        torch.matmul(h_missing, g_no_missing.transpose(0, 1))
        + f_h_combine_missing
    )                                                            # (M, N)
    ip_missing_max = ip_missing.max(dim=1, keepdim=True)[0]      # (M, 1)
    ip_missing_exp = torch.exp(ip_missing - ip_missing_max)
    ip_loss1 = torch.log(ip_missing_exp.mean(dim=1)) + ip_missing_max.squeeze(1)
    z_given_x_loss1 = -ip_loss1.sum()

    # ---- Modality-missing samples: normaliser --------------------------
    # Both inner-product tables are computed once. g_h does not depend on
    # the missing sample, so it must not be rebuilt per sample.
    f_h = torch.matmul(f_missing, h_eye_t)                       # (M, C)
    g_h = torch.matmul(g_no_missing, h_eye_t)                    # (N, C)
    joint = f_h.unsqueeze(1) + g_h.unsqueeze(0)                  # (M, N, C)
    # One max per missing sample, taken over every (complete sample, class)
    # pair, again only for numerical stability.
    joint_max = joint.flatten(1).max(dim=1)[0]                   # (M,)
    weighted = p * torch.exp(joint - joint_max[:, None, None])
    f_loss = torch.log(weighted.sum(dim=-1).mean(dim=1)) + joint_max
    z_given_x_loss2 = f_loss.sum()

    z_given_x_loss = z_given_x_loss1 + z_given_x_loss2

    return z_given_xy_loss + z_given_x_loss





    




