
# Two-classifier version
# X, P(Y|X)
# X, Z, P(Y|Z)
import torch
import torch.nn as nn
import logging
import sklearn.metrics
import copy
import numpy as np
import torch.nn.functional as F
from IPython import embed
from utils import index_to_mask

from torch.autograd import Function
import ot

def l2diff(x1, x2):
    """
    Euclidean (p=2) norm of the difference ``x1 - x2``.

    The norm is taken over all elements of the difference, so the result
    is a single scalar tensor.
    """
    delta = x1 - x2
    return torch.norm(delta, p=2)

def moment_diff(sx1, sx2, k):
    """
    Distance between the k-th order moments of two sample sets.

    Each input is raised element-wise to the power ``k`` and averaged over
    dimension 0; the two resulting moment vectors are compared with
    ``l2diff``.
    """
    moment_1, moment_2 = (samples.pow(k).mean(0) for samples in (sx1, sx2))
    return l2diff(moment_1, moment_2)

def cmd(X, X_test, K=5):
    """
    Sum of moment differences between ``X`` and ``X_test`` up to order ``K``.

    The first term is the L2 distance between the means (over dimension 0)
    of the two inputs. The remaining ``K - 1`` terms are ``moment_diff`` of
    the mean-centred samples for orders 2, 3, ..., K.
    (Presumably the Central Moment Discrepancy without any range scaling.)
    """
    mean_src = X.mean(0)
    mean_tgt = X_test.mean(0)
    centred_src = X - mean_src
    centred_tgt = X_test - mean_tgt

    # order 1: difference of the means themselves
    terms = [l2diff(mean_src, mean_tgt)]
    # orders 2..K: moment differences of the centralized samples
    terms.extend(
        moment_diff(centred_src, centred_tgt, order) for order in range(2, K + 1)
    )
    return sum(terms)

# L2 distance
def L2_dist(x,y):
    '''
    compute the pairwise squared L2 distances between the rows of two matrices

    Entry (i, j) of the result is ||x[i]||^2 + ||y[j]||^2 - 2 * <x[i], y[j]>.
    '''
    # squared row norms, shaped as a column for x and a row for y so that
    # their sum broadcasts to the full (rows of x) by (rows of y) grid
    sq_norm_x = x.square().sum(1).reshape(-1, 1)
    sq_norm_y = y.square().sum(1).reshape(1, -1)
    cross = torch.matmul(x, y.transpose(0, 1))
    return sq_norm_x + sq_norm_y - 2.0 * cross

