from torch.autograd import Function
import torch
import torch.nn as nn
import torch.nn.functional as F
#from dgl.nn.pytorch.conv import GraphConv
import srgnn_utils as utils
import numpy as np
import pickle
import networkx as nx
import scipy.sparse as sp
#import dgl
from sklearn.metrics import f1_score
import torch_geometric.utils 
from IPython import embed
from torch.nn import Sigmoid, SiLU, ReLU, Tanh
import pandas as pd
from utils import NNNodeBenchmarker, index_to_mask
from basic_gnn import GCN, GAT, MLP, GraphSAGE, APPNP
from deepjdot_semi import NNNodeBenchmarker_JDOT
from cdan import NNNodeBenchmarker_CDAN

import argparse
from tqdm import tqdm
import wandb

def KMM(X,Xtest,_A=None, _sigma=1e1,beta=0.2):
    """Solve a kernel-mean-matching style quadratic program for sample weights.

    The kernel is the average of three ``exp(-gamma * squared_distance)`` terms
    with ``gamma`` in ``(1, 0.1, 0.001)``. The weights ``w`` minimise
    ``0.5 * w' H w + q' w`` subject to ``w >= beta`` elementwise and
    ``_A w = _A.sum(1)``, where ``H`` is the kernel on ``X`` and ``q`` is the
    negated, ``n / m``-scaled row sum of the kernel between ``X`` and ``Xtest``.

    Args:
        X: 2-D torch tensor with one row per training sample (n rows).
        Xtest: 2-D torch tensor with one row per test sample (m rows).
        _A: numpy array whose rows are equality constraints over the n weights;
            rows that are entirely zero are dropped. Required in practice even
            though the signature defaults to None.
        _sigma: unused; kept so existing callers keep working.
        beta: lower bound applied to every weight.

    Returns:
        Tuple ``(weights, mmd)``: ``weights`` is the QP solution as a numpy
        array, ``mmd`` is the kernel MMD estimate between X and Xtest as a
        Python float.

    Raises:
        ValueError: if ``_A`` is None.
    """
    if _A is None:
        # Previously this surfaced as an opaque numpy axis error further down.
        raise ValueError("KMM requires the constraint matrix _A; got None")

    gammas = (1e0, 1e-1, 1e-3)

    def _mixed_kernel(sq_dist):
        # Average of the three exponential kernels over one distance matrix.
        return sum(torch.exp(-gamma * sq_dist) for gamma in gammas) / len(gammas)

    # Each distance matrix is computed once and shared by the three kernels.
    H = _mixed_kernel(pairwise_distances(X))
    f = _mixed_kernel(pairwise_distances(X, Xtest))
    # z must be normalised like H and f; otherwise the MMD estimate is biased
    # (it would be non-zero even when X and Xtest are the same sample).
    z = _mixed_kernel(pairwise_distances(Xtest, Xtest))
    MMD_dist = H.mean() - 2 * f.mean() + z.mean()

    nsamples = X.shape[0]
    ntest = Xtest.shape[0]
    # Match dtype/device of f so the product also works for non-float32 inputs.
    ones = torch.ones((ntest, 1), dtype=f.dtype, device=f.device)
    f = - nsamples / ntest * f.matmul(ones)
    G = - np.eye(nsamples)
    _A = _A[~np.all(_A==0, axis=1)]
    b = _A.sum(1)
    h = - beta * np.ones((nsamples,1))

    from cvxopt import matrix, solvers
    solvers.options['show_progress'] = False
    # cvxopt works on CPU double arrays; detach so tensors tracked by autograd
    # can be converted as well.
    P = matrix(H.detach().cpu().numpy().astype(np.double))
    q = matrix(f.detach().cpu().numpy().astype(np.double))
    sol=solvers.qp(P, q, matrix(G), matrix(h), matrix(_A), matrix(b))
    return np.array(sol['x']), MMD_dist.item()

def pairwise_distances(x, y=None):
    '''
    Squared Euclidean distances between the rows of two matrices.

    Input: x is a Nxd matrix
           y is an optional Mxd matrix; when omitted, x is compared with itself
    Output: a NxM matrix D with D[i,j] = ||x[i,:]-y[j,:]||^2,
            clamped from below at 0
    '''
    other = x if y is None else y

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    row_sq = x.pow(2).sum(dim=1).view(-1, 1)
    col_sq = other.pow(2).sum(dim=1).view(1, -1)
    cross = torch.mm(x, other.t())

    sq_dist = row_sq + col_sq - 2.0 * cross
    # The expansion can produce small negative values from rounding.
    return sq_dist.clamp(min=0.0, max=np.inf)

