
# Two-classifier version:
# X, P(Y|X)
# X, Z, P(Y|Z)
import torch
import torch.nn as nn
import logging
import sklearn.metrics
import copy
import numpy as np
import torch.nn.functional as F
from IPython import embed
from utils import index_to_mask
from tllib.modules.grl import WarmStartGradientReverseLayer
from torch.autograd import Function
import ot
from typing import Optional
from tllib.utils.metric import binary_accuracy, accuracy

class RandomizedMultiLinearMap(nn.Module):
    r"""Random multi linear map

    Given two inputs :math:`f` and :math:`g`, the definition is

    .. math::
        T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),

    where :math:`\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices
    sampled only once and fixed in training, and :math:`d` is ``output_dim``.

    Args:
        features_dim (int): dimension of input :math:`f`
        num_classes (int): dimension of input :math:`g`
        output_dim (int, optional): dimension of output tensor. ``None`` falls back to
            the default. Default: 1024

    Shape:
        - f: (minibatch, features_dim)
        - g: (minibatch, num_classes)
        - Outputs: (minibatch, output_dim)
    """

    def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):
        super(RandomizedMultiLinearMap, self).__init__()
        if output_dim is None:
            output_dim = 1024
        # Plain tensor attributes, not nn.Parameter or registered buffers. They are
        # never trained, are not part of state_dict(), and are not moved by
        # Module.to(). forward() copies them to the input's device on each call.
        self.Rf = torch.randn(features_dim, output_dim)
        self.Rg = torch.randn(num_classes, output_dim)
        self.output_dim = output_dim

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        f = torch.mm(f, self.Rf.to(f.device))
        g = torch.mm(g, self.Rg.to(g.device))
        # Element-wise product scaled by 1/sqrt(d), with d = output_dim.
        output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
        return output


class MultiLinearMap(nn.Module):
    """Multi linear map: per-sample outer product of ``g`` and ``f``, flattened.

    Entry ``c * F + i`` of each output row equals ``g[:, c] * f[:, i]``.

    Shape:
        - f: (minibatch, F)
        - g: (minibatch, C)
        - Outputs: (minibatch, F * C)
    """

    def __init__(self):
        super().__init__()

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        column = g.unsqueeze(-1)  # (B, C, 1)
        row = f.unsqueeze(-2)     # (B, 1, F)
        outer = torch.bmm(column, row)  # (B, C, F)
        return outer.view(f.size(0), -1)

