import torch
import torch.nn.functional as F
import math

import torch.nn as nn

from torch_geometric.utils import softmax

class SparsificationDecoder(torch.nn.Module):
    """Preference-conditioned decoder that assigns a probability to every edge.

    A small hyper-network maps a 2-dimensional preference vector to the
    query/key projection matrices (``assign``). ``reset`` caches per-edge keys,
    per-group queries and preference-weighted distances, and ``forward`` turns
    them into probabilities normalised within each group of edges.

    Expected call order: ``assign(pref)`` -> ``reset(...)`` -> ``forward()``.

    Required ``model_params`` keys: ``embedding_dim``, ``qkv_dim``,
    ``head_num``, ``aug_query_weights``, ``aug_query``, ``tanh``,
    ``logit_clipping`` and ``only_by_distance``.
    """

    def __init__(
        self,
        **model_params
    ):
        super().__init__()
        self.model_params = model_params
        self.embedding_dim = self.model_params['embedding_dim']
        self.qkv_dim = self.model_params['qkv_dim']
        self.n_heads = self.model_params['head_num']

        hyper_input_dim = 2  # size of the preference vector
        hyper_hidden_embd_dim = 256
        # Size of the slice of the hyper-network output that feeds each
        # weight generator.
        self.embd_dim = 2

        # One slice for Wq and one for Wk, plus three more (avg / max / min
        # query transforms) when augmented query weights are enabled.
        if self.model_params['aug_query_weights']:
            self.hyper_output_dim = 5 * self.embd_dim
        else:
            self.hyper_output_dim = 2 * self.embd_dim

        self.tanh = self.model_params['tanh']
        self.tanh_clipping = self.model_params['logit_clipping']

        self.hyper_fc1 = nn.Linear(hyper_input_dim, hyper_hidden_embd_dim, bias=True)
        self.hyper_fc2 = nn.Linear(hyper_hidden_embd_dim, hyper_hidden_embd_dim, bias=True)
        self.hyper_fc3 = nn.Linear(hyper_hidden_embd_dim, self.hyper_output_dim, bias=True)

        projection_size = self.embedding_dim * self.n_heads * self.qkv_dim
        self.hyper_Wq = nn.Linear(self.embd_dim, projection_size, bias=False)
        self.hyper_Wk = nn.Linear(self.embd_dim, projection_size, bias=False)

        if self.model_params['aug_query_weights']:
            square_size = self.embedding_dim * self.embedding_dim
            self.hyper_Wq_avg = nn.Linear(self.embd_dim, square_size, bias=False)
            self.hyper_Wq_max = nn.Linear(self.embd_dim, square_size, bias=False)
            self.hyper_Wq_min = nn.Linear(self.embd_dim, square_size, bias=False)

    def assign(self, pref):
        """Generate the projection weights for the preference vector ``pref``.

        Args:
            pref: 1-D tensor with two entries (the hyper-network input size).
                It is stored on ``self.pref`` and reused by ``reset``.

        Side effects:
            Sets ``self.Wq`` and ``self.Wk`` with shape
            ``(embedding_dim, n_heads * qkv_dim)`` and, when
            ``aug_query_weights`` is enabled, ``self.Wq_avg``, ``self.Wq_max``
            and ``self.Wq_min`` with shape ``(embedding_dim, embedding_dim)``.
        """
        # The three hyper layers are applied back to back, with no activation
        # in between.
        hyper_embd = self.hyper_fc1(pref)
        hyper_embd = self.hyper_fc2(hyper_embd)
        mid_embd = self.hyper_fc3(hyper_embd)
        self.pref = pref

        step = self.embd_dim
        head_width = self.n_heads * self.qkv_dim
        self.Wq = self.hyper_Wq(mid_embd[:step]).reshape(self.embedding_dim, head_width)
        self.Wk = self.hyper_Wk(mid_embd[step:2 * step]).reshape(self.embedding_dim, head_width)

        if self.model_params['aug_query_weights']:
            square = (self.embedding_dim, self.embedding_dim)
            self.Wq_avg = self.hyper_Wq_avg(mid_embd[2 * step:3 * step]).reshape(square)
            self.Wq_max = self.hyper_Wq_max(mid_embd[3 * step:4 * step]).reshape(square)
            self.Wq_min = self.hyper_Wq_min(mid_embd[4 * step:5 * step]).reshape(square)

    def reset(self, dists, edge_emb, edge_indices):
        """Cache distances, keys and group-level queries for ``forward``.

        Requires ``assign`` to have been called first (reads ``self.pref`` and,
        with ``aug_query_weights``, the generated ``Wq_*`` matrices).

        Args:
            dists: ``(B * E, 2)`` tensor holding two distance values per edge.
            edge_emb: ``(B * E, D)`` edge embeddings; used as keys and
                aggregated per group to build the queries.
            edge_indices: 2-D integer tensor whose identical columns define
                the groups (presumably ``(2, B * E)`` -- confirm with caller).

        Side effects:
            Sets ``self.dists`` (``(B * E,)``), ``self.shared_nodes`` (group id
            of every edge), ``self.k`` and ``self.q`` (``(B * E, D)``).
        """
        # Linear (weighted-sum) scalarisation of the two objectives:
        # (B * E, 2) -> (B * E,).
        # NOTE: the Chebyshev-style alternatives were switched off in the
        # original code (hidden behind an always-true condition), so
        # model_params["training_method"] is intentionally not consulted.
        self.dists = self.pref[0] * dists[:, 0] + self.pref[1] * dists[:, 1]

        # Edges whose index columns are identical end up in the same group.
        unique_cols, self.shared_nodes = torch.unique(edge_indices, dim=1, return_inverse=True)
        num_groups = unique_cols.shape[1]
        feat_dim = edge_emb.shape[1]

        self.k = edge_emb

        # Buffers are created from edge_emb so they share its device and
        # dtype; default-device / float32 buffers break on CUDA or float64.
        group_sums = edge_emb.new_zeros((num_groups, feat_dim))
        group_sums.index_add_(0, self.shared_nodes, edge_emb)

        group_counts = (
            torch.bincount(self.shared_nodes, minlength=num_groups)
            .clamp(min=1)
            .unsqueeze(-1)
        )
        group_means = group_sums / group_counts

        if not self.model_params['aug_query']:
            self.q = group_means[self.shared_nodes]  # (B * E, D)
            return

        # Query additionally built from the per-group max and min embeddings.
        scatter_index = self.shared_nodes[:, None].expand(-1, feat_dim)
        group_max = torch.scatter_reduce(
            edge_emb.new_full((num_groups, feat_dim), float('-inf')),
            0,
            scatter_index,
            edge_emb,
            reduce="amax",  # element-wise maximum per group
        )
        group_min = torch.scatter_reduce(
            edge_emb.new_full((num_groups, feat_dim), float('inf')),
            0,
            scatter_index,
            edge_emb,
            reduce="amin",  # element-wise minimum per group
        )

        if self.model_params['aug_query_weights']:
            q_avg = F.linear(group_means[self.shared_nodes], self.Wq_avg)
            q_min = F.linear(group_min[self.shared_nodes], self.Wq_min)
            q_max = F.linear(group_max[self.shared_nodes], self.Wq_max)
            self.q = q_avg + q_min + q_max  # (B * E, embedding_dim)
        else:
            self.q = (
                group_means[self.shared_nodes]
                + group_max[self.shared_nodes]
                + group_min[self.shared_nodes]
            )  # (B * E, D)

    def forward(self):
        """Return per-edge probabilities normalised within each group.

        Requires ``assign`` and ``reset`` to have been called first.

        Returns:
            Tensor of shape ``(B * E, 1)``: the grouped softmax of the edge
            scores, using ``self.shared_nodes`` as group ids.
        """
        if self.model_params['only_by_distance']:
            score = -self.dists
        else:
            # Projections are only needed on this branch.
            # NOTE(review): F.linear treats the rows of Wq / Wk as output
            # features, so this looks like it relies on
            # embedding_dim == n_heads * qkv_dim -- confirm.
            final_q = F.linear(self.q, self.Wq).reshape(-1, self.n_heads, self.qkv_dim)
            logits_k = F.linear(self.k, self.Wk).reshape(-1, self.n_heads, self.qkv_dim)

            # Per-head dot product, averaged over heads, minus a distance
            # penalty.
            compatibility = torch.multiply(final_q, logits_k).sum(dim=2).mean(dim=1)
            score = compatibility / math.sqrt(self.embedding_dim) - self.dists / math.sqrt(2)

        if self.tanh:
            score = self.tanh_clipping * torch.tanh(score)

        probs = softmax(score, self.shared_nodes)

        # probs: (B * E, 1)
        return probs.unsqueeze(1)
