import math

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pickle

from functools import partial
from einops import rearrange

def find_prob(yM, yB, loss_type):
    """Sum, over the batch, the score the model assigns to each true class.

    Args:
        yM: 2-D tensor of model outputs; ``yM.size(1)`` is used as the
            number of classes and ``yM.size(0)`` as the batch size.
        yB: integer class indices accepted by ``one_hot`` (assumed to be a
            1-D LongTensor with one label per row of ``yM`` -- confirm
            against callers).
        loss_type: ``'mse'`` or ``'ce'``.
            * ``'mse'``: the raw output ``g`` at the true class is folded
              as ``min(g, 2 - g)`` before summing.
            * ``'ce'``: the softmax (over dim 1) value at the true class
              is summed.

    Returns:
        A 0-dim tensor holding the sum over the batch.

    Raises:
        ValueError: if ``loss_type`` is not ``'mse'`` or ``'ce'``.
    """
    # Fail early with a clear message; previously an unknown loss_type fell
    # through both branches and crashed with UnboundLocalError on return.
    if loss_type not in ('mse', 'ce'):
        raise ValueError(
            "loss_type must be 'mse' or 'ce', got {!r}".format(loss_type)
        )

    # Boolean mask with exactly one True per label, at the true-class column.
    true_class_mask = nn.functional.one_hot(yB, yM.size(1)).bool()

    if loss_type == 'ce':
        scores = nn.functional.softmax(yM, dim=1)
    else:
        scores = yM

    picked = torch.masked_select(scores, true_class_mask)
    # masked_select already returns a 1-D tensor; the reshape doubles as a
    # check that exactly one value per row of yM was selected.
    picked = torch.reshape(picked, [yM.size(0)])

    if loss_type == 'mse':
        # Fold around 1 so outputs overshooting the target 1.0 count
        # symmetrically with outputs undershooting it.
        picked = torch.min(picked, 2 - picked)

    return torch.sum(picked)

def mse_loss(y, y_hat, num_classes=10):
    """Mean squared error between ``y`` and the one-hot encoding of ``y_hat``.

    NOTE: the parameter names are inverted relative to the usual convention:
    here ``y`` is the dense (prediction-like) tensor and ``y_hat`` holds the
    integer class labels.

    Args:
        y: floating tensor that must broadcast against the one-hot labels,
            i.e. against shape ``(*y_hat.shape, num_classes)``.
        y_hat: integer class indices accepted by ``F.one_hot``.
        num_classes: width of the one-hot encoding. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        A 0-dim tensor: the mean of the element-wise squared differences.
    """
    targets = F.one_hot(y_hat, num_classes)
    # one_hot yields an integer tensor; cast to y's dtype before subtracting.
    targets = targets.type(y.dtype)
    return torch.mean((y - targets) ** 2)
