
import sys
import random
import time
import math
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import KDTree
from scipy.stats import wasserstein_distance

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
# Compute device shared by this module: the GPU when CUDA is usable, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Cosine similarity taken along the last dimension (eps is the module's
# small-denominator guard).
d_cosine = nn.CosineSimilarity(eps=1e-8, dim=-1)


class SinkhornDistance(nn.Module):
    """Entropy-regularised optimal transport between two sets of samples.

    The ground cost adds the cosine distance between sample features to a
    dissimilarity between the per-sample vector sets. The transport plan is
    computed with log-domain Sinkhorn iterations under uniform marginals.

    Args:
        eps: Entropic regularisation strength; it is also the temperature
            used by the log-domain updates.
        max_iter: Maximum number of Sinkhorn iterations.
        dis: Name of the ground cost. Only ``'cos'`` is implemented.
        reduction: ``'none'`` (default), ``'mean'`` or ``'sum'``; applied to
            the returned cost.
        thresh: Early-stopping tolerance. Iteration stops once the mean
            absolute change of the row potential drops below this value.
    """

    def __init__(self, eps, max_iter, dis, reduction='none', thresh=1e-1):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.dis = dis
        self.thresh = thresh

    def forward(self, x, x_vector, y, y_vector, data_num_N=None):
        """Compute the transport plan and transport cost from ``x`` to ``y``.

        Args:
            x: 2-D tensor with one feature row per source sample.
            x_vector: 3-D tensor; ``x_vector[i]`` is the matrix of vectors
                attached to source sample ``i`` (one vector per row).
            y: 2-D tensor with one feature row per target sample.
            y_vector: 3-D tensor laid out like ``x_vector`` for the target
                samples; it must have the same number of vectors per sample.
            data_num_N: Unused. Kept so existing call sites keep working.

        Returns:
            A tuple ``(pi, prob_pi, cost)``. ``pi`` is the transport plan,
            detached and moved to the CPU. ``prob_pi`` is ``pi`` scaled by
            the number of target samples, so each column sums to roughly one.
            ``cost`` is ``sum(pi * C)`` after the configured reduction; it
            stays on the compute device and is not detached.

        Raises:
            ValueError: If ``self.dis`` is not ``'cos'``.
        """
        if self.dis != 'cos':
            raise ValueError(
                f"unsupported distance {self.dis!r}: only 'cos' is implemented")

        C = self._cost_matrix(x, x_vector, y, y_vector)

        x_points = x.shape[-2]
        y_points = y.shape[-2]

        # Uniform marginals shaped after C, so a single-point side stays 1-D
        # and the unsqueeze calls in M() remain valid.
        mu = torch.full(C.shape[:-1], 1.0 / x_points,
                        dtype=torch.float, device=C.device)
        nu = torch.full(C.shape[:-2] + C.shape[-1:], 1.0 / y_points,
                        dtype=torch.float, device=C.device)
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        for _ in range(self.max_iter):
            u_prev = u
            u = self.eps * (log_mu - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (
                log_nu - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)
            ) + v
            err = (u - u_prev).abs().sum(-1).mean()
            if err.item() < self.thresh:
                break

        pi = torch.exp(self.M(C, u, v))
        # The v-update runs last, so the columns of pi sum to nu = 1 / y_points;
        # multiplying by y_points turns every column into a distribution.
        prob_pi = y_points * pi
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return pi.detach().cpu(), prob_pi.detach().cpu(), cost

    def _cost_matrix(self, x, x_vector, y, y_vector):
        """Build the ground cost ``C`` on the module-level ``device``."""
        # Cosine distance between feature rows. The norm is clamped at eps so
        # an all-zero row yields distance 1 instead of NaN.
        norm_x = F.normalize(x, dim=-1, eps=1e-8)
        norm_y = F.normalize(y, dim=-1, eps=1e-8)
        C_mean = (1 - torch.matmul(norm_x, norm_y.transpose(-2, -1))).to(device)

        # Vector-set similarity: for every pair (i, j) take the absolute dot
        # product of matching rows of x_vector[i] and y_vector[j], then
        # average over the rows. This equals
        # sum(|diag(x_vector[i] @ y_vector[j].T)|) / num_vector for all pairs.
        num_vector = x_vector.shape[1]
        overlap = torch.einsum('ikd,jkd->ijk', x_vector, y_vector).abs().sum(dim=-1)
        C_vector = (overlap / num_vector).to(C_mean)

        return C_mean + 1 - C_vector

    def M(self, C, u, v):
        """Modified cost for logarithmic updates: M_ij = (-c_ij + u_i + v_j) / eps."""
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def ave(u, u1, tau):
        """Barycenter subroutine, used by kinetic acceleration through extrapolation."""
        return tau * u + (1 - tau) * u1
