import warnings

import torch
import torch.distributions as D
from ema_pytorch import EMA
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.models.generative_model.generative_model import GenerativeModel
from src.utils.sparse_utils import (
    sparse_normalization,
    sparse_random_choice,
    sparse_softmax,
    sparse_sum,
)

warnings.simplefilter("ignore", FutureWarning)


def alpha_log(t, k, A=4.5):
    """Return the accuracy rate alpha of the logarithmic schedule.

    Computes ``scale / (1 - t)`` with ``scale = (A + log2(k ** 1.5)) / k``.
    Entries with ``k <= 1`` are replaced by the constant ``1e-6``.

    :param t: Time; a tensor or Python scalar broadcastable against `k`.
        Values equal to 1 lead to a division by zero.
    :param k: Number of categories; a tensor or Python scalar.
    :param A: Additive offset inside the scale. Defaults to `4.5`.
    :return: Tensor of alpha values with the broadcast shape of `t` and `k`.
    """
    # `as_tensor` avoids the copy-construct UserWarning that `torch.tensor(<tensor>)`
    # emits, while still accepting plain Python numbers.
    k = torch.as_tensor(k)
    scale = (A + k.pow(1.5).log2()) / k
    alpha = scale / (1 - t)
    # Out-of-place masking: broadcasts like the arithmetic above and also discards
    # the inf/nan produced by k == 0.
    degenerate = (k <= 1).to(alpha.device)
    return torch.where(degenerate, torch.full_like(alpha, 1e-6), alpha)


def beta_log(t, k, A=4.5):
    """Return the accumulated accuracy beta of the logarithmic schedule.

    Computes ``-scale * log(1 - t)`` with ``scale = (A + log2(k ** 1.5)) / k``.
    Entries with ``k <= 1`` are replaced by ``0``.

    :param t: Time; a tensor or Python scalar broadcastable against `k`.
    :param k: Number of categories; a tensor or Python scalar.
    :param A: Additive offset inside the scale. Defaults to `4.5`.
    :return: Tensor of beta values with the broadcast shape of `t` and `k`.
    """
    # `as_tensor` avoids the copy-construct UserWarning that `torch.tensor(<tensor>)`
    # emits, while still accepting plain Python numbers.
    k = torch.as_tensor(k)
    scale = (A + k.pow(1.5).log2()) / k
    beta = -scale * torch.as_tensor(1 - t).log()
    # Out-of-place masking: broadcasts like the arithmetic above and also discards
    # the inf/nan produced by k == 0.
    degenerate = (k <= 1).to(beta.device)
    return torch.where(degenerate, torch.zeros_like(beta), beta)


