#!/usr/bin/env python3
"""
MeshGraphNets Baseline Models

This module implements MeshGraphNet and MultiScaleMeshGraphNet as baseline models
for comparison with the GenSynth diffusion model. These models use direct supervised
learning instead of diffusion, taking forcings and guiding inputs to directly predict
the target velocity field.

Key differences from GenSynth:
- No diffusion process (no noise levels)
- Direct prediction from inputs to targets
- Uses same graph network architectures as the denoiser but simplified

MeshGraphNet: Uses only the original→original network (ori2ori_gnn)
MultiScaleMeshGraphNet: Uses all four networks like the full denoiser
"""

import dataclasses
from typing import Any, Dict, Optional, Sequence

import chex
import flax.nnx as nnx
import jax
import jax.numpy as jnp

from gen_da import deep_typed_graph_net, typed_graph
from gen_da.denoiser import DenoiserArchitectureConfig


class MeshGraphNet(nnx.Module):
    """
    MeshGraphNet baseline using only the original→original graph network.

    Directly predicts the target field from the input features with a single
    DeepTypedGraphNet operating on the original mesh (no diffusion and no
    noise conditioning).

    Per-node inputs, concatenated in this order: guiding field (C_g),
    observation mask (1), original coordinates, one-hot node type (3) and the
    wind-direction unit vector [cos θ, sin θ] (2). Per-edge inputs: the
    optional static ``o2o_features`` followed by two directional features
    (see ``_directional_edge_features``).
    """

    def __init__(self,
                 config: DenoiserArchitectureConfig,
                 rngs: nnx.Rngs,
                 mesh,
                 example_graph_structures: Dict[str, jnp.ndarray],
                 target_channels: int,
                 guiding_channels: int):
        """
        Initialize MeshGraphNet.

        Args:
            config: Architecture configuration (``latent_size``,
                ``hidden_layers`` and ``node_output_size`` are read).
            rngs: Random number generators used to initialise the GNN.
            mesh: Mesh object, forwarded to ``DeepTypedGraphNet``.
            example_graph_structures: Example graph structure. Its
                ``original_coordinates`` and optional ``o2o_features`` fix the
                node/edge input sizes of the network.
            target_channels: Number of target field channels; used as the
                output size when ``config.node_output_size`` is falsy.
            guiding_channels: Number of guiding field channels expected in
                ``forcings["U_field_guiding"]``.
        """
        self.cfg = config
        self.rngs = rngs
        self.mesh = mesh
        self.target_channels = target_channels
        self.guiding_channels = guiding_channels

        # Edge keys
        self.oo_key = typed_graph.EdgeSetKey("o2o", ("orig", "orig"))

        # Feature dimensions taken from the example graph.
        tmpl = example_graph_structures
        orig_struct_dim = tmpl["original_coordinates"].shape[1]
        base_o2o = tmpl.get("o2o_features", None)
        base_o2o_dim = base_o2o.shape[1] if base_o2o is not None else 0
        # __call__ always appends 2 directional features to the static ones.
        o2o_edge_dim = base_o2o_dim + 2

        # guiding + obs_mask + coordinates + node_types + [cos θ, sin θ]
        ori2ori_in_dim = guiding_channels + 1 + orig_struct_dim + 3 + 2

        def zeros_nf(d):
            return jnp.zeros((1, 1, d), jnp.float32)

        # Single-node / single-edge template passed as ``graph_template``;
        # only its feature sizes carry information (presumably used by
        # DeepTypedGraphNet to infer layer input sizes).
        graph_o2o_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={"orig": typed_graph.NodeSet(
                n_node=jnp.array([1]),
                features=zeros_nf(ori2ori_in_dim),
            )},
            edges={self.oo_key: typed_graph.EdgeSet(
                n_edge=jnp.array([1]),
                indices=typed_graph.EdgesIndices(senders=jnp.array([0]), receivers=jnp.array([0])),
                features=zeros_nf(o2o_edge_dim),
            )}
        )
        self.graph_o2o_template = graph_o2o_t

        print("Init MeshGraphNet Original→Original GNN")
        self.ori2ori_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"o2o": config.latent_size},
            embed_edges=True, embed_nodes=True,
            f32_aggregation=True, include_sent_messages_in_node_update=False,
            mlp_hidden_size=config.latent_size, mlp_num_hidden_layers=config.hidden_layers,
            node_latent_size={"orig": config.latent_size},
            node_output_size={"orig": config.node_output_size or target_channels},
            num_message_passing_steps=4,
            use_layer_norm=True, use_norm_conditioning=False,  # No noise conditioning
            rngs=rngs, mesh=mesh,
            graph_template=graph_o2o_t,
        )

    def __call__(self, forcings: Dict[str, Any]) -> jnp.ndarray:
        """
        Forward pass of MeshGraphNet.

        Args:
            forcings: Dictionary with
                - ``"U_field_guiding"``: guiding field, shape (B, N_o, C_g).
                - ``"obs_mask"``: optional (B, N_o) mask; zeros if absent.
                - ``"angle_deg"``: optional wind direction in degrees, either
                  a scalar or one value per sample; 0 if absent.
                - ``"graph_structures"``: one graph-structure dict, or a
                  sequence of per-sample dicts with identical shapes (only
                  the first is used).

        Returns:
            Predicted field of shape (B, N_o, C_out), float32, where C_out is
            ``config.node_output_size or target_channels``.

        Raises:
            ValueError: if the guiding field is not rank 3, its channel count
                differs from ``guiding_channels``, its node count differs from
                ``original_coordinates``, or ``graph_structures`` is empty.
            AssertionError: if per-sample graph structures differ in shape.
        """
        guiding = jnp.asarray(forcings.get("U_field_guiding", 0.0))
        if guiding.ndim != 3:
            raise ValueError(
                "forcings['U_field_guiding'] must have shape (B, N_o, C_g); "
                f"got shape {guiding.shape}"
            )
        B, N_o, C_g = guiding.shape
        if C_g != self.guiding_channels:
            raise ValueError(
                f"U_field_guiding has {C_g} channels; model was built for "
                f"{self.guiding_channels}"
            )
        obs_mask = jnp.asarray(forcings.get("obs_mask", jnp.zeros((B, N_o), jnp.float32)))  # (B, N_o)
        d = self._wind_direction(forcings.get("angle_deg"), B)  # (B, 2)

        g = self._select_graph(forcings["graph_structures"])

        orig_xy = jnp.asarray(g["original_coordinates"], jnp.float32)
        N_o_struct = orig_xy.shape[0]
        if N_o_struct != N_o:
            raise ValueError(f"Input CFD nodes ({N_o}) != original_coordinates ({N_o_struct})")
        node_types_oh = jax.nn.one_hot(jnp.asarray(g["node_types"], jnp.int32), 3, dtype=jnp.float32)

        oo_s = jnp.asarray(g["o2o_senders"], jnp.int32)
        oo_r = jnp.asarray(g["o2o_receivers"], jnp.int32)
        oo_f = self._optional_f32(g, "o2o_features")

        # Edge features: optional static features, then the directional pair.
        dir_edge = self._directional_edge_features(orig_xy, oo_s, oo_r, d)  # (E, B, 2)
        if oo_f is None:
            o2o_edges = dir_edge
        else:
            o2o_edges = jnp.concatenate([self._broadcast(oo_f, B), dir_edge], axis=-1)

        # Node features laid out (N_o, B, F); only guiding inputs are used
        # (no target inputs, no case encoding).
        combined = jnp.concatenate([
            guiding.transpose(1, 0, 2).astype(jnp.float32),            # (N_o, B, C_g)
            obs_mask.transpose(1, 0)[..., None].astype(jnp.float32),  # (N_o, B, 1)
            self._broadcast(orig_xy, B),                              # (N_o, B, D)
            self._broadcast(node_types_oh, B),                        # (N_o, B, 3)
            jnp.broadcast_to(d[None, :, :], (N_o, B, 2)),             # (N_o, B, 2)
        ], axis=-1)

        graph = self.graph_o2o_template._replace(
            nodes={"orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=combined)},
            edges={self.oo_key: typed_graph.EdgeSet(
                n_edge=jnp.array([oo_s.size]),
                indices=typed_graph.EdgesIndices(senders=oo_s, receivers=oo_r),
                features=o2o_edges,
            )}
        )

        # No global/noise conditioning: this is direct prediction.
        out = self.ori2ori_gnn(graph)
        pred = out.nodes["orig"].features.astype(jnp.float32)  # (N_o, B, C_out)
        return pred.transpose(1, 0, 2)  # (B, N_o, C_out)

    @staticmethod
    def _wind_direction(angle_deg: Any, B: int) -> jnp.ndarray:
        """Return wind-direction unit vectors [cos θ, sin θ] of shape (B, 2).

        ``angle_deg`` is in degrees and may be None (treated as 0°), a scalar
        shared by the whole batch, or B per-sample values.
        """
        if angle_deg is None:
            angle_deg = jnp.zeros((B,), jnp.int32)
        angle = jnp.broadcast_to(jnp.ravel(jnp.asarray(angle_deg, jnp.float32)), (B,))
        theta = angle * (jnp.pi / 180.0)
        return jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)

    @staticmethod
    def _directional_edge_features(coords: jnp.ndarray,
                                   senders: jnp.ndarray,
                                   receivers: jnp.ndarray,
                                   d: jnp.ndarray) -> jnp.ndarray:
        """Return per-edge directional features of shape (E, B, 2).

        The edge vector ``coords[receiver] - coords[sender]`` is projected onto
        the wind direction ``d`` (B, 2) and onto ``d`` rotated by +90°; both
        projections are divided by the edge length (+1e-6 so zero-length
        edges do not divide by zero). ``coords`` must be (N, 2) because the
        projection contracts against the 2-component ``d``.
        """
        rel = coords[receivers] - coords[senders]  # (E, 2)
        length = jnp.linalg.norm(rel, axis=-1, keepdims=True) + 1e-6  # (E, 1)
        d_perp = jnp.stack([-d[:, 1], d[:, 0]], axis=-1)  # (B, 2)
        p_par = jnp.einsum("ed,bd->eb", rel, d)  # (E, B)
        p_perp = jnp.einsum("ed,bd->eb", rel, d_perp)  # (E, B)
        return jnp.stack([p_par / length, p_perp / length], axis=-1)

    @staticmethod
    def _optional_f32(g: Dict[str, Any], key: str) -> Optional[jnp.ndarray]:
        """Return ``g[key]`` as a float32 array, or None if absent/None."""
        value = g.get(key, None)
        return None if value is None else jnp.asarray(value, jnp.float32)

    @classmethod
    def _select_graph(cls, gstructs: Any) -> Dict[str, jnp.ndarray]:
        """Return the graph-structure dict used for the whole batch.

        ``gstructs`` is either a single dict (shared by every sample) or a
        sequence of per-sample dicts that must all have the same shapes; only
        the first one is used.
        """
        if isinstance(gstructs, dict):
            return gstructs
        if len(gstructs) == 0:
            raise ValueError("forcings['graph_structures'] is empty")
        cls._ensure_same_shapes(gstructs)
        return gstructs[0]

    @staticmethod
    def _broadcast(arr: jnp.ndarray, B: int) -> jnp.ndarray:
        """Broadcast a (N, F) array to (N, B, F) without copying per sample."""
        return jnp.broadcast_to(arr[:, None, :], (arr.shape[0], B, arr.shape[1]))

    @staticmethod
    def _ensure_same_shapes(graphs: Sequence[Dict[str, jnp.ndarray]]):
        """Raise AssertionError unless every graph has each key of
        ``graphs[0]`` with the same shape (non-array values compare as ())."""
        base = graphs[0]
        for g in graphs[1:]:
            for k, v in base.items():
                if k not in g or jnp.shape(v) != jnp.shape(g[k]):
                    raise AssertionError(
                        f"Batched samples must share the same topology. Mismatch in '{k}'."
                    )


