import numpy as np
from sklearn.mixture import GaussianMixture
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datautil.getdataloader import utilDataset
import datautil.imgdata.util as imgutil

def get_internal_features(self, inputs, outputs):
    """Forward-hook callback that stores a module's output in a module global.

    The three positional parameters follow the order PyTorch uses when it
    invokes a hook registered with ``register_forward_hook``. Only
    ``outputs`` is used: it is published as the module-level name
    ``internal_classifier`` so that code in this module can read it after
    the forward pass. ``self`` and ``inputs`` are ignored.
    """
    # Equivalent to a ``global`` declaration followed by an assignment.
    globals()["internal_classifier"] = outputs

def myGMMfit(args, network, dataloader):
    """Fit a Gaussian mixture with one component per class on ``fc0`` features.

    The raw data behind ``dataloader`` is wrapped in a new ``utilDataset``
    that uses the transform returned by ``imgutil.image_test(args.dataset)``
    and is loaded as one single batch. That batch is pushed through
    ``network`` on the CPU in eval mode without gradients, while a forward
    hook on ``network[1].fc0`` captures that layer's output. The captured
    features are then used to build the mixture.

    Args:
        args: Configuration object; reads ``dataset``, ``N_WORKERS``,
            ``device`` and ``num_classes``.
        network: Indexable model whose element ``1`` exposes an ``fc0``
            sub-module.
        dataloader: Loader whose ``dataset`` provides ``get_raw_data()`` and
            a ``loader`` attribute.

    Returns:
        A fitted ``GaussianMixture`` whose component ``i`` has been replaced
        by a single Gaussian estimated from the samples labelled ``i``, with
        uniform component weights.

    Note:
        On exit (success or failure) the hook is removed and the network is
        put into train mode and moved to ``args.device``.
    """
    img_dict, clabel, dlabel = dataloader.dataset.get_raw_data()
    gmm_dataset = utilDataset(
        img_dict,
        np.array(clabel),
        np.array(dlabel),
        loader=dataloader.dataset.loader,
        transform=imgutil.image_test(args.dataset),
    )
    # batch_size == len(dataset): the whole dataset arrives as one batch.
    gmm_dataloader = DataLoader(dataset=gmm_dataset, batch_size=len(gmm_dataset), num_workers=args.N_WORKERS)

    image, label, _ = next(iter(gmm_dataloader))

    handle_internal_features = network[1].fc0.register_forward_hook(get_internal_features)
    try:
        network = network.cpu()
        network.eval()
        with torch.no_grad():
            # Only the hook's side effect is needed; the prediction itself
            # is not used.
            network(image)
    finally:
        # Always undo the temporary state, even if the forward pass failed,
        # so the hook does not stay attached and the model is not left on
        # the CPU in eval mode.
        handle_internal_features.remove()
        network.train()
        network = network.to(args.device)

    # Written by get_internal_features during the forward pass above.
    gmmX = internal_classifier
    gmmY = label

    # NOTE: an earlier variant restricted gmmX/gmmY to the samples whose
    # argmax prediction matched the label; that filtering is disabled.

    # Hyperparameters may be different for different dataset.
    gmmModel = GaussianMixture(
        n_components=args.num_classes,
        covariance_type='full',
        max_iter=300,
        init_params='kmeans',
    )

    # This fit creates all fitted attributes with the right shapes; every
    # component is overwritten below.
    gmmModel.fit(gmmX, gmmY)

    gmmModelSingle = GaussianMixture(n_components=1, covariance_type='full')

    for i in range(args.num_classes):
        class_features = gmmX[gmmY == i, :]
        gmmModelSingle.fit(class_features)
        # Uniform prior over classes rather than the empirical class
        # frequency gmmY[gmmY==i].shape[0]/gmmY.shape[0].
        gmmModel.weights_[i] = 1 / args.num_classes
        gmmModel.covariances_[i] = gmmModelSingle.covariances_[0]
        gmmModel.means_[i] = gmmModelSingle.means_[0]

        # Keep the derived precision matrices consistent with the new
        # covariance.
        gmmModel.precisions_cholesky_[i] = gmmModelSingle.precisions_cholesky_[0]
        gmmModel.precisions_[i] = gmmModelSingle.precisions_[0]

    return gmmModel

