import copy
import warnings
from typing import Any, Dict

import torch
from jaxtyping import Float
from lightning import LightningModule
from torch import Tensor, nn
from torch_geometric.data import Batch
from torchmetrics import MeanMetric

from src.models.eval.binary_tree import (
    compute_valid_binary_tree,
    make_valid_binary_tree_per_graph_size_plot,
    make_valid_binary_tree_plot,
)
from src.models.eval.entropy import make_entropy_per_class_plot
from src.models.eval.jet_likelihood import (
    compute_llh_fraction,
    compute_llhs,
    compute_valid_llh_tree,
    make_llh_plot,
)
from src.models.generative_model.generative_model import GenerativeModel

# Install a process-wide filter at import time that suppresses every
# `DeprecationWarning`, regardless of which module emits it.
warnings.filterwarnings("ignore", category=DeprecationWarning)


class GinkgoModule(LightningModule):
    """`LightningModule` wrapper for CatFlow on the Ginkgo dataset.

    The module delegates the forward pass, the loss and the sampling procedure to
    the wrapped generative model (`self.gen`) and is responsible for metric
    bookkeeping and logging during training, validation and testing.
    """

    # Numbers of generation steps evaluated at test time. The likelihood plot
    # produced at the end of testing uses the samples generated with the largest
    # entry.
    TEST_N_STEPS = (5, 10, 50, 100, 500, 1000)
    # Names of the generation metrics tracked for every entry of `TEST_N_STEPS`.
    TEST_GEN_METRICS = (
        "valid_binary_tree",
        "valid_llh_tree",
        "valid_tree",
        "llh_fraction",
    )

    def __init__(
        self,
        generative_model: GenerativeModel,
        optimizer: torch.optim.Optimizer,
        compile: bool,
        n_steps: int = 50,
        n_val_reps: int = 10,
        n_test_reps: int = 100,
        n_test_gen_reps: int = 10,
        t_start: float = 0.0,
        pt_min_sqrt: float = 4.0,
        qcd_rate: float = 1.5,
    ) -> None:
        """Initialize a `GinkgoModule`.

        :param generative_model: The generation framework.
        :param optimizer: The optimizer to use for training. It is called as
            `optimizer(params=...)` in `configure_optimizers`, so it has to be a
            callable that builds the optimizer (e.g. a partially applied class).
        :param compile: Whether to use torch.compile for optimized runtime.
        :param n_steps: Number of steps used for generation during validation.
        :param n_val_reps: Number of repetitions to compute validation loss.
        :param n_test_reps: Number of repetitions to compute test loss.
        :param n_test_gen_reps: Number of repetitions to compute test generation metrics.
        :param t_start: Starting time for generation.
        :param pt_min_sqrt: Parameter of ginkgo simulator.
        :param qcd_rate: Parameter of ginkgo simulator.
        """
        super().__init__()

        # Expose the constructor arguments through `self.hparams`; the generative
        # model itself is excluded and kept as a regular submodule instead.
        self.save_hyperparameters(logger=False, ignore=["generative_model"])
        self.gen = generative_model

        # State of the most recent validation batch, kept for the plots created
        # in `on_validation_epoch_end`.
        self.last_val_input_trajectory = None
        self.last_val_pred_trajectory = None
        self.last_val_data = None

        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()
        self.test_ema_loss = MeanMetric()
        self.valid_binary_tree = MeanMetric()
        self.valid_llh_tree = MeanMetric()
        self.valid_tree = MeanMetric()
        self.llh_fraction = MeanMetric()
        self.test_metrics = nn.ModuleDict()

        # Per-epoch accumulators for the likelihood plots.
        self.val_llhs_pred_list = []
        self.val_llhs_target_list = []
        self.val_n_list = []
        self.llhs_pred_list = []
        self.llhs_target_list = []
        self.n_list = []

        # One loss metric per model variant ("" = regular, "_ema" = EMA) and one
        # metric per (variant, step count, generation metric) combination.
        for model in ["", "_ema"]:
            self.test_metrics[f"test{model}/loss"] = MeanMetric()
            for steps in self.TEST_N_STEPS:
                for metric in self.TEST_GEN_METRICS:
                    self.test_metrics[f"test{model}/{steps}/{metric}"] = MeanMetric()

    def forward(self, data: Batch, use_ema: bool = False) -> Float[Tensor, "n_edges"]:
        """Perform a training pass with the generative model.
        That is:
            1. Sample time step t according to U([0,1]). Store it in `data.time`.
            2. Sample input distribution for timestep t. Store it in `data.edge_attr`.
            3. Compute output distribution. Store it in `data.edge_attr_pred`.

        :param data: PyG batch object.
        :param use_ema: Forwarded unchanged to the generative model; presumably
            selects the EMA weights instead of the regular ones.
        :return: The predicted edge attributes.
        """
        return self.gen(data, use_ema)

    def on_train_start(self) -> None:
        """Reset validation metrics after sanity check."""
        self.val_loss.reset()
        self.valid_binary_tree.reset()
        self.valid_llh_tree.reset()
        self.llh_fraction.reset()
        self.valid_tree.reset()
        self.val_llhs_pred_list = []
        self.val_llhs_target_list = []
        self.val_n_list = []

    def model_step(self, data: Batch, use_ema: bool = False) -> Float[Tensor, "1"]:
        """Perform a training pass with the generative model and compute its loss.

        :param data: PyG batch object.
        :param use_ema: Forwarded to `forward`.
        :return: The loss returned by the generative model for this pass.
        """
        # Read the target before the forward pass, which receives the same
        # `data` object.
        target = data.edge_attr_target
        pred = self.forward(data, use_ema)
        return self.gen.loss(data, pred, target)

    def _log_epoch_metric(
        self,
        name: str,
        metric: MeanMetric,
        batch_size: int,
        prog_bar: bool = False,
    ) -> None:
        """Log `metric` under `name`, aggregated per epoch only.

        :param name: Key under which the metric is logged.
        :param metric: Metric object to log.
        :param batch_size: Batch size passed on to `self.log`.
        :param prog_bar: Whether to show the metric in the progress bar.
        """
        self.log(
            name,
            metric,
            on_step=False,
            on_epoch=True,
            prog_bar=prog_bar,
            batch_size=batch_size,
        )

    def _generation_scores(self, edge_attr_pred, data: Batch):
        """Compute the generation metrics for one set of predicted edge attributes.

        :param edge_attr_pred: Final element of a sampled prediction trajectory.
        :param data: PyG batch object the prediction belongs to.
        :return: A tuple `(scores, llhs)`. `scores` maps every name in
            `TEST_GEN_METRICS` to its value; `llhs` is the
            `(llhs_pred, llhs_target, n)` triple returned by `compute_llhs`.
        """
        llhs_pred, llhs_target, n = compute_llhs(
            edge_attr_pred, data, self.hparams.pt_min_sqrt, self.hparams.qcd_rate
        )
        valid_binary_tree = compute_valid_binary_tree(edge_attr_pred, data)
        valid_llh_tree = compute_valid_llh_tree(llhs_pred)
        scores = {
            "valid_binary_tree": valid_binary_tree,
            "valid_llh_tree": valid_llh_tree,
            "valid_tree": valid_binary_tree * valid_llh_tree,
            "llh_fraction": compute_llh_fraction(llhs_pred, llhs_target),
        }
        return scores, (llhs_pred, llhs_target, n)

    def training_step(self, data: Batch, batch_idx: int) -> Float[Tensor, "1"]:
        """Perform a single training step on a batch of data from the training set.

        :param data: PyG batch object.
        :param batch_idx: Not used.
        :return: The mean loss for the training step.
        """
        loss = self.model_step(data)
        # NOTE(review): the EMA update happens inside `training_step`, i.e. before
        # the returned loss has been used for this batch's optimizer step --
        # confirm that this one-step lag is intended.
        self.gen.ema.update()
        self.train_loss(loss)
        self._log_epoch_metric(
            "train/loss", self.train_loss, data.num_graphs, prog_bar=True
        )
        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""

    def validation_step(self, data: Batch, batch_idx: int) -> None:
        """Perform a validation step on a batch of data from the validation set.

        The loss is evaluated `n_val_reps` times on fresh copies of the batch;
        afterwards one sample is generated with the EMA model and scored.

        :param data: PyG batch object.
        :param batch_idx: Not used.
        """
        for _ in range(self.hparams.n_val_reps):
            # Each repetition works on its own copy so that the original batch
            # is left untouched for the generation below.
            loss = self.model_step(copy.deepcopy(data))
            self.val_loss(loss)
            self._log_epoch_metric(
                "val/loss", self.val_loss, data.num_graphs, prog_bar=True
            )

        pred_trajectory, input_trajectory = self.gen.sample(
            data=data,
            n_steps=self.hparams.n_steps,
            t_start=self.hparams.t_start,
            use_ema=True,
        )
        self.last_val_pred_trajectory = pred_trajectory
        self.last_val_input_trajectory = input_trajectory
        self.last_val_data = data

        scores, (llhs_pred, llhs_target, n) = self._generation_scores(
            pred_trajectory[-1], data
        )
        val_metrics = {
            "valid_binary_tree": self.valid_binary_tree,
            "valid_llh_tree": self.valid_llh_tree,
            "valid_tree": self.valid_tree,
            "llh_fraction": self.llh_fraction,
        }
        for name, metric in val_metrics.items():
            metric(scores[name])
            # Only the binary-tree validity is shown in the progress bar.
            self._log_epoch_metric(
                f"val/{name}",
                metric,
                data.num_graphs,
                prog_bar=(name == "valid_binary_tree"),
            )

        self.val_llhs_pred_list += llhs_pred
        self.val_llhs_target_list += llhs_target
        self.val_n_list += n

    def on_validation_epoch_start(self) -> None:
        """Clear the likelihood accumulators before a validation epoch."""
        self.val_llhs_pred_list = []
        self.val_llhs_target_list = []
        self.val_n_list = []

    def on_validation_epoch_end(self) -> None:
        """Create and log the validation plots from the last validation batch.

        Nothing is logged without a logger or if no validation step has stored
        a trajectory yet.
        """
        if not self.logger:
            return
        if self.last_val_pred_trajectory is None or self.last_val_data is None:
            # No validation batch has been processed, so there is nothing to plot.
            return

        experiment = self.logger.experiment
        pred_trajectory = self.last_val_pred_trajectory
        val_data = self.last_val_data

        plot_input_entropy_per_class = make_entropy_per_class_plot(
            self.last_val_input_trajectory, val_data
        )
        experiment.log({"val/entropy_per_class_plot": plot_input_entropy_per_class})

        plot_pred_entropy_per_class = make_entropy_per_class_plot(
            pred_trajectory, val_data
        )
        experiment.log({"val/pred_entropy_per_class_plot": plot_pred_entropy_per_class})

        plot_binary_tree = make_valid_binary_tree_plot(pred_trajectory, val_data)
        experiment.log({"val/binary_tree_plot": plot_binary_tree})

        plot_binary_tree_per_graph_size = make_valid_binary_tree_per_graph_size_plot(
            pred_trajectory[-1], val_data
        )
        experiment.log(
            {"val/binary_tree_per_graph_size": plot_binary_tree_per_graph_size}
        )

        if len(self.val_llhs_pred_list) > 0:
            plot_llh = make_llh_plot(
                self.val_llhs_pred_list, self.val_llhs_target_list, self.val_n_list
            )
            if plot_llh is not None:
                experiment.log({"val/llh_plot": plot_llh})

    def on_test_epoch_start(self) -> None:
        """Clear the likelihood accumulators before a test epoch.

        Mirrors `on_validation_epoch_start` so that repeated test runs do not
        mix samples from earlier runs into the likelihood plot.
        """
        self.llhs_pred_list = []
        self.llhs_target_list = []
        self.n_list = []

    def test_step(self, data: Batch, batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        For both the regular and the EMA model, the loss is evaluated
        `n_test_reps` times, and for every entry of `TEST_N_STEPS` a sample is
        generated and scored `n_test_gen_reps` times.

        :param data: PyG batch object.
        :param batch_idx: Not used.
        """
        batch_size = data.num_graphs
        # Samples generated with the largest step count feed the likelihood plot.
        plot_n_steps = max(self.TEST_N_STEPS)

        for use_ema in [False, True]:
            model = "_ema" if use_ema else ""

            loss_name = f"test{model}/loss"
            loss_metric = self.test_metrics[loss_name]
            for _ in range(self.hparams.n_test_reps):
                # Evaluate the loss with the model variant that is being tested.
                loss = self.model_step(copy.deepcopy(data), use_ema=use_ema)
                loss_metric(loss.cpu())
                self._log_epoch_metric(loss_name, loss_metric, batch_size)

            for n_steps in self.TEST_N_STEPS:
                for _ in range(self.hparams.n_test_gen_reps):
                    pred_trajectory, _input_trajectory = self.gen.sample(
                        copy.deepcopy(data),
                        n_steps,
                        self.hparams.t_start,
                        use_ema=use_ema,
                    )
                    scores, (llhs_pred, llhs_target, n) = self._generation_scores(
                        pred_trajectory[-1], data
                    )
                    for name in self.TEST_GEN_METRICS:
                        metric_name = f"test{model}/{n_steps}/{name}"
                        metric = self.test_metrics[metric_name]
                        metric(scores[name])
                        self._log_epoch_metric(metric_name, metric, batch_size)

                    # NOTE(review): samples of the regular and the EMA model end
                    # up in the same lists and hence in one plot -- confirm that
                    # this is intended.
                    if n_steps == plot_n_steps:
                        self.llhs_pred_list += llhs_pred
                        self.llhs_target_list += llhs_target
                        self.n_list += n

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        # Same guard as in `on_validation_epoch_end`: skip the plot if nothing
        # has been collected.
        if self.logger and len(self.llhs_pred_list) > 0:
            plot_llh = make_llh_plot(
                self.llhs_pred_list, self.llhs_target_list, self.n_list
            )
            if plot_llh is not None:
                self.logger.experiment.log({"test/llh_plot": plot_llh})

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            # NOTE(review): the compiled module is stored as `self.net`, but every
            # other method of this class calls `self.gen` directly, so the
            # compiled version looks unused here -- confirm the intended target.
            self.net = torch.compile(self.gen)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers to use in your optimization.
        Normally you'd need one.

        :return: A dict containing the configured optimizer to be
            used for training.
        """
        # `hparams.optimizer` is used as a factory that receives the parameters.
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        return {"optimizer": optimizer}


if __name__ == "__main__":
    # Smoke test: the module can be constructed without a real generative
    # model or optimizer.
    _ = GinkgoModule(
        generative_model=None,
        optimizer=None,
        compile=None,
        n_steps=None,
    )
