import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import scipy.sparse as sp
from models.utils import edgeindex

class GraphAttention(nn.Module):
    """Two-layer graph attention network with an optional edge-level head.

    The first ``GATConv`` layer maps ``feat_dim`` inputs to ``hidden_dim``
    channels with ``num_heads`` heads; the second consumes
    ``hidden_dim * num_heads`` channels and emits ``output_dim`` channels
    with a single head.  When ``task_level`` is anything other than
    ``"node"``, a linear layer is also created that maps the concatenated
    embeddings of an edge's two endpoints (``2 * output_dim``) back to
    ``output_dim``.
    """

    def __init__(self, feat_dim, hidden_dim, num_heads, output_dim, dropout, task_level):
        super().__init__()
        # Both attributes are filled in from outside the class; only
        # ``query_edges`` is read by ``forward``.
        self.query_edges = None
        self.adj = None

        self.gat1 = GATConv(feat_dim, hidden_dim, heads=num_heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim * num_heads, output_dim, heads=1, dropout=dropout)
        self.relu = nn.ReLU()

        needs_edge_head = task_level != "node"
        if needs_edge_head:
            self.fc2_node_edge = nn.Linear(output_dim * 2, output_dim)

    def forward(self, feature, adj):
        """Run both attention layers and, if query edges are set, the edge head.

        Args:
            feature: Node feature tensor handed to the first ``GATConv``.
            adj: Graph connectivity forwarded unchanged to both ``GATConv``
                layers (presumably a PyG edge index -- confirm with caller).

        Returns:
            The second layer's node embeddings when ``self.query_edges`` is
            ``None``.  Otherwise one row per query edge, computed by
            ``fc2_node_edge`` from the concatenated embeddings of the nodes
            in columns 0 and 1 of ``self.query_edges``.
        """
        hidden = self.relu(self.gat1(feature, adj))
        node_embeddings = self.gat2(hidden, adj)

        edges = self.query_edges
        if edges is None:
            return node_embeddings

        # Gather both endpoints of every query edge and join them feature-wise.
        endpoints_a = node_embeddings[edges[:, 0]]
        endpoints_b = node_embeddings[edges[:, 1]]
        edge_input = torch.cat((endpoints_a, endpoints_b), dim=-1)
        return self.fc2_node_edge(edge_input)

class GAT(nn.Module):
    """Wrapper that caches preprocessed graph inputs for a ``GraphAttention`` model.

    Call :meth:`preprocess` once to store the edge index and node features,
    then :meth:`forward` / :meth:`model_forward` to run the underlying model
    on the cached inputs.
    """

    def __init__(self, feat_dim, hidden_dim, num_heads, output_dim, dropout, task_level):
        super(GAT, self).__init__()
        self.base_model = GraphAttention(
            feat_dim=feat_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            output_dim=output_dim,
            dropout=dropout,
            task_level=task_level,
        )
        # Both are populated by ``preprocess``; ``None`` means "not ready yet".
        self.processed_feature = None
        self.processed_adj = None

    def preprocess(self, adj, feature, type="default", tensor_feat=True):
        """Convert and cache the graph structure and node features.

        Args:
            adj: Adjacency structure passed to ``edgeindex``.  With
                ``type="tensor"`` it must expose ``is_sparse`` / ``to_dense``
                (i.e. be a torch tensor); sparse tensors are densified first.
            feature: Node features, or ``None`` to leave any previously
                cached features untouched.
            type: ``"tensor"`` stores the result of ``edgeindex`` as-is; any
                other value additionally converts it to an ``int64`` tensor.
                (The name shadows the builtin but is kept for callers that
                pass it by keyword.)
            tensor_feat: When true, ``feature`` is converted with
                ``torch.FloatTensor``; otherwise it is stored unchanged.
        """
        if type == "tensor":
            if adj.is_sparse:
                adj = adj.to_dense()
            adj_norm = edgeindex(adj)
        else:
            adj_norm = torch.tensor(edgeindex(adj), dtype=torch.int64)
        self.processed_adj = adj_norm

        if feature is not None:
            self.processed_feature = torch.FloatTensor(feature) if tensor_feat else feature
            return

        # No features supplied.  This path used to crash unconditionally: it
        # read ``self.naive_graph_op`` (never defined by this class) and then
        # called ``torch.FloatTensor(None)``.  The graph operator is now
        # treated as optional (a subclass may define it) and the cached
        # features are simply left as they were.
        graph_op = getattr(self, "naive_graph_op", None)
        if graph_op is not None:
            self.base_model.adj = torch.tensor(graph_op.construct_adj(adj), dtype=torch.int64)
        self.pre_msg_learnable = False

    def model_forward(self, idx, device, ori=None):
        """Alias of :meth:`forward` in which ``ori`` is optional."""
        return self.forward(idx, device, ori)

    def forward(self, idx, device, ori):
        """Run the base model on the cached inputs.

        Args:
            idx: Index applied to the model output, or ``None`` to return the
                full output.
            device: Device the cached features and edge index are moved to.
            ori: Query edges for edge-level prediction.  When not ``None`` it
                is stored on the base model and stays in effect for later
                calls, because passing ``None`` does not reset it.

        Returns:
            ``output[idx]`` if ``idx`` is given, otherwise the whole output.

        Raises:
            RuntimeError: If the cached features or edge index are missing,
                i.e. ``preprocess`` has not supplied them yet.
        """
        if self.processed_feature is None or self.processed_adj is None:
            raise RuntimeError(
                "GAT.preprocess() must be called with an adjacency and features before forward()."
            )

        processed_feature = self.processed_feature.to(device)
        processed_adj = self.processed_adj.to(device)
        if ori is not None:
            self.base_model.query_edges = ori

        output = self.base_model(processed_feature, processed_adj)
        return output[idx] if idx is not None else output