import numpy as np
import torch
#from datetime import datetime
from abc import ABC, abstractmethod
import os.path as osp
import torch
import logging
import sklearn.metrics
import copy
import wandb
import torch.nn as nn
from IPython import embed
from tqdm import tqdm
import os.path as osp
import numpy as np
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_zip)
from torch_geometric.io import read_txt_array
from torch_geometric.utils import remove_self_loops
from torch_geometric.data import Data, DataLoader



def l2diff(x1, x2):
    """Return the Euclidean (L2) norm of the element-wise difference ``x1 - x2``.

    For tensors with more than one dimension this is the norm over all
    entries (i.e. the Frobenius norm for matrices).
    """
    difference = x1 - x2
    return difference.norm(p=2)

def moment_diff(sx1, sx2, k):
    """Return the L2 distance between the k-th moments of two samples.

    Each moment is the element-wise ``k``-th power averaged over dim 0.
    ``cmd`` passes mean-centred samples here, so in that use these are
    central moments.
    """
    moment_a = torch.mean(sx1 ** k, dim=0)
    moment_b = torch.mean(sx2 ** k, dim=0)
    return l2diff(moment_a, moment_b)

def cmd(X, X_test, K=5):
    """Central moment discrepancy between two samples.

    Sums the L2 distance between the means taken over dim 0 and, for every
    order ``k`` in ``2..K``, the L2 distance between the k-th moments of the
    mean-centred samples. No range normalisation is applied.

    Args:
        X: first sample, reduced along dim 0.
        X_test: second sample, reduced along dim 0.
        K: highest moment order included (default 5).

    Returns:
        Scalar tensor holding the summed distances.
    """
    mean_a = X.mean(0)
    mean_b = X_test.mean(0)
    centred_a = X - mean_a
    centred_b = X_test - mean_b
    terms = [l2diff(mean_a, mean_b)]
    terms.extend(moment_diff(centred_a, centred_b, order) for order in range(2, K + 1))
    return sum(terms)

def pairwise_distances(x, y=None):
    '''
    Squared Euclidean distances between the rows of ``x`` and the rows of ``y``.

    Args:
        x: N x d matrix.
        y: optional M x d matrix; when omitted, ``x`` is compared with itself.

    Returns:
        N x M matrix ``dist`` with ``dist[i, j] = ||x[i, :] - y[j, :]||^2``,
        computed via the expansion ``|a|^2 + |b|^2 - 2 a.b`` and clamped at 0
        to remove small negative values caused by floating-point cancellation.
    '''
    if y is None:
        y = x
    x_sq = x.pow(2).sum(dim=1).view(-1, 1)
    y_sq = y.pow(2).sum(dim=1).view(1, -1)
    cross = torch.mm(x, y.t())
    dist = x_sq + y_sq - 2.0 * cross
    return dist.clamp(min=0.0)

def MMD(X, Xtest, alpha=1e0):
    """Squared maximum mean discrepancy with kernel ``exp(-alpha * ||a - b||^2)``.

    Averages the kernel over all pairs, including the diagonal of the
    within-sample blocks (biased estimate): mean k(X, X) - 2 * mean k(X, Xtest)
    + mean k(Xtest, Xtest).
    """
    k_xx = torch.exp(-alpha * pairwise_distances(X)).mean()
    k_xy = torch.exp(-alpha * pairwise_distances(X, Xtest)).mean()
    k_yy = torch.exp(-alpha * pairwise_distances(Xtest, Xtest)).mean()
    return k_xx - 2 * k_xy + k_yy

