import copy

import torch
from torch import nn
from torch.nn import BatchNorm1d, Parameter
from torch.nn.init import ones_, zeros_

import dgl
from dgl.nn.pytorch.conv import GraphConv, SAGEConv


class LayerNorm(nn.Module):
    r"""Graph-level layer normalization.

    Without ``batch``, ``x`` is normalized with a single mean and (biased)
    standard deviation computed over all of its elements. With ``batch``, each
    graph (the rows of ``x`` sharing one ``batch`` id) is normalized with its
    own mean and variance, computed over all of that graph's nodes and
    channels. An optional per-channel affine transform is applied afterwards.

    Args:
        in_channels (int): Number of feature channels (size of
            ``weight``/``bias``).
        eps (float, optional): Added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If ``True``, learn a per-channel ``weight``
            (initialized to ones) and ``bias`` (initialized to zeros).
            (default: :obj:`True`)
    """

    def __init__(self, in_channels, eps=1e-5, affine=True):
        super().__init__()
        self.in_channels = in_channels
        self.eps = eps

        if affine:
            self.weight = Parameter(torch.empty(in_channels))
            self.bias = Parameter(torch.empty(in_channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        # With affine=False both parameters are None; nothing to initialize.
        if self.weight is not None:
            ones_(self.weight)
        if self.bias is not None:
            zeros_(self.bias)

    def forward(self, x, batch=None):
        """Normalize ``x``.

        Args:
            x (torch.Tensor): Node features. The batched path assumes shape
                ``(num_nodes, in_channels)``.
            batch (torch.Tensor, optional): Graph id of every row of ``x``;
                assumed to hold non-negative integers.

        Returns:
            torch.Tensor: Normalized (and optionally affinely transformed)
            features with the same shape as ``x``.
        """
        if batch is None:
            x = x - x.mean()
            out = x / (x.std(unbiased=False) + self.eps)
        else:
            batch = batch.to(device=x.device, dtype=torch.long)
            batch_size = int(batch.max()) + 1
            # Elements per graph = (#nodes in graph) * (#channels). Clamp so
            # that ids without any node do not cause a division by zero.
            counts = torch.bincount(batch, minlength=batch_size).to(x.dtype)
            norm = counts.clamp_(min=1).mul_(x.size(-1)).view(-1, 1)

            mean = self._sum_per_graph(x, batch, batch_size) / norm
            x = x - mean.index_select(0, batch)
            # The variance must be computed from the *centered* features.
            var = self._sum_per_graph(x * x, batch, batch_size) / norm
            out = x / (var + self.eps).sqrt().index_select(0, batch)

        if self.weight is not None and self.bias is not None:
            out = out * self.weight + self.bias

        return out

    @staticmethod
    def _sum_per_graph(t, batch, batch_size):
        """Sum ``t`` over the nodes and channels of each graph.

        For a 2-D ``t`` the result has shape ``(batch_size, 1)``.
        """
        per_graph = t.new_zeros((batch_size,) + tuple(t.shape[1:]))
        per_graph.index_add_(0, batch, t)
        return per_graph.sum(dim=-1, keepdim=True)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_channels})"


class MLP_Predictor(nn.Module):
    r"""MLP used for predictor. The MLP has one hidden layer.

    Architecture:
    ``Linear(input_size, hidden_size) -> PReLU -> Linear(hidden_size, output_size)``.

    Args:
        input_size (int): Size of input features.
        output_size (int): Size of output features.
        hidden_size (int, optional): Size of hidden layer. (default: :obj:`512`).
    """

    def __init__(self, input_size, output_size, hidden_size=512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.PReLU(1),
            nn.Linear(hidden_size, output_size, bias=True),
        )
        self.reset_parameters()

    def forward(self, x):
        return self.net(x)

    def reset_parameters(self):
        # Re-initialize only the Linear layers (PyTorch's default
        # kaiming_uniform scheme); the PReLU slope is left untouched.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.reset_parameters()


class GCN(nn.Module):
    """Stack of ``GraphConv -> BatchNorm1d -> PReLU`` stages.

    Args:
        layer_sizes (sequence of int): Feature sizes; each consecutive pair
            gives the input/output dimension of one GraphConv.
        batch_norm_mm (float, optional): Value passed as ``momentum`` to every
            BatchNorm1d. (default: :obj:`0.99`)
    """

    def __init__(self, layer_sizes, batch_norm_mm=0.99):
        super().__init__()

        stages = []
        for dim_in, dim_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            stages.extend(
                [
                    GraphConv(dim_in, dim_out),
                    BatchNorm1d(dim_out, momentum=batch_norm_mm),
                    nn.PReLU(),
                ]
            )
        self.layers = nn.ModuleList(stages)

    def forward(self, g):
        """Encode the node features stored in ``g.ndata["feat"]``."""
        h = g.ndata["feat"]
        for module in self.layers:
            # Only the graph convolutions need the graph structure.
            h = module(g, h) if isinstance(module, GraphConv) else module(h)
        return h

    def reset_parameters(self):
        resettable = [m for m in self.layers if hasattr(m, "reset_parameters")]
        for module in resettable:
            module.reset_parameters()


