import torch
import torch.nn as nn
import sklearn.metrics
import numpy as np
import torch.nn.functional as F
import ot



def L2_dist(x, y):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, so entry
    (i, j) is the squared distance between x[i] and y[j]. Because of floating
    point cancellation, entries can come out very slightly negative.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m).
    """
    row_norms_x = torch.sum(torch.square(x), 1).reshape(-1, 1)   # (n, 1)
    row_norms_y = torch.sum(torch.square(y), 1).reshape(1, -1)   # (1, m)
    inner = torch.matmul(x, y.transpose(0, 1))                   # (n, m)
    return row_norms_x + row_norms_y - 2.0 * inner

class Identity_Encoder(nn.Module):
    """Encoder that returns node features unchanged.

    Provides the same ``forward(features, edges)`` / ``reset_parameters()``
    surface as ``MLP_Encoder`` but has no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, edges=None):
        # ``edges`` is accepted only to match the encoder call signature.
        return features

    def reset_parameters(self):
        # Nothing to re-initialise.
        return None

class MLP_Encoder(nn.Module):
    """Two-layer perceptron that embeds node features.

    Pipeline: Linear(in_feats, hid_feats) -> Tanh -> Dropout
    -> Linear(hid_feats, hid_feats) -> Dropout. Dropout is applied to the
    final output as well. ``edges`` is accepted for call-signature
    compatibility and is ignored.

    Args:
        in_feats: input feature dimension.
        hid_feats: hidden and output dimension.
        dropout: dropout probability used at both dropout points.
    """

    def __init__(self, in_feats, hid_feats, dropout=0):
        super().__init__()
        input_proj = nn.Linear(in_feats, hid_feats)
        output_proj = nn.Linear(hid_feats, hid_feats)
        # Attribute names determine state_dict keys; keep them stable.
        self.mlp = nn.ModuleList([input_proj, output_proj])
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.Tanh()

    def forward(self, features, edges=None):
        layers = list(self.mlp)
        h = features
        # Every layer except the last is followed by the nonlinearity.
        for hidden in layers[:-1]:
            h = self.dropout(self.activation(hidden(h)))
        return self.dropout(layers[-1](h))

    def reset_parameters(self):
        """Re-initialise the weights of every linear layer."""
        for linear in self.mlp:
            linear.reset_parameters()