def cmd(X, X_test, K=5):
    """
    central moment discrepancy (cmd) between two samples

    Sums the l2 distance between the dim-0 means of X and X_test and the
    l2 distances between their centered moments of order 2 through K.

    - Zellinger, Werner, et al. "Robust unsupervised domain adaptation for
    neural networks via moment alignment.", TODO
    - Zellinger, Werner, et al. "Central moment discrepancy (CMD) for
    domain-invariant representation learning.", ICLR, 2017.
    """
    mean_a = X.mean(0)
    mean_b = X_test.mean(0)
    centered_a = X - mean_a
    centered_b = X_test - mean_b

    # First-order term, followed by the higher-order centered moments.
    terms = [l2diff(mean_a, mean_b)]
    terms.extend(
        moment_diff(centered_a, centered_b, order) for order in range(2, K + 1)
    )
    return sum(terms)

def l2diff(x1, x2):
    """
    2-norm of the difference x1 - x2
    """
    delta = x1 - x2
    return delta.norm(p=2)

def moment_diff(sx1, sx2, k):
    """
    l2 distance between the k-th order moments of two samples

    Each moment is the mean over dim 0 of the elementwise k-th power.
    """
    moments = [sample.pow(k).mean(0) for sample in (sx1, sx2)]
    return l2diff(*moments)

class ReverseLayerF(Function):
    """Autograd function that passes its input through unchanged in the
    forward pass and multiplies the incoming gradient by ``-alpha`` in the
    backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scaling factor for the backward pass.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; alpha itself receives no gradient.
        flipped = -grad_output * ctx.alpha
        return flipped, None

class ToyGNN(nn.Module):
    """Stack of GraphConv layers over a fixed graph, with auxiliary losses
    computed on the last hidden representation.

    NOTE(review): ``GraphConv`` is not imported in the visible part of this
    file (the dgl import is commented out) — confirm it is in scope before
    instantiating this class.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()

        # One input conv, (n_layers - 1) hidden convs, then the output conv
        # mapping hidden units -> n_classes. None of them applies an
        # activation itself; forward() applies it explicitly.
        conv_in_widths = [in_feats] + [n_hidden] * (n_layers - 1)
        convs = [GraphConv(width, n_hidden, activation=None)
                 for width in conv_in_widths]
        convs.append(GraphConv(n_hidden, n_classes, activation=None))
        self.layers = nn.ModuleList(convs)

        self.g = g
        self.activation = activation

        # fcs is built here but not used by any method of this class.
        self.fcs = nn.ModuleList([nn.Linear(n_hidden, n_hidden, bias=True),
                                  nn.Linear(n_hidden, 2, bias=True)])
        # Two-way discriminator used by dann_output.
        self.disc = GraphConv(n_hidden, 2, activation=None)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        """Run every layer but the last with activation and dropout, cache the
        result in ``self.h``, and return the last layer's output."""
        hidden = features
        for conv in self.layers[:-1]:
            hidden = self.dropout(self.activation(conv(self.g, hidden)))
        self.h = hidden

        classifier = self.layers[-1]
        return classifier(self.g, hidden)

    def dann_output(self, idx_train, iid_train, alpha=1):
        """Discriminator loss on the gradient-reversed cached representation.

        Rows in ``idx_train`` are given target 1 and rows in ``iid_train``
        target 0. Requires a prior forward() call so that ``self.h`` exists.

        NOTE(review): reads the module-level names ``xent`` and ``labels``,
        which this file only defines in its ``__main__`` section.
        """
        reversed_h = ReverseLayerF.apply(self.h, alpha)
        train_term = xent(self.disc(self.g, reversed_h)[idx_train,:],
                          torch.ones_like(labels[idx_train])).mean()
        iid_term = xent(self.disc(self.g, reversed_h)[iid_train,:],
                        torch.zeros_like(labels[iid_train])).mean()
        dann_loss = train_term + iid_term
        return dann_loss

    def shift_robust_output(self, idx_train, iid_train, alpha = 1):
        """Return ``alpha`` times the cmd between the cached representations of
        the ``idx_train`` rows and the ``iid_train`` rows."""
        biased_h = self.h[idx_train, :]
        iid_h = self.h[iid_train, :]
        return alpha * cmd(biased_h, iid_h)

    def output(self, features):
        """Apply every layer but the last to ``features``, without activation
        or dropout, and return the result."""
        out = features
        for conv in self.layers[:-1]:
            out = conv(self.g, out)
        return out

