import copy
import random
from functools import wraps
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix, from_scipy_sparse_matrix
from torch_scatter import scatter_add
from torch_sparse import coalesce
from utils import drop_adj, drop_feature
from semi.argparser import args

# helper functions
from scipy.sparse import csr_matrix
import networkx as nx
from scipy.linalg import fractional_matrix_power, inv
import scipy.sparse as sp
from typing import Optional

def default(val, def_val):
    """Return ``val`` unless it is ``None``; in that case return ``def_val``.

    Only ``None`` triggers the fallback -- other falsy values (0, "", []) are kept.
    """
    if val is None:
        return def_val
    return val

def flatten(t):
    """Reshape ``t`` to two dimensions, keeping the leading dimension and merging the rest."""
    leading = t.shape[0]
    return t.reshape(leading, -1)

def singleton(cache_key):
    """Decorator factory that caches a method's result on ``self.<cache_key>``.

    The decorated method is executed only while ``getattr(self, cache_key)`` is
    ``None``. Its return value is stored on that attribute and returned as-is on
    every later call. The attribute must already exist on the instance.
    """
    def decorator(fn):
        @wraps(fn)
        def cached(self, *args, **kwargs):
            value = getattr(self, cache_key)
            if value is None:
                value = fn(self, *args, **kwargs)
                setattr(self, cache_key, value)
            return value
        return cached
    return decorator

# augmentation utils

class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``; otherwise return the input unchanged.

    One draw from Python's ``random`` module is consumed per call.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        skip = random.random() > self.p
        return x if skip else self.fn(x)

# exponential moving average

class EMA:
    """Exponential-moving-average rule with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``; with no previous value (``None``) return ``new``."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Move ``ma_model``'s parameters towards ``current_model``'s, in place.

    Parameters are paired positionally (``zip`` over ``.parameters()``). Each
    averaged tensor is replaced by ``ema_updater.update_average(averaged, live)``.
    """
    paired = zip(current_model.parameters(), ma_model.parameters())
    for live, averaged in paired:
        averaged.data = ema_updater.update_average(averaged.data, live.data)

# MLP class for projector and predictor

class MLP(nn.Module):
    """Two-layer perceptron ``Linear(dim, hidden) -> PReLU -> Linear(hidden, projection_size)``.

    Used as projector / predictor head. ``hidden_size`` defaults to 512.
    """

    def __init__(self, dim, projection_size, hidden_size = 512):
        super().__init__()
        # Order matters: it fixes the ``net.0`` / ``net.1`` / ``net.2`` state-dict keys.
        # TODO: a BatchNorm1d(hidden_size) after the first Linear was tried and is disabled.
        layers = [
            nn.Linear(dim, hidden_size),
            nn.PReLU(),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets

class NetWrapper(nn.Module):
    """Wrap ``net`` so one of its layers' outputs can be read out and projected.

    ``layer`` selects the representation. A ``str`` is looked up among
    ``net.named_modules()``, and an ``int`` indexes ``net.children()``. With
    ``layer == -1`` the final output of ``net(adj, feat)`` is used directly.

    Any other layer's output is captured with a forward hook and flattened to
    ``[batch, -1]``. It is then fed to a projector ``MLP`` built lazily on first
    use, with its input size taken from the captured representation.
    """

    def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None  # created once by _get_projector
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden = None  # written by _hook, consumed (and cleared) by get_representation
        self.hook_registered = False

    def _find_layer(self):
        """Return the module selected by ``self.layer``, or ``None`` if it cannot be found."""
        if isinstance(self.layer, str):
            modules = dict(self.net.named_modules())
            return modules.get(self.layer, None)
        if isinstance(self.layer, int):
            children = list(self.net.children())
            # Out-of-range indices return None so _register_hook's assertion
            # reports a readable message instead of a bare IndexError.
            if -len(children) <= self.layer < len(children):
                return children[self.layer]
        return None

    def _hook(self, _, __, output):
        """Forward hook: stash the selected layer's output, flattened to ``[batch, -1]``."""
        self.hidden = flatten(output)

    def _register_hook(self):
        """Attach ``_hook`` to the selected layer; asserts that the layer exists."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        """Build (only once) a projector MLP whose input size is ``hidden``'s second dimension."""
        _, dim = hidden.shape
        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, adj, feat):
        """Run ``net(adj, feat)`` and return the representation selected by ``layer``."""
        if self.layer == -1:
            # The final output is returned as-is, so no hook is needed. Skipping it
            # also avoids keeping that output (and its autograd graph) alive in
            # self.hidden between calls, and works for nets without children.
            return self.net(adj, feat)

        if not self.hook_registered:
            self._register_hook()

        self.net(adj, feat)
        hidden = self.hidden
        self.hidden = None  # release the captured tensor as soon as it is handed out
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, adj, feat):
        """Return ``projector(get_representation(adj, feat))``."""
        representation = self.get_representation(adj, feat)
        projector = self._get_projector(representation)
        return projector(representation)


