import math
from typing import Dict, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn import HeteroConv
from torch_geometric.data import HeteroData
from .sage import SAGEConv


class HetEncoder(nn.Module):
    """Single-layer heterogeneous encoder: projection -> HeteroConv -> ReLU.

    Args:
        metadata: Pair whose element 0 lists the node types and whose
            element 1 lists the edge types (relation keys).
        in_dims: Input feature width for every node type in ``metadata[0]``.
        hid_dim: Width of the projected and convolved features.
    """

    def __init__(self,
                 metadata: Tuple[list, list],
                 in_dims: Dict[str, int],
                 hid_dim: int):
        super().__init__()
        self.metadata = metadata
        self.in_dims = in_dims
        self.hid_dim = hid_dim

        node_types = metadata[0]
        edge_types = metadata[1]

        # One input projection per node type, bringing all types to hid_dim.
        self.lin_ntype = nn.ModuleDict({
            node_type: nn.Linear(in_dims[node_type], hid_dim)
            for node_type in node_types
        })

        # One mean-aggregating SAGEConv per relation; HeteroConv sums the
        # per-relation outputs.
        rel_convs = {
            edge_type: SAGEConv(hid_dim, hid_dim, aggr='mean', temporal=False, etypes=1)
            for edge_type in edge_types
        }
        self.hetero_conv = HeteroConv(rel_convs, aggr='sum')

    def forward(self,
                x_dict: Dict[str, Tensor],
                edge_index_dict: Dict[Tuple[str, str, str], Tensor]) -> Dict[str, Tensor]:
        """Encode node features.

        Args:
            x_dict: Raw features per node type; every key must have a
                projection in ``lin_ntype``.
            edge_index_dict: Edge indices per relation, passed unchanged to
                the hetero convolution.

        Returns:
            ReLU-activated output of the hetero convolution, keyed by whatever
            node types that convolution returns.
        """
        projected = {}
        for node_type, feats in x_dict.items():
            projected[node_type] = self.lin_ntype[node_type](feats)

        conv_out = self.hetero_conv(projected, edge_index_dict)

        activated = {}
        for node_type, hidden in conv_out.items():
            activated[node_type] = F.relu(hidden)
        return activated


class MultiViewNodeEncoder(nn.Module):
    """Encode one node set under several relation views and fuse the results.

    The input is projected to ``hid_dim``, passed through one SAGEConv per
    relation (each with its own edge index), and the per-relation outputs are
    concatenated and fused back to ``hid_dim`` by a linear layer.

    Args:
        in_dim: Width of the raw node features.
        hid_dim: Width of the hidden and output features.
        num_relations: Number of relation views, i.e. convolutions.
    """

    def __init__(self, in_dim: int, hid_dim: int, num_relations: int):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_relations = num_relations

        # Shared input projection.
        self.lin_in = nn.Linear(in_dim, hid_dim)

        # One convolution per relation view.
        self.convs_per_rel = nn.ModuleList([
            SAGEConv(hid_dim, hid_dim, aggr='mean', temporal=False, etypes=1)
            for _ in range(num_relations)
        ])

        # Fuses the concatenated per-relation embeddings back to hid_dim.
        self.relation_aggr = nn.Linear(num_relations * hid_dim, hid_dim)

    def forward(self, x: Tensor, edge_indices_list: list) -> Tensor:
        """Return the fused node embedding.

        Args:
            x: Node features, ``[N, in_dim]``.
            edge_indices_list: One edge index per relation, in the same order
                as the convolutions were created.

        Returns:
            Fused, ReLU-activated embedding, ``[N, hid_dim]``.

        Raises:
            ValueError: If the number of edge indices differs from the number
                of relation convolutions.
        """
        edge_indices_list = list(edge_indices_list)
        num_views = len(self.convs_per_rel)
        # zip() would silently drop surplus edge indices, and a short list
        # would only surface as a shape mismatch inside relation_aggr.
        if len(edge_indices_list) != num_views:
            raise ValueError(
                f"expected {num_views} edge_index tensors (one per relation), "
                f"got {len(edge_indices_list)}"
            )

        h = F.relu(self.lin_in(x))  # [N, hid_dim]

        h_per_rel = [
            conv(h, edge_index)  # [N, hid_dim]
            for conv, edge_index in zip(self.convs_per_rel, edge_indices_list)
        ]

        h_concat = torch.cat(h_per_rel, dim=-1)  # [N, V*hid_dim]
        return F.relu(self.relation_aggr(h_concat))  # [N, hid_dim]


