import torch
import torchvision
from PIL import Image
import torch.nn as nn
import numpy as np
import torch
import torch.nn as nn
import ot 


def optimize_ot(PS, prop):
    """Hard-assign samples to classes by solving an exact optimal-transport problem.

    Every one of the N samples carries mass 1/N and class k must receive mass
    ``prop[k]``. Moving sample i to class k costs ``-log(PS[i, k])``, so the
    optimal plan favours likely classes while respecting the class proportions.

    Args:
        PS: array-like of shape (N, K) with per-sample class probabilities
            (e.g. softmax outputs).
        prop: length-K target class proportions. ``ot.emd`` requires both
            marginals to carry the same total mass, i.e. ``sum(prop) == 1``.

    Returns:
        (N, K) float array of one-hot rows; row i marks the class that receives
        the largest share of sample i's mass in the transport plan.
    """
    PS = np.asarray(PS)
    N, K = PS.shape

    sample_mass = np.ones(N) / N                # source marginal, shape (N,)
    class_mass = np.array(prop, dtype=float)    # target marginal, shape (K,)
    # The epsilon keeps log() finite for zero-probability entries.
    cost = -np.log(PS + 1e-10)                  # shape (N, K)

    # ot.emd(a, b, M) returns a plan of shape (len(a), len(b)) == (N, K).
    plan = ot.emd(sample_mass, class_mass, cost)
    assignments = np.argmax(plan, axis=1)       # shape (N,)
    # Size the one-hot by K rather than a hard-coded class count.
    return np.eye(K)[assignments]


def optimize_ot_cifar100(PS, prop):
    """OT hard-assignment of samples to classes (originally used for CIFAR-100).

    Every one of the N samples carries mass 1/N and class k must receive mass
    ``prop[k]``. Moving sample i to class k costs ``-log(PS[i, k])``.

    Args:
        PS: array-like of shape (N, K) with per-sample class probabilities.
        prop: length-K target class proportions summing to 1 (``ot.emd``
            requires equal total mass on both marginals).

    Returns:
        (N, K) float array of one-hot rows, one per sample. For CIFAR-100
        inputs K is 100, matching the previous hard-coded width.
    """
    PS = np.asarray(PS)
    N, K = PS.shape

    sample_mass = np.ones(N) / N                # source marginal, shape (N,)
    class_mass = np.array(prop, dtype=float)    # target marginal, shape (K,)
    # The epsilon keeps log() finite for zero-probability entries.
    cost = -np.log(PS + 1e-10)                  # shape (N, K)

    # ot.emd(a, b, M) returns a plan of shape (len(a), len(b)) == (N, K).
    plan = ot.emd(sample_mass, class_mass, cost)
    assignments = np.argmax(plan, axis=1)       # shape (N,)
    # Width follows K; a hard-coded 100 would fail for K > 100.
    return np.eye(K)[assignments]


def optimize_ot_soft(PS, prop):
    """Return the raw exact optimal-transport plan between samples and classes.

    Unlike the hard-assignment variants, no argmax is taken: the full plan is
    returned so callers can use it as soft labels.

    Args:
        PS: array of shape (N, K) with per-sample class probabilities.
        prop: length-K target class proportions summing to 1.

    Returns:
        Array of shape (N, K) from ``ot.emd``, whose source marginal is the
        uniform mass 1/N per sample and whose target marginal is ``prop``.
    """
    n_samples, _ = PS.shape

    uniform_source = np.array(np.ones(n_samples) / n_samples, dtype=float)
    class_target = np.array(prop, dtype=float)
    # Negative log-probability cost; the epsilon avoids log(0).
    cost_matrix = -np.log(PS + 1e-10)

    return ot.emd(uniform_source, class_target, cost_matrix)