# main class
class IGSD(nn.Module):
    """Graph self-distillation model with an online encoder, an EMA target encoder and a supervised head.

    ``forward`` runs the supervised path: ``sup_encoder`` -> ``fc1`` -> ReLU ->
    ``fc2``, flattened to 1-D. ``unsup_loss`` computes a symmetric BYOL-style
    loss between online predictions of one view and target-encoder projections
    of the other view.

    Several settings are read from the global ``args`` namespace: ``hid_dim``,
    ``drop_prob``, ``aug_type`` and ``device``.
    """

    def __init__(self, online_encoder, sup_encoder, feat_dim, hidden_layer=-2,
                 projection_size=256, projection_hidden_size=4096,
                 augment_fn=None, moving_average_decay=0.99):
        # NOTE: feat_dim, hidden_layer and augment_fn are accepted for interface
        # compatibility but are not used here.
        super().__init__()
        self.online_encoder = online_encoder
        self.encoder = sup_encoder
        # Lazily created deep copy of online_encoder (see _get_target_encoder).
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # Supervised head applied to the sup_encoder output in forward().
        self.fc1 = torch.nn.Linear(2 * args.hid_dim, args.hid_dim)
        self.fc2 = torch.nn.Linear(args.hid_dim, 1)
        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight reachable from this module and zero its bias.

        ``self.modules()`` also yields Linear layers inside ``online_encoder``,
        ``encoder`` and ``online_predictor``, so those are re-initialised too.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    # Author's note on expected shapes: adj - [B,N,N], feat - [B,N,feat_dim]
    def graph_aug(self, adj, feat):
        """Return ``(drop_adj(adj, p), drop_feature(feat, p))`` with ``p = args.drop_prob``."""
        return drop_adj(adj, args.drop_prob), drop_feature(feat, args.drop_prob)

    def compute_ppr(self, data, alpha=0.2, self_loop=True):
        """Return a graph whose weighted edges are the dense PPR diffusion of ``data``.

        Computes ``alpha * (I - (1 - alpha) * D^-1/2 A D^-1/2)^-1``. Here ``A`` is
        the dense, weighted adjacency of ``data`` (plus ``I`` when ``self_loop``)
        and ``D`` is its row-sum degree matrix.

        Every non-zero entry becomes an edge whose weight is stored in
        ``edge_attr`` with shape ``[E, 1]``. ``x`` is taken from ``data``, and the
        result is moved to ``args.device``.
        """
        # max_num_nodes keeps trailing isolated nodes, so the diffusion matrix
        # has one row per node of data.x.
        dense = to_dense_adj(edge_index=data.edge_index, edge_attr=data.edge_attr,
                             max_num_nodes=data.num_nodes)
        # Drop the batch dim explicitly (and a trailing single edge-feature dim)
        # so that a one-node graph does not collapse to a scalar.
        a = dense[0]
        if a.dim() == 3 and a.size(-1) == 1:
            a = a.squeeze(-1)
        a = a.cpu().numpy()

        if self_loop:
            a = a + np.eye(a.shape[0])  # A^ = A + I_n
        deg = np.sum(a, 1)  # diagonal of D^ (row sums of A^)
        # D^(-1/2) of a diagonal matrix is the element-wise power. Zero-degree
        # nodes (only possible without self loops) get 0 instead of inf, as in
        # transition_matrix(normalization='sym').
        with np.errstate(divide='ignore'):
            deg_inv_sqrt = np.power(deg, -0.5)
        deg_inv_sqrt[np.isinf(deg_inv_sqrt)] = 0
        at = deg_inv_sqrt[:, None] * a * deg_inv_sqrt[None, :]  # A~ = D^(-1/2) A^ D^(-1/2)

        n = a.shape[0]
        diff = alpha * inv(np.eye(n) - (1 - alpha) * at)  # alpha * (I_n - (1 - alpha) A~)^-1
        edge_index, edge_weight = from_scipy_sparse_matrix(csr_matrix(diff))
        graph = Data(x=data.x, edge_index=edge_index,
                     edge_attr=edge_weight.unsqueeze(1).float()).to(args.device)
        # csr_matrix keeps only non-zero entries, so every edge must carry a non-zero weight.
        assert np.count_nonzero(graph.edge_attr.cpu().numpy()) == edge_index.shape[1]
        return graph

    @singleton('target_encoder')
    def _get_target_encoder(self):
        """Return the target encoder, deep-copying ``online_encoder`` on first use."""
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        """Drop the target encoder so the next ``_get_target_encoder`` call re-copies the online encoder."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """EMA-update the target encoder's parameters towards the online encoder's."""
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        # Resolves to the module-level update_moving_average function, not this method.
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def loss_fn(self, x, y):
        """Return ``2 - 2 * cos(x, y)`` along the last dimension (values in ``[0, 4]``)."""
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)

    def forward(self, data, mask=None):
        """Supervised prediction, flattened to 1-D. ``mask`` is unused.

        The ``sup_encoder`` output's last dimension must be ``2 * args.hid_dim``
        (the ``in_features`` of ``fc1``).
        """
        out, _ = self.encoder(data)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)  # [..., 1]
        return out.view(-1)

    def unsup_loss(self, adj, diff, mask=None):
        """Symmetric BYOL-style loss between two views, averaged over all elements.

        With ``args.aug_type == 'diff'`` the two views are ``adj`` and ``diff``.
        Otherwise two stochastic views are drawn with ``graph_aug``.
        ``mask`` is unused.
        """
        if args.aug_type == 'diff':
            # latent=None is passed explicitly to the online encoder, as before.
            diff_proj, _ = self.online_encoder(diff, latent=None)
            adj_proj, _ = self.online_encoder(adj, latent=None)
            online_pred_two = self.online_predictor(diff_proj)
            online_pred_one = self.online_predictor(adj_proj)

            with torch.no_grad():
                target_encoder = self._get_target_encoder()
                target_proj_one, _ = target_encoder(adj)
                target_proj_two, _ = target_encoder(diff)
        else:
            # graph_aug takes (adj, feat). NOTE(review): ``diff`` is treated as the
            # node-feature tensor in this mode -- confirm against callers.
            adj_one, feat_one = self.graph_aug(adj, diff)
            adj_two, feat_two = self.graph_aug(adj, diff)
            online_proj_one, _ = self.online_encoder(adj_one, feat_one)
            online_proj_two, _ = self.online_encoder(adj_two, feat_two)
            # Online predictions are needed by the loss below for both branches.
            online_pred_one = self.online_predictor(online_proj_one)
            online_pred_two = self.online_predictor(online_proj_two)
            with torch.no_grad():
                target_encoder = self._get_target_encoder()
                target_proj_one, _ = target_encoder(adj_one, feat_one)
                target_proj_two, _ = target_encoder(adj_two, feat_two)

        # Each online view predicts the *other* view's target projection.
        loss_one = self.loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = self.loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

    def embed(self, adj, feat, mask=None):
        """Return detached embeddings: online predictions summed over dimension 1. ``mask`` is unused."""
        online_proj, _ = self.online_encoder(adj, feat)
        online_pred = self.online_predictor(online_proj)
        online_pred = online_pred.sum(1)  # author's note: [bs, num_node, feat_dim] -> [bs, feat_dim]
        return online_pred.detach()

    def transition_matrix(self, edge_index, edge_weight, num_nodes,
                          normalization):
        r"""Normalize the edge weights of a sparse adjacency matrix.

        Degrees are weighted sums computed with ``scatter_add``. Entries where a
        degree is zero get an inverse degree of 0 instead of ``inf``.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Normalization scheme:

                1. :obj:`"sym"`: Symmetric normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A}
                   \mathbf{D}^{-1/2}`.
                2. :obj:`"col"`: Column-wise normalization
                   :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`.
                3. :obj:`"row"`: Row-wise normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`.
                4. :obj:`None`: No normalization.

        Raises:
            ValueError: for any other ``normalization`` value.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if normalization == 'sym':
            row, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        elif normalization == 'col':
            _, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[col]
        elif normalization == 'row':
            row, _ = edge_index
            deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[row]
        elif normalization is None:
            pass
        else:
            raise ValueError(
                'Transition matrix normalization {} unknown.'.format(
                    normalization))

        return edge_index, edge_weight