if __name__ == '__main__':
    # Script entry point: load a dataset, build the graph, then train and
    # evaluate one benchmark model per train split stored in the seeds file,
    # and print mean/std of the test metrics.
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type = float , default = 0.01)
    parser.add_argument('--alpha', type = float , default = 0.1)
    parser.add_argument('--beta', type = float , default = 0.1)
    parser.add_argument('--dataset', default = 'cora')
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--gpu', type = int , default = 0)
    args = parser.parse_args()

    DATASET = args.dataset
    EPOCH = 200
    # option of 'SRGNN','DANN' and None
    METHOD = 'SRGNN'
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else "cpu")
    adj, features, one_hot_labels, ori_idx_train, idx_val, idx_test = utils.load_data(DATASET)
    # Add self-loops before converting to a torch_geometric graph.
    nx_g = nx.Graph(adj+ sp.eye(adj.shape[0]))
    g = torch_geometric.utils.from_networkx(nx_g)
    #g = dgl.from_networkx(nx_g).to(device)
    # Class index of each one-hot row; rows with no label get -1.
    labels = torch.LongTensor([np.where(r==1)[0][0] if r.sum() > 0 else -1 for r in one_hot_labels]) #.to(device)
    g.x = torch.FloatTensor(utils.preprocess_features(features)).to(device)
    g.edge_index = g.edge_index.to(device)
    g.y = labels.to(device)
    # Module-level on purpose: ToyGNN.dann_output reads `xent` and `labels`.
    xent = nn.CrossEntropyLoss(reduction='none')

    GCN_benchmark = NNNodeBenchmarker_JDOT(arch='I_GCN', model_class=APPNP, benchmark_params={'lr': args.lr, 'epochs': 200}, h_params={'alpha':0.1, 'iterations':10, 'in_channels':g.x.shape[1], 'hidden_channels':32, 'dropout':0.0, 'num_layers':1, 'out_channels':labels.max().item() + 1, 'act':Tanh()}, device=device)    #GCN

    # `run` is passed to train() below but was never defined, which raised a
    # NameError on the first split. Create a wandb run only when --log is set.
    # NOTE(review): assumes train() accepts a wandb run or None — confirm
    # against NNNodeBenchmarker_JDOT.train.
    run = wandb.init(config=vars(args)) if args.log else None

    plot_data = {'loss':[], 'method':[], 'test_roc':[], 'test_macro_f1':[], 'val_roc':[], 'test_acc':[]}
    # NOTE: pickle.load can execute arbitrary code; only load trusted seed files.
    # The context manager closes the file handle, which was previously leaked.
    with open('dataset/ood_semi/localized_seeds_{}.p'.format(DATASET), 'rb') as seed_file:
        data_specs = pickle.load(seed_file)
    torch.set_num_threads(6)

    idx_test = torch.LongTensor(idx_test).to(device)
    for train_test in tqdm(data_specs):
        idx_train = torch.LongTensor(train_test).to(device)
        if args.dataset == 'cora':
            # For cora, test on every node that is not in the current train split.
            idx_test = torch.LongTensor(list(set(range(g.num_nodes)) - set(train_test))).to(device)

        # Random test subset of the same size as the train split. It is not
        # used further in this block, but the randperm call is kept so the
        # global RNG stream is unchanged.
        perm = torch.randperm(idx_test.shape[0])
        iid_train = idx_test[perm[:idx_train.shape[0]]]
        train_mask = index_to_mask(idx_train, g.num_nodes)
        test_mask = index_to_mask(idx_test, g.num_nodes)
        val_mask = index_to_mask(torch.LongTensor(idx_val), g.num_nodes)

        GCN_benchmark.reset_parameters()

        GCN_benchmark.SetMasks(train_mask, val_mask, test_mask)
        GCN_benchmark.idx_test = idx_test
        # The same graph is passed as the first and second argument. The
        # original note here said None as the second parameter means no
        # adaptation.
        losses, test_res = GCN_benchmark.train(g, g, 'accuracy', False, run, alpha=args.alpha, beta=args.beta)

        #test_res = GCN_benchmark.test(g, test_on_val=False, da=True)
        plot_data['loss'].append(np.log(test_res['logloss']+1e-8))
        #plot_data['val_roc'].append(val_res['rocauc_ovr'])
        plot_data['test_roc'].append(test_res['rocauc_ovr'])
        plot_data['test_acc'].append(test_res['f1_micro'])
        plot_data['test_macro_f1'].append(test_res['f1_macro'])
        plot_data['method'].append('GCN')

    if run is not None:
        run.finish()

    #plot_df = pd.DataFrame(data=plot_data)
    print("roc_auc:", np.mean(plot_data['test_roc']), np.std(plot_data['test_roc']))
    print("micro f1:", np.mean(plot_data['test_acc']), np.std(plot_data['test_acc']))
    print("macro f1:", np.mean(plot_data['test_macro_f1']), np.std(plot_data['test_macro_f1']))
    #print(f"GCN mean:{plot_df.loc[plot_df['method'] == 'GCN', 'val_roc'].mean()}, std:{plot_df.loc[plot_df['method'] == 'GCN', 'val_roc'].std()}")