class Identity_Encoder(nn.Module):
    """Encoder that passes its input features through untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, features, edges=None):
        """Return ``features`` as-is; ``edges`` is accepted but ignored."""
        return features

    def reset_parameters(self):
        """Nothing to reset: this module registers no parameters."""
        return None

class MLP_Encoder(nn.Module):
    """
    Two-layer perceptron encoder.

    Layout: Linear(in_feats, hid_feats) -> Tanh -> Dropout ->
    Linear(hid_feats, hid_feats) -> Dropout.
    """

    def __init__(self,
                 in_feats,
                 hid_feats,
                 dropout=0):
        super().__init__()
        input_layer = nn.Linear(in_feats, hid_feats)
        output_layer = nn.Linear(hid_feats, hid_feats)
        self.mlp = nn.ModuleList([input_layer, output_layer])
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.Tanh()

    def forward(self, features, edges=None):
        """
        Encode node features; ``edges`` is accepted but never used.

        Every layer is followed by dropout; the Tanh activation is applied
        after every layer except the last one.
        """
        hidden = features
        last_position = len(self.mlp) - 1
        for position, layer in enumerate(self.mlp):
            hidden = layer(hidden)
            if position != last_position:
                hidden = self.activation(hidden)
            hidden = self.dropout(hidden)
        return hidden

    def reset_parameters(self):
        """Re-initialise every linear layer of the MLP."""
        for linear in self.mlp:
            linear.reset_parameters()

    def dann_output(self, idx_train, iid_train, alpha=1):
        """Placeholder: no DANN loss is implemented here, returns None."""
        return None

    def cmd(self, h_src, h_tgt, alpha = 0.1):
        """Return ``alpha`` times the module-level ``cmd(h_src, h_tgt)``."""
        # inside a method body the bare name `cmd` resolves to the
        # module-level function, not to this method
        discrepancy = cmd(h_src, h_tgt)
        return alpha * discrepancy

# general benchmarkers
class NNNodeBenchmarker_JDOT():
  """Trains and evaluates an encoder + classifier pair with a JDOT-style loss.

  Each training step first solves an optimal-transport problem between the
  source training nodes and the target test nodes (coupling ``gamma``), then
  minimises: source cross-entropy + beta * transported label loss
  + alpha * transported squared feature distance.

  The classifier built from ``model_class`` must be callable as
  ``classifier(x, edge_index)`` and expose ``output(x, edge_index)``
  (presumably its hidden representation -- confirm against the model class)
  as well as ``reset_parameters()``.
  """

  # L2 weight decay used for every Adam optimizer this class builds.
  _WEIGHT_DECAY = 5e-4

  def __init__(self, arch, model_class, benchmark_params, h_params, device=None):
    """
    Args:
      arch: 'I_GCN' / 'I-GCN' selects the identity encoder; any other value
        selects an MLP encoder mapping in_channels -> in_channels.
      model_class: classifier constructor, called as ``model_class(**h_params)``.
      benchmark_params: mapping providing 'epochs' and 'lr'.
      h_params: classifier hyper-parameters; 'out_channels' is always read,
        'in_channels' and 'dropout' are read for the MLP encoder.
      device: optional device that both modules are moved to.
    """
    self._epochs = benchmark_params['epochs']
    self._lr = benchmark_params['lr']
    # AdjustParams reads and updates this mapping. A shallow copy keeps the
    # caller's dict untouched.
    self._h_params = dict(h_params)

    # Encoder first, classifier second: keeps the parameter-initialisation
    # order (and therefore RNG consumption) stable.
    if arch in ('I_GCN', 'I-GCN'):
      self._encoder = Identity_Encoder()
    else:
      self._encoder = MLP_Encoder(h_params['in_channels'], h_params['in_channels'], h_params['dropout'])
    self.classifier = model_class(**h_params)

    if device is not None:
      self._encoder = self._encoder.to(device)
      self._model = self.classifier.to(device)
    else:
      # Define the alias on every code path, not only when a device is given.
      self._model = self.classifier

    self._optimizer = self._make_optimizer()
    self._criterion = torch.nn.CrossEntropyLoss()
    self._train_mask = None
    self._val_mask = None
    self._test_mask = None
    self._num_class = h_params['out_channels']

  def _make_optimizer(self):
    """Build a fresh Adam optimizer over the encoder and classifier parameters."""
    params = list(self._encoder.parameters()) + list(self.classifier.parameters())
    return torch.optim.Adam(params, lr=self._lr, weight_decay=self._WEIGHT_DECAY)

  def AdjustParams(self, generator_config):
    """Copy 'num_clusters' from ``generator_config`` into the stored h_params.

    Only the stored hyper-parameter mapping is updated; the already
    constructed classifier is not rebuilt.
    """
    if 'num_clusters' in generator_config and self._h_params is not None:
      self._h_params['out_channels'] = generator_config['num_clusters']

  def reset_parameters(self):
    """Re-initialise both modules and start over with a fresh optimizer."""
    self._encoder.reset_parameters()
    self.classifier.reset_parameters()
    self._optimizer = self._make_optimizer()

  def SetMasks(self, train_mask, val_mask, test_mask):
    """Store the node masks used by train_step (train/test) and test (val/test)."""
    self._train_mask = train_mask
    self._val_mask = val_mask
    self._test_mask = test_mask

  def set_lambda(self, lambd):
    """Store ``lambd`` on the instance (not read anywhere in this class)."""
    self.lambd = lambd

  def reset_grad(self):
    """Zero the gradients of every parameter held by the optimizer."""
    self._optimizer.zero_grad()

  def discrepancy(self, out1, out2):
    """Mean absolute difference between the softmax of two logit tensors."""
    return torch.mean(torch.abs(F.softmax(out1, dim=-1) - F.softmax(out2, dim=-1)))

  def train_step(self, data, tgt_data, alpha=0.1, beta=0.1):
    """Run one JDOT update.

    Args:
      data: source graph; reads ``x``, ``edge_index`` and ``y``.
      tgt_data: target graph; reads ``x`` and ``edge_index``.
      alpha: weight of the feature-distance term (ground cost and loss).
      beta: weight of the label term (ground cost and loss).

    Returns:
      Tuple of floats ``(loss_s, loss_t, loss_z)``: source cross-entropy,
      transported label loss and transported feature loss (unweighted).
    """
    src_mask = self._train_mask
    tgt_mask = self._test_mask

    # One-hot source labels, built directly on the labels' device: indexing
    # a CPU identity matrix with CUDA labels raises a device-mismatch error.
    ys_cat = torch.eye(self._num_class, dtype=torch.float32,
                       device=data.y.device)[data.y][src_mask]

    # Phase 1: with the networks frozen, compute the optimal coupling.
    self.classifier.eval()
    self._encoder.eval()
    with torch.no_grad():
      Z = self._encoder(data.x, data.edge_index)
      Z_t = self._encoder(tgt_data.x, tgt_data.edge_index)
      out_t = self.classifier(Z_t, tgt_data.edge_index)

      Z = self.classifier.output(Z, data.edge_index)
      Z_t = self.classifier.output(Z_t, tgt_data.edge_index)

      # JDOT ground metric: squared feature distance + squared distance
      # between source one-hot labels and target predicted probabilities.
      C0 = torch.cdist(Z[src_mask], Z_t[tgt_mask], p=2.0)**2
      C1 = torch.cdist(ys_cat, F.softmax(out_t[tgt_mask], -1), p=2)**2
      C = alpha*C0 + beta*C1

      # JDOT optimal coupling (gamma) between uniform marginals.
      n_src = Z[src_mask].shape[0]
      n_tgt = Z_t[tgt_mask].shape[0]
      gamma = ot.emd(torch.FloatTensor(ot.unif(n_src)),
                     torch.FloatTensor(ot.unif(n_tgt)),
                     C).to(data.y.device)

    # Phase 2: gradient step with the coupling held fixed. The forward
    # passes stay in this order so dropout consumes the RNG identically.
    self.classifier.train()
    self._encoder.train()
    self.reset_grad()
    Z = self._encoder(data.x, data.edge_index)
    Z_t = self._encoder(tgt_data.x, tgt_data.edge_index)

    out_t = self.classifier(Z_t, tgt_data.edge_index)
    yt_pred = F.log_softmax(out_t, -1)
    # Entry (i, j): cross-entropy of target node j against source label i.
    pairwise_ce = -torch.matmul(ys_cat, torch.transpose(yt_pred[tgt_mask], 1, 0))

    out_s1 = self.classifier(Z, data.edge_index)
    loss_s1 = self._criterion(out_s1[src_mask], data.y[src_mask])
    loss_t = torch.sum(gamma * pairwise_ce)

    Z = self.classifier.output(Z, data.edge_index)
    Z_t = self.classifier.output(Z_t, tgt_data.edge_index)
    loss_z = torch.sum(gamma * L2_dist(Z[src_mask], Z_t[tgt_mask]))

    loss = loss_s1 + beta * loss_t + alpha * loss_z
    loss.backward()
    self._optimizer.step()
    return loss_s1.item(), loss_t.item(), loss_z.item()

  def get_embeddings(self, data, tgt_data):
    """Return ``(z_src, z_tgt, h_src, h_tgt)`` computed in eval mode.

    ``z_*`` are encoder outputs and ``h_*`` are ``classifier.output`` applied
    to them, for the source and target graph respectively.
    """
    self.classifier.eval()
    self._encoder.eval()
    with torch.no_grad():
      z_src = self._encoder(data.x, data.edge_index)
      z_tgt = self._encoder(tgt_data.x, tgt_data.edge_index)
      # Reuse the encoder outputs instead of running the encoder twice.
      h_src = self.classifier.output(z_src, data.edge_index)
      h_tgt = self.classifier.output(z_tgt, tgt_data.edge_index)
    return z_src, z_tgt, h_src, h_tgt

  def test(self, data, test_on_val=False, da=False):
    """Evaluate on the validation or test mask of ``data``.

    Args:
      data: graph to evaluate; reads ``x``, ``edge_index`` and ``y``.
      test_on_val: use the validation mask instead of the test mask.
      da: unused; kept for interface compatibility.

    Returns:
      Dict with keys 'accuracy', 'f1_micro', 'f1_macro', 'rocauc_ovr',
      'rocauc_ovo' and 'logloss'.
    """
    self.classifier.eval()
    self._encoder.eval()
    # Evaluation needs no autograd graph.
    with torch.no_grad():
      out = self.classifier(self._encoder(data.x, data.edge_index), data.edge_index)

    mask = self._val_mask if test_on_val else self._test_mask
    logits = out[mask]
    # Hard predictions come from the logits so ties resolve as before.
    pred_best = logits.cpu().numpy().argmax(-1)
    # log_loss expects probabilities, not raw logits. float64 keeps the row
    # sums within sklearn's sum-to-one tolerance.
    prob = F.softmax(logits.double(), dim=-1).cpu().numpy()
    correct = data.y[mask].detach().cpu().numpy()

    n_classes = out.shape[-1]
    pred_onehot = np.zeros((len(pred_best), n_classes))
    pred_onehot[np.arange(pred_best.shape[0]), pred_best] = 1

    correct_onehot = np.zeros((len(correct), n_classes))
    correct_onehot[np.arange(correct.shape[0]), correct] = 1

    results = {
        'accuracy': sklearn.metrics.accuracy_score(correct, pred_best),
        'f1_micro': sklearn.metrics.f1_score(correct, pred_best,
                                                  average='micro'),
        'f1_macro': sklearn.metrics.f1_score(correct, pred_best,
                                                  average='macro'),
        'rocauc_ovr': sklearn.metrics.roc_auc_score(correct_onehot,
                                                         pred_onehot,
                                                         multi_class='ovr'),
        'rocauc_ovo': sklearn.metrics.roc_auc_score(correct_onehot,
                                                         pred_onehot,
                                                         multi_class='ovo'),
        'logloss': sklearn.metrics.log_loss(correct, prob)}

    return results

  def train(self, data, tgt_data,
            tuning_metric: str,
            tuning_metric_is_loss: bool,
            wandb = None, alpha=0.1, beta=0.1):
    """Train for the configured number of epochs with model selection.

    After every step the model is scored on the validation mask of ``data``;
    whenever ``tuning_metric`` improves, the test-mask metrics of ``data``
    are recorded.

    Args:
      data: source graph.
      tgt_data: target graph.
      tuning_metric: key of the metric dict used for model selection.
      tuning_metric_is_loss: True if lower values of the metric are better.
      wandb: optional object with a ``log(dict)`` method.
      alpha, beta: forwarded to ``train_step``.

    Returns:
      ``(losses, test_metrics)``: the per-epoch loss tuples, and the test
      metrics at the best validation epoch, or None if no epoch ever
      improved on the initial best (e.g. zero epochs).
    """
    losses = []
    best_val_metric = np.inf if tuning_metric_is_loss else -np.inf
    # Must exist even when the loop body never records an improvement.
    test_metrics = None

    for _ in range(self._epochs):
      step_loss = self.train_step(data, tgt_data, alpha, beta)
      losses.append(step_loss)

      val_metrics = self.test(data, test_on_val=True)
      if wandb is not None:
        wandb.log({'loss_s':step_loss[0], 'loss_t':step_loss[1], 'loss_z':step_loss[2], 'accuracy': val_metrics['accuracy']})

      score = val_metrics[tuning_metric]
      if tuning_metric_is_loss:
        improved = score < best_val_metric
      else:
        improved = score > best_val_metric
      if improved:
        best_val_metric = score
        test_metrics = self.test(data, test_on_val=False)
    return losses, test_metrics