class MLP_Encoder(nn.Module):
    """Two-layer perceptron that maps node features to a hidden representation.

    Layout: Linear(in_feats, hid_feats) -> Tanh -> Dropout
            -> Linear(hid_feats, hid_feats) -> Dropout.
    """

    def __init__(self,
                 in_feats,
                 hid_feats,
                 dropout=0):
        super().__init__()
        first = nn.Linear(in_feats, hid_feats)
        second = nn.Linear(hid_feats, hid_feats)
        self.mlp = nn.ModuleList([first, second])
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.Tanh()

    def forward(self, features, edges=None):
        """Encode ``features``; ``edges`` is accepted for interface compatibility and ignored."""
        h = features
        *hidden_layers, output_layer = self.mlp
        for layer in hidden_layers:
            # Hidden layers are followed by the non-linearity, then dropout.
            h = self.dropout(self.activation(layer(h)))
        # The last layer has no activation but is still followed by dropout.
        return self.dropout(output_layer(h))

    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for linear in self.mlp:
            linear.reset_parameters()

    def dann_output(self, idx_train, iid_train, alpha=1):
        """Placeholder for a domain-adversarial loss; not implemented, returns None."""
        return None

    def cmd(self, h_src, h_tgt, alpha = 0.1):
        """Return ``alpha`` times the module-level ``cmd`` between ``h_src`` and ``h_tgt``."""
        return alpha * cmd(h_src, h_tgt)

