from torch import nn
import torch.nn.functional as F
import random
import torch
import numpy as np
from torch.nn.parameter import Parameter


class BilinearSim(torch.nn.Module):
    r"""Similarity between sets of structural patterns and structural embeddings.

    Only ``struct_sim`` takes part in :meth:`forward`. ``sim``, ``weight`` and
    ``bias`` are registered but not used by the forward pass; they are kept so
    that the set of parameters (and ``state_dict`` keys) stays the same.
    """

    def __init__(self, in_channels: int = 512, struct_in_channels: int = 512, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
        r"""
        Args:
            in_channels: size of the ``sim`` layer and of the ``weight`` parameter.
            struct_in_channels: feature size handled by ``struct_sim``.
            device: device used for the ``weight`` and ``bias`` parameters.
            dtype: dtype used for the ``weight`` and ``bias`` parameters.
        """
        super().__init__()
        self.sim = torch.nn.Linear(in_channels, in_channels)
        self.struct_sim = torch.nn.Linear(struct_in_channels, struct_in_channels)
        factory_kwargs = {"device": device, "dtype": dtype}
        # Zero-filled rather than torch.empty: these parameters are never
        # trained by forward(), so uninitialised memory (possibly NaN/inf)
        # would otherwise be carried around and saved into checkpoints.
        self.weight = Parameter(
            torch.zeros((1, in_channels, in_channels), **factory_kwargs)
        )
        self.bias = Parameter(torch.zeros(1, **factory_kwargs))

    def forward(self, x, y, struct_patterns, struct_emb):
        r"""
        Args:
            x: sequence whose length gives the number of pattern sets to score;
                its contents are not read.
            y: unused.
            struct_patterns: indexable collection; ``struct_patterns[i]`` is the
                structural pattern matrix of set ``i``
                (presumably ``[num_patterns, struct_in_channels]`` -- confirm
                against the caller).
            struct_emb: structural embeddings to compare against
                (presumably ``[num_nodes, struct_in_channels]`` -- confirm
                against the caller).

        Returns:
            Tensor with one row per pattern set. Row ``i`` holds, for every
            column of ``struct_sim(struct_patterns[i]) @ struct_emb.T``, the
            largest softmax weight taken over dimension 0.
        """
        scores = [
            self._pattern_score(struct_patterns[idx], struct_emb)
            for idx in range(len(x))
        ]
        return torch.stack(scores, dim=0)

    def _pattern_score(self, pattern, struct_emb):
        r"""Return the column-wise max of the dim-0 softmax for one pattern set."""
        logits = torch.matmul(self.struct_sim(pattern), struct_emb.T)
        return F.softmax(logits, dim=0).max(dim=0).values


class ARC_New(nn.Module):
    def __init__(self, in_feats, h_feats=32, num_layers=2, dropout_rate=0, activation='ReLU', beta=1, num_hops=4, device='cpu', **kwargs):
        r"""Build the feature MLP, the structural layers and the attention heads.

        Args:
            in_feats: input feature size of the first linear layer.
            h_feats: hidden/output size of the feature MLP.
            num_layers: controls the MLP depth; ``0`` stops construction right
                after the basic attributes are set.
            dropout_rate: dropout probability; values ``<= 0`` use ``nn.Identity``.
            activation: name of an activation class looked up on ``torch.nn``.
            beta: weight of the structural term in the loss and anomaly score.
            num_hops: number of hops; sizes the concatenated residual embeddings.
            device: accepted for interface compatibility; not used here.
            **kwargs: optional ``temperature`` (softmax temperature used by
                ``cross_attention``, default ``1.0``); other keys are ignored.
        """
        super(ARC_New, self).__init__()
        self.layers = nn.ModuleList()
        self.act = getattr(nn, activation)()
        self.num_hops = num_hops
        self.h_feats = h_feats
        self.st_dim = 10
        # cross_attention divides its scores by this value; it was never
        # defined before, which made every attention call raise AttributeError.
        self.temperature = kwargs.get('temperature', 1.0)
        struct_dim = self.st_dim * num_hops
        if num_layers == 0:
            return
        # NOTE: the creation order of the layers below is kept as-is so that
        # seeded weight initialisation and state_dict layout do not change.
        self.layers.append(nn.Linear(in_feats, h_feats))
        self.struct_layers = nn.ModuleList()
        self.struct_layers.append(nn.Linear(self.st_dim, self.st_dim))
        # Together with the input layer this gives max(1, num_layers - 1)
        # linear layers in the feature MLP.
        for _ in range(1, num_layers - 1):
            self.layers.append(nn.Linear(h_feats, h_feats))
        # num_hops + 1 structural layers in total.
        for _ in range(1, num_hops + 1):
            self.struct_layers.append(nn.Linear(self.st_dim, self.st_dim))
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
        self.beta = beta
        embedding_dim = h_feats * num_hops
        self.embedding_dim = embedding_dim
        # Query/key projections for the feature and structural attention.
        self.Wq = nn.Linear(embedding_dim, embedding_dim // 2)
        self.Wk = nn.Linear(embedding_dim, embedding_dim // 2)
        self.Wq_struct = nn.Linear(struct_dim, struct_dim)
        self.Wk_struct = nn.Linear(struct_dim, struct_dim)
        self.domain_sim = BilinearSim(embedding_dim,  struct_dim)
        self.criteria = nn.TripletMarginLoss(margin=0.2, p=2.0, eps=1e-06, swap=True)
        self.criteria_struct = nn.TripletMarginLoss(margin=0.1, p=2.0, eps=1e-06, swap=True)

    def get_embedding(self, h, wl_pos):
        x_list = h.x_list
        struct_feat = h.one_node_features
        struct_feat_list = []
        # Z^{[l]} = MLP(X^{[l]}
        for i, layer in enumerate(self.layers):
            if i != 0:
                x_list = [self.dropout(x) for x in x_list]
            x_list = [layer(x) for x in x_list]
            if i != len(self.layers) - 1:
                x_list = [self.act(x) for x in x_list]
            ## struct_feat and x_list share the similar neural networks
        for i, layer in enumerate(self.struct_layers):
            out = torch.matmul(h.adj, layer(struct_feat))
            if i != len(self.struct_layers) - 1:
                out = self.act(out)
            struct_feat = out
            struct_feat_list.append(out)

        residual_list = []
        residual_struct = []
        # Z^{[0]}
        first_element = x_list[0]
        first_struct_feat = struct_feat_list[0]
        for h_i in x_list[1:]:
            # R^{[l]} = Z^{[l]}-Z^{[0]}
            residual_list.append(h_i - first_element)
            # first_element = residual_list[-1]
        for h_i in struct_feat_list[1:]:
            residual_struct.append(h_i - first_struct_feat)
            # first_struct_feat = residual_struct[-1]
        # H = [R^{[1]} || ... || R^{[L]}]
        struct_feat = torch.hstack(residual_struct)
        X = torch.hstack(residual_list)
        return X, struct_feat


    def forward(self, h_train, struct_train, wl_pos, y_train, patterns_list, struct_patterns, num_prompt=50):
        # get the indices of both normal nodes and abnormal nodes
        anomaly_indices = torch.nonzero((y_train == 1)).squeeze(1).tolist()
        all_normal_indices = torch.nonzero((y_train == 0)).squeeze(1).tolist()
        num_prompt = min(num_prompt, len(anomaly_indices))
        if len(all_normal_indices) < num_prompt:
            num_prompt = len(all_normal_indices)
        normal_indices = random.sample(all_normal_indices, num_prompt)
        if len(anomaly_indices) > num_prompt:
            anomaly_indices = random.sample(anomaly_indices, num_prompt)
        # gather both anomaly and normal nodes embeddings
        anomaly_emb = h_train[anomaly_indices]
        normal_emb = h_train[normal_indices]
        y_positive = torch.ones([len(normal_indices)]).to(y_train.device)
        y_negative = -torch.ones([len(anomaly_indices)]).to(y_train.device)
        struct_normal_emb = struct_train[normal_indices]
        struct_anomaly_emb = struct_train[anomaly_indices]
        normal_dom_sim = self.domain_sim(patterns_list, normal_emb, struct_patterns, struct_normal_emb)
        anomaly_dom_sim = self.domain_sim(patterns_list, anomaly_emb, struct_patterns, struct_anomaly_emb)
        tilde_normal_embeds = self.cross_attention(normal_emb, patterns_list, normal_dom_sim, self.Wq, self.Wk)
        tilde_anomaly_embeds = self.cross_attention(anomaly_emb, patterns_list, anomaly_dom_sim, self.Wq, self.Wk)
        tilde_struct_emb = self.cross_attention(struct_normal_emb, struct_patterns, normal_dom_sim, self.Wq_struct, self.Wk_struct)
        loss = F.cosine_embedding_loss(normal_emb, tilde_normal_embeds,y_positive) + F.cosine_embedding_loss(anomaly_emb, tilde_anomaly_embeds, y_negative) + F.cosine_embedding_loss(normal_emb, tilde_anomaly_embeds, y_negative)
        loss += self.criteria(tilde_normal_embeds, normal_emb, anomaly_emb) + self.beta * self.criteria_struct(tilde_struct_emb, struct_normal_emb, struct_anomaly_emb)
        return loss


    def patterns_extraction(self, h_train, struct_emb, adj_train, y_train, num_prompt=50):
        normal_indices = torch.nonzero((y_train == 0)).squeeze(1).tolist()
        if len(normal_indices) > num_prompt:
            normal_indices = random.sample(normal_indices, num_prompt)
        # gather both anomaly and normal nodes embeddings
        normal_emb = h_train[normal_indices]
        struct_emb = struct_emb[normal_indices]
        return normal_emb, struct_emb


    def patterns_extraction_for_test_graph(self, h_test, struct_emb, num_prompt=10):
        all_indices  = np.arange(h_test.shape[0]).tolist()
        selected_indices = random.sample(all_indices, num_prompt)
        selected_node_emb = h_test[selected_indices]
        struct_emb = struct_emb[selected_indices]
        return selected_node_emb, struct_emb


    def inference(self, patterns, h_test, adj, struct_patterns, struct_test):
        dom_sim = self.domain_sim(patterns, h_test, struct_patterns, struct_test)
        query_embeds = self.cross_attention(h_test, patterns, dom_sim, self.Wq, self.Wk)
        query_struct_embeds = self.cross_attention(struct_test, struct_patterns, dom_sim, self.Wq_struct, self.Wk_struct)
        query_score = (torch.sqrt(torch.sum((query_embeds - h_test) ** 2, dim=1)) +
                       self.beta * torch.sqrt(torch.sum((query_struct_embeds - struct_test) ** 2, dim=1)))
        return query_score, dom_sim


    def cross_attention(self, query_X, support_X, dom_sim, Wq, Wk):
        dom_sim = dom_sim.T
        Q = F.leaky_relu(Wq(query_X))  # query
        emb_list = 0
        for idx in range(len(support_X)):
            K = F.leaky_relu(Wk(support_X[idx]))  # key
            attention_scores = torch.matmul(Q, K.T) / torch.sqrt(
                torch.tensor(self.embedding_dim, dtype=torch.float32))
            k = int(attention_scores.shape[1] * 0.1)
            # Get the indices of the k smallest values along each row
            _, indices = torch.topk(attention_scores, k, dim=1, largest=False)
            # Create a mask of ones
            mask = torch.ones_like(attention_scores)
            # Use advanced indexing to set the k smallest values to 0
            row_indices = torch.arange(attention_scores.size(0)).unsqueeze(1)  # shape (600, 1)
            mask[row_indices, indices] = 0
            attention_scores[row_indices, indices] = float('-inf')
            attention_weights = F.softmax(attention_scores/ self.temperature, dim=1)
            weighted_query_embeddings = torch.matmul(attention_weights, support_X[idx])
            emb_list += dom_sim[:, idx] * weighted_query_embeddings.T
        emb_list = emb_list / len(support_X)
        return emb_list.T
