import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.utils import scatter

from .flow_matching import FlowMatchingModel
from .model import Model
from .. import Graph


def batch_wise_mean(
    field: torch.Tensor,
    batch: torch.LongTensor,
    batch_size: 'int | None' = None,
) -> torch.Tensor:
    r"""Compute the batch-wise (per-graph) mean of a node-level field.

        A two-dimensional field is first averaged over its feature axis, so
        each graph's result is the mean over all of that graph's entries.

        Args:
            field (Tensor): The field. Dimension: (num_nodes) or
                (num_nodes, num_features).
            batch (LongTensor): The batch vector assigning each node to a
                graph. Dimension: (num_nodes).
            batch_size (int, optional): Number of graphs in the batch. If
                ``None`` (default), it is inferred as ``batch.max() + 1``.
                Inference needs a host-device synchronisation on GPU and
                cannot account for trailing graphs that have no nodes.

        Returns:
            Tensor: The batch-wise mean. Dimension: (batch_size).

        Raises:
            ValueError: If ``field`` is not one- or two-dimensional.
    """
    # Explicit check instead of `assert` so validation survives `python -O`.
    if field.dim() not in (1, 2):
        raise ValueError('field must be one- or two-dimensional')
    if field.dim() == 2:
        field = field.mean(dim=1) # Dimension: (num_nodes)
    if batch_size is None:
        if batch.numel() == 0:
            # `batch.max()` raises on an empty tensor; an empty batch holds no graphs.
            return field.new_zeros(0)
        batch_size = batch.max().item() + 1
    return scatter(field, batch, dim=0, dim_size=batch_size, reduce='mean') # Dimension: (batch_size)


class FlowMatchingLoss(nn.Module):
    r"""Squared-error loss for the flow matching model.

        The squared difference between the model's predicted advection field
        and the target ``graph.advection_field`` is averaged per graph (via
        ``graph.batch``). The result has one loss value per graph; reducing it
        to a scalar is left to the caller.

        The module holds no parameters and takes no constructor arguments.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        model: FlowMatchingModel,
        graph: Graph,
    ) -> torch.Tensor:
        r"""Compute the per-graph flow matching loss.

            Args:
                model (FlowMatchingModel): The flow matching model. It is
                    called as ``model(graph)`` to predict the advection field.
                graph (Graph): The input graph. It must provide
                    ``advection_field`` (the regression target) and ``batch``
                    (the batch vector).

            Returns:
                Tensor: The per-graph mean squared error. Dimension: (batch_size).

            Raises:
                ValueError: If the prediction's shape differs from the target's.
        """
        # Inference
        pred_v = model(graph) # This is the predicted advection field
        target = graph.advection_field
        # Explicit check (not `assert`) so it still runs under `python -O`.
        # Without it, a mismatched shape could silently broadcast in the
        # subtraction below and produce a wrong loss.
        if pred_v.shape != target.shape:
            raise ValueError(f'output.shape = {pred_v.shape}, advection_field.shape = {target.shape}')
        # Compute the loss
        return batch_wise_mean((pred_v - target)**2, graph.batch) # Dimension: (batch_size)