import torch
from torch_geometric.nn import pool
from torch.nn import ModuleList
from torch_geometric.nn import GENConv, MLP, Linear
from .stat import MMDLoss

class GenGeomAutoencoder(torch.nn.Module):
    """Graph autoencoder producing one latent code per graph, regularised by an MMD term.

    Pipeline (see ``forward``):
      1. ``lifter_e`` lifts node and edge features to ``latent_dim``.
      2. ``GEN_E`` encodes each graph to a single ``z_dim`` vector.
      3. ``lifter_d`` lifts the original edge features separately.
      4. ``GEN_D`` decodes the latent back to per-node features on the same topology.
      5. An MLP projects the result to ``node_feature_dim``.

    Required keyword arguments:
        node_feature_dim: Width of ``data.x`` before any optional concatenation;
            also the output width of the reconstruction.
        edge_feature_dim: Width of ``data.edge_attr``.
        latent_dim: Hidden width shared by the lifters, encoder and decoder.
        z_dim: Size of the per-graph latent code.
        use_boundary_encoding: If true, ``data.boundary_encoding`` is appended to
            the node features (and fed to the decoder).
        be_dim: Width of ``data.boundary_encoding``.
        use_pos: If true, ``data.pos`` is appended to the node features (and fed
            to the decoder).
        pos_dim: Width of ``data.pos``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        required = {'node_feature_dim', 'edge_feature_dim', 'latent_dim', 'z_dim',
                    'use_boundary_encoding', 'be_dim', 'use_pos', 'pos_dim'}
        missing = required.difference(kwargs)
        assert not missing, f"GenGeomAutoencoder missing required kwargs: {sorted(missing)}"

        # Optional per-node inputs that forward() concatenates onto data.x.
        additional_inputs  = kwargs['be_dim'] if kwargs['use_boundary_encoding'] else 0
        additional_inputs += kwargs['pos_dim'] if kwargs['use_pos'] else 0
        lifted_node_dim = kwargs['node_feature_dim'] + additional_inputs

        self.lifter_e = Lifter(lifted_node_dim, kwargs['edge_feature_dim'], kwargs['latent_dim'])
        self.encoder = GEN_E(kwargs['latent_dim'], kwargs['z_dim'])
        # NOTE(review): lifter_d is only ever called with x=None, so its node_projector
        # is never used. It is kept so existing state_dicts/checkpoints still load.
        self.lifter_d = Lifter(lifted_node_dim, kwargs['edge_feature_dim'], kwargs['latent_dim'])
        self.decoder = GEN_D(kwargs['z_dim'], kwargs['latent_dim'],
                             use_boundary_encoding=kwargs['use_boundary_encoding'], be_dim=kwargs['be_dim'],
                             use_pos=kwargs['use_pos'], pos_dim=kwargs['pos_dim'])
        self.projector = MLP(in_channels=kwargs['latent_dim'], hidden_channels=kwargs['latent_dim'],
                             out_channels=kwargs['node_feature_dim'], num_layers=2,
                             act='relu', norm='layer', plain_last=True, dropout=0.0)

        self.z_dim = kwargs['z_dim']
        self.use_boundary_encoding = kwargs['use_boundary_encoding']
        self.use_pos = kwargs['use_pos']

    def forward(self, data):
        """Encode ``data`` to per-graph latents and reconstruct its node features.

        NOTE: ``data`` is modified in place and also returned. On return:
          - ``data.z`` holds the latent codes, one row per graph.
          - ``data.x`` holds the reconstruction (width ``node_feature_dim``).
          - ``data.edge_attr`` holds the encoder-lifted edge features.

        The optional boundary encoding / positions are concatenated onto
        ``data.x`` first. Running ``forward`` again on the returned object would
        therefore encode the reconstruction, not the original input. Pass a
        fresh copy of the input for every call.
        """
        if self.use_boundary_encoding:
            data.x = torch.cat([data.x, data.boundary_encoding], dim=1)

        if self.use_pos:
            data.x = torch.cat([data.x, data.pos], dim=1)

        # Snapshot the un-lifted graph; the decoder lifts its edge features separately.
        data_clone = data.clone().detach()

        # Lift the input node and edge features to latent_dim.
        data.x, data.edge_attr = self.lifter_e(data.x, data.edge_attr)

        # Encode to one latent vector per graph.
        data.z = self.encoder(data)

        # data_clone is private to this call, so lifting it in place is safe.
        data.x = self._decode_graph(data.z, data_clone)

        return data

    def compute_loss(self, data):
        """Return ``(total_loss, recon_loss, mmd_loss)`` for the output of ``forward``.

        ``recon_loss`` is the mean squared error between ``data.x`` (reconstruction)
        and ``data.y`` (target). ``mmd_loss`` compares ``data.z`` with an
        equally-shaped standard-normal sample using ``MMDLoss``.
        """
        mmd = MMDLoss(device=data.x.device)
        recon_loss = torch.mean((data.x - data.y)**2.)
        # randn_like already allocates on data.z's device and dtype.
        prior_sample = torch.randn_like(data.z)
        mmd_loss = mmd(data.z, prior_sample)
        loss = recon_loss + mmd_loss
        return loss, recon_loss, mmd_loss

    def decode(self, z_samples, data_batch):
        """
        Decode latent samples to node features.
        Edge features are lifted with ``lifter_d`` and the decoder output is
        projected back to ``node_feature_dim``.

        ``data_batch`` is not modified, so the same batch can be reused for
        several decode calls (e.g. multiple z samples).

        Args:
            z_samples: Latent samples, one row per graph in ``data_batch`` [num_graphs, z_dim]
            data_batch: Batch of graph data with edge information (raw, un-lifted edge_attr)

        Returns:
            Decoded node features [num_nodes_total, node_feature_dim]
        """
        # Lifting replaces edge_attr. Doing that on the caller's object would make
        # a second decode() on the same batch feed already-lifted
        # (latent_dim-wide) features into lifter_d, so work on a copy.
        return self._decode_graph(z_samples, data_batch.clone())

    def _decode_graph(self, z, graph):
        """Lift ``graph.edge_attr`` in place, decode ``z`` on ``graph`` and project to node features.

        Overwrites ``graph.edge_attr``; callers must pass a graph they own.
        """
        _, graph.edge_attr = self.lifter_d(None, graph.edge_attr)
        return self.projector(self.decoder(z, graph))
    

class Lifter(torch.nn.Module):
    """Project node features and edge features to a common ``latent_dim`` width.

    Nodes and edges each get their own 2-layer MLP (ReLU activation, layer
    norm). Either input may be ``None``; its projection is then skipped and
    ``None`` is returned in that slot.
    """

    def __init__(self, node_feature_dim, edge_feature_dim, latent_dim):
        super().__init__()
        # Hyper-parameters shared by both projectors; only the input width differs.
        mlp_options = dict(hidden_channels=latent_dim, out_channels=latent_dim,
                           num_layers=2, act='relu', norm='layer')
        self.node_projector = MLP(in_channels=node_feature_dim, **mlp_options)
        self.edge_projector = MLP(in_channels=edge_feature_dim, **mlp_options)

    def forward(self, x, edge_attr):
        """Return ``(projected_x, projected_edge_attr)``, passing ``None`` inputs through."""
        node_out = None
        if x is not None:
            node_out = self.node_projector(x)

        edge_out = None
        if edge_attr is not None:
            edge_out = self.edge_projector(edge_attr)

        return node_out, edge_out


class GEN_E(torch.nn.Module):
    """Graph encoder: residual GENConv stack, then mean pooling to one latent code per graph.

    Each layer conditions on global context by concatenating every node's
    feature with its graph's mean node feature before the convolution.

    Args:
        latent_dim: Width of the node features consumed and produced by each conv.
        z_dim: Size of the per-graph latent code.
        num_layers: Number of residual GENConv layers (default 6, the original depth).
    """

    def __init__(self, latent_dim, z_dim, num_layers=6):
        super().__init__()

        # Input is [node feature || graph-mean feature], hence latent_dim * 2.
        self.convs = ModuleList([
            GENConv(latent_dim*2, latent_dim, norm='layer') for _ in range(num_layers)
        ])

        self.to_z = Linear(latent_dim, z_dim)

    def forward(self, graph):
        """Encode ``graph`` to latent codes of shape [num_graphs, z_dim].

        ``graph`` must provide ``x`` (width ``latent_dim``), ``edge_index`` and
        ``edge_attr``. A missing/``None`` ``batch`` vector is treated as a single
        graph. ``graph.x`` is not modified.
        """
        x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr
        batch = getattr(graph, 'batch', None)
        if batch is None:
            # Un-batched graph: every node belongs to graph 0. With batch=None,
            # x_global[batch] would add a dimension instead of gathering rows.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        for conv in self.convs:
            x_global = pool.global_mean_pool(x, batch)                  # [num_graphs, hidden_dim]
            conv_input = torch.cat([x, x_global[batch]], dim=1)         # [num_nodes, hidden_dim * 2]
            # Out-of-place residual: x initially aliases graph.x, so `x += ...`
            # would silently overwrite the caller's node features.
            x = x + conv(conv_input, edge_index, edge_attr=edge_attr)

        x = pool.global_mean_pool(x, batch)

        # Project to z_dim
        return self.to_z(x)


class GEN_D(torch.nn.Module):
    """Graph decoder: broadcast per-graph latents to nodes, then refine with residual GENConvs.

    Args:
        z_dim: Size of the per-graph latent code.
        latent_dim: Width of the node features produced by ``from_z`` and each conv.
        use_boundary_encoding: Append ``graph.boundary_encoding`` (width ``be_dim``)
            to every node's latent before ``from_z``.
        be_dim: Width of ``graph.boundary_encoding``.
        use_pos: Append ``graph.pos`` (width ``pos_dim``) to every node's latent
            before ``from_z``.
        pos_dim: Width of ``graph.pos``.
        num_layers: Number of residual GENConv layers (default 6, the original depth).
    """

    def __init__(self, z_dim, latent_dim, use_boundary_encoding=False, be_dim=3, use_pos=False, pos_dim=2,
                 num_layers=6):
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()
        self.use_boundary_encoding = use_boundary_encoding
        self.use_pos = use_pos

        additional_inputs = be_dim if use_boundary_encoding else 0
        additional_inputs += pos_dim if use_pos else 0

        self.from_z = Linear(z_dim+additional_inputs, latent_dim)
        # Input is [node feature || graph-mean feature], hence latent_dim * 2.
        self.convs = ModuleList([
            GENConv(latent_dim*2, latent_dim, norm='layer') for _ in range(num_layers)
        ])

    def forward(self, z, graph):
        """Decode per-graph latents ``z`` into node features of width ``latent_dim``.

        Args:
            z: Latent codes, one row per graph; row ``i`` is copied to every node
                whose ``graph.batch`` entry is ``i``.
            graph: Provides ``edge_index``, ``edge_attr`` and, when enabled,
                ``boundary_encoding`` / ``pos``. A missing/``None`` ``batch`` vector
                is treated as a single graph.

        Returns:
            Node features [num_nodes, latent_dim].
        """
        edge_index, edge_attr = graph.edge_index, graph.edge_attr
        batch = getattr(graph, 'batch', None)
        if batch is None:
            # Un-batched graph: all nodes take z's first row. With batch=None,
            # z[batch] would add a dimension instead of gathering rows.
            batch = torch.zeros(graph.num_nodes, dtype=torch.long, device=z.device)

        # Per-node latent followed by the optional inputs, in the same order
        # that from_z's input width was computed in __init__.
        node_inputs = [z[batch]]  # [num_nodes, z_dim]
        if self.use_boundary_encoding:
            node_inputs.append(graph.boundary_encoding)
        if self.use_pos:
            node_inputs.append(graph.pos)

        # Project to match the latent dimension
        x = self.from_z(torch.cat(node_inputs, dim=1))

        for conv in self.convs:
            x_global = pool.global_mean_pool(x, batch)                  # [num_graphs, hidden_dim]
            conv_input = torch.cat([x, x_global[batch]], dim=1)         # [num_nodes, hidden_dim * 2]
            x = x + conv(conv_input, edge_index, edge_attr=edge_attr)

        return x
    