import copy
import random
from functools import wraps
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix, from_scipy_sparse_matrix
from torch_scatter import scatter_add
from torch_sparse import coalesce
#from utils import drop_adj, drop_feature
from argparser import args

# helper functions
from scipy.sparse import csr_matrix
import networkx as nx
from scipy.linalg import fractional_matrix_power, inv
import scipy.sparse as sp
from typing import Optional
from supcon_loss import SupConLoss

def default(val, def_val):
    """Return ``val``, falling back to ``def_val`` only when ``val`` is ``None``.

    Falsy values other than ``None`` (``0``, ``""``, ``False``) are returned
    unchanged.
    """
    if val is None:
        return def_val
    return val

def flatten(t):
    """Collapse every axis after the first into a single trailing axis.

    ``t`` must expose ``shape`` and ``reshape``; the result has shape
    ``(t.shape[0], -1)``.
    """
    leading = t.shape[0]
    return t.reshape(leading, -1)

def singleton(cache_key):
    """Build a method decorator that memoises the result on the instance.

    The decorated method's return value is stored on ``self`` under the
    attribute named ``cache_key``. While that attribute holds a non-``None``
    value the method body is skipped and the stored value is returned.
    The attribute must already exist on the instance (it is read with a plain
    ``getattr``), and a ``None`` result is never treated as cached.
    """
    def decorate(method):
        @wraps(method)
        def cached_call(self, *call_args, **call_kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                # Nothing stored yet: compute once and remember it on the instance.
                cached = method(self, *call_args, **call_kwargs)
                setattr(self, cache_key, cached)
            return cached
        return cached_call
    return decorate

# augmentation utils

class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``, else pass it through.

    Randomness comes from Python's ``random`` module, so it is controlled by
    ``random.seed`` rather than the torch RNG.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # A draw strictly above ``p`` skips the transform.
        return x if random.random() > self.p else self.fn(x)

# exponential moving average

class EMA():
    """Exponential moving average with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into ``old``.

        Returns ``new`` unchanged when there is no previous value
        (``old is None``); otherwise ``old * beta + (1 - beta) * new``.
        """
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Move ``ma_model``'s parameters towards ``current_model``'s, in place.

    Parameters are paired positionally via ``parameters()``. Each
    moving-average parameter's ``.data`` is replaced by
    ``ema_updater.update_average(old, new)``. ``current_model`` is only read.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for source_param, averaged_param in param_pairs:
        averaged_param.data = ema_updater.update_average(averaged_param.data, source_param.data)

# MLP class for projector and predictor

class MLP(nn.Module):
    """Two-layer perceptron used as projector / predictor head.

    Layout: ``Linear(dim, hidden_size) -> PReLU -> Linear(hidden_size,
    projection_size)``, stored as ``self.net``.
    """

    def __init__(self, dim, projection_size, hidden_size = 512):
        super().__init__()
        # TODO: an nn.BatchNorm1d(hidden_size) after the first Linear is
        # currently disabled, and PReLU is used in place of an in-place ReLU.
        layers = [
            nn.Linear(dim, hidden_size),
            nn.PReLU(),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        projected = self.net(x)
        return projected

class IGSD(nn.Module):
    """Online/target encoder pair trained with distillation-style losses.

    The module holds three encoders:

    * ``online_encoder`` -- its ``embed`` outputs are passed through
      ``online_predictor`` (an :class:`MLP`).
    * ``target_encoder`` -- a deep copy of ``online_encoder`` that is only
      evaluated under ``torch.no_grad()`` and is moved towards the online
      weights by :meth:`update_moving_average`.
    * ``encoder`` -- the ``sup_encoder`` argument, used only by :meth:`forward`.

    Args:
        online_encoder: module exposing ``embed(x)`` / ``embed(x, latent=...)``.
        sup_encoder: module called directly by :meth:`forward`.
        feat_dim: accepted for interface compatibility; not used here.
        hidden_layer: accepted for interface compatibility; not used here.
        projection_size: input and output width of the online predictor.
        projection_hidden_size: hidden width of the online predictor.
        augment_fn: accepted for interface compatibility; not used here.
        moving_average_decay: ``beta`` of the EMA applied to the target encoder.

    The ``mask`` argument of the loss methods and of :meth:`forward` is
    accepted but not used.
    """

    def __init__(self, online_encoder, sup_encoder, feat_dim, hidden_layer = -2, projection_size = 256,
                 projection_hidden_size = 4096, augment_fn = None, moving_average_decay = 0.99):
        super().__init__()
        self.online_encoder = online_encoder
        self.encoder = sup_encoder
        # Must exist (as None) before the @singleton getter reads it.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
        self.target_encoder = self._get_target_encoder()

        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias.

        NOTE(review): ``self.modules()`` walks *all* submodules, so this also
        re-initialises Linear layers inside ``encoder`` and inside the target
        encoder that was copied just before this call. The target's Linear
        weights are therefore drawn independently of the online encoder's --
        confirm this is intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    @singleton('target_encoder')
    def _get_target_encoder(self):
        """Return the target encoder, deep-copying the online encoder if unset."""
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        """Drop the target encoder; it is rebuilt lazily on next use."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Apply one EMA step moving the target encoder towards the online one.

        Raises:
            AssertionError: if the target encoder is currently ``None``.
        """
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def loss_fn(self, x, y):
        """Return ``2 - 2 * cos(x, y)`` computed along the last dimension.

        Both inputs are L2-normalised first; for unit vectors this equals the
        squared Euclidean distance. Inputs are broadcast against each other.
        """
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)

    def forward(self, data, mask=None):
        """Run ``data`` through the supervised encoder and return its output."""
        return self.encoder(data)

    def unsup_loss(self, adj, diff, mask=None):
        """Symmetric consistency loss between two views, averaged to a scalar.

        Each view's online prediction is compared (via :meth:`loss_fn`) with
        the target encoder's embedding of the *other* view.
        """
        diff_proj = self.online_encoder.embed(diff, latent=None)
        adj_proj = self.online_encoder.embed(adj, latent=None)
        online_pred_two = self.online_predictor(diff_proj)
        online_pred_one = self.online_predictor(adj_proj)

        with torch.no_grad():
            # Use the lazy getter (as neg_loss does) so the loss still works
            # after reset_moving_average() has set the attribute to None.
            target_encoder = self._get_target_encoder()
            target_proj_one = target_encoder.embed(adj)
            target_proj_two = target_encoder.embed(diff)

        loss_one = self.loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = self.loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

    def neg_loss(self, adj, diff=None, mask=None):
        """Log-sum-exp loss over all cross-view (prediction, target) pairs.

        Reads ``args.aug_type`` and ``args.alpha`` from the module-level
        ``args``. When ``args.aug_type == 'diff'`` the two views are ``adj``
        and ``diff``; otherwise two views are drawn with ``self.graph_aug``.

        NOTE(review): ``graph_aug`` is not defined in this class. The
        non-``'diff'`` branch only works if a subclass or the caller provides
        it; otherwise it raises ``AttributeError``.
        """
        if args.aug_type == 'diff':
            online_proj_one = self.online_encoder.embed(adj)
            online_proj_two = self.online_encoder.embed(diff)

            online_pred_one = self.online_predictor(online_proj_one)
            online_pred_two = self.online_predictor(online_proj_two)

            with torch.no_grad():
                target_encoder = self._get_target_encoder()
                target_proj_one = target_encoder.embed(adj)
                target_proj_two = target_encoder.embed(diff)
        else:
            # TODO: augmentation path is unfinished (see graph_aug note above).
            adj_one, feat_one = self.graph_aug(adj)
            adj_two, feat_two = self.graph_aug(adj)
            online_proj_one, _ = self.online_encoder(adj_one, feat_one)
            online_proj_two, _ = self.online_encoder(adj_two, feat_two)

            # Previously missing: without these the loss below raised
            # NameError on online_pred_one / online_pred_two.
            online_pred_one = self.online_predictor(online_proj_one)
            online_pred_two = self.online_predictor(online_proj_two)

            with torch.no_grad():
                target_encoder = self._get_target_encoder()
                target_proj_one, _ = target_encoder(adj_one, feat_one)
                target_proj_two, _ = target_encoder(adj_two, feat_two)

        # unsqueeze(0) vs unsqueeze(1) broadcasts every prediction against
        # every target, giving a pairwise loss matrix.
        loss1 = (self.loss_fn(online_pred_one.unsqueeze(dim=0), target_proj_two.unsqueeze(dim=1)) / args.alpha).exp().sum(-1)
        loss2 = (self.loss_fn(online_pred_two.unsqueeze(dim=0), target_proj_one.unsqueeze(dim=1)) / args.alpha).exp().sum(-1)
        loss = (loss1 + loss2).log()
        return loss.mean()

    def supcon_loss(self, adj, diff, mask=None):
        """Sum of pairwise :meth:`loss_fn` values between the two views' online predictions.

        Every prediction for ``adj`` is compared with every prediction for
        ``diff`` via broadcasting, and the result is summed (not averaged).
        """
        diff_proj = self.online_encoder.embed(diff, latent=None)
        adj_proj = self.online_encoder.embed(adj, latent=None)
        online_pred_two = self.online_predictor(diff_proj)
        online_pred_one = self.online_predictor(adj_proj)

        with torch.no_grad():
            # NOTE(review): these target projections are computed but never
            # used below. The calls are kept so any side effects of the
            # target forward pass are unchanged -- confirm whether they can
            # be removed.
            target_encoder = self._get_target_encoder()
            target_proj_one = target_encoder.embed(adj)
            target_proj_two = target_encoder.embed(diff)

        online_pred_one = torch.unsqueeze(online_pred_one, dim=1)
        online_pred_two = torch.unsqueeze(online_pred_two, dim=0)

        loss = self.loss_fn(online_pred_one, online_pred_two)

        return loss.sum()

    def embed(self, adj):
        """Return the online encoder's embedding of ``adj``."""
        return self.online_encoder.embed(adj)
