import torch
from torch import nn

from ..flow_matching_model import FlowMatchingModel
from ...blocks import SinusoidalTimeEmbedding
from ...models.multi_scale_gnn import *
from ....graph import Graph


class FiLM(nn.Module):
    """Feature-wise Linear Modulation (FiLM) parameter generator.

    A small two-layer MLP maps a conditioning embedding to a per-feature scale
    (``gamma``) and shift (``beta``). Elsewhere in this file they are applied
    as ``gamma * h + beta``.

    Args:
        emb_width (int): Size of the last dimension of the conditioning embedding.
        hidden_dim (int): Number of modulated features; ``gamma`` and ``beta``
            both have this size in their last dimension.
    """

    def __init__(self, emb_width: int, hidden_dim: int):
        super().__init__()
        # The (Linear, SELU, Linear) order determines the state_dict keys
        # ``film_net.0.*`` / ``film_net.2.*``; keep it stable for checkpoints.
        layers = [
            nn.Linear(emb_width, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 2 * hidden_dim),
        ]
        self.film_net = nn.Sequential(*layers)

    def forward(self, remb):
        """Return ``(gamma, beta)`` computed from the embedding ``remb``.

        The MLP output (last dimension ``2 * hidden_dim``) is split in half along
        the last dimension: the first half is ``gamma``, the second is ``beta``.
        """
        params = self.film_net(remb)
        scale, shift = torch.chunk(params, 2, dim=-1)
        return scale, shift


