from typing import Optional
from itertools import chain
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from .acm_gin import ACM_GIN_model
from torch_geometric.utils import dropout_edge
from torch_geometric.utils import add_self_loops


def sce_loss(x, y, alpha=3):
    """Return the mean of ``(1 - cos(x, y)) ** alpha`` over all rows.

    Both inputs are L2-normalised along the last dimension, so the row-wise
    dot product is the cosine similarity between matching rows.

    Args:
        x: Predicted features; the last dimension is the feature axis.
        y: Target features, same shape as ``x``.
        alpha: Exponent applied to the per-row cosine error.

    Returns:
        A scalar tensor holding the averaged error.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)

    # Row-wise cosine similarity of the unit vectors.
    cosine = (x_unit * y_unit).sum(dim=-1)
    per_row_error = (1 - cosine).pow(alpha)

    return per_row_error.mean()


def setup_module(
    m_type,
    in_dim,
    out_dim,
    num_hidden,
    num_layers,
    activation,
    batchnorm,
) -> nn.Module:

    if m_type == "acm_gin":
        mod = ACM_GIN_model(
            int(in_dim),
            int(out_dim),
            num_layers,
            int(num_hidden),
            batchnorm,
            activation=activation,
        )
    else:
        raise NotImplementedError

    return mod


class PreModel(nn.Module):
    """Masked-feature autoencoder for graph pre-training.

    A random subset of node features is corrupted, the corrupted graph is
    encoded, the encoding is projected and the masked positions are zeroed
    again ("re-mask"), and a decoder reconstructs the original features.
    ``forward`` returns the reconstruction loss on the masked nodes.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        """Create the encoder, decoder, mask token and loss.

        Args:
            in_dim: Node feature size; also the decoder output size.
            edge_in_dim: Accepted but not used by this class.
            num_hidden: Hidden size of encoder and decoder.
            num_layers: Number of encoder layers (the decoder has one).
            nhead: Only checked to divide ``num_hidden``.
            nhead_out: Only checked to divide ``num_hidden``.
            activation: Activation specifier passed to ``setup_module``.
            feat_drop: Accepted but not used by this class.
            attn_drop: Accepted but not used by this class.
            negative_slope: Accepted but not used by this class.
            residual: Accepted but not used by this class.
            norm: Accepted but not used by this class.
            mask_rate: Fraction of candidate nodes that get masked.
            encoder_type: Module type passed to ``setup_module``.
            decoder_type: Module type passed to ``setup_module``.
            loss_fn: ``"sce"`` or ``"mse"``.
            drop_edge_rate: Edge-dropping probability; ``0`` disables it.
            replace_rate: Fraction of masked nodes whose features are
                replaced by another node's features instead of the token.
            alpha_l: Exponent used by the ``"sce"`` loss.
            concat_hidden: If true, the encoder's hidden states are
                concatenated before the encoder-to-decoder projection.
            batchnorm: Batch-norm setting passed to ``setup_module``.

        Raises:
            AssertionError: If ``num_hidden`` is not divisible by ``nhead``
                or ``nhead_out``.
            NotImplementedError: If a module type or ``loss_fn`` is not
                supported.
        """
        super().__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0, "num_hidden must be divisible by nhead"
        assert (
            num_hidden % nhead_out == 0
        ), "num_hidden must be divisible by nhead_out"

        enc_num_hidden = num_hidden

        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Build encoder
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Build decoder for attribute prediction
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learnable token added to the zeroed features of masked nodes.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))

        # With concat_hidden the projection input is sized for num_layers
        # hidden states of width num_hidden each.
        proj_in_dim = dec_in_dim * num_layers if concat_hidden else dec_in_dim
        self.encoder_to_decoder = nn.Linear(proj_in_dim, dec_in_dim, bias=False)

        # Setup loss function
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        """Hidden size given at construction (``num_hidden``)."""
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion named by ``loss_fn``.

        Args:
            loss_fn: ``"mse"`` for ``nn.MSELoss`` or ``"sce"`` for
                ``sce_loss`` with ``alpha=alpha_l``.
            alpha_l: Exponent bound into the ``"sce"`` loss.

        Raises:
            NotImplementedError: For any other ``loss_fn``.
        """
        if loss_fn == "mse":
            return nn.MSELoss()
        if loss_fn == "sce":
            return partial(sce_loss, alpha=alpha_l)
        raise NotImplementedError(
            f"Unsupported loss function {loss_fn!r}; expected 'mse' or 'sce'"
        )

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Corrupt the features of a random subset of nodes.

        Masked nodes are split into "token" nodes (features zeroed, then the
        mask token is added) and "noise" nodes (features replaced by those of
        randomly drawn candidate nodes).

        Args:
            x: Node feature tensor indexed by node along dimension 0.
            mask_rate: Fraction of candidate nodes to mask.
            virtual_node_index: Optional tensor of node indices that must
                never be masked nor used as a noise source.

        Returns:
            ``(out_x, (mask_nodes, keep_nodes))`` where ``out_x`` is a
            corrupted copy of ``x`` and the two index tensors partition the
            candidate nodes.
        """
        num_nodes = x.shape[0]
        candidates = torch.arange(num_nodes, device=x.device)

        # Remove virtual node index from masking candidates
        if virtual_node_index is not None:
            candidates = candidates[~torch.isin(candidates, virtual_node_index)]

        num_candidates = len(candidates)
        perm = candidates[torch.randperm(num_candidates, device=x.device)]

        # random masking
        num_mask_nodes = int(mask_rate * num_candidates)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()

        if self._replace_rate > 0:
            num_noise_nodes = min(
                int(self._replace_rate * num_mask_nodes), num_mask_nodes
            )
            # Split the masked set exactly: slicing with a negative count
            # would select *every* masked node when the count is 0 (``-0:``),
            # and truncating both counts separately could leave masked nodes
            # with their original features.
            num_token_nodes = num_mask_nodes - num_noise_nodes

            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_nodes = mask_nodes[perm_mask[num_token_nodes:]]

            # Noise features are drawn from candidate nodes only.
            noise_pick = torch.randperm(num_candidates, device=x.device)[
                :num_noise_nodes
            ]
            noise_source = candidates[noise_pick]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_source]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Return the masked-attribute reconstruction loss for ``batch``.

        Reads ``x``, ``edge_index``, ``edge_attr`` and ``batch`` from the
        input object, plus ``virtual_node_index`` when present.
        """
        # ---- attribute reconstruction ----
        virtual_node_index = getattr(batch, "virtual_node_index", None)
        return self.mask_attr_prediction(
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
            virtual_node_index,
        )

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask, encode, re-mask, decode and score the reconstruction.

        Args:
            x: Node features.
            edge_index: Edge index tensor.
            edge_attr: Edge attributes, or ``None``.
            batch: Accepted but not used here.
            virtual_node_index: Nodes excluded from masking, or ``None``.

        Returns:
            The criterion value between reconstructed and original features
            of the masked nodes.
        """
        use_x, (mask_nodes, _) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            use_edge_index, edge_mask = dropout_edge(
                edge_index, self._drop_edge_rate
            )
            # Keep the attributes aligned with the surviving edges.
            use_edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
            # Pass the node count explicitly: inferring it from the surviving
            # edges would skip self-loops for trailing isolated nodes.
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index,
                use_edge_attr,
                fill_value="min",
                num_nodes=x.shape[0],
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            # * remask, re-mask
            rep[mask_nodes] = 0
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        return self.criterion(x_rec, x_init)

    def embed(self, x, edge_index, edge_attr, batch):
        """Encode an uncorrupted graph and apply the decoder projection.

        ``batch`` is accepted but not used.
        """
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        return self.encoder_to_decoder(enc_rep)

    @property
    def enc_params(self):
        """Iterator over the encoder's parameters."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Iterator over the projection's and the decoder's parameters."""
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())
