import torch
import os.path as osp
import aug.augmentors as A
import torch_geometric.transforms as T
import torch.nn.functional as F

from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from aug.eval import get_split, LREvaluator
from torch_geometric.nn import GCNConv
from torch_geometric.nn.inits import uniform
from torch_geometric.datasets import Planetoid

import copy
import gc

import ot
from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein2, gromov_wasserstein, fused_gromov_wasserstein
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj
# from torchmetrics.functional import pairwise_cosine_similarity
from geomloss import SamplesLoss  # See also ImagesLoss, VolumesLoss
import numpy as np
from ot.gromov._utils import init_matrix, gwloss, gwggrad, init_matrix_semirelaxed, tensor_product
from ot.backend import get_backend

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def f(G):
    """Return half the squared Frobenius norm of the tensor ``G``.

    Equivalent to ``0.5 * sum(G_ij ** 2)``; ``df`` is its gradient.
    """
    return G.pow(2).sum() / 2


def df(G):
    """Return the gradient of ``f`` at ``G``.

    Since ``f(G) = 0.5 * ||G||_F^2``, the gradient is ``G`` itself; the same
    object is returned (no copy is made).
    """
    gradient = G
    return gradient


class GConv(nn.Module):
    """Stack of ``num_layers`` GCN convolutions followed by PReLU.

    The first convolution maps ``input_dim`` -> ``hidden_dim`` and every
    later one maps ``hidden_dim`` -> ``hidden_dim``. One ``nn.PReLU`` module
    (a single set of learnable slopes) is shared and applied after every layer.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.activation = nn.PReLU(hidden_dim)
        for idx in range(num_layers):
            # Only the very first layer consumes the raw input width.
            in_channels = hidden_dim if idx else input_dim
            self.layers.append(GCNConv(in_channels, hidden_dim))

    def forward(self, x, edge_index, edge_weight=None):
        """Apply each convolution then the shared activation, in order."""
        h = x
        for layer in self.layers:
            h = self.activation(layer(h, edge_index, edge_weight))
        return h


class Encoder(torch.nn.Module):
    """Two-view graph encoder trained with an optimal-transport alignment loss.

    ``augmentor`` is a pair ``(aug1, aug2)``. Each is called as
    ``aug(x, edge_index, edge_weight)`` and must return a 4-tuple whose first
    three items are the augmented features, edges and edge weights. View 1 is
    encoded by ``encoder1`` and view 2 by ``encoder2``.
    """

    def __init__(self, encoder1, encoder2, augmentor, hidden_dim):
        super(Encoder, self).__init__()
        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.augmentor = augmentor
        # NOTE(review): ``project`` and ``linear`` are not used by forward().
        # They are kept so the parameter set and state_dict keys are unchanged.
        self.project = torch.nn.Linear(hidden_dim, hidden_dim)
        uniform(hidden_dim, self.project.weight)
        # NOTE(review): the extra 512 input width is hard-coded, independent of hidden_dim.
        self.linear = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim + 512, hidden_dim) for _ in range(1)])

    @staticmethod
    def corruption(x, edge_index, edge_weight):
        """Return ``x`` with its rows randomly permuted; edges pass through unchanged."""
        return x[torch.randperm(x.size(0))], edge_index, edge_weight

    def comp(self, u, v, m, eps=0.1):
        """Build an (unnormalised) entropic transport plan from dual potentials.

        Computes ``exp((u_i + v_j - m_ij) / eps)``. ``u`` is reshaped to a
        column and ``v`` to a row so the potentials combine as an outer sum.
        For 1-D potentials, ``u.t() + v`` would instead be an element-wise sum.

        Args:
            u: potential on the first point cloud (``N`` elements).
            v: potential on the second point cloud (``M`` elements).
            m: ``(N, M)`` cost matrix.
            eps: entropic temperature; defaults to 0.1.

        Returns:
            An ``(N, M)`` tensor.
        """
        return torch.exp((u.reshape(-1, 1) + v.reshape(1, -1) - m) / eps)

    def forward(self, x, edge_index, mode, sigma=1, rho=1, edge_weight=None):
        """Encode two augmented views and, in training mode, compute the loss.

        Args:
            x: node features.
            edge_index: graph connectivity handed to both augmentors.
            mode: ``'train'`` also computes the loss; any other value returns
                only the embeddings.
            sigma: mixing weight in [0, 1]. ``sigma < 1`` selects the
                semi-relaxed fused Gromov-Wasserstein branch
                (``alpha = 1 - sigma``); ``sigma == 1`` selects the Sinkhorn
                potential branch.
            rho: weight of the cost-matrix term of the loss.
            edge_weight: optional edge weights handed to both augmentors.

        Returns:
            ``(z1, z2)`` when ``mode != 'train'``, else ``(z1, z2, loss)``.

        Raises:
            ValueError: in ``'train'`` mode when ``sigma`` is outside [0, 1].
                Otherwise no loss branch would apply.
        """
        if mode == 'train' and not 0 <= sigma <= 1:
            raise ValueError(f'sigma must be in [0, 1], got {sigma!r}')

        aug1, aug2 = self.augmentor
        x1, edge_index1, edge_weight1, _ = aug1(x, edge_index, edge_weight)
        x2, edge_index2, edge_weight2, _ = aug2(x, edge_index, edge_weight)

        z1 = self.encoder1(x1, edge_index1, edge_weight1)
        z2 = self.encoder2(x2, edge_index2, edge_weight2)

        if mode != 'train':
            return z1, z2

        # Pairwise Euclidean costs between the two views: raw features (Mp)
        # and learned embeddings (Mb).
        Mp = ot.dist(x1, x2, metric='euclidean')
        Mb = ot.dist(z1, z2, metric='euclidean')

        if sigma < 1:
            loss = self._fused_gw_loss(
                x1, edge_index1, x2, edge_index2, Mp, Mb, sigma, rho)
        else:
            loss = self._sinkhorn_loss(x1, x2, z1, z2, Mp, Mb, rho)
        return z1, z2, loss

    def _fused_gw_loss(self, x1, edge_index1, x2, edge_index2, Mp, Mb, sigma, rho):
        """Loss for ``sigma < 1``: semi-relaxed fused-GW plan vs. EMD plan.

        ``P`` is a semi-relaxed fused Gromov-Wasserstein plan between the two
        views, built from feature cost ``Mp``, adjacencies ``C1``/``C2`` and
        ``alpha = 1 - sigma``. ``B`` is the exact EMD plan under the embedding
        cost ``Mb``.
        """
        n1, n2 = x1.shape[0], x2.shape[0]
        # max_num_nodes keeps trailing isolated nodes, so both adjacency
        # matrices always match the node counts used for h1/h2 and Mp.
        C1 = to_dense_adj(edge_index1, max_num_nodes=n1).squeeze(0)
        C2 = to_dense_adj(edge_index2, max_num_nodes=n2).squeeze(0)
        h1 = ot.unif(n1, type_as=x1)
        h2 = ot.unif(n2, type_as=x2)

        P = semirelaxed_fused_gromov_wasserstein(
            Mp, C1, C2, h1, symmetric=True, alpha=1 - sigma, log=False, G0=None)

        # Structure-based cost from P via POT's semi-relaxed GW helpers.
        # OM is a uniform (n1, n2) matrix whose column sums feed the marginal term.
        nx = get_backend(h1, C1, C2)
        constC, hC1, hC2, fC2t = init_matrix_semirelaxed(
            C1, C2, h1, loss_fun='square_loss', nx=nx)
        OM = torch.ones(n1, n2, device=x1.device) / (n1 * n2)
        qOneM = nx.sum(OM, 0)
        ones_p = nx.ones(n1, type_as=h1)
        marginal_product = nx.outer(ones_p, nx.dot(qOneM, fC2t))
        Mp2 = F.normalize(
            tensor_product(constC + marginal_product, hC1, hC2, P, nx=nx))
        Mp = sigma * Mp + (1 - sigma) * Mp2

        B = ot.emd(h1, h2, Mb)

        # NOTE(review): KLDivLoss expects log-probabilities as its first
        # argument; Mp and Mb are cost matrices here. Kept as-is.
        loss = nn.KLDivLoss(reduction='batchmean')(Mp, Mb)
        return rho * loss + torch.linalg.matrix_norm(P - B, ord='fro')

    def _sinkhorn_loss(self, x1, x2, z1, z2, Mp, Mb, rho):
        """Loss for ``sigma == 1``: compare Sinkhorn plans on features and embeddings.

        Dual potentials from a debiased Sinkhorn ``SamplesLoss`` are converted
        into plans with ``comp``. ``P`` comes from the raw features and ``B``
        from the embeddings.
        """
        # geomloss's entropic temperature is blur ** p; with p=2 this is 0.1,
        # matching comp()'s default eps.
        sl = SamplesLoss(loss='sinkhorn', p=2, debias=True, blur=0.1 ** (1 / 2),
                         backend='tensorized', potentials=True)
        # NOTE(review): both plans use the feature cost Mp, including B, which
        # comes from embedding potentials. Confirm this is intended.
        m = Mp
        u, v = sl(x1, x2)
        P = self.comp(u, v, m)
        u, v = sl(z1, z2)
        B = self.comp(u, v, m)

        sloss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
        loss = sloss(Mp, Mb)
        return rho * loss + torch.linalg.matrix_norm(P - B, ord='fro')


def train(encoder_model, data, optimizer, sigma=1, rho=1):
    """Perform one optimisation step and return the scalar training loss.

    Switches ``encoder_model`` to training mode and clears stale gradients.
    It then runs a ``mode='train'`` forward pass on ``data.x`` /
    ``data.edge_index`` with the given ``sigma`` and ``rho``. Finally it
    back-propagates the returned loss and applies one ``optimizer`` step.

    Returns:
        The loss as a Python number (``loss.item()``).
    """
    encoder_model.train()
    optimizer.zero_grad()

    outputs = encoder_model(data.x, data.edge_index, mode='train',
                            sigma=sigma, rho=rho)
    _, _, batch_loss = outputs

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


def test(encoder_model, data):
    """Score the current embeddings with a logistic-regression probe.

    Puts the model in eval mode and encodes ``data`` in ``'test'`` mode.
    The two view embeddings are summed. ``get_split`` draws a 10% train /
    80% test split, and the result of ``LREvaluator`` on that split is
    returned.
    """
    encoder_model.eval()
    view1, view2 = encoder_model(data.x, data.edge_index, mode='test')
    embeddings = view1 + view2
    split = get_split(num_samples=embeddings.size(0),
                      train_ratio=0.1, test_ratio=0.8)
    evaluator = LREvaluator()
    return evaluator(embeddings, data.y, split)


def main(num_epochs=1000, eval_every=2, sigma=1, rho=1, lr=0.001):
    """Train the dual-encoder model on Cora and periodically report accuracy.

    Args:
        num_epochs: number of training epochs.
        eval_every: evaluate the linear probe every this many epochs.
        sigma: loss mixing weight forwarded to ``train`` (must be in [0, 1]).
        rho: cost-matrix loss weight forwarded to ``train``.
        lr: Adam learning rate.
    """
    path = 'datasets'
    dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())

    data = dataset[0].to(device)

    aug1 = A.Identity()
    aug2 = A.Compose([A.EdgePerturbation(pe=0.3), A.RWSampling(
        use=False, num_seeds=100, walk_length=10), A.FeatureMasking(pf=0.4), A.NodeDropping(pn=0.0)])

    gconv1 = GConv(input_dim=dataset.num_features,
                   hidden_dim=512, num_layers=2).to(device)
    gconv2 = GConv(input_dim=dataset.num_features,
                   hidden_dim=512, num_layers=2).to(device)

    encoder_model = Encoder(encoder1=gconv1, encoder2=gconv2, augmentor=(
        aug1, aug2), hidden_dim=512).to(device)

    optimizer = Adam(encoder_model.parameters(), lr=lr)

    # Track the real best accuracy; the old message printed the current
    # accuracy under a "Best" label.
    best_acc = float('-inf')
    with tqdm(total=num_epochs, desc='(T)') as pbar:
        for epoch in range(1, num_epochs + 1):
            loss = train(encoder_model, data, optimizer, sigma=sigma, rho=rho)

            pbar.set_postfix({'loss': loss})
            pbar.update()

            if epoch % eval_every == 0:
                test_result = test(encoder_model, data)
                best_acc = max(best_acc, test_result["acc"])
                print(f'Test ACC={test_result["acc"]:.4f}, Best test ACC={best_acc:.4f}')


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