class Identity_Encoder(nn.Module):
    """Pass-through encoder that returns node features unchanged.

    It offers the same ``forward(features, edges=None)`` and
    ``reset_parameters()`` methods as ``MLP_Encoder``, so the two are
    interchangeable. ``edges`` is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, edges=None):
        return features

    def reset_parameters(self):
        """No-op: this encoder holds no parameters."""
        return None

class MLP_Encoder(nn.Module):
    """Two-layer MLP node encoder: Linear -> Tanh -> Dropout -> Linear -> Dropout.

    Args:
        in_feats: input feature dimension.
        hid_feats: hidden and output dimension.
        dropout: dropout probability applied after every layer (default 0).
    """

    def __init__(self,
                 in_feats,
                 hid_feats,
                 dropout=0):
        super().__init__()
        first = nn.Linear(in_feats, hid_feats)
        second = nn.Linear(hid_feats, hid_feats)
        self.mlp = nn.ModuleList([first, second])
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.Tanh()

    def forward(self, features, edges=None):
        """Encode ``features``; ``edges`` is accepted for interface parity and ignored."""
        *hidden_layers, last_layer = self.mlp
        h = features
        for layer in hidden_layers:
            # Every layer except the last is followed by the non-linearity.
            h = self.dropout(self.activation(layer(h)))
        return self.dropout(last_layer(h))

    def reset_parameters(self):
        """Re-initialise every linear layer in place."""
        for layer in self.mlp:
            layer.reset_parameters()

    def dann_output(self, idx_train, iid_train, alpha=1):
        """Placeholder for a DANN loss; currently a no-op returning None."""
        return None

    def cmd(self, h_src, h_tgt, alpha = 0.1):
        """Return ``alpha`` times the module-level ``cmd`` between two embeddings."""
        # NOTE(review): no `cmd` function is defined or imported in this file, so
        # this call raises NameError unless `cmd` is injected into the module
        # namespace elsewhere. Confirm where it is supposed to come from.
        return alpha * cmd(h_src, h_tgt)

# general benchmarkers
class NNNodeBenchmarker_CDAN:
    """Node-classification benchmarker trained with CDAN domain adaptation.

    An encoder feeds a graph classifier built from ``model_class``. The encoder
    is ``Identity_Encoder`` for the ``I_GCN``/``I-GCN`` architecture and
    ``MLP_Encoder`` otherwise. During training, a domain discriminator sees,
    through a warm-start gradient-reversal layer, the ``MultiLinearMap`` of the
    classifier embeddings (``classifier.output``) and the detached softmax
    predictions. It learns to label source nodes 1 and target nodes 0.

    ``data`` and ``tgt_data`` must expose ``x``, ``edge_index``, ``y`` and
    ``num_nodes``; presumably they are torch_geometric ``Data`` objects (confirm
    against callers). ``self.idx_test`` is read by ``train`` but never assigned
    in this class, so the caller must set it first.
    """

    def __init__(self, arch, model_class, benchmark_params, h_params, device=None):
        """Build the encoder, classifier, discriminator and optimizer.

        Args:
            arch: architecture name. ``'I_GCN'`` or ``'I-GCN'`` selects an
                identity encoder; anything else selects an ``MLP_Encoder``.
            model_class: callable that builds the graph classifier from
                ``**h_params``.
            benchmark_params: dict with ``'epochs'`` and ``'lr'``.
            h_params: model hyper-parameters (``in_channels``,
                ``hidden_channels``, ``out_channels``, ``dropout``, ...).
                The caller's dict is not modified.
            device: optional device the trainable modules are moved to.
        """
        # Work on a copy. The MLP branch rewrites 'in_channels', and changing the
        # caller's dict would give any later benchmarker built from it the
        # wrong encoder input size.
        h_params = dict(h_params)
        self._h_params = h_params
        self._epochs = benchmark_params['epochs']
        self._lr = benchmark_params['lr']
        self._num_class = h_params['out_channels']

        # The input width, hidden_channels * out_channels, matches the flattened
        # MultiLinearMap of the classifier embedding (presumably hidden_channels
        # wide) and its out_channels-way prediction.
        self.domain_discriminator = nn.Sequential(
            nn.Linear(h_params['hidden_channels'] * h_params['out_channels'],
                      h_params['hidden_channels']),
            nn.ReLU(),
            nn.Linear(h_params['hidden_channels'], 2),
        )

        if arch in ('I_GCN', 'I-GCN'):
            self._encoder = Identity_Encoder()
        else:
            # Default: an MLP encoder in front of the classifier, which then
            # consumes hidden_channels-dimensional node features.
            self._encoder = MLP_Encoder(h_params['in_channels'],
                                        h_params['hidden_channels'],
                                        h_params['dropout'])
            h_params['in_channels'] = h_params['hidden_channels']
        self.classifier = model_class(**h_params)

        if device is not None:
            self._encoder = self._encoder.to(device)
            self.classifier = self.classifier.to(device)
            self.domain_discriminator = self.domain_discriminator.to(device)
        self.map = MultiLinearMap()
        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)
        self._optimizer = self._build_optimizer()
        self._criterion = torch.nn.CrossEntropyLoss()
        self._train_mask = None
        self._val_mask = None
        self._test_mask = None

    def _build_optimizer(self):
        """Return a fresh Adam optimizer over the encoder, classifier and discriminator."""
        params = (list(self._encoder.parameters())
                  + list(self.classifier.parameters())
                  + list(self.domain_discriminator.parameters()))
        return torch.optim.Adam(params, lr=self._lr, weight_decay=5e-4)

    def AdjustParams(self, generator_config):
        """Copy ``num_clusters`` into the stored hyper-parameters as ``out_channels``.

        Only the stored dict changes; modules that are already built are not
        rebuilt. This is a no-op when no hyper-parameters are stored.
        """
        h_params = getattr(self, '_h_params', None)
        if 'num_clusters' in generator_config and h_params is not None:
            h_params['out_channels'] = generator_config['num_clusters']

    def reset_parameters(self):
        """Re-initialise all trainable modules and recreate the optimizer."""
        self._encoder.reset_parameters()
        self.classifier.reset_parameters()
        # Reset the discriminator too, so a reset run does not inherit its
        # weights from the previous one.
        for layer in self.domain_discriminator:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
        self._optimizer = self._build_optimizer()

    def SetMasks(self, train_mask, val_mask, test_mask):
        """Store the node masks used for training and evaluation."""
        self._train_mask = train_mask
        self._val_mask = val_mask
        self._test_mask = test_mask

    def set_lambda(self, lambd):
        """Store ``lambd`` on the instance; it is not read elsewhere in this class."""
        self.lambd = lambd

    def reset_grad(self):
        """Zero the gradients of every parameter the optimizer manages."""
        self._optimizer.zero_grad()

    @staticmethod
    def _domain_labels(n_src, n_tgt, device):
        """Return long domain labels: ``n_src`` ones (source) then ``n_tgt`` zeros (target)."""
        return torch.cat((
            torch.ones(n_src, dtype=torch.long, device=device),
            torch.zeros(n_tgt, dtype=torch.long, device=device),
        ))

    def cdan(self, z_s, z_t, out_s, out_t):
        """Compute the conditional adversarial (CDAN) discriminator loss.

        Args:
            z_s, z_t: source and target embeddings, stacked along dim 0.
            out_s, out_t: matching classifier outputs (raw scores).

        Returns:
            ``(cross_entropy, accuracy)`` of the discriminator on the concatenated batch.
        """
        f = torch.cat((z_s, z_t), dim=0)
        # The softmax predictions are detached, so no gradient flows back
        # through them.
        g = F.softmax(torch.cat((out_s, out_t), dim=0), dim=1).detach()
        h = self.grl(self.map(f, g))
        d = self.domain_discriminator(h)
        d_label = self._domain_labels(z_s.size(0), z_t.size(0), z_s.device)
        return F.cross_entropy(d, d_label, reduction='mean'), accuracy(d, d_label)

    def dann(self, z_s, z_t):
        """Compute the unconditional (DANN) discriminator loss on raw embeddings.

        NOTE(review): the discriminator's first layer expects
        hidden_channels * out_channels inputs. This method therefore only
        works when the embedding has that width; confirm before enabling it.
        """
        f = torch.cat((z_s, z_t), dim=0)
        h = self.grl(f)
        d = self.domain_discriminator(h)
        d_label = self._domain_labels(z_s.size(0), z_t.size(0), z_s.device)
        return F.cross_entropy(d, d_label, reduction='mean'), accuracy(d, d_label)

    def train_step(self, data, tgt_data, alpha=0.1):
        """Run one optimisation step of source classification loss + ``alpha`` * CDAN loss.

        Source nodes come from ``self._train_mask``. Target nodes are taken
        from ``tgt_data`` using ``self._val_mask``, which ``train`` resamples
        every epoch.

        Returns:
            ``(source_loss, domain_loss, domain_accuracy)``. The first two are
            Python floats; the third is whatever ``accuracy`` returns.
        """
        self.classifier.train()
        self._encoder.train()
        self.reset_grad()
        z_src = self._encoder(data.x, data.edge_index)
        z_tgt = self._encoder(tgt_data.x, tgt_data.edge_index)

        # Forward order is kept as-is so dropout RNG consumption stays unchanged.
        out_t = self.classifier(z_tgt, tgt_data.edge_index)
        out_s1 = self.classifier(z_src, data.edge_index)
        loss_s1 = self._criterion(out_s1[self._train_mask], data.y[self._train_mask])

        h_src = self.classifier.output(z_src, data.edge_index)
        h_tgt = self.classifier.output(z_tgt, tgt_data.edge_index)

        # NOTE(review): _val_mask is sized by data.num_nodes (see train()) but
        # indexes target tensors; this assumes both graphs have the same node
        # count. TODO confirm.
        loss_t, acc_t = self.cdan(h_src[self._train_mask], h_tgt[self._val_mask],
                                  out_s1[self._train_mask], out_t[self._val_mask])
        loss = loss_s1 + alpha * loss_t
        loss.backward()
        self._optimizer.step()
        return loss_s1.item(), loss_t.item(), acc_t

    def get_embeddings(self, data, tgt_data):
        """Return ``(z_src, z_tgt, h_src, h_tgt)``: encoder outputs and classifier embeddings.

        Both modules are run in eval mode and without autograd.
        """
        self.classifier.eval()
        self._encoder.eval()
        with torch.no_grad():
            z_src = self._encoder(data.x, data.edge_index)
            z_tgt = self._encoder(tgt_data.x, tgt_data.edge_index)
            # Reuse the encoder outputs instead of running the encoder again.
            h_src = self.classifier.output(z_src, data.edge_index)
            h_tgt = self.classifier.output(z_tgt, tgt_data.edge_index)
        return z_src, z_tgt, h_src, h_tgt

    def test(self, data, test_on_val=False, da=False):
        """Evaluate the classifier on ``data`` restricted to one node mask.

        Args:
            data: graph to evaluate.
            test_on_val: evaluate on ``self._val_mask`` instead of ``self._test_mask``.
            da: unused; kept for interface compatibility.

        Returns:
            dict with 'accuracy', 'f1_micro', 'f1_macro', 'rocauc_ovr',
            'rocauc_ovo' and 'logloss'.
        """
        self.classifier.eval()
        self._encoder.eval()
        mask = self._val_mask if test_on_val else self._test_mask
        # Inference only, so no autograd graph is built.
        with torch.no_grad():
            out = self.classifier(self._encoder(data.x, data.edge_index), data.edge_index)
            scores = out[mask]
            # log_loss expects class probabilities. The classifier emits raw
            # scores: they are fed to CrossEntropyLoss and softmaxed in cdan().
            proba = F.softmax(scores.double(), dim=-1).cpu().numpy()
        pred = scores.cpu().numpy()
        correct = data.y[mask].detach().cpu().numpy()

        pred_best = pred.argmax(-1)
        n_classes = out.shape[-1]
        pred_onehot = np.eye(n_classes)[pred_best]
        correct_onehot = np.eye(n_classes)[correct]

        # NOTE(review): with a one-hot (multilabel-indicator) y_true,
        # roc_auc_score ignores `multi_class`, so both AUC entries are the same
        # macro per-class AUC of hard predictions. It also raises if a class
        # is absent from `correct`.
        results = {
            'accuracy': sklearn.metrics.accuracy_score(correct, pred_best),
            'f1_micro': sklearn.metrics.f1_score(correct, pred_best,
                                                 average='micro'),
            'f1_macro': sklearn.metrics.f1_score(correct, pred_best,
                                                 average='macro'),
            'rocauc_ovr': sklearn.metrics.roc_auc_score(correct_onehot,
                                                        pred_onehot,
                                                        multi_class='ovr'),
            'rocauc_ovo': sklearn.metrics.roc_auc_score(correct_onehot,
                                                        pred_onehot,
                                                        multi_class='ovo'),
            # labels= keeps log_loss valid when the masked subset lacks some classes.
            'logloss': sklearn.metrics.log_loss(correct, proba,
                                                labels=np.arange(n_classes))}
        return results

    def train(self, data, tgt_data,
              tuning_metric: str,
              tuning_metric_is_loss: bool,
              wandb = None):
        """Train for ``self._epochs`` epochs and track the best evaluation metrics.

        Each epoch samples as many nodes from ``self.idx_test`` as there are
        source training nodes. The sample is stored as ``self._val_mask`` and
        used as the target batch in ``train_step``.

        Args:
            data: source graph.
            tgt_data: target graph.
            tuning_metric: key of ``test()``'s result dict to select on.
            tuning_metric_is_loss: lower is better when True, higher otherwise.
            wandb: optional logger with a ``log(dict)`` method.

        Returns:
            ``(losses, best_val_metrics)``. ``losses`` holds one
            ``train_step`` tuple per epoch; ``best_val_metrics`` is None if no
            epoch ran.
        """
        losses = []
        best_val_metric = np.inf if tuning_metric_is_loss else -np.inf
        best_val_metrics = None

        for _ in range(self._epochs):
            perm = torch.randperm(self.idx_test.shape[0])
            iid_train = self.idx_test[perm[:self._train_mask.sum()]]
            self._val_mask = index_to_mask(iid_train, data.num_nodes)
            step_loss = self.train_step(data, tgt_data)
            losses.append(step_loss)

            # NOTE(review): model selection is done on the *test* mask, because
            # _val_mask is reused above as the random target batch. The reported
            # "best" metrics are therefore test-selected.
            val_metrics = self.test(data, test_on_val=False)
            if wandb is not None:
                # 'loss_z' carries train_step's third value, the discriminator accuracy.
                wandb.log({'loss_s': step_loss[0], 'loss_t': step_loss[1],
                           'loss_z': step_loss[2], 'accuracy': val_metrics['accuracy']})

            metric = val_metrics[tuning_metric]
            improved = (metric < best_val_metric) if tuning_metric_is_loss else (metric > best_val_metric)
            if improved:
                best_val_metric = metric
                best_val_metrics = copy.deepcopy(val_metrics)
        return losses, best_val_metrics
