import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear, to_hetero
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import softmax
from torch_scatter import scatter_softmax,scatter

import warnings

# NOTE: this silences every warning for the whole process, not only the ones
# raised by this module.
warnings.filterwarnings('ignore')

# Default compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def edge_type_to_str(edge_type: tuple) -> str:
    """Flatten an edge-type tuple into a single string, parts joined by '__'."""
    separator = '__'
    flattened = separator.join(edge_type)
    return flattened


class WeightedHGTConv(HGTConv):
    """HGT-style attention layer that can take one scalar weight per edge.

    Each node type gets its own key/query/value projection and output
    projection; each edge type gets a relation scoring layer and a projection
    that lifts a scalar edge weight to one value per attention head.

    Args:
        in_channels: Size of the incoming node features (same for all types).
        out_channels: Size of the node features produced by the layer.
        metadata: ``(node_types, edge_types)`` pair describing the graph.
        heads: Number of attention heads; each head works on
            ``out_channels // heads`` channels.
    """

    def __init__(self, in_channels, out_channels, metadata, heads=4):
        # ``use_edge_attr`` is not an option of HGTConv; passing it along
        # only risks an "unexpected keyword argument" error in the base
        # class, so it is no longer forwarded.
        super().__init__(
            in_channels,
            out_channels,
            metadata,
            heads=heads,
        )

        self.heads = heads
        self.dim_per_head = out_channels // heads
        proj_dim = heads * self.dim_per_head
        node_types, edge_types = metadata[0], metadata[1]

        # Per-node-type key / query / value projections.
        self.k_lin = nn.ModuleDict({
            node_type: Linear(in_channels, proj_dim)
            for node_type in node_types
        })
        self.q_lin = nn.ModuleDict({
            node_type: Linear(in_channels, proj_dim)
            for node_type in node_types
        })
        self.v_lin = nn.ModuleDict({
            node_type: Linear(in_channels, proj_dim)
            for node_type in node_types
        })

        # ModuleDict keys must be strings, so edge-type tuples are flattened.
        self.edge_type_strs = [self.edge_type_to_str(et) for et in edge_types]

        # Per-edge-type scoring of the source keys (one scalar per head).
        self.a_rel = nn.ModuleDict({
            et_str: nn.Linear(self.dim_per_head, 1, bias=False)
            for et_str in self.edge_type_strs
        })

        # Lifts a scalar edge weight to one value per attention head.
        self.edge_weight_lin = torch.nn.ModuleDict({
            et: Linear(1, heads) for et in self.edge_type_strs
        })
        # NOTE(review): registered but never read by forward/hgt_forward;
        # kept so existing state_dicts continue to load.
        self.edge_scale = torch.nn.Parameter(torch.ones(heads))

        # Per-node-type output projection applied after aggregation.
        self.a_lin = nn.ModuleDict({
            node_type: nn.Linear(proj_dim, out_channels)
            for node_type in node_types
        })

    @staticmethod
    def edge_type_to_str(edge_type):
        """Join the parts of an edge-type tuple with '__'."""
        return '__'.join(edge_type)

    def _softmax(self, att, index):
        """Softmax over ``att`` along dim 0, grouped by ``index``."""
        return scatter_softmax(att, index, dim=0)

    def forward(self, x_dict, edge_index_dict, edge_weight_dict=None):
        """Project the optional edge weights and run the attention pass.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by edge-type tuple.
            edge_weight_dict: Optional per-edge weights keyed by edge-type
                tuple. Edge types without a registered weight projection are
                ignored.

        Returns:
            Dict of updated node features keyed by node type.
        """
        edge_attr_dict = {}
        if edge_weight_dict is not None:
            for edge_type, edge_weight in edge_weight_dict.items():
                et_str = self.edge_type_to_str(edge_type)
                if et_str not in self.edge_weight_lin:
                    continue
                # reshape (not view) so non-contiguous weights are accepted.
                edge_attr_dict[edge_type] = self.edge_weight_lin[et_str](
                    edge_weight.reshape(-1, 1)
                )

        return self.hgt_forward(
            x_dict,
            edge_index_dict,
            edge_attr_dict=edge_attr_dict,
        )

    def hgt_forward(self, x_dict, edge_index_dict, edge_attr_dict=None):
        """Attention-weighted message passing over every edge type.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by
                ``(src_type, relation, dst_type)``; row 0 indexes source
                nodes and row 1 indexes destination nodes.
            edge_attr_dict: Optional per-head edge terms keyed by edge type,
                as produced by ``forward``.

        Returns:
            Dict of updated node features keyed by node type. Node types that
            receive no messages get the output projection of a zero vector.
        """
        head_shape = (-1, self.heads, self.dim_per_head)
        proj_dim = self.heads * self.dim_per_head

        query_dict, key_dict, value_dict = {}, {}, {}
        for node_type, x in x_dict.items():
            query_dict[node_type] = self.q_lin[node_type](x).view(*head_shape)
            key_dict[node_type] = self.k_lin[node_type](x).view(*head_shape)
            value_dict[node_type] = self.v_lin[node_type](x).view(*head_shape)

        # Accumulators follow the dtype and device of the input features so
        # non-float32 inputs do not end up mixed with float32 buffers.
        out_dict = {
            node_type: x.new_zeros(x.size(0), proj_dim)
            for node_type, x in x_dict.items()
        }

        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            edge_str = self.edge_type_to_str(edge_type)
            src_idx, dst_idx = edge_index[0], edge_index[1]

            edge_attr = edge_attr_dict.get(edge_type) if edge_attr_dict else None

            query = query_dict[dst_type][dst_idx]
            key = key_dict[src_type][src_idx]
            # Per-edge, per-head dot product of destination query and source key.
            att = torch.einsum('ehd,ehd->eh', query, key)

            if edge_attr is not None:
                att = att + edge_attr

            # squeeze(-1) removes only the trailing size-1 axis. A bare
            # squeeze() would also drop the head axis when heads == 1 (or the
            # edge axis for a single edge) and break the addition below.
            rel_att = self.a_rel[edge_str](key).squeeze(-1)
            att = att + rel_att

            # NOTE(review): the clamped scores are exponentiated and then fed
            # to a softmax, i.e. exponentiated twice. Left unchanged so the
            # outputs of already-trained weights stay the same - confirm
            # whether this is intended.
            att = torch.clamp(att, -5, 5)
            att = torch.exp(att)
            alpha = self._softmax(att, dst_idx)

            msg = value_dict[src_type][src_idx]
            if edge_attr is not None:
                # The edge term also scales the message, one factor per head.
                msg = msg * edge_attr.unsqueeze(-1)
            msg = msg * alpha.unsqueeze(-1)

            aggregated = scatter(
                src=msg,
                index=dst_idx,
                dim=0,
                dim_size=x_dict[dst_type].size(0),
                reduce="sum",
            )

            # Merge the head axis back and add to this destination type's total.
            out_dict[dst_type] += aggregated.view(-1, proj_dim)

        for node_type in out_dict:
            out_dict[node_type] = self.a_lin[node_type](out_dict[node_type])

        return out_dict
    
    
