import torch
import torch.nn as nn
from .AutoVisualMultiGINComplete import AutoVisualNet
import ot
import scipy as sp

class L2OGWDMultiGINComplete(AutoVisualNet):
    """AutoVisualNet variant trained with a Gromov-Wasserstein alignment loss.

    ``loss_gwd`` compares the (unbiased, squared) MMD between the embeddings of
    two inputs against the Gromov-Wasserstein distance computed on the raw
    input features, and returns their absolute difference.
    """

    def __init__(self, input_dim, gnn_hidden, gnn_out, out_dim,
                 n_graph_view, n_transformer, device):
        # All model construction is delegated to AutoVisualNet.
        super().__init__(input_dim=input_dim, gnn_hidden=gnn_hidden, gnn_out=gnn_out, out_dim=out_dim,
                         n_graph_view=n_graph_view, n_transformer=n_transformer, device=device)

    @staticmethod
    def rbf_kernel(X1, X2, gamma=None):
        """
        Compute the RBF (Gaussian) kernel matrix between two sets of data.

        K[i, j] = exp(-gamma * ||X1[i] - X2[j]||^2)

        Args:
            X1: torch.Tensor, (n_samples_1, n_features), the first set of data.
            X2: torch.Tensor, (n_samples_2, n_features), the second set of data.
            gamma: float or 0-d torch.Tensor, the kernel width. Required despite
                the ``None`` default.
        Returns:
            K: torch.Tensor, (n_samples_1, n_samples_2), the RBF kernel matrix.
        Raises:
            ValueError: if gamma is None.
        """
        # Validate before doing any pairwise-distance work.
        if gamma is None:
            raise ValueError('gamma should be specified.')
        D = torch.cdist(X1, X2)
        return torch.exp(-gamma * D ** 2)

    @staticmethod
    def _offdiag_mean(K):
        """Mean of a square matrix's off-diagonal entries (sum over i != j, divided by n(n-1))."""
        n = K.size(0)
        return (K.sum() - K.diagonal().sum()) / (n * (n - 1))

    def mmd(self, X1, X2, gamma=None):
        """
        Compute the unbiased estimate of the squared maximum mean discrepancy (MMD^2_u).

        The within-set kernel averages exclude the i == j terms (U-statistic).
        Including them, as a plain ``sum() / (n * (n - 1))`` would, biases the
        estimate upward by roughly 1/(n1 - 1) + 1/(n2 - 1).

        Args:
            X1: torch.Tensor, (n_samples_1, n_features), the first distribution.
            X2: torch.Tensor, (n_samples_2, n_features), the second distribution.
            gamma: RBF kernel width. If None, it is set to
                1 / (2 * mean_pairwise_distance ** 2) over the pooled samples.
        Returns:
            mmd: 0-d torch.Tensor, the MMD^2_u estimate. It may be slightly
                negative because the estimator is unbiased.
        Raises:
            ValueError: if either set has fewer than two samples (the unbiased
                estimator is undefined there).
        """
        n1, n2 = X1.size(0), X2.size(0)
        if n1 < 2 or n2 < 2:
            raise ValueError('mmd needs at least two samples in each set.')
        if gamma is None:
            M = torch.cat([X1, X2], dim=0)
            mean_dist = torch.cdist(M, M).mean()
            # If every point coincides, mean_dist == 0 and gamma would be inf,
            # turning the kernels into NaN; clamp to a tiny positive bandwidth.
            mean_dist = mean_dist.clamp_min(torch.finfo(mean_dist.dtype).eps)
            gamma = 1.0 / (2 * mean_dist ** 2)
        K11 = self.rbf_kernel(X1, X1, gamma)
        K22 = self.rbf_kernel(X2, X2, gamma)
        K12 = self.rbf_kernel(X1, X2, gamma)
        return self._offdiag_mean(K11) + self._offdiag_mean(K22) - 2 * K12.mean()

    @staticmethod
    def _to_numpy(X):
        """Return X in a SciPy-consumable form; torch tensors are detached and moved to CPU."""
        if isinstance(X, torch.Tensor):
            return X.detach().cpu().numpy()
        return X

    @staticmethod
    def gwd_fn(X1, X2):
        """
        Compute the Gromov-Wasserstein distance between two point clouds.

        Each point cloud is turned into its pairwise Euclidean distance matrix,
        scaled by its maximum entry, and the two are compared with POT's
        conditional-gradient GW solver under uniform marginals.

        Args:
            X1: array-like or torch.Tensor, (n_samples_1, n_features).
            X2: array-like or torch.Tensor, (n_samples_2, n_features).
        Returns:
            float, the GW distance. It is computed outside autograd, so it
            carries no gradient.
        Raises:
            ValueError: if either input has no samples.
        """
        # Explicit submodule import: ``import scipy`` alone does not expose
        # ``scipy.spatial`` on older SciPy releases.
        from scipy.spatial.distance import cdist

        # SciPy cannot consume CUDA tensors or tensors that require grad.
        X1 = L2OGWDMultiGINComplete._to_numpy(X1)
        X2 = L2OGWDMultiGINComplete._to_numpy(X2)

        C1 = cdist(X1, X1)
        C2 = cdist(X2, X2)
        if C1.size == 0 or C2.size == 0:
            raise ValueError('gwd_fn needs at least one sample in each set.')

        # Scale each cost matrix into [0, 1]. Skip the division when all
        # points coincide (max == 0), which would otherwise give 0/0 = NaN.
        C1_max, C2_max = C1.max(), C2.max()
        if C1_max > 0:
            C1 /= C1_max
        if C2_max > 0:
            C2 /= C2_max

        p = ot.unif(C1.shape[0])
        q = ot.unif(C2.shape[0])

        # conditional gradient algorithm
        _, log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=False, log=True)
        return float(log0['gw_dist'])

    def loss_gwd(self, x1, graph1, x2, graph2):
        """
        Compute the GW-alignment loss |MMD^2_u(emb1, emb2) - GWD(sum(x1[0]), sum(x2[0]))|.

        Both inputs are embedded with ``self.forward``. The MMD between the two
        embedding sets is compared with the Gromov-Wasserstein distance between
        the corresponding raw input features.

        Args:
            x1, graph1: inputs for the first sample set, passed to ``self.forward``.
            x2, graph2: inputs for the second sample set, passed to ``self.forward``.
        Returns:
            0-d torch.Tensor, the absolute difference. The GWD term is computed
            with SciPy/POT and acts as a constant target, so gradients flow only
            through the MMD term.
        """
        emb1 = self.forward(x1, graph1)
        emb2 = self.forward(x2, graph2)
        mmd = self.mmd(emb1, emb2)
        # NOTE(review): sum(x[0]) adds up the elements of x[0], presumably
        # per-view feature tensors of identical shape. Confirm against the
        # input layout expected by AutoVisualNet.forward.
        x1_stack, x2_stack = sum(x1[0]), sum(x2[0])
        gwd = self.gwd_fn(x1_stack, x2_stack)
        return torch.abs(mmd - gwd)