import warnings

import torch
import torch.distributions
from ema_pytorch import EMA
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch
from torchdyn.core import NeuralODE

from src.models.generative_model.generative_model import GenerativeModel

# Importing this module silences every FutureWarning for the whole process.
warnings.simplefilter(action="ignore", category=FutureWarning)


class CatFlow(GenerativeModel):
    """
    CatFlow from Variational Flow Matching for Graph Generation (Eijkelboom et al. 2024).

    Instead of the optimal-transport conditional vector field, this variant uses a
    vector field that leads to the conditional probability paths of the Bayesian
    Flow Network. Concretely, `forward_pass` linearly interpolates every edge between
    clamped Gaussian noise (t = 0) and the one-hot encoded target (t = 1), and the
    backbone is trained to predict the target probabilities from that interpolant.
    """

    def __init__(
        self,
        backbone: nn.Module,
        min_sqrt_beta: float = 1e-10,
        epsilon: float = 1e-6,
    ) -> None:
        """Initialize CatFlow.

        :param backbone: Network that maps a batch (with per-node time `t` and noisy
            `edge_attr`) to predicted edge probabilities.
        :param min_sqrt_beta: Minimum square root beta value. Stored on the module but
            not used by this class. Defaults to `1e-10`.
        :param epsilon: Lower bound applied to predicted probabilities before taking
            the logarithm in `loss`. Defaults to `1e-6`.
        """
        super().__init__()

        self.backbone = backbone
        self.min_sqrt_beta = min_sqrt_beta
        # Guards `torch.log` in `loss` against zero predicted probabilities.
        self.epsilon = epsilon
        # Exponential moving average copy of the backbone, used when `use_ema=True`.
        self.ema = EMA(
            self.backbone,
            beta=0.999,
            update_after_step=128,
            update_every=1,
        )

    @property
    def device(self) -> torch.device:
        """Return the device of the module's first parameter.

        :return: Device of the first registered parameter (presumably the `backbone`'s).
        """
        return next(self.parameters()).device

    @torch.no_grad()
    def sample_t(self, data: Batch) -> Float[Tensor, "n_nodes"]:
        """Sample one time per graph from U[0, 1) and broadcast it to that graph's nodes.

        :param data: PyG batch object.
        :return: Time of each node.
        """
        t = torch.rand((data.batch_size,), device=self.device)
        # `data.batch` maps each node to its graph index, so every node of a graph
        # receives the same time.
        return t[data.batch]

    def forward(self, data: Batch, use_ema: bool = False) -> Float[Tensor, "n_edges"]:
        """Perform a training forward pass of CatFlow.

        1. Sample a time t ~ U[0, 1) per graph and store the per-node times in `data.t`.
        2. Interpolate the edge attributes at time t and store them in `data.edge_attr`.
        3. Predict the output edge probabilities with the backbone.

        `data` is modified in place (`t` and `edge_attr` are overwritten).

        :param data: PyG batch object.
        :param use_ema: Use the EMA copy of the backbone instead of the live backbone.
        :return: Predicted parameters of the output distribution.
        """
        t = self.sample_t(data)
        data.t = t

        data.edge_attr = self.forward_pass(data, t)

        model = self.ema if use_ema else self.backbone
        return model(data)

    def loss(
        self,
        data: Batch,
        pred: Float[Tensor, "n_edges"],
        target: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "1"]:
        """Compute the cross-entropy loss.

        Only the probabilities at positions where `target == 1` contribute. They are
        clamped below at `self.epsilon` so that a zero prediction gives a large but
        finite loss instead of `inf` (and NaN gradients).

        :param data: PyG batch object (unused here).
        :param pred: Predicted edge probabilities.
        :param target: One hot encoded target edges.
        :returns: Average cross entropy loss.
        """
        target_mask = target == 1
        target_probs = pred[target_mask].clamp_min(self.epsilon)
        return -torch.log(target_probs).mean()

    @torch.no_grad()
    def forward_pass(
        self, data: Batch, t: Float[Tensor, "n_nodes"]
    ) -> Float[Tensor, "n_edges"]:
        """Interpolate between noisy initial edge probabilities and one-hot target edges.

        The starting point is Gaussian noise with mean 0.5 and standard deviation 0.25,
        clamped below at 1e-4. Each edge uses the time of its source node
        (`edge_index[0]`).

        :param data: PyG batch object. Must provide `edge_index`, `edge_attr` (which
            sets the shape/dtype of the noise) and `edge_attr_target`.
        :param t: Time of each node.
        :returns: Edge probabilities at time t.
        """
        t_edge = t[data.edge_index[0]]
        noise = torch.randn_like(data.edge_attr, device=self.device) * 0.25 + 0.5
        noise = noise.clamp(min=0.0001)
        # Add trailing singleton dims so per-edge times broadcast along the edge axis
        # even when edge attributes carry extra feature dims (no-op for 1-D attributes).
        t_edge = t_edge.reshape(-1, *([1] * (noise.dim() - 1)))
        return (1 - t_edge) * noise + t_edge * data.edge_attr_target

    @torch.inference_mode()
    def sample(
        self, data: Batch, n_steps: int, t_start: float = 0.0, use_ema: bool = False
    ) -> tuple[Float[Tensor, "n_steps n_edges"], Float[Tensor, "n_steps n_edges"]]:
        """Generate samples by integrating the learned vector field from `t_start` to 1.

        `data.t` and `data.edge_attr` are overwritten while integrating. If `data.t`
        is missing, per-node times are created on `self.device`.

        :param data: PyG batch object.
        :param n_steps: Number of time points in the integration grid.
        :param t_start: Starting time. Defaults to `0.0`.
        :param use_ema: Use the EMA copy of the backbone instead of the live backbone.
        :returns: A pair of trajectories of edge probabilities. The second is an
            independent copy of the first; the ODE is solved only once.
        """
        t_span = torch.linspace(t_start, 1, n_steps, device=self.device)
        if getattr(data, "t", None) is None:
            # Fresh batches have no per-node times yet; `forward` normally sets them.
            data.t = torch.zeros(data.num_nodes, device=self.device)
        t_start_nodes = torch.ones_like(data.t) * t_start
        edge_attr_start = self.forward_pass(data, t_start_nodes)

        model = self.ema if use_ema else self.backbone

        def vector_field(t, edge_attr, args):
            data.t = torch.ones_like(data.t) * t
            data.edge_attr = edge_attr
            edge_attr_pred = model(data)
            # Velocity towards the predicted endpoint. It is singular at t = 1, so this
            # presumably relies on the Euler solver evaluating the field only at the
            # start of each step (never at t = 1) — re-check if the solver changes.
            return (edge_attr_pred - edge_attr) / (1 - t)

        # NOTE(review): `sensitivity`, `atol` and `rtol` presumably have no effect for a
        # fixed-step Euler solve under inference mode; kept unchanged.
        ode = NeuralODE(
            vector_field=vector_field,
            solver="euler",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4,
        )
        trajectory = ode.trajectory(edge_attr_start, t_span)
        # Callers expect a pair; return an independent copy rather than integrating
        # the same ODE a second time.
        return trajectory, trajectory.clone()