# general benchmarkers
class NNNodeBenchmarker_GDOT():
    """Node-classification benchmarker for graph domain adaptation via optimal transport.

    Pairs an encoder (``Identity_Encoder`` when ``arch == 'I_GCN'``, otherwise
    an ``MLP_Encoder``) with a classifier built as ``model_class(**h_params)``.
    ``train_batch`` solves an exact OT problem (``ot.emd``) between labelled
    source nodes and target nodes. It then trains on source cross-entropy plus
    OT-weighted target-label and feature-alignment losses.

    Args:
        arch: architecture tag; ``'I_GCN'`` selects the identity encoder.
        model_class: classifier constructor, called with ``**h_params``.
        benchmark_params: dict providing at least ``'epochs'`` and ``'lr'``.
        h_params: classifier hyper-parameters. ``'out_channels'`` is always
            read; the MLP branch also reads ``'in_channels'``,
            ``'hidden_channels'`` and ``'dropout'``. The caller's dict is not
            modified.
        device: optional torch device to move encoder and classifier to.
    """

    def __init__(self, arch, model_class, benchmark_params, h_params, device=None):
        # Work on a copy: the MLP branch rewrites 'in_channels', and mutating
        # the caller's dict would corrupt any later benchmarker built from it.
        h_params = dict(h_params)
        self._epochs = benchmark_params['epochs']
        self.device = device
        self.arch = arch
        if arch == 'I_GCN':
            self._encoder = Identity_Encoder()
        else:
            self._encoder = MLP_Encoder(h_params['in_channels'],
                                        h_params['hidden_channels'],
                                        h_params['dropout'])
            # The classifier consumes the encoder's hidden representation.
            h_params['in_channels'] = h_params['hidden_channels']
        self.classifier = model_class(**h_params)

        if device is not None:
            self._encoder = self._encoder.to(device)
            self.classifier = self.classifier.to(device)
        # Alias for callers that use ``_model``. It used to be set only when a
        # device was given.
        self._model = self.classifier

        self._lr = benchmark_params['lr']
        # NOTE(review): this uses Adam (L2-style weight decay) while
        # reset_parameters() builds AdamW (decoupled weight decay); confirm
        # which optimizer is intended.
        self._optimizer = torch.optim.Adam(
            list(self._encoder.parameters()) + list(self.classifier.parameters()),
            lr=self._lr,
            weight_decay=5e-4)
        # Per-node losses; train_batch applies node_norm weights itself.
        self._criterion = torch.nn.CrossEntropyLoss(reduction="none")
        self._train_mask = None
        self._val_mask = None
        self._test_mask = None
        self._num_class = h_params['out_channels']

    def reset_parameters(self):
        """Re-initialise encoder and classifier and build a fresh AdamW optimizer."""
        self._encoder.reset_parameters()
        self.classifier.reset_parameters()
        self._optimizer = torch.optim.AdamW(
            list(self.classifier.parameters()) + list(self._encoder.parameters()),
            lr=self._lr,
            weight_decay=5e-4)

    def SetMasks(self, train_mask, val_mask, test_mask):
        """Store the node masks used by ``test``."""
        self._train_mask = train_mask
        self._val_mask = val_mask
        self._test_mask = test_mask

    def reset_grad(self):
        """Zero the gradients of all optimised parameters."""
        self._optimizer.zero_grad()

    def train_batch(self, batch, tgt_batch, wandb=None, alpha=0.1, beta=0.5):
        """Run one OT-regularised training step on a source/target batch pair.

        Step 1 (eval mode, no gradients) builds the transport cost between
        source nodes in ``batch.train_mask`` and target nodes in
        ``tgt_batch.test_mask``:
            C = alpha * ||z_s - z_t||^2 + beta * ||onehot(y_s) - softmax(out_t)||^2
        It then solves for the coupling ``gamma`` with ``ot.emd`` under
        uniform marginals.

        Step 2 (train mode) minimises
            loss_s1 + beta * loss_t + alpha * loss_z
        with these terms:
        - loss_s1: node_norm-weighted source cross-entropy.
        - loss_t: gamma-weighted cross-entropy of target predictions against
          source labels.
        - loss_z: gamma-weighted squared feature distance.

        Args:
            batch: source graph; reads x, edge_index, y, train_mask, node_norm.
            tgt_batch: target graph; reads x, edge_index, test_mask, node_norm.
            wandb: optional logger exposing ``log(dict)``.
            alpha: weight of the feature-distance terms.
            beta: weight of the label terms.

        Returns:
            The total loss as a Python float.
        """
        src_mask = batch.train_mask
        tgt_mask = tgt_batch.test_mask
        tgt_norm = tgt_batch.node_norm[tgt_mask]
        # One-hot source labels built on the labels' own device. Indexing a
        # CPU torch.eye with GPU labels raises, so F.one_hot is used instead.
        # Masking first means only training labels need to be valid classes.
        ys_onehot = F.one_hot(batch.y[src_mask], self._num_class).float()

        # --- Step 1: transport plan (no gradients, dropout disabled) ---
        self.classifier.eval()
        self._encoder.eval()
        with torch.no_grad():
            z_s = self._encoder(batch.x, batch.edge_index)
            z_t = self._encoder(tgt_batch.x, tgt_batch.edge_index)
            out_t = self.classifier(z_t, tgt_batch.edge_index)
            feat_cost = torch.cdist(z_s[src_mask], z_t[tgt_mask], p=2.0) ** 2
            label_cost = torch.cdist(ys_onehot, F.softmax(out_t[tgt_mask], -1), p=2) ** 2
            cost = alpha * feat_cost + beta * label_cost
            n_src = src_mask.sum().item()
            n_tgt = tgt_mask.sum().item()
            gamma = ot.emd(torch.FloatTensor(ot.unif(n_src)),
                           torch.FloatTensor(ot.unif(n_tgt)),
                           cost).to(batch.y.device)

        # --- Step 2: gradient step ---
        self.classifier.train()
        self._encoder.train()
        self.reset_grad()
        z_s = self._encoder(batch.x, batch.edge_index)
        z_t = self._encoder(tgt_batch.x, tgt_batch.edge_index)

        out_t = self.classifier(z_t, tgt_batch.edge_index)
        yt_log_prob = F.log_softmax(out_t, -1)
        # Pairwise cross-entropy: entry (i, j) = -log p_t[j][y_i].
        pair_ce = -torch.matmul(ys_onehot, torch.transpose(yt_log_prob[tgt_mask], 1, 0))
        out_s = self.classifier(z_s, batch.edge_index)

        loss_s1 = self._criterion(out_s[src_mask], batch.y[src_mask])
        loss_s1 = (loss_s1 * batch.node_norm[src_mask]).sum()
        loss_t = torch.sum(gamma * pair_ce * tgt_norm)
        loss_z = torch.sum(gamma * L2_dist(z_s[src_mask], z_t[tgt_mask]) * tgt_norm)
        loss = loss_s1 + beta * loss_t + alpha * loss_z
        loss.backward()
        self._optimizer.step()
        if wandb is not None:
            wandb.log({'loss_s1': loss_s1.item(), 'loss_z': loss_z.item(), 'loss_t': loss_t.item()})
        return loss.item()

    def test(self, data, test_on_val=False, da=False):
        """Evaluate on the validation or test mask stored by ``SetMasks``.

        Args:
            data: graph exposing x, edge_index and y.
            test_on_val: evaluate on the stored val mask instead of the test mask.
            da: unused; kept for interface compatibility.

        Returns:
            dict with keys accuracy, f1_micro, f1_macro, rocauc_ovr,
            rocauc_ovo and logloss.
        """
        mask = self._val_mask if test_on_val else self._test_mask
        self.classifier.eval()
        self._encoder.eval()
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            out = self.classifier(self._encoder(data.x, data.edge_index), data.edge_index)
        logits = out[mask].detach().cpu()
        pred = logits.numpy()
        # The classifier emits logits (train_batch feeds them to
        # CrossEntropyLoss), but log_loss expects class probabilities.
        prob = F.softmax(logits.double(), -1).numpy()
        pred_best = pred.argmax(-1)
        correct = data.y[mask].detach().cpu().numpy()

        n_classes = out.shape[-1]
        pred_onehot = np.eye(n_classes)[pred_best]
        correct_onehot = np.eye(n_classes)[correct]

        # NOTE(review): with 2-D one-hot y_true, sklearn treats ROC AUC as a
        # multilabel problem and ignores ``multi_class``, so rocauc_ovr equals
        # rocauc_ovo. The scores are also hard predictions, not probabilities.
        # Kept unchanged so results stay comparable with earlier runs.
        results = {
            'accuracy': sklearn.metrics.accuracy_score(correct, pred_best),
            'f1_micro': sklearn.metrics.f1_score(correct, pred_best, average='micro'),
            'f1_macro': sklearn.metrics.f1_score(correct, pred_best, average='macro'),
            'rocauc_ovr': sklearn.metrics.roc_auc_score(correct_onehot, pred_onehot,
                                                        multi_class='ovr'),
            'rocauc_ovo': sklearn.metrics.roc_auc_score(correct_onehot, pred_onehot,
                                                        multi_class='ovo'),
            # labels= keeps log_loss working when the split lacks some classes.
            'logloss': sklearn.metrics.log_loss(correct, prob,
                                                labels=np.arange(n_classes)),
        }
        return results
  