class MultiScaleMeshGraphNet(nnx.Module):
    """
    MultiScaleMeshGraphNet baseline using all four graph networks like the full denoiser.

    This uses the complete multi-scale architecture without diffusion (direct
    prediction):

    1. original→original (2 message-passing steps) embeds the node inputs
       into latents on the original mesh;
    2. original→reduced (1 step) transfers them to the reduced nodes, whose
       own inputs are the reduced coordinates;
    3. reduced→reduced (6 steps) processes on the reduced mesh;
    4. reduced→original (1 step) maps back onto the original-node latents
       from stage 1 and emits ``config.node_output_size or target_channels``
       channels.

    Original-node inputs, concatenated in this order: guiding field (C_g),
    observation mask (1), original coordinates, one-hot node type (3) and the
    wind-direction unit vector [cos θ, sin θ] (2).
    """

    def __init__(self,
                 config: DenoiserArchitectureConfig,
                 rngs: nnx.Rngs,
                 mesh,
                 example_graph_structures: Dict[str, jnp.ndarray],
                 target_channels: int,
                 guiding_channels: int):
        """
        Initialize MultiScaleMeshGraphNet.

        Args:
            config: Architecture configuration (``latent_size``,
                ``hidden_layers``, ``node_output_size`` and
                ``grid2mesh_aggregate_normalization`` are read).
            rngs: Random number generators used to initialise the GNNs.
            mesh: Mesh object, forwarded to each ``DeepTypedGraphNet``.
            example_graph_structures: Example graph structure whose
                coordinate and edge-feature arrays fix the input sizes.
            target_channels: Number of target field channels; used as the
                output size when ``config.node_output_size`` is falsy.
            guiding_channels: Number of guiding field channels expected in
                ``forcings["U_field_guiding"]``.
        """
        self.cfg = config
        self.rngs = rngs
        self.mesh = mesh
        self.target_channels = target_channels
        self.guiding_channels = guiding_channels

        # Edge keys
        self.o2r_key = typed_graph.EdgeSetKey("o2r", ("orig", "red"))
        self.rr_key = typed_graph.EdgeSetKey("r2r", ("red", "red"))
        self.oo_key = typed_graph.EdgeSetKey("o2o", ("orig", "orig"))
        self.r2o_key = typed_graph.EdgeSetKey("r2o", ("red", "orig"))

        # Feature dimensions taken from the example graph.
        tmpl = example_graph_structures
        orig_struct_dim = tmpl["original_coordinates"].shape[1]
        red_struct_dim = tmpl["reduced_coordinates"].shape[1]

        base_o2o = tmpl.get("o2o_features", None)
        base_o2o_dim = base_o2o.shape[1] if base_o2o is not None else 0
        # __call__ always appends 2 directional features to the static ones.
        o2o_edge_dim = base_o2o_dim + 2
        o2r_edge_dim = tmpl["o2r_features"].shape[1]
        base_r2r = tmpl.get("r2r_features", None)
        r2r_edge_dim = base_r2r.shape[1] if base_r2r is not None else 0
        r2o_edge_dim = tmpl["r2o_features"].shape[1]

        cfg = self.cfg
        C_t = target_channels
        C_g = guiding_channels
        node_out = cfg.node_output_size or C_t
        latent = cfg.latent_size

        def zeros_nf(d):
            return jnp.zeros((1, 1, d), jnp.float32)

        def single_edge(features):
            # One self-loop edge on node 0; only the feature size matters.
            return dict(
                n_edge=jnp.array([1]),
                indices=typed_graph.EdgesIndices(senders=jnp.array([0]), receivers=jnp.array([0])),
                features=features,
            )

        # guiding + obs_mask + struct + node_types + [cos θ, sin θ]
        ori2ori_in_dim = C_g + 1 + orig_struct_dim + 3 + 2

        # Single-node templates passed as ``graph_template``; presumably used
        # by DeepTypedGraphNet to infer layer input sizes.
        graph_o2o_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={"orig": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(ori2ori_in_dim))},
            edges={self.oo_key: typed_graph.EdgeSet(**single_edge(zeros_nf(o2o_edge_dim)))},
        )

        graph_o2r_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "orig": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(latent)),
                "red": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(red_struct_dim)),
            },
            edges={self.o2r_key: typed_graph.EdgeSet(**single_edge(zeros_nf(o2r_edge_dim)))},
        )

        graph_r2r_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={"red": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(latent))},
            edges={self.rr_key: typed_graph.EdgeSet(
                **single_edge(zeros_nf(r2r_edge_dim) if r2r_edge_dim > 0 else None)
            )},
        )

        graph_r2o_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "red": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(latent)),
                "orig": typed_graph.NodeSet(n_node=jnp.array([1]), features=zeros_nf(latent)),
            },
            edges={self.r2o_key: typed_graph.EdgeSet(**single_edge(zeros_nf(r2o_edge_dim)))},
        )

        # Store templates
        self.graph_o2o_template = graph_o2o_t
        self.graph_o2r_template = graph_o2r_t
        self.graph_r2r_template = graph_r2r_t
        self.graph_r2o_template = graph_r2o_t

        # Initialize all four GNN networks (none use noise conditioning).
        print("Init MultiScale Original→Original GNN")
        self.ori2ori_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"o2o": latent},
            embed_edges=True, embed_nodes=True,
            f32_aggregation=True, include_sent_messages_in_node_update=False,
            mlp_hidden_size=latent, mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"orig": latent},
            node_output_size=None,
            num_message_passing_steps=2,
            use_layer_norm=True, use_norm_conditioning=False,
            rngs=rngs, mesh=mesh,
            graph_template=graph_o2o_t,
        )

        print("Init MultiScale Original→Reduced GNN")
        self.ori2red_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            aggregate_normalization=cfg.grid2mesh_aggregate_normalization,
            edge_latent_size={"o2r": latent},
            embed_edges=True, embed_nodes=True,
            f32_aggregation=True, include_sent_messages_in_node_update=False,
            mlp_hidden_size=latent, mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"orig": latent, "red": latent},
            node_output_size=None,
            num_message_passing_steps=1,
            use_layer_norm=True, use_norm_conditioning=False,
            rngs=rngs, mesh=mesh,
            graph_template=graph_o2r_t,
        )

        print("Init MultiScale Reduced→Reduced GNN")
        self.red2red_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"r2r": latent},
            embed_edges=True, embed_nodes=False,
            f32_aggregation=True, include_sent_messages_in_node_update=False,
            mlp_hidden_size=latent, mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"red": latent},
            node_output_size=None,
            num_message_passing_steps=6,
            use_layer_norm=True, use_norm_conditioning=False,
            rngs=rngs, mesh=mesh,
            graph_template=graph_r2r_t,
        )

        print("Init MultiScale Reduced→Original GNN")
        self.red2ori_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"r2o": latent},
            embed_edges=True, embed_nodes=False,
            f32_aggregation=True, include_sent_messages_in_node_update=False,
            mlp_hidden_size=latent, mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"red": latent, "orig": latent},
            node_output_size={"orig": node_out},
            num_message_passing_steps=1,
            use_layer_norm=True, use_norm_conditioning=False,
            rngs=rngs, mesh=mesh,
            graph_template=graph_r2o_t,
        )

    def __call__(self, forcings: Dict[str, Any]) -> jnp.ndarray:
        """
        Forward pass of MultiScaleMeshGraphNet.

        Args:
            forcings: Dictionary with
                - ``"U_field_guiding"``: guiding field, shape (B, N_o, C_g).
                - ``"obs_mask"``: optional (B, N_o) mask; zeros if absent.
                - ``"angle_deg"``: optional wind direction in degrees, either
                  a scalar or one value per sample; 0 if absent.
                - ``"graph_structures"``: one graph-structure dict, or a
                  sequence of per-sample dicts with identical shapes (only
                  the first is used).

        Returns:
            Predicted field of shape (B, N_o, C_out), float32, where C_out is
            ``config.node_output_size or target_channels``.

        Raises:
            ValueError: if the guiding field is not rank 3, its channel count
                differs from ``guiding_channels``, its node count differs from
                ``original_coordinates``, or ``graph_structures`` is empty.
            AssertionError: if per-sample graph structures differ in shape.
        """
        guiding = jnp.asarray(forcings.get("U_field_guiding", 0.0))
        if guiding.ndim != 3:
            raise ValueError(
                "forcings['U_field_guiding'] must have shape (B, N_o, C_g); "
                f"got shape {guiding.shape}"
            )
        B, N_o, C_g = guiding.shape
        if C_g != self.guiding_channels:
            raise ValueError(
                f"U_field_guiding has {C_g} channels; model was built for "
                f"{self.guiding_channels}"
            )
        obs_mask = jnp.asarray(forcings.get("obs_mask", jnp.zeros((B, N_o), jnp.float32)))  # (B, N_o)
        d = self._wind_direction(forcings.get("angle_deg"), B)  # (B, 2)

        g = self._select_graph(forcings["graph_structures"])

        # Static graph components.
        orig_xy = jnp.asarray(g["original_coordinates"], jnp.float32)
        red_xy = jnp.asarray(g["reduced_coordinates"], jnp.float32)
        node_types_oh = jax.nn.one_hot(jnp.asarray(g["node_types"], jnp.int32), 3, dtype=jnp.float32)

        oo_s = jnp.asarray(g["o2o_senders"], jnp.int32)
        oo_r = jnp.asarray(g["o2o_receivers"], jnp.int32)
        oo_f = self._optional_f32(g, "o2o_features")

        o2r_s = jnp.asarray(g["o2r_senders"], jnp.int32)
        o2r_r = jnp.asarray(g["o2r_receivers"], jnp.int32)
        o2r_f = jnp.asarray(g["o2r_features"], jnp.float32)

        rr_s = jnp.asarray(g["r2r_senders"], jnp.int32)
        rr_r = jnp.asarray(g["r2r_receivers"], jnp.int32)
        rr_f = self._optional_f32(g, "r2r_features")

        r2o_s = jnp.asarray(g["r2o_senders"], jnp.int32)
        r2o_r = jnp.asarray(g["r2o_receivers"], jnp.int32)
        r2o_f = jnp.asarray(g["r2o_features"], jnp.float32)

        N_o_struct = orig_xy.shape[0]
        if N_o_struct != N_o:
            raise ValueError(f"Input CFD nodes ({N_o}) != original_coordinates ({N_o_struct})")
        N_r = red_xy.shape[0]

        # ---------------- Original → Original ---------------- #
        # Node features laid out (N_o, B, F); only guiding inputs are used
        # (no target inputs, no case encoding).
        combined = jnp.concatenate([
            guiding.transpose(1, 0, 2).astype(jnp.float32),            # (N_o, B, C_g)
            obs_mask.transpose(1, 0)[..., None].astype(jnp.float32),  # (N_o, B, 1)
            self._broadcast(orig_xy, B),                              # (N_o, B, D)
            self._broadcast(node_types_oh, B),                        # (N_o, B, 3)
            jnp.broadcast_to(d[None, :, :], (N_o, B, 2)),             # (N_o, B, 2)
        ], axis=-1)

        dir_edge = self._directional_edge_features(orig_xy, oo_s, oo_r, d)  # (E, B, 2)
        if oo_f is None:
            o2o_edges = dir_edge
        else:
            o2o_edges = jnp.concatenate([self._broadcast(oo_f, B), dir_edge], axis=-1)

        graph1 = self.graph_o2o_template._replace(
            nodes={"orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=combined)},
            edges={self.oo_key: typed_graph.EdgeSet(
                n_edge=jnp.array([oo_s.size]),
                indices=typed_graph.EdgesIndices(senders=oo_s, receivers=oo_r),
                features=o2o_edges,
            )}
        )
        out_oo = self.ori2ori_gnn(graph1)
        updated_orig = out_oo.nodes["orig"].features.astype(jnp.float32)  # (N_o, B, L)

        # ---------------- Original → Reduced ------------------ #
        graph2 = self.graph_o2r_template._replace(
            nodes={
                "orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=updated_orig),
                "red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=self._broadcast(red_xy, B)),
            },
            edges={self.o2r_key: typed_graph.EdgeSet(
                n_edge=jnp.array([o2r_s.size]),
                indices=typed_graph.EdgesIndices(senders=o2r_s, receivers=o2r_r),
                features=self._broadcast(o2r_f, B),
            )}
        )
        out_o2r = self.ori2red_gnn(graph2)
        latent_red = out_o2r.nodes["red"].features.astype(jnp.float32)  # (N_r, B, L)

        # ---------------- Reduced → Reduced -------------------- #
        graph3 = self.graph_r2r_template._replace(
            nodes={"red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=latent_red)},
            edges={self.rr_key: typed_graph.EdgeSet(
                n_edge=jnp.array([rr_s.size]),
                indices=typed_graph.EdgesIndices(senders=rr_s, receivers=rr_r),
                features=self._broadcast(rr_f, B) if rr_f is not None else None,
            )}
        )
        out_rr = self.red2red_gnn(graph3)
        # Cast like every other stage output so each GNN receives float32.
        updated_red = out_rr.nodes["red"].features.astype(jnp.float32)  # (N_r, B, L)

        # ---------------- Reduced → Original ------------------- #
        # Decodes onto the stage-1 original-node latents.
        graph4 = self.graph_r2o_template._replace(
            nodes={
                "red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=updated_red),
                "orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=updated_orig),
            },
            edges={self.r2o_key: typed_graph.EdgeSet(
                n_edge=jnp.array([r2o_s.size]),
                indices=typed_graph.EdgesIndices(senders=r2o_s, receivers=r2o_r),
                features=self._broadcast(r2o_f, B),
            )}
        )
        out_r2o = self.red2ori_gnn(graph4)
        pred = out_r2o.nodes["orig"].features.astype(jnp.float32)  # (N_o, B, C_out)
        return pred.transpose(1, 0, 2)  # (B, N_o, C_out)

    @staticmethod
    def _wind_direction(angle_deg: Any, B: int) -> jnp.ndarray:
        """Return wind-direction unit vectors [cos θ, sin θ] of shape (B, 2).

        ``angle_deg`` is in degrees and may be None (treated as 0°), a scalar
        shared by the whole batch, or B per-sample values.
        """
        if angle_deg is None:
            angle_deg = jnp.zeros((B,), jnp.int32)
        angle = jnp.broadcast_to(jnp.ravel(jnp.asarray(angle_deg, jnp.float32)), (B,))
        theta = angle * (jnp.pi / 180.0)
        return jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)

    @staticmethod
    def _directional_edge_features(coords: jnp.ndarray,
                                   senders: jnp.ndarray,
                                   receivers: jnp.ndarray,
                                   d: jnp.ndarray) -> jnp.ndarray:
        """Return per-edge directional features of shape (E, B, 2).

        The edge vector ``coords[receiver] - coords[sender]`` is projected onto
        the wind direction ``d`` (B, 2) and onto ``d`` rotated by +90°; both
        projections are divided by the edge length (+1e-6 so zero-length
        edges do not divide by zero). ``coords`` must be (N, 2) because the
        projection contracts against the 2-component ``d``.
        """
        rel = coords[receivers] - coords[senders]  # (E, 2)
        length = jnp.linalg.norm(rel, axis=-1, keepdims=True) + 1e-6  # (E, 1)
        d_perp = jnp.stack([-d[:, 1], d[:, 0]], axis=-1)  # (B, 2)
        p_par = jnp.einsum("ed,bd->eb", rel, d)  # (E, B)
        p_perp = jnp.einsum("ed,bd->eb", rel, d_perp)  # (E, B)
        return jnp.stack([p_par / length, p_perp / length], axis=-1)

    @staticmethod
    def _optional_f32(g: Dict[str, Any], key: str) -> Optional[jnp.ndarray]:
        """Return ``g[key]`` as a float32 array, or None if absent/None."""
        value = g.get(key, None)
        return None if value is None else jnp.asarray(value, jnp.float32)

    @classmethod
    def _select_graph(cls, gstructs: Any) -> Dict[str, jnp.ndarray]:
        """Return the graph-structure dict used for the whole batch.

        ``gstructs`` is either a single dict (shared by every sample) or a
        sequence of per-sample dicts that must all have the same shapes; only
        the first one is used.
        """
        if isinstance(gstructs, dict):
            return gstructs
        if len(gstructs) == 0:
            raise ValueError("forcings['graph_structures'] is empty")
        cls._ensure_same_shapes(gstructs)
        return gstructs[0]

    @staticmethod
    def _broadcast(arr: jnp.ndarray, B: int) -> jnp.ndarray:
        """Broadcast a (N, F) array to (N, B, F) without copying per sample."""
        return jnp.broadcast_to(arr[:, None, :], (arr.shape[0], B, arr.shape[1]))

    @staticmethod
    def _ensure_same_shapes(graphs: Sequence[Dict[str, jnp.ndarray]]):
        """Raise AssertionError unless every graph has each key of
        ``graphs[0]`` with the same shape (non-array values compare as ())."""
        base = graphs[0]
        for g in graphs[1:]:
            for k, v in base.items():
                if k not in g or jnp.shape(v) != jnp.shape(g[k]):
                    raise AssertionError(
                        f"Batched samples must share the same topology. Mismatch in '{k}'."
                    )