class JointNodeVGAE(nn.Module):
    """Variational graph autoencoder over one node set with several relations.

    - Encoder: ``MultiViewNodeEncoder`` followed by linear mean / log-std heads.
    - Decoder:
      - Node attributes: ``Linear(hid_dim -> in_dim)``
      - Structure: one ``Linear(2*hid_dim -> 1)`` per relation name
    - Optional label conditioning through two bias-free linear maps, one added
      to the encoder input and one added to the sampled latent.

    Args:
        in_dim: Width of the node features.
        hid_dim: Width of the hidden / latent space.
        num_relations: Number of relation views handed to the encoder.
        relation_names: Names used as keys of the per-relation edge decoders.
    """

    def __init__(self, in_dim: int, hid_dim: int, num_relations: int, relation_names: list):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_relations = num_relations
        self.relation_names = relation_names

        self.encoder = MultiViewNodeEncoder(in_dim, hid_dim, num_relations)

        # Variational heads.
        self.enc_mu = nn.Linear(hid_dim, hid_dim)
        self.enc_logstd = nn.Linear(hid_dim, hid_dim)

        # Attribute decoder.
        self.dec_attr = nn.Linear(hid_dim, in_dim)

        # One edge scorer per relation, applied to concatenated endpoint latents.
        self.dec_edges = nn.ModuleDict({
            rel_name: nn.Linear(2 * hid_dim, 1)
            for rel_name in relation_names
        })

        # Label conditioning for the encoder input and the decoder latent.
        self.map_label_e = nn.Linear(1, in_dim, bias=False)
        self.map_label_d = nn.Linear(1, hid_dim, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _prepare_label(label: Tensor, layer: nn.Linear) -> Tensor:
        """Return ``label`` as an ``[N, 1]`` tensor matching ``layer``'s dtype and device."""
        if label.dim() == 1:
            # A flat label vector would otherwise be read as one sample with
            # N features by the Linear(1, ...) conditioning layer.
            label = label.unsqueeze(-1)
        return label.to(device=layer.weight.device, dtype=layer.weight.dtype)

    def encode(self, x: Tensor, edge_indices_list: list, label: Optional[Tensor] = None):
        """Encode nodes into the parameters of the latent Gaussian.

        Args:
            x: Node features, ``[N, in_dim]``.
            edge_indices_list: One edge index per relation, forwarded to the
                encoder.
            label: Optional per-node label of shape ``[N]`` or ``[N, 1]``;
                when given, its linear embedding is added to ``x``.

        Returns:
            Tuple ``(mu, logstd)``, each ``[N, hid_dim]``.
        """
        if label is not None:
            x = x + self.map_label_e(self._prepare_label(label, self.map_label_e))

        h = self.encoder(x, edge_indices_list)  # [N, hid_dim]
        mu = self.enc_mu(h)  # [N, hid_dim]
        logstd = self.enc_logstd(h)  # [N, hid_dim]

        return mu, logstd

    def reparameterize(self, mu: Tensor, logstd: Tensor) -> Tensor:
        """Draw ``mu + eps * exp(logstd)`` with ``eps`` standard normal.

        Noise is drawn on every call, regardless of training / eval mode.
        """
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(logstd)

    def decode_node_attr(self, z: Tensor) -> Tensor:
        """Reconstruct node attributes from latents; returns ``[N, in_dim]``."""
        return self.dec_attr(z)  # [N, in_dim]

    def decode_edges(self, z: Tensor, edge_index: Tensor, relation_name: str) -> Tensor:
        """Score the given edges of one relation.

        Args:
            z: Node latents, ``[N, hid_dim]``.
            edge_index: Row 0 holds source indices, row 1 destination indices.
            relation_name: Key of the decoder in ``dec_edges``.

        Returns:
            One raw logit per edge, ``[E]``.
        """
        z_src = z[edge_index[0]]  # [E, hid_dim]
        z_dst = z[edge_index[1]]  # [E, hid_dim]
        z_edge = torch.cat([z_src, z_dst], dim=-1)  # [E, 2*hid_dim]
        logits = self.dec_edges[relation_name](z_edge).squeeze(-1)  # [E]
        return logits

    def sample(self, num_nodes: int, label: Optional[Tensor] = None):
        """Generate node attributes from standard-normal latents.

        Args:
            num_nodes: Number of nodes to generate.
            label: Optional per-node label of shape ``[num_nodes]`` or
                ``[num_nodes, 1]``; its linear embedding is added to the
                sampled latent.

        Returns:
            Tuple ``(x_new, z)``: decoded attributes ``[num_nodes, in_dim]``
            and the latents ``[num_nodes, hid_dim]`` they were decoded from.
        """
        device = next(self.parameters()).device
        z = torch.randn(num_nodes, self.hid_dim, device=device)
        if label is not None:
            z = z + self.map_label_d(self._prepare_label(label, self.map_label_d))
        x_new = self.decode_node_attr(z)
        return x_new, z


class HetVGAE(nn.Module):
    """Variational Graph Autoencoder for HeteroData.

    - Encoder: HetEncoder (HeteroConv with per-relation SAGEConv)
    - Mean/LogStd heads: per node-type Linear
    - Decoder:
      - Node attributes: per node-type Linear(hid->in_dim)
      - Structure: per relation Linear(2*hid->1)
    - Optional label conditioning on the target node type through two
      bias-free linear maps (encoder input and sampled latent).

    Args:
        metadata: Pair ``(node_types, edge_types)``; edge types are
            ``(src_type, relation, dst_type)`` triples.
        in_dims: Input feature width per node type.
        hid_dim: Width of the hidden / latent space.
        target_ntype: Node type that receives label conditioning and is
            generated by ``sample``.

    NOTE(review): structure decoders and ``decode_stru`` results are keyed by
    the relation name alone (``rel[1]``), so two edge types that share a
    relation name share one decoder and overwrite each other's predictions.
    Confirm relation names are unique in the metadata used.
    """
    def __init__(self,
                 metadata: Tuple[list, list],
                 in_dims: Dict[str, int],
                 hid_dim: int,
                 target_ntype: str):
        super().__init__()
        self.metadata = metadata
        self.in_dims = in_dims
        self.hid_dim = hid_dim
        self.target_ntype = target_ntype

        self.encoder = HetEncoder(metadata, in_dims, hid_dim)

        # Per node-type heads
        self.enc_mu = nn.ModuleDict({nt: nn.Linear(hid_dim, hid_dim) for nt in metadata[0]})
        self.enc_logstd = nn.ModuleDict({nt: nn.Linear(hid_dim, hid_dim) for nt in metadata[0]})
        self.dec_attr = nn.ModuleDict({nt: nn.Linear(hid_dim, in_dims[nt]) for nt in metadata[0]})

        # Per relation decoder for structure (keyed by relation name only)
        self.dec_stru = nn.ModuleDict({rel[1]: nn.Linear(2 * hid_dim, 1) for rel in metadata[1]})

        # Label conditioning (optional, keep interface)
        self.map_label_e = nn.Linear(1, in_dims[target_ntype], bias=False)
        self.map_label_d = nn.Linear(1, hid_dim, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _prepare_label(label: Tensor, layer: nn.Linear) -> Tensor:
        """Return ``label`` as an ``[N, 1]`` tensor matching ``layer``'s dtype and device."""
        if label.dim() == 1:
            # A flat label vector would otherwise be read as one sample with
            # N features by the Linear(1, ...) conditioning layer.
            label = label.unsqueeze(-1)
        return label.to(device=layer.weight.device, dtype=layer.weight.dtype)

    def encode(self, data: HeteroData, label: Optional[Tensor] = None):
        """Encode a heterogeneous graph into per-type latent Gaussian parameters.

        Args:
            data: Graph providing ``x_dict`` and ``edge_index_dict``.
            label: Optional label per target-type node, shape ``[N]`` or
                ``[N, 1]``. When given and the target type has features, its
                linear embedding is added to those features.

        Returns:
            Tuple ``(mu, logstd)`` of dicts keyed by the node types the
            encoder returned.
        """
        x_dict = data.x_dict
        if label is not None and self.target_ntype in x_dict:
            # Shallow copy so the conditioned features do not replace the
            # entry in the dict obtained from ``data``.
            x_dict = dict(x_dict)
            label_emb = self.map_label_e(self._prepare_label(label, self.map_label_e))
            x_dict[self.target_ntype] = x_dict[self.target_ntype] + label_emb

        h_dict = self.encoder(x_dict, data.edge_index_dict)
        mu = {nt: self.enc_mu[nt](h) for nt, h in h_dict.items()}
        logstd = {nt: self.enc_logstd[nt](h) for nt, h in h_dict.items()}
        return mu, logstd

    def reparameterize(self, mu_t: Tensor, logstd_t: Tensor) -> Tensor:
        """Draw ``mu_t + eps * exp(logstd_t)`` with ``eps`` standard normal.

        Noise is drawn on every call, regardless of training / eval mode.
        """
        eps = torch.randn_like(mu_t)
        return mu_t + eps * torch.exp(logstd_t)

    def decode_attr(self, z_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Reconstruct attributes for every node type that has a decoder."""
        return {nt: self.dec_attr[nt](z) for nt, z in z_dict.items() if nt in self.dec_attr}

    def decode_stru(self, z_dict: Dict[str, Tensor],
                    pos_edge_index_dict: Dict[Tuple[str, str, str], Tensor]) -> Dict[str, Tensor]:
        """Score the given edges of each relation.

        Edge types whose source or destination type has no latent in
        ``z_dict`` are skipped.

        Args:
            z_dict: Latents per node type.
            pos_edge_index_dict: Edge indices per ``(src, relation, dst)``;
                row 0 indexes source nodes, row 1 destination nodes.

        Returns:
            Raw edge logits keyed by relation name, one value per edge.
        """
        preds = {}
        for (st, rel, dt), ei in pos_edge_index_dict.items():
            if st not in z_dict or dt not in z_dict:
                continue
            z_src = z_dict[st]
            z_dst = z_dict[dt]
            ze = torch.cat([z_src[ei[0]], z_dst[ei[1]]], dim=-1)
            preds[rel] = self.dec_stru[rel](ze).squeeze(-1)
        return preds

    def forward(self, data: HeteroData, label: Optional[Tensor] = None):
        """Encode, sample latents, and decode attributes and structure.

        Returns:
            Tuple ``(x_rec, edge_logits)``: reconstructed attributes per node
            type and edge logits per relation name.
        """
        mu, logstd = self.encode(data, label)
        z = {nt: self.reparameterize(mu[nt], logstd[nt]) for nt in mu}
        x_rec = self.decode_attr(z)
        edge_logits = self.decode_stru(z, data.edge_index_dict)
        # Kept on the module so the caller can read them after the forward
        # pass (presumably for a KL term; confirm against the training loop).
        self.mu = mu
        self.logstd = logstd
        return x_rec, edge_logits

    def sample(self, num_nodes_target: int, label: Optional[Tensor] = None):
        """Generate target-type node attributes from standard-normal latents.

        Args:
            num_nodes_target: Number of target-type nodes to generate.
            label: Optional label per generated node, shape ``[N]`` or
                ``[N, 1]``; its linear embedding is added to the latent.

        Returns:
            Tuple ``(x_new, z_t)``: decoded attributes and the latents they
            were decoded from.
        """
        device = next(self.parameters()).device
        z_t = torch.randn(num_nodes_target, self.hid_dim, device=device)
        if label is not None:
            z_t = z_t + self.map_label_d(self._prepare_label(label, self.map_label_d))
        x_new = self.dec_attr[self.target_ntype](z_t)
        return x_new, z_t