class ConditionAwareGraphFlowMatching(FlowMatchingModel):
    r"""Approximate the graph-conditioned vector field.

    The flow variable ``r`` is embedded once and used twice: it FiLM-modulates
    (scale and shift) the encoded node features, and the same embedding is
    passed to the multi-scale GNN propagator.

    Args:
        arch (dict): Dictionary with the architecture of the model. It must contain the following keys:
            - 'in_node_features' (int): Number of input node features. This is the number of features of the noisy field.
            - 'cond_node_features' (int, optional): Number of conditional node features. Defaults to 0.
            - 'cond_edge_features' (int, optional): Number of conditional edge features. Defaults to 0.
            - 'in_edge_features' (int, optional): Legacy key kept for backward compatibility. If present,
              it is added to 'cond_edge_features'.
            - 'depths' (list): List of integers with the number of layers at each depth.
            - 'fnns_depth' (int, optional): Number of layers in the FNNs. Defaults to 2.
            - 'fnns_width' (int): Width of the FNNs.
            - 'aggr' (str, optional): Aggregation method. Defaults to 'mean'.
            - 'dropout' (float, optional): Dropout probability. Defaults to 0.0.
            - 'emb_width' (int, optional): Width of the r embedding. Defaults to 4 * fnns_width.
            - 'dim' (int, optional): Dimension of the latent space. Defaults to 2.
            - 'scalar_rel_pos' (bool, optional): Whether to use scalar relative positions. Defaults to True.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def load_arch(self, arch: dict):
        """Read and validate the hyperparameters in ``arch``, then build all sub-modules.

        Raises:
            KeyError: If a required key ('in_node_features', 'depths', 'fnns_width') is missing.
            AssertionError: If a hyperparameter has an invalid value.
        """
        self.arch = arch
        # Hyperparameters
        self.in_node_features   = arch['in_node_features']
        self.cond_node_features = arch.get('cond_node_features', 0)
        self.cond_edge_features = arch.get('cond_edge_features', 0)
        self.depths             = arch['depths']
        self.fnns_depth         = arch.get('fnns_depth', 2)
        self.fnns_width         = arch['fnns_width']
        self.aggr               = arch.get('aggr', 'mean')
        self.dropout            = arch.get('dropout', 0.0)
        self.emb_width          = arch.get('emb_width', self.fnns_width * 4)
        self.dim                = arch.get('dim', 2)
        self.scalar_rel_pos     = arch.get('scalar_rel_pos', True)
        if 'in_edge_features' in arch:  # To support backward compatibility
            self.cond_edge_features = arch['in_edge_features'] + self.cond_edge_features
        # Validate the inputs. The type check on `depths` must come before len(),
        # otherwise a non-list value raises a bare TypeError instead of this message.
        assert self.in_node_features > 0, "Input node features must be a positive integer"
        assert self.cond_node_features >= 0, "Condition features must be a non-negative integer"
        assert self.cond_edge_features >= 0, "Conditional edge features must be a non-negative integer"
        assert isinstance(self.depths, list), "Depths (`depths`) must be a list of integers"
        assert len(self.depths) > 0, "Depths (`depths`) must be a non-empty list of integers"
        assert all(isinstance(depth, int) for depth in self.depths), "Depths (`depths`) must be a list of integers"
        assert all(depth > 0 for depth in self.depths), "Depths (`depths`) must be a list of positive integers"
        assert self.fnns_depth >= 2, "FNNs depth (`fnns_depth`) must be at least 2"
        assert self.fnns_width > 0, "FNNs width (`fnns_width`) must be a positive integer"
        assert self.aggr in ('mean', 'sum'), "Aggregation method (`aggr`) must be either 'mean' or 'sum'"
        assert self.dropout >= 0.0 and self.dropout < 1.0, "Dropout (`dropout`) must be a float between 0.0 and 1.0"
        self.out_node_features = self.in_node_features
        # NOTE: sub-modules are created in a fixed order; changing it would change
        # state_dict ordering and the random initialisation sequence.
        # r embedding
        self.r_embedding = nn.Sequential(
            SinusoidalTimeEmbedding(self.fnns_width),
            nn.Linear(self.fnns_width, self.emb_width),
            nn.SELU(),
        )
        # Node encoder: [field_r, optional node inputs] -> fnns_width
        self.node_encoder = nn.Sequential(
            nn.Linear(self.in_node_features + self.cond_node_features, self.fnns_width * 2),
            nn.SELU(),
            nn.Linear(self.fnns_width * 2, self.fnns_width),
        )
        # FiLM generator: r embedding -> (gamma, beta), each of width fnns_width
        self.film = FiLM(
            emb_width=self.emb_width,
            hidden_dim=self.fnns_width,
        )
        # Edge encoder
        self.edge_encoder = nn.Sequential(
            nn.Linear(self.cond_edge_features, self.fnns_width * 2),
            nn.SELU(),
            nn.Linear(self.fnns_width * 2, self.fnns_width),
        )
        self.propagator = MultiScaleGnn(
            depths            = self.depths,
            fnns_depth        = self.fnns_depth,
            fnns_width        = self.fnns_width,
            emb_features      = self.emb_width,
            aggr              = self.aggr,
            activation        = nn.SELU,
            dropout           = self.dropout,
            dim               = self.dim,
            scalar_rel_pos    = self.scalar_rel_pos,
        )
        # Node decoder: fnns_width -> out_node_features (== in_node_features)
        self.node_decoder = nn.Sequential(
            nn.Linear(self.fnns_width, self.fnns_width * 2),
            nn.SELU(),
            nn.Linear(self.fnns_width * 2, self.out_node_features),
        )

    @property
    def num_fields(self) -> int:
        """Number of output node features (equal to ``in_node_features``)."""
        return self.out_node_features

    def reset_parameters(self):
        """Re-initialise the learnable parameters of every sub-module.

        A child that defines ``reset_parameters`` is reset directly and is
        trusted to reset its own sub-modules. A container that does not define
        it (e.g. ``nn.Sequential`` or ``FiLM``) is descended into. This makes the
        ``nn.Linear`` layers inside the encoders, the decoder and the FiLM
        generator get reset as well.
        """
        def _reset(module: nn.Module) -> None:
            # Depth-first walk that stops at the first module owning a reset method.
            for child in module.children():
                if hasattr(child, 'reset_parameters'):
                    child.reset_parameters()
                else:
                    _reset(child)

        _reset(self)

    def forward(
        self,
        graph: Graph,
    ) -> torch.Tensor:
        """Evaluate the vector field on ``graph``.

        Args:
            graph (Graph): Must provide ``r``, ``field_r`` and ``edge_attr``.
                The optional node attributes ``cond``, ``field``, ``loc``,
                ``glob`` and ``omega`` are concatenated after ``field_r`` (in
                that order) when present. The optional ``edge_cond`` is
                concatenated after ``edge_attr`` when present. If ``batch`` is
                set, it indexes the per-graph FiLM parameters for each node.

        Returns:
            torch.Tensor: Decoded node features with ``in_node_features`` columns.
        """
        assert hasattr(graph, 'r'), "graph must have an attribute 'r'"
        assert hasattr(graph, 'field_r'), "graph must have an attribute 'field_r'"
        # Embed r
        emb = self.r_embedding(graph.r)  # Shape (batch_size, emb_width)
        # Encode the node features; the concatenation order must match training.
        optional_node_inputs = [
            graph.get(key) for key in ('cond', 'field', 'loc', 'glob', 'omega')
        ]
        v = self.node_encoder(
            torch.cat(
                [graph.field_r, *(f for f in optional_node_inputs if f is not None)],
                dim=1,
            )
        )  # Shape (num_nodes, fnns_width)
        # FiLM-modulate the node latents with the r embedding
        gamma, beta = self.film(emb)  # Shape (batch_size, fnns_width)
        batch = getattr(graph, 'batch', None)
        if batch is None:
            # Un-batched graph: broadcast the modulation over all nodes. Indexing
            # with None would instead prepend a spurious dimension to v.
            v = gamma * v + beta
        else:
            # Gather each node's graph-level modulation parameters.
            v = gamma[batch] * v + beta[batch]
        # Encode the edge features
        optional_edge_inputs = [graph.get('edge_cond')]
        e = self.edge_encoder(
            torch.cat(
                [graph.edge_attr, *(f for f in optional_edge_inputs if f is not None)],
                dim=1,
            )
        )
        # Propagate the latent node features (conditioned on the r embedding)
        v, _ = self.propagator(graph, v, e, emb)
        # Decode the latent node features
        return self.node_decoder(v)