class BFN(GenerativeModel):
    """Bayesian Flow Networks (Graves et al. 2023) adapted for hierarchical graphs."""

    def __init__(
        self,
        backbone: nn.Module,
        min_sqrt_beta: float = 1e-10,
        epsilon: float = 1e-6,
    ) -> None:
        """Initialize Bayesian Flow Network.

        :param backbone: Backbone.
        :param min_sqrt_beta: Minimum square root beta value. Defaults to `1e-10`.
        :param epsilon: Epsilon value. Defaults to `1e-6`.
        """
        super().__init__()

        self.backbone = backbone
        # EMA wrapper around the backbone; used instead of the backbone when
        # `use_ema=True` is passed to `forward` / `sample`.
        self.ema = EMA(
            self.backbone,
            beta=0.9999,
            update_after_step=10000,
            update_every=10,
        )
        # Stored only; not read by any method defined in this class.
        self.min_sqrt_beta = min_sqrt_beta
        self.epsilon = epsilon

    @property
    def device(self) -> torch.device:
        """Returns device of the first model parameter.

        Falls back to the CPU device when the module has no parameters instead of
        leaking a `StopIteration` out of the property.

        :return: Device of the first parameter, or CPU if there is none.
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    @torch.no_grad()
    def sample_t(self, data: Batch) -> Float[Tensor, "n_nodes"]:
        """Sample random time in [0, 1) for each graph. Return time for each node.

        :param data: PyG batch object.
        :return: Time of each node.
        """
        t = torch.rand((data.batch_size,), device=self.device)
        # Broadcast the per-graph time to the nodes of that graph.
        return t[data.batch]

    def forward(self, data: Batch, use_ema: bool = False) -> Float[Tensor, "n_edges"]:
        """Perform a forward pass for the Bayesian Flow Network.
        1. Sample time step t according to U([0,1]). Store it in `data.t`.
        2. Sample parameters of input distribution for timestep t. Store it in `data.edge_attr`.
        3. Compute parameters of output distribution.

        :param data: PyG batch object.
        :param use_ema: If `True`, evaluate the EMA model instead of the backbone.
            Defaults to `False`.
        :return: Predicted parameters of output distribution.
        """
        data = self.prepare_data(data)
        t = self.sample_t(data)
        data.t = t

        input_params = self.forward_pass(data, t)
        data.edge_attr = input_params

        network = self.ema if use_ema else self.backbone
        return network(data)

    def loss(
        self,
        data: Batch,
        pred: Float[Tensor, "n_edges"],
        target: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "1"]:
        """Compute the continuous time loss for the BFN.

        The squared error between `target` and `pred` is reduced with `sparse_sum`,
        scaled by `data.n_parents` and by the alpha of the time stored in `data.t`,
        and finally averaged.

        :param data: PyG batch object. Must carry `t` and `n_parents`.
        :param pred: Predicted parameters of output distribution.
        :param target: One hot encoded data distribution.
        :returns: Average continuous time loss.
        """
        squared_error = (target - pred).square()
        kl = data.n_parents * sparse_sum(data.edge_index, squared_error, axis=1)
        alpha = self.t_to_alpha(data.t, data.n_parents)
        continuous_time_loss = alpha * kl
        return continuous_time_loss.mean()

    @torch.inference_mode()
    def sample(
        self, data: Batch, n_steps: int, t_start: float = 0.0, use_ema: bool = False
    ) -> tuple[Float[Tensor, "n_steps n_edges"], Float[Tensor, "n_steps n_edges"]]:
        """Generate samples using the Bayesian Flow Network.

        :param data: PyG batch object. `data.t` is only used as a template for the
            shape and device of the time tensor.
        :param n_steps: Number of generation steps.
        :param t_start: Starting time. Defaults to `0.0`.
        :param use_ema: If `True`, evaluate the EMA model instead of the backbone.
            Defaults to `False`.
        :returns: Tuple of the stacked trajectory of output distribution parameters
            and the stacked trajectory of input distribution parameters.
        """
        pred_trajectory = []
        input_trajectory = []
        data = self.prepare_data(data)

        start_step = int(t_start * n_steps) + 1
        # Step i is evaluated at time (i - 1) / n_steps, so the time must start at
        # the time of `start_step`. Starting at 0 regardless of `t_start` would
        # merely shorten the trajectory. For `t_start == 0` this is all zeros.
        t = torch.full_like(data.t, (start_step - 1) / n_steps, dtype=torch.float)
        input_params = self.forward_pass(data, t)

        network = self.ema if use_ema else self.backbone
        for _ in range(start_step, n_steps + 1):
            data.edge_attr = input_params
            data.t = t
            output_params = network(data)
            input_trajectory.append(input_params)
            pred_trajectory.append(output_params)
            t_edge = t[data.edge_index[0]]
            alpha = self.t_to_alpha(t_edge, data.K) / n_steps
            input_params = self.update_input_params(
                data, input_params, output_params, alpha
            )
            # In-place on purpose: `data.t` refers to this same tensor.
            t += 1 / n_steps
        return torch.stack(pred_trajectory), torch.stack(input_trajectory)

    @torch.inference_mode()
    def update_input_params(
        self,
        data: Batch,
        input_params: Float[Tensor, "n_edges"],
        output_params: Float[Tensor, "n_edges"],
        alpha: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "n_edges"]:
        """Update input distribution parameters based on sample of receiver distribution.
        1. Draw sample from output distribution (categorical distribution).
        2. Transform this sample to receiver distribution.
        3. Use sample from receiver distribution to update input distribution.

        :param data: PyG batch object.
        :param input_params: Parameters representing the input distribution.
        :param output_params: Parameters representing the predicted output distribution.
        :param alpha: Accuracy parameter for each edge at the current step.
        :returns: Updated parameters of input distribution.
        """
        output_dist_sample = sparse_random_choice(data.edge_index, output_params)

        # The receiver distribution is the sender distribution built around the
        # sample drawn from the output distribution.
        receiver_dist = self.get_sender_dist(output_dist_sample, data.K, alpha)
        receiver_dist_sample = receiver_dist.sample()

        return self.bayes_update(data, input_params, receiver_dist_sample)

    def prepare_data(self, data: Batch) -> Batch:
        """Precompute the following properties for PyG batch object:
        data.K (Int[Tensor, "n_edges"]): Stores number of potential parents of start node for each edge.

        Additionally moves the optional `beta` / `alpha` attributes of the model to
        the model device when they exist.

        :param data: PyG batch object.
        :returns: Updated PyG batch object.
        """
        # `beta` / `alpha` are not assigned anywhere in this class and are not read
        # by its methods, so they are only moved when something else provided them.
        for name in ("beta", "alpha"):
            schedule = getattr(self, name, None)
            if schedule is not None:
                setattr(self, name, schedule.to(self.device))
        data.K = data.n_parents[data.edge_index[0]]
        return data

    def t_to_beta(
        self, t_edge: Float[Tensor, "n_edges"], K: Int[Tensor, "n_edges"]
    ) -> Tensor:
        """Convert time t to beta for each edge using `beta_log`.

        :param t_edge: Time for each edge.
        :param K: Number of potential parents for each edge.
        :returns: Beta parameter for each edge.
        """
        return beta_log(t_edge, K)

    def t_to_alpha(
        self, t_edge: Float[Tensor, "n_edges"], K: Int[Tensor, "n_edges"]
    ) -> Tensor:
        """Convert time t to alpha for each edge using `alpha_log`.

        :param t_edge: Time for each edge.
        :param K: Number of potential parents for each edge.
        :returns: Alpha parameter for each edge.
        """
        return alpha_log(t_edge, K)

    def count_dist(self, data: Batch, beta: Float[Tensor, "n_edges"]) -> D.Distribution:
        """Construct count distribution based on one-hot encoded target and beta.

        :param data: PyG batch object. Must carry `K` and `edge_attr_target`.
        :param beta: Beta parameter for each edge.
        :returns: Normal count distribution.
        """
        mean = beta * ((data.K * data.edge_attr_target) - 1)
        std_dev = beta.sqrt() * torch.sqrt(data.K)
        # Validation is disabled because beta (and hence the std) is 0 at t == 0.
        return D.Normal(mean, std_dev, validate_args=False)

    def count_sample(
        self, data: Batch, beta: Float[Tensor, "n_edges"]
    ) -> Float[Tensor, "n_edges"]:
        """Sample from the count distribution.

        :param data: PyG batch object.
        :param beta: Beta parameter for each edge.
        :returns: Sample from the count distribution.
        """
        return self.count_dist(data, beta).rsample()

    def get_alpha(self, data: Batch, i: int, n_steps: int) -> Float[Tensor, "n_edges"]:
        """Return alpha for each edge at step i according to the flow schedule.

        NOTE(review): `self.max_sqrt_beta` is not assigned anywhere in this class and
        this method is not called by the other methods of the class; the attribute
        must be provided elsewhere before calling it - confirm.

        :param data: PyG batch object.
        :param i: Current step.
        :param n_steps: Total number of steps.
        :returns: Alpha value for each edge.
        """
        max_sqrt_beta = self.max_sqrt_beta[data.K]
        return ((max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)

    def get_sender_dist(
        self,
        x: Float[Tensor, "n_edges"],
        K: Float[Tensor, "n_edges"],
        alpha: Float[Tensor, "n_edges"],
    ) -> D.Distribution:
        """Return the sender distribution with accuracy alpha obtained by adding appropriate noise
        to the data x.

        :param x: Input tensor.
        :param K: Number of potential parents for each edge.
        :param alpha: Accuracy parameter for each edge.
        :returns: Sender distribution.
        """
        mean = alpha * ((K * x) - 1)
        std = (K * alpha).sqrt()
        return D.Normal(mean, std)

    def bayes_update(
        self,
        data: Batch,
        input_params: Float[Tensor, "n_edges"],
        sample: Float[Tensor, "n_edges"],
    ) -> Float[Tensor, "n_edges"]:
        """Update the distribution parameters using Bayes' theorem in light of noisy sample y.

        :param data: PyG batch object; its `edge_index` is passed to the normalization.
        :param input_params: Current input parameters.
        :param sample: Noisy sample of the receiver distribution.
        :returns: Updated input parameters.
        """
        unnormalized = input_params * sample.exp()
        return sparse_normalization(data.edge_index, unnormalized)

    @torch.no_grad()
    def forward_pass(
        self, data: Batch, t: Float[Tensor, "n_nodes"]
    ) -> Float[Tensor, "n_edges"]:
        """Return parameters input distribution at time t conditioned on data.

        :param data: PyG batch object. Must carry `K` (see `prepare_data`) and
            `edge_attr_target`.
        :param t: Time tensor.
        :returns: Parameters of input distribution at time t.
        """
        t_edge = t[data.edge_index[0]]
        beta = self.t_to_beta(t_edge, data.K)
        logits = self.count_sample(data, beta)
        probs = sparse_softmax(data.edge_index, logits, axis=1)
        # At t == 0 the input parameters are set to 1 / K explicitly.
        uniform = torch.ones_like(probs) / data.K
        return torch.where(t_edge == 0, uniform, probs)