class LinkPredictionModel(torch.nn.Module):
    """Two-layer WeightedHGTConv encoder with a linear link classifier.

    Args:
        in_features: Size of the raw node features (same for all node types).
        conv1_in: Output size of the per-type encoders / input of ``conv1``.
        conv1_out: Output size of ``conv1``.
        conv2_in: Input size of ``conv2``. ``forward`` feeds ``conv1``'s
            output straight into ``conv2``, so this should equal ``conv1_out``.
        conv2_out: Output size of ``conv2`` (the final embedding size).
        metadata: ``(node_types, edge_types)`` pair describing the graph.
        heads: Number of attention heads used by both convolutions.
    """

    def __init__(self, in_features, conv1_in, conv1_out, conv2_in, conv2_out,
                 metadata, heads=8):
        super().__init__()
        self.metadata = metadata

        # One small feature encoder per node type.
        self.encoders = torch.nn.ModuleDict({
            node_type: nn.Sequential(
                nn.Linear(in_features, conv1_in),
                nn.LeakyReLU(),
                nn.LayerNorm(conv1_in),
            )
            for node_type in metadata[0]
        })

        self.conv1 = WeightedHGTConv(
            in_channels=conv1_in,
            out_channels=conv1_out,
            metadata=metadata,
            heads=heads,
        )

        self.conv2 = WeightedHGTConv(
            in_channels=conv2_in,
            out_channels=conv2_out,
            metadata=metadata,
            heads=heads,
        )

        # Scores a (source, destination) pair from the concatenated embeddings.
        self.classifier = nn.Sequential(
            nn.Linear(conv2_out * 2, 1)
        )

        # NOTE(review): not used by any method of this class; kept so
        # existing state_dicts continue to load.
        self.edge_weight_lin = nn.Linear(1, 1)

    def forward(self, x_dict, edge_index_dict, edge_weight_dict=None):
        """Encode node features and run both convolutions.

        Args:
            x_dict: Raw node features keyed by node type.
            edge_index_dict: Edge indices keyed by edge-type tuple.
            edge_weight_dict: Optional per-edge weights keyed by edge type.

        Returns:
            Dict of node embeddings keyed by node type.
        """
        x_dict = {
            node_type: self.encoders[node_type](x)
            for node_type, x in x_dict.items()
        }
        x_dict = self.conv1(x_dict, edge_index_dict, edge_weight_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict, edge_weight_dict)
        return x_dict

    def predict_link(self, src_emb, dst_emb, edge_weight=None):
        """Return the sigmoid link score for paired embeddings.

        ``edge_weight`` is accepted but not used.
        """
        similarity = self.classifier(torch.cat([src_emb, dst_emb], dim=-1))
        return torch.sigmoid(similarity)

    def get_node_embeddings(self, data):
        """Compute embeddings for every node type of ``data`` without gradients.

        Switches the model to eval mode, moves the graph tensors to the
        device the model's parameters live on, and returns the embeddings
        on the CPU, keyed by node type.
        """
        self.eval()
        # Use the model's own device instead of a module-level default, so a
        # model that was never moved to the GPU still receives matching inputs.
        model_device = next(self.parameters()).device

        with torch.no_grad():
            edge_index_dict = {
                et: data[et].edge_index.to(model_device)
                for et in data.edge_types
            }
            # Only edge types that actually carry an ``edge_weight`` attribute.
            edge_weight_dict = {
                et: data[et].edge_weight.to(model_device)
                for et in data.edge_types
                if 'edge_weight' in data[et]
            }

            x_dict = self(
                x_dict={nt: data[nt].x.to(model_device) for nt in data.node_types},
                edge_index_dict=edge_index_dict,
                edge_weight_dict=edge_weight_dict,
            )

            return {nt: emb.cpu() for nt, emb in x_dict.items()}