class Identity_Encoder(nn.Module):
    """Pass-through encoder: returns its input features unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, features):
        # No transformation: the downstream model sees the raw features.
        return features

    def reset_parameters(self):
        # Nothing to reset: this module holds no parameters.
        return None

def index_to_mask(index, size):
    """Return a boolean tensor of length ``size`` that is True exactly at ``index``.

    The mask is allocated on the same device as ``index``.
    """
    flags = torch.zeros(size, dtype=torch.bool, device=index.device)
    flags[index] = True
    return flags

def random_planetoid_splits(data, num_classes, percls_trn=20, val_lb=500, Flag=0):
    """Assign random train/val/test node masks to ``data`` (in place) and return it.

    For each class ``c`` in ``range(num_classes)`` the nodes with ``data.y == c``
    are shuffled and the first ``percls_trn`` of them become training nodes.

    Args:
        data: graph object exposing ``y`` (node labels) and ``num_nodes``; it
            receives ``train_mask``, ``val_mask`` and ``test_mask`` attributes.
        num_classes: number of label values to split over.
        percls_trn: number of training nodes taken from each class.
        val_lb: if ``Flag == 0``, the TOTAL number of validation nodes drawn
            from the shuffled union of all non-training nodes; otherwise the
            number of validation nodes taken from EACH class.
        Flag: selects the validation scheme described above (``0`` = global).

    Returns:
        The same ``data`` object, with the masks built by ``index_to_mask``.
        Every node not used for training or validation goes to the test mask.
    """
    indices = []
    for i in range(num_classes):
        index = (data.y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    train_index = torch.cat([i[:percls_trn] for i in indices], dim=0)

    # Value comparison, not identity: `Flag is 0` depends on int caching and is
    # False for e.g. numpy integers, silently selecting the per-class branch.
    if Flag == 0:
        rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0)
        rest_index = rest_index[torch.randperm(rest_index.size(0))]

        data.train_mask = index_to_mask(train_index, size=data.num_nodes)
        data.val_mask = index_to_mask(rest_index[:val_lb], size=data.num_nodes)
        data.test_mask = index_to_mask(
            rest_index[val_lb:], size=data.num_nodes)
    else:
        val_index = torch.cat([i[percls_trn:percls_trn+val_lb]
                               for i in indices], dim=0)
        rest_index = torch.cat([i[percls_trn+val_lb:] for i in indices], dim=0)
        rest_index = rest_index[torch.randperm(rest_index.size(0))]

        data.train_mask = index_to_mask(train_index, size=data.num_nodes)
        data.val_mask = index_to_mask(val_index, size=data.num_nodes)
        data.test_mask = index_to_mask(rest_index, size=data.num_nodes)
    return data

def translate_pq_to_lambda(p_q, d):
    """Convert the edge-probability ratio ``p_q`` into the cSBM signal parameter.

    Returns ``(p_q - 1) / (p_q + 1) * sqrt(d)``. ``ContextualSBM`` uses this
    to derive ``c_in`` / ``c_out``, whose ratio equals ``p_q``.
    """
    ratio = (p_q - 1) / (p_q + 1)
    return ratio * np.sqrt(d)

# u-v = 

def ContextualSBM(n, d, p_q, p, mu, train_percent=0.01, u=None, delta=0.0, Z=None):
    """Sample a two-class contextual stochastic block model (cSBM) graph.

    Args:
        n: number of nodes.
        d: degree parameter. Same-class pairs are connected with probability
            ``c_in / n`` and cross-class pairs with ``c_out / n``, where
            ``c_in = d + sqrt(d) * Lambda`` and ``c_out = d - sqrt(d) * Lambda``.
        p_q: ratio converted into ``Lambda`` by ``translate_pq_to_lambda``.
        p: node-feature dimension.
        mu: feature signal strength.
        train_percent: sizes the training split (``round(train_percent * n /
            num_classes)`` nodes per class) and the validation split
            (``int(n * train_percent)`` nodes).
        u: length-``p`` feature direction. If ``None`` it is drawn from
            ``N(0, I_p / p)`` (previously ``None`` raised a ``TypeError``).
        delta: shift along ``u`` applied to every node regardless of class.
        Z: optional ``(n, p)`` noise matrix; drawn i.i.d. ``N(0, 1)`` if ``None``.

    Returns:
        A ``Data`` object with float32 ``x``, int64 ``edge_index`` (both
        directions of every sampled edge), ``y`` in ``{0, 1}``, random
        train/val/test masks, and the generation parameters as attributes.
    """
    Lambda = translate_pq_to_lambda(p_q, d)
    c_in = d + np.sqrt(d) * Lambda
    c_out = d - np.sqrt(d) * Lambda

    # Labels in {+1, -1}.
    # NOTE(review): this gives the +1 class int(n/2)+1 nodes, so for even n the
    # classes differ by two nodes. Kept unchanged so generated graphs stay
    # comparable with earlier runs; confirm whether an exact half split was intended.
    y = np.ones(n, dtype=int)
    y[int(n / 2) + 1:] = -1

    # One Bernoulli draw per unordered node pair, in this exact order:
    # vectorising would change the random stream for a given seed.
    edge_index = [[], []]
    for i in range(n - 1):
        for j in range(i + 1, n):
            prob = c_in / n if y[i] * y[j] > 0 else c_out / n
            if np.random.binomial(1, prob) > 0.5:
                # Store the edge in both directions.
                edge_index[0].extend((i, j))
                edge_index[1].extend((j, i))

    # Node features: class-dependent shift along u plus isotropic noise.
    x = np.zeros([n, p])
    if Z is None:
        Z = np.random.normal(0, 1, [n, p])
    if u is None:
        # cSBM convention: u ~ N(0, I_p / p), so ||u|| is about 1.
        u = np.random.normal(0, 1 / np.sqrt(p), p)
    for i in range(n):
        x[i] = np.sqrt(mu / n) * (y[i] * u + u * delta) + Z[i] / np.sqrt(p)

    # dtype=torch.long keeps edge_index integral even when no edge was sampled
    # (an empty nested list would otherwise become a float tensor).
    data = Data(x=torch.tensor(x, dtype=torch.float32),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                y=torch.tensor((y + 1) // 2, dtype=torch.int64))
    # order edge list and remove duplicates if any.
    data.coalesce()

    num_class = len(np.unique(y))
    val_lb = int(n * train_percent)
    percls_trn = int(round(train_percent * n / num_class))
    data = random_planetoid_splits(data, num_class, percls_trn, val_lb)

    # add parameters to attribute
    data.Lambda = Lambda
    data.mu = mu
    data.n = n
    data.p = p
    data.d = d
    data.train_percent = train_percent
    data.p_q = c_in / c_out
    data.u = u
    data.delta = delta
    return data

class Benchmarker(ABC):
    """Abstract base for model benchmarkers.

    Stores the model class and parameter dicts. It then gives subclasses a
    chance to adapt them to the data generator via ``AdjustParams``.
    """

    def __init__(self, generator_config,
                 model_class=None, benchmark_params=None, h_params=None):
        if model_class is None:
            self._model_name = ''
        else:
            self._model_name = model_class.__name__
        self._model_class = model_class
        self._benchmark_params = benchmark_params
        self._h_params = h_params
        self.AdjustParams(generator_config)

    def AdjustParams(self, generator_config):
        """Hook for subclasses whose architecture depends on the input data; no-op here."""
        pass

    def GetModelName(self):
        """Return the model class name, or '' when no model class was given."""
        return self._model_name

    @abstractmethod
    def Benchmark(self, element,
                  test_on_val: bool = False,
                  tuning_metric: str = None,
                  tuning_metric_is_loss: bool = False):
        """Train and test the model on one pipeline element.

        Args:
            element: output of the 'Convert to torchgeo' beam stage.
            test_on_val: evaluate on the validation split instead of the test split.
            tuning_metric: name of the metric used for model selection.
            tuning_metric_is_loss: whether lower values of ``tuning_metric`` are better.

        Returns:
            Dict with 'losses' (per-epoch loss values) and 'test_metrics'
            (named test metrics). This base version returns empty results.
        """
        del element, test_on_val, tuning_metric, tuning_metric_is_loss  # unused
        return {'losses': [], 'test_metrics': {}}

# general benchmarkers
class NNNodeBenchmarker():
    """Train and evaluate an encoder + graph model for node classification.

    Node features pass through an encoder first: an ``Identity_Encoder`` when
    ``arch == 'I_GCN'``, an ``MLP_Encoder`` otherwise. They then go through
    ``model_class(**h_params)``. When a target-domain graph is given to
    :meth:`train`, the loss adds a ``cmd`` penalty between the source and
    target ``model.output`` representations.
    """

    def __init__(self, arch, model_class, benchmark_params, h_params, device=None):
        """Build the encoder, the model and an Adam optimizer over both.

        Args:
            arch: ``'I_GCN'`` for an identity encoder. Any other value builds an
                ``MLP_Encoder`` from ``h_params['in_channels']``,
                ``h_params['hidden_channels']`` and ``h_params['dropout']``.
            model_class: model constructor, called with ``**h_params``. With
                the MLP encoder, ``in_channels`` is replaced by
                ``hidden_channels`` because the model consumes encoder outputs.
            benchmark_params: dict providing ``'epochs'`` and ``'lr'``.
            h_params: model hyper-parameters; copied, never mutated.
            device: optional device that both modules are moved to.
        """
        # Copy so the in_channels rewrite below does not leak into the
        # caller's dict (and into any later benchmarker built from it).
        h_params = dict(h_params)
        self._h_params = h_params  # read by AdjustParams
        self._epochs = benchmark_params['epochs']
        if arch == 'I_GCN':
            self._encoder = Identity_Encoder()
        else:
            self._encoder = MLP_Encoder(h_params['in_channels'],
                                        h_params['hidden_channels'],
                                        h_params['dropout'])
            h_params['in_channels'] = h_params['hidden_channels']
        self._model = model_class(**h_params)
        if device is not None:
            self._encoder = self._encoder.to(device)
            self._model = self._model.to(device)
        # TODO(palowitch): make optimizer configurable.
        self._lr = benchmark_params['lr']
        self._criterion = torch.nn.CrossEntropyLoss()
        self._train_mask = None
        self._val_mask = None
        self._test_mask = None
        # Built eagerly so train_step works even if reset_parameters() is
        # never called; reset_parameters() replaces it with a fresh one.
        self._optimizer = self._make_optimizer()

    def _make_optimizer(self):
        """Return a new Adam optimizer over the model and encoder parameters."""
        return torch.optim.Adam(list(self._model.parameters()) + list(self._encoder.parameters()),
                                lr=self._lr,
                                weight_decay=5e-4)

    def AdjustParams(self, generator_config):
        """Set ``out_channels`` from ``generator_config['num_clusters']`` when present."""
        if 'num_clusters' in generator_config and self._h_params is not None:
            self._h_params['out_channels'] = generator_config['num_clusters']

    def SetMasks(self, train_mask, val_mask, test_mask):
        """Store the node masks used by train_step and test."""
        self._train_mask = train_mask
        self._val_mask = val_mask
        self._test_mask = test_mask

    def train_step(self, data, tgt_data=None):
        """Run one optimisation step and return the loss tensor.

        The loss is cross-entropy on the ``self._train_mask`` nodes. When
        ``tgt_data`` is given, ``0.1 * cmd(h_src, h_tgt)`` is added. Here
        ``h_*`` are the ``self._model.output`` representations of the two graphs.
        """
        self._model.train()
        self._encoder.train()
        self._optimizer.zero_grad()  # Clear gradients.
        # Encode the source features once and reuse them for both the
        # classification loss and the CMD term, so both see the same dropout mask.
        z_src = self._encoder(data.x)
        out = self._model(z_src, data.edge_index)
        loss = self._criterion(out[self._train_mask], data.y[self._train_mask])
        if tgt_data is not None:
            z_tgt = self._encoder(tgt_data.x)
            h_src = self._model.output(z_src, data.edge_index)
            h_tgt = self._model.output(z_tgt, tgt_data.edge_index)
            loss = loss + 0.1 * cmd(h_src, h_tgt)
        loss.backward()  # Derive gradients.
        self._optimizer.step()  # Update parameters based on gradients.
        return loss

    def get_embeddings(self, data, tgt_data):
        """Return ``(z_src, z_tgt, h_src, h_tgt)``, computed in eval mode without gradients.

        ``z_*`` are encoder outputs; ``h_*`` are ``self._model.output`` of them.
        """
        self._model.eval()
        self._encoder.eval()  # disable encoder dropout as well
        with torch.no_grad():
            z_src = self._encoder(data.x)
            z_tgt = self._encoder(tgt_data.x)
            h_src = self._model.output(z_src, data.edge_index)
            h_tgt = self._model.output(z_tgt, tgt_data.edge_index)
        return z_src, z_tgt, h_src, h_tgt

    def test(self, data, test_on_val=False, da=False):
        """Evaluate on the validation mask (``test_on_val``) or on the test mask.

        ``da`` is accepted for call compatibility and ignored.

        Returns:
            Dict with 'accuracy', 'f1_micro', 'f1_macro', 'rocauc_ovr' and
            'rocauc_ovo' (the ROC AUCs use one-hot predictions), plus
            'logloss' (computed from softmax probabilities).
        """
        self._model.eval()
        self._encoder.eval()
        mask = self._val_mask if test_on_val else self._test_mask
        with torch.no_grad():  # evaluation only: no autograd graph needed
            out = self._model(self._encoder(data.x), data.edge_index)
            scores_t = out[mask]
            # The model is trained with CrossEntropyLoss, so `out` holds
            # unnormalised scores while log_loss needs probabilities. Softmax is
            # also correct if the model already emits log-probabilities.
            probs = torch.softmax(scores_t.double(), dim=-1).cpu().numpy()
        scores = scores_t.cpu().numpy()
        pred_best = scores.argmax(-1)
        correct = data.y[mask].cpu().numpy()

        n_classes = out.shape[-1]
        pred_onehot = np.eye(n_classes)[pred_best]
        correct_onehot = np.eye(n_classes)[correct]

        results = {
            'accuracy': sklearn.metrics.accuracy_score(correct, pred_best),
            'f1_micro': sklearn.metrics.f1_score(correct, pred_best,
                                                 average='micro'),
            'f1_macro': sklearn.metrics.f1_score(correct, pred_best,
                                                 average='macro'),
            'rocauc_ovr': sklearn.metrics.roc_auc_score(correct_onehot,
                                                        pred_onehot,
                                                        multi_class='ovr'),
            'rocauc_ovo': sklearn.metrics.roc_auc_score(correct_onehot,
                                                        pred_onehot,
                                                        multi_class='ovo'),
            # Explicit labels so the metric works when a split lacks a class.
            'logloss': sklearn.metrics.log_loss(correct, probs,
                                                labels=np.arange(n_classes))}
        return results

    def reset_parameters(self):
        """Re-initialise encoder and model weights and rebuild the optimizer."""
        self._encoder.reset_parameters()
        self._model.reset_parameters()
        self._optimizer = self._make_optimizer()

    def train(self, data, tgt_data,
              tuning_metric: str,
              tuning_metric_is_loss: bool,
              wandb = None):
        """Train for ``self._epochs`` epochs, tracking the best validation metrics.

        Args:
            data: source graph (features, edges, labels).
            tgt_data: optional target graph; when not None a CMD term is added.
            tuning_metric: key of the ``test`` results used for model selection.
            tuning_metric_is_loss: lower is better when True, higher otherwise.
            wandb: optional object with a ``log(dict)`` method (this parameter
                shadows the module-level ``wandb`` import).

        Returns:
            ``(losses, best_val_metrics)``: per-epoch training losses and a
            copy of the validation metrics from the best epoch (None if the
            tuning metric never improved, e.g. zero epochs).
        """
        losses = []
        best_val_metric = np.inf if tuning_metric_is_loss else -np.inf
        best_val_metrics = None
        for _ in range(self._epochs):
            losses.append(float(self.train_step(data, tgt_data)))
            val_metrics = self.test(data, test_on_val=True, da=tgt_data is not None)
            if wandb is not None:
                wandb.log({'loss': losses[-1], 'accuracy': val_metrics['accuracy']})
            current = val_metrics[tuning_metric]
            if tuning_metric_is_loss:
                improved = current < best_val_metric
            else:
                improved = current > best_val_metric
            if improved:
                best_val_metric = current
                best_val_metrics = copy.deepcopy(val_metrics)
        return losses, best_val_metrics

    def Benchmark(self, element,
                  tuning_metric: str = None,
                  tuning_metric_is_loss: bool = False):
        """Train on the first graph of ``element`` and test on every graph.

        ``element`` must provide 'torch_data' (list of graphs), 'masks' (list
        of ``(train, val, test)`` mask triples aligned with 'torch_data'),
        'skipped' and 'sample_id'.

        Returns:
            A copy of ``element`` with 'losses', 'val_metrics' and
            'test_metrics' (one dict per graph) filled in. 'skipped' is set
            to True if training or testing raised.
        """
        torch_data = element['torch_data']
        masks = element['masks']
        skipped = element['skipped']
        sample_id = element['sample_id']

        out = {
            'skipped': skipped,
            'results': None
        }
        out.update(element)
        out['losses'] = None
        out['val_metrics'] = {}
        out['test_metrics'] = []

        if skipped:
            logging.info(f'Skipping benchmark for sample id {sample_id}')
            return out

        train_mask, val_mask, test_mask = masks[0]
        self.SetMasks(train_mask, val_mask, test_mask)

        val_metrics = {}
        test_metrics = []
        losses = None
        try:
            # Single-domain benchmark: no target graph, hence tgt_data=None
            # (it is a required positional argument of train()).
            losses, val_metrics = self.train(torch_data[0], None,
                                             tuning_metric=tuning_metric,
                                             tuning_metric_is_loss=tuning_metric_is_loss)
            for data_i, mask_i in zip(torch_data, masks):
                train_mask, val_mask, test_mask = mask_i
                self.SetMasks(train_mask, val_mask, test_mask)
                test_metrics.append(self.test(data_i, test_on_val=False))
        except Exception:  # pylint: disable=broad-except
            # Best effort per sample: record the failure (with traceback)
            # rather than aborting the whole run.
            logging.info(f'Failed to run for sample id {sample_id}', exc_info=True)
            out['skipped'] = True

        out['losses'] = losses
        out['test_metrics'].extend(dict(metrics) for metrics in test_metrics)
        if val_metrics:  # train() returns None when the metric never improved
            out['val_metrics'].update(val_metrics)
        return out

class DomainDataset(InMemoryDataset):
    r"""Single-graph node-classification dataset read from raw text files.

    :meth:`process` reads three files from the dataset's ``raw_dir``:

    * ``{name}_edgelist.txt``: comma-separated node-index pairs, one edge per line;
    * ``{name}_docs.txt``: comma-separated float node features, one node per line;
    * ``{name}_labels.txt``: one integer class label per line.

    It stores one :obj:`torch_geometric.data.Data` object with random boolean
    ``train_mask`` / ``val_mask`` / ``test_mask`` attributes. These cover 70%,
    10% and the remaining nodes respectively, with sizes truncated by ``int``.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): Domain name used as the prefix of the raw file names.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform applied to the
            :obj:`torch_geometric.data.Data` object in :meth:`process`
            before it is saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): Forwarded to the base class; not
            applied by :meth:`process`. (default: :obj:`None`)
    """
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # Set before the base constructor runs, since raw_file_names and
        # process read it (the base class presumably triggers them).
        self.name = name
        super(DomainDataset, self).__init__(root, transform, pre_transform, pre_filter)

        # NOTE(review): newer torch versions default torch.load to
        # weights_only=True, which may reject pickled Data objects; confirm
        # against the pinned torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Must match the files process() actually reads; the unprefixed
        # names used before never matched them.
        return ['{}_{}.txt'.format(self.name, stem)
                for stem in ('docs', 'edgelist', 'labels')]

    @property
    def processed_file_names(self):
        # NOTE(review): the name does not include self.name, so domains that
        # share one `root` would presumably reuse the first processed graph;
        # confirm callers use one root per domain.
        return ['data.pt']

    def download(self):
        # Raw files must be placed in raw_dir by hand; nothing is downloaded.
        pass

    @staticmethod
    def _read_lines(path):
        """Return the UTF-8 lines of ``path`` without CR/LF characters, skipping blank lines."""
        with open(path, 'rb') as f:
            lines = [str(raw, encoding="utf-8").replace("\r", "").replace("\n", "")
                     for raw in f]
        return [line for line in lines if line.strip()]

    @staticmethod
    def _index_mask(indices, size):
        """Return a boolean mask of length ``size`` that is True at ``indices``."""
        mask = torch.zeros([size], dtype=torch.bool)
        mask[torch.as_tensor(indices, dtype=torch.long)] = True
        return mask

    def process(self):
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        feature_rows = [line.split(",") for line in self._read_lines(docs_path)]
        x = np.array(feature_rows, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        y = np.array(self._read_lines(label_path), dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

        if x.size(0) != y.size(0):
            raise ValueError(
                'Feature file {} has {} rows but label file {} has {}'.format(
                    docs_path, x.size(0), label_path, y.size(0)))

        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 70% / 10% / rest split of the nodes.
        num_nodes = y.shape[0]
        perm = np.random.permutation(num_nodes)
        train_size = int(num_nodes * 0.7)
        val_size = int(num_nodes * 0.1)
        # Boolean masks: uint8 index masks are deprecated in PyTorch.
        data.train_mask = self._index_mask(perm[:train_size], num_nodes)
        data.val_mask = self._index_mask(perm[train_size:train_size + val_size], num_nodes)
        data.test_mask = self._index_mask(perm[train_size + val_size:], num_nodes)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data_list = [data]
        data, slices = self.collate(data_list)

        torch.save((data, slices), self.processed_paths[0])