class GraphSAGE_GCN(nn.Module):
    r"""Three-layer mean-aggregator GraphSAGE encoder with dense skip links.

    Each layer applies ``SAGEConv -> LayerNorm -> PReLU``. The second layer
    consumes ``h1 + skip_0(x)`` and the third consumes ``h1 + h2 + skip_1(x)``,
    where both skips are bias-free linear maps of the raw input features.

    Args:
        layer_sizes (sequence of 3 ints):
            ``(input_size, hidden_size, embedding_size)``.
    """

    def __init__(self, layer_sizes):
        super().__init__()

        in_dim, hid_dim, emb_dim = layer_sizes

        conv_dims = ((in_dim, hid_dim), (hid_dim, hid_dim), (hid_dim, emb_dim))
        self.convs = nn.ModuleList(
            [SAGEConv(src, dst, "mean") for src, dst in conv_dims]
        )
        self.skip_lins = nn.ModuleList(
            [nn.Linear(in_dim, hid_dim, bias=False) for _ in range(2)]
        )
        self.layer_norms = nn.ModuleList(
            [LayerNorm(dim) for dim in (hid_dim, hid_dim, emb_dim)]
        )
        self.activations = nn.ModuleList([nn.PReLU() for _ in range(3)])

    def _stage(self, idx, g, h, batch):
        """Apply conv -> norm -> activation of layer ``idx``."""
        h = self.convs[idx](g, h)
        h = self.layer_norms[idx](h, batch)
        return self.activations[idx](h)

    def forward(self, g):
        """Encode ``g.ndata["feat"]``; uses ``g.ndata["batch"]`` if present."""
        x = g.ndata["feat"]
        batch = g.ndata["batch"] if "batch" in g.ndata.keys() else None

        h1 = self._stage(0, g, x, batch)
        h2 = self._stage(1, g, h1 + self.skip_lins[0](x), batch)
        return self._stage(2, g, h1 + h2 + self.skip_lins[1](x), batch)

    def reset_parameters(self):
        for module in (*self.convs, *self.skip_lins):
            module.reset_parameters()
        for act in self.activations:
            act.weight.data.fill_(0.25)
        for norm in self.layer_norms:
            norm.reset_parameters()


class BGRL(nn.Module):
    r"""BGRL architecture for Graph representation learning.

    The online branch is ``encoder`` followed by ``predictor``. The target
    branch is a deep copy of ``encoder`` whose weights are re-initialized and
    frozen (``requires_grad=False``). It is updated only through
    :meth:`update_target_network`.

    Args:
        encoder (torch.nn.Module): Encoder network to be duplicated and used in
            both online and target networks.
        predictor (torch.nn.Module): Predictor network used to predict the
            target projection from the online projection.

    .. note::
        `encoder` must have a `reset_parameters` method, as the weights of the
        target network will be initialized differently from the online network.
    """

    def __init__(self, encoder, predictor):
        super().__init__()
        # online network
        self.online_encoder = encoder
        self.predictor = predictor

        # target network: independent copy with freshly initialized weights
        self.target_encoder = copy.deepcopy(encoder)
        self.target_encoder.reset_parameters()

        # stop gradient: the target is only updated via the moving average
        for param in self.target_encoder.parameters():
            param.requires_grad = False

    def trainable_parameters(self):
        r"""Returns the parameters that will be updated via an optimizer."""
        return list(self.online_encoder.parameters()) + list(
            self.predictor.parameters()
        )

    @torch.no_grad()
    def update_target_network(self, mm):
        r"""Performs a momentum update of the target network's weights.

        ``target <- mm * target + (1 - mm) * online`` for every parameter.

        Args:
            mm (float): Momentum used in moving average update, in ``[0, 1]``.

        Raises:
            ValueError: If ``mm`` lies outside ``[0, 1]``.
        """
        if not 0.0 <= mm <= 1.0:
            raise ValueError(f"momentum mm must be in [0, 1], got {mm}")
        for online_param, target_param in zip(
            self.online_encoder.parameters(), self.target_encoder.parameters()
        ):
            target_param.data.mul_(mm).add_(online_param.data, alpha=1.0 - mm)

    def forward(self, online_x, target_x):
        """Return ``(online prediction, target embedding)``.

        The target embedding is computed under ``torch.no_grad`` so no
        gradient flows into the target encoder.
        """
        # forward online network
        online_y = self.online_encoder(online_x)

        # prediction
        online_q = self.predictor(online_y)

        # forward target network; no_grad already yields a graph-free tensor
        with torch.no_grad():
            target_y = self.target_encoder(target_x)
        return online_q, target_y


def compute_representations(net, dataset, device):
    r"""Pre-computes the representations for the entire data.

    ``net`` is switched to eval mode and run under ``torch.no_grad`` on every
    graph of ``dataset`` after moving it to ``device``. When the dataset holds
    a single graph, self-loops are added (``dgl.add_self_loop``) first.
    Multi-graph datasets are used as-is.

    Returns:
        [torch.Tensor, torch.Tensor]: Representations and labels
        (``ndata["label"]``), each concatenated along dim 0 across graphs.
    """
    net.eval()

    def _embed(graph):
        graph = graph.to(device)
        with torch.no_grad():
            return net(graph), graph.ndata["label"]

    if len(dataset) == 1:
        graphs = [dgl.add_self_loop(dataset[0])]
    else:
        graphs = dataset

    reps, labels = [], []
    for graph in graphs:
        rep, label = _embed(graph)
        reps.append(rep)
        labels.append(label)

    return [torch.cat(reps, dim=0), torch.cat(labels, dim=0)]