def opt_sk(model, epoch, trainloader, num_samples=50000, num_classes=10):
    """Compute OT-based pseudo-labels for every sample yielded by ``trainloader``.

    For each batch, the model's softmax predictions are assigned to classes by
    ``optimize_ot``. The assigned class proportions are constrained to match the
    batch's label histogram.

    Args:
        model: callable returning ``(features, scores)`` for a batch; the
            scores are turned into probabilities with softmax over dim 1.
        epoch: unused; kept for interface compatibility.
        trainloader: iterable of ``(data, label_input, index)`` batches, where
            ``index`` gives each sample's position in the output vector.
        num_samples: length of the returned label vector (default 50000).
        num_classes: minimum histogram length for ``label_input`` (default 10).

    Returns:
        CUDA ``torch.LongTensor`` of shape (num_samples,) holding class indices.
        Positions never visited by the loader stay 0.
    """
    # np.float was removed in NumPy 1.24; class indices are stored as int64.
    selflabels = np.zeros(num_samples, dtype=np.int64)

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        for data, label_input, index in trainloader:
            data = data.cuda()
            _, output = model(data)
            probs = nn.functional.softmax(output, 1).cpu().numpy()
            prop = np.bincount(np.asarray(label_input), minlength=num_classes) / data.size(0)
            onehot = optimize_ot(probs, prop)
            # optimize_ot returns (B, K) one-hot rows; keep only the class index
            # so the values fit the 1-D label vector.
            selflabels[np.asarray(index)] = np.argmax(onehot, axis=1)

    return torch.from_numpy(selflabels).cuda()

def opt_sk_new(model, epoch, trainloader, num_samples=50000, num_classes=10):
    """Compute OT-based pseudo-labels and return them one-hot encoded.

    Same labelling procedure as ``opt_sk``: per batch, softmax predictions are
    assigned to classes by ``optimize_ot``, matching the batch's label histogram.

    Args:
        model: callable returning ``(features, scores)`` for a batch; the
            scores are turned into probabilities with softmax over dim 1.
        epoch: unused; kept for interface compatibility.
        trainloader: iterable of ``(data, label_input, index)`` batches, where
            ``index`` gives each sample's row in the output.
        num_samples: number of rows in the result (default 50000).
        num_classes: one-hot width and minimum histogram length (default 10).

    Returns:
        CUDA float tensor of shape (num_samples, num_classes) with one-hot rows.
        Rows never visited by the loader encode class 0.
    """
    # np.float was removed in NumPy 1.24; class indices are stored as int64.
    labels = np.zeros(num_samples, dtype=np.int64)

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        for data, label_input, index in trainloader:
            data = data.cuda()
            _, output = model(data)
            probs = nn.functional.softmax(output, 1).cpu().numpy()
            prop = np.bincount(np.asarray(label_input), minlength=num_classes) / data.size(0)
            onehot = optimize_ot(probs, prop)
            # optimize_ot returns (B, K) one-hot rows; keep only the class index.
            labels[np.asarray(index)] = np.argmax(onehot, axis=1)

    selflabels = torch.from_numpy(labels).view(num_samples, -1)  # (num_samples, 1)
    selflabels_onehot = torch.zeros(num_samples, num_classes).scatter_(1, selflabels, 1).cuda()

    return selflabels_onehot

def opt_sk_rotation(model, epoch, trainloader, num_samples=50000, num_classes=10):
    """Compute OT-based pseudo-labels for a loader that yields 5-tuples.

    Identical to ``opt_sk`` except that each batch is
    ``(data, label_input, index, _, _)``; the last two items are ignored.

    Args:
        model: callable returning ``(features, scores)`` for a batch; the
            scores are turned into probabilities with softmax over dim 1.
        epoch: unused; kept for interface compatibility.
        trainloader: iterable of 5-tuples as described above, where ``index``
            gives each sample's position in the output vector.
        num_samples: length of the returned label vector (default 50000).
        num_classes: minimum histogram length for ``label_input`` (default 10).

    Returns:
        CUDA ``torch.LongTensor`` of shape (num_samples,) holding class indices.
        Positions never visited by the loader stay 0.
    """
    # np.float was removed in NumPy 1.24; class indices are stored as int64.
    selflabels = np.zeros(num_samples, dtype=np.int64)

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        for data, label_input, index, _, _ in trainloader:
            data = data.cuda()
            _, output = model(data)
            probs = nn.functional.softmax(output, 1).cpu().numpy()
            prop = np.bincount(np.asarray(label_input), minlength=num_classes) / data.size(0)
            onehot = optimize_ot(probs, prop)
            # optimize_ot returns (B, K) one-hot rows; keep only the class index
            # so the values fit the 1-D label vector.
            selflabels[np.asarray(index)] = np.argmax(onehot, axis=1)

    return torch.from_numpy(selflabels).cuda()