def generateTheta(L,endim):
    """Draw ``L`` random unit vectors of dimension ``endim``.

    Each row is sampled from a standard normal distribution and then scaled
    to unit Euclidean length.

    Returns:
        A float32 tensor of shape ``(L, endim)``.
    """
    directions = np.random.normal(size=(L, endim))
    # Per-row Euclidean norm, kept as a column so it broadcasts over rows.
    row_norms = np.sqrt(np.sum(directions ** 2, axis=1, keepdims=True))
    unit_directions = directions / row_norms
    return torch.from_numpy(unit_directions).to(dtype=torch.float)

def oneDWassersteinV3(p,q):
    """Mean squared gap between the sorted columns of ``p`` and ``q``.

    Both inputs are transposed on their first two dimensions, each resulting
    row is sorted in descending order via ``torch.topk``, and the squared
    differences of the sorted values are averaged over the last dimension.
    For 2-D inputs this yields one value per column of ``p``.

    NOTE(review): ``k`` is taken from ``p``'s first dimension for both
    inputs, so if ``q`` has more rows only its largest values are compared,
    and if it has fewer rows ``torch.topk`` raises.
    """
    # ~10 Times faster than V1

    # Reference TensorFlow formulation this was ported from:
    # W2=(tf.nn.top_k(tf.transpose(p),k=tf.shape(p)[0]).values-
    #     tf.nn.top_k(tf.transpose(q),k=tf.shape(q)[0]).values)**2
    # return K.mean(W2, axis=-1)

    n_samples = p.shape[0]
    p_sorted, _ = torch.topk(p.transpose(0, 1), k=n_samples)
    q_sorted, _ = torch.topk(q.transpose(0, 1), k=n_samples)
    squared_gap = (p_sorted - q_sorted).pow(2)
    return squared_gap.mean(dim=-1)


def sWasserstein(P,Q,theta,nclass,Cp=None,Cq=None,lambda_=10.0,class_lambda=100.0):
    """Sliced Wasserstein discrepancy between sample sets ``P`` and ``Q``.

    Both sets are projected onto the rows of ``theta`` and compared
    projection by projection with ``oneDWassersteinV3``. When both ``Cp``
    and ``Cq`` are given, a class-conditional term is added for each of the
    ``nclass`` classes.

    Args:
        P, Q: Sample matrices; multiplied by ``theta`` transposed, so the
            last dimension must match ``theta``'s second dimension.
        theta: Projection directions, one per row.
        nclass: Number of class columns to iterate over in ``Cp``/``Cq``.
        Cp, Cq: Optional per-sample class indicators. Assumed to have shape
            (n_samples, nclass), with a non-zero entry in column ``i``
            marking membership of class ``i`` (e.g. one-hot) — TODO confirm
            against callers.
        lambda_: Weight of the global term (default 10.0, the previous
            hard-coded value).
        class_lambda: Weight of each class-conditional term (default 100.0,
            the previous hard-coded value).

    Returns:
        A scalar tensor holding the weighted discrepancy.
    """
    theta_t = torch.transpose(theta, 0, 1)
    p = torch.matmul(P, theta_t)
    q = torch.matmul(Q, theta_t)
    sw = lambda_ * torch.mean(oneDWassersteinV3(p, q))
    if (Cp is not None) and (Cq is not None):
        for i in range(nclass):
            # Select the projected samples (rows) that belong to class i.
            # The previous torch.gather/torch.squeeze/torch.where chain
            # raised TypeError, because single-argument torch.where returns
            # a tuple, and it addressed dim=1 (projections) instead of the
            # sample dimension.
            pi = p[torch.not_equal(Cp[:, i], 0)]
            qi = q[torch.not_equal(Cq[:, i], 0)]
            if pi.shape[0] == 0 or qi.shape[0] == 0:
                # A class missing from either set has no defined distance;
                # skipping it avoids a NaN mean or a top-k failure.
                continue
            # oneDWassersteinV3 takes its sample count from pi, so qi needs
            # at least as many rows as pi.
            sw = sw + class_lambda * torch.mean(oneDWassersteinV3(pi, qi))
    return sw
