from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

from greatx.nn.layers import GCNConv, Sequential, activations


def inject_noise(x, scale_noise):
    """Return ``x`` plus Gaussian noise scaled by ``scale_noise``.

    A standard-normal tensor with the same shape, dtype and device as ``x``
    (``torch.randn_like``) is multiplied by ``scale_noise`` and added to
    ``x``. The input tensor itself is not modified; a new tensor is returned.
    """
    gaussian = torch.randn_like(x)
    perturbation = gaussian * scale_noise
    return x + perturbation


class NoisyGCN(nn.Module):
    """GCN whose first-layer output is perturbed with Gaussian noise in training.

    The network is a stack of ``GCNConv`` layers. Each hidden ``GCNConv`` is
    optionally followed by ``BatchNorm1d``, then an activation (when the
    entry in ``acts`` is truthy), then ``Dropout``. A final ``GCNConv``
    projects to ``out_channels``. In training mode the output of the first
    ``GCNConv`` is passed through ``inject_noise`` with scale
    ``noise_ratio_1``. If ``hids`` is empty, the first ``GCNConv`` is the
    output layer, so the noise is added to its output.

    When ``attention`` is True, the edge weights used by every ``GCNConv`` are
    replaced by thresholded cosine similarities of the input node features
    (see :meth:`compute_attention`).

    Args:
        in_channels: Input feature size of the first ``GCNConv``.
        out_channels: Output size of the final ``GCNConv``; ``forward``
            returns ``log_softmax`` over this dimension.
        hids: Hidden sizes, one per hidden ``GCNConv``. Defaults to ``[16]``.
        acts: Activation names passed to ``activations.get``, one per hidden
            layer; a falsy entry skips the activation. Defaults to ``['relu']``.
        dropout: Probability of the ``Dropout`` after each hidden layer.
        bias: Forwarded to every ``GCNConv``.
        bn: Insert ``BatchNorm1d`` after each hidden ``GCNConv`` when True.
        normalize: Forwarded to every ``GCNConv``.
        noise_ratio_1: Scale passed to ``inject_noise`` after the first layer.
        attention: Recompute edge weights from feature similarity when True.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 hids: Optional[List[int]] = None,
                 acts: Optional[List[str]] = None,
                 dropout: float = 0.5, bias: bool = True, bn: bool = False,
                 normalize: bool = True, noise_ratio_1: float = 0.1,
                 attention: bool = False):
        super().__init__()

        # ``None`` sentinels avoid shared mutable list defaults; the effective
        # defaults are the same as before ([16] / ['relu']).
        hids = [16] if hids is None else hids
        acts = ['relu'] if acts is None else acts

        assert len(hids) == len(acts), "Length of hids and acts must match."

        self.noise_ratio_1 = noise_ratio_1
        self.dropout = dropout
        self.attention = attention
        # Stores the constructor flag (it shadows the module-level sklearn
        # ``normalize`` only inside this method).
        self.normalize = normalize

        layers = []
        in_c = in_channels
        for hid, act in zip(hids, acts):
            layers.append(GCNConv(in_c, hid, bias=bias, normalize=normalize))
            if bn:
                layers.append(nn.BatchNorm1d(hid))
            if act:
                layers.append(activations.get(act))
            layers.append(nn.Dropout(dropout))
            in_c = hid
        layers.append(GCNConv(in_c, out_channels, bias=bias, normalize=normalize))

        self.conv = Sequential(*layers)

    def reset_parameters(self):
        """Delegate parameter re-initialisation to ``self.conv``."""
        self.conv.reset_parameters()

    def compute_attention(self, x, edge_index, num_nodes, threshold: float = 0.1):
        """Replace edge weights with thresholded cosine similarity of endpoints.

        For each edge ``e`` the weight is the cosine similarity between
        ``x[row[e]]`` and ``x[col[e]]``. Values strictly below ``threshold``
        are set to 0. All-zero feature rows yield similarity 0. Features are
        detached, so no gradient flows through the returned weights.

        Only the similarities of existing edges are computed, on ``x``'s
        device. The dense ``num_nodes x num_nodes`` similarity matrix, which
        needs O(N^2) memory and a round trip to the CPU, is never built.

        Args:
            x: Node feature matrix with nodes along dim 0 (assumed 2-D).
            edge_index: Pair ``(row, col)`` of node-index tensors, one entry
                per edge.
            num_nodes: Unused; kept for signature compatibility.
            threshold: Similarities below this value become 0 (default 0.1).

        Returns:
            ``(edge_index, edge_weight)``: ``edge_index`` unchanged and a
            float32 ``edge_weight`` with one value per edge on ``x``'s device.
        """
        row, col = edge_index
        with torch.no_grad():
            feats = x.detach()
            sim = F.cosine_similarity(feats[row], feats[col], dim=-1)
            # Drop weak links by zeroing their weight.
            sim = sim.masked_fill(sim < threshold, 0.0)
        edge_weight = sim.to(dtype=torch.float32)
        return edge_index, edge_weight

    def forward(self, x, edge_index, edge_weight=None):
        """Run the network and return per-node log-probabilities.

        Args:
            x: Node feature matrix with nodes along dim 0.
            edge_index: Graph connectivity passed to every ``GCNConv``.
            edge_weight: Optional edge weights passed to every ``GCNConv``;
                discarded and recomputed when ``self.attention`` is True.

        Returns:
            ``F.log_softmax`` of the final layer's output over dim 1.
        """
        num_nodes = x.size(0)

        # Apply attention before the first layer.
        if self.attention:
            edge_index, edge_weight = self.compute_attention(x, edge_index, num_nodes)

        # A single iterator over the layers avoids slicing ``self.conv[1:]``,
        # which (for ``nn.Sequential``) builds a new container on every call.
        layers = iter(self.conv)

        # The first module is always a GCNConv; noise goes right after it,
        # and only in training mode.
        x = next(layers)(x, edge_index, edge_weight)
        if self.training:
            x = inject_noise(x, self.noise_ratio_1)

        # Remaining modules: BatchNorm / activation / Dropout take only ``x``;
        # GCNConv layers also need the graph.
        for layer in layers:
            if isinstance(layer, GCNConv):
                x = layer(x, edge_index, edge_weight)
            else:
                x = layer(x)

        return F.log_softmax(x, dim=1)
