from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import lightning as L
import pandas as pd
import torch
import wandb
from torch import Tensor, nn
from torch_geometric.utils import unbatch
from torchmetrics.classification import (
    MulticlassAUROC,
    MulticlassConfusionMatrix,
    MulticlassPrecision,
    MulticlassRecall,
)

from src.modules.act import GEGLU
from src.utils.constants import ATOM_ENCODING
from src.utils.plotting import plot_density_point_cloud, plot_occ_pointcloud
from src.utils.torch_utils import select_single_sample_index


class OccupancyAutoencoder(nn.Module):
    """Encoder/decoder pair predicting occupancy and field values at query positions.

    ``encoder`` and ``decoder`` are factories (classes or partials): each is called
    with ``act=act`` to build the actual submodule.

    Args:
        encoder: Factory returning the encoder module when called as ``encoder(act=...)``.
        decoder: Factory returning the decoder module when called as ``decoder(act=...)``.
        act: Activation forwarded to both factories (``GEGLU`` by default).
        initialize_weights: Optional function applied to every submodule through
            ``nn.Module.apply`` once both submodules exist.
    """

    def __init__(
        self,
        encoder: Callable[..., nn.Module],
        decoder: Callable[..., nn.Module],
        act: Callable[..., nn.Module] = GEGLU,
        initialize_weights: Optional[Callable[[nn.Module], None]] = None,
    ):
        super().__init__()
        self.encoder = encoder(act=act)
        self.decoder = decoder(act=act)
        self.initialize_weights = initialize_weights

        if self.initialize_weights is not None:
            # ``apply`` visits every submodule recursively (and ``self`` last).
            self.apply(self.initialize_weights)

    def forward(
        self,
        enc_pos: Tensor,
        enc_occ: Tensor,
        enc_field: Tensor,
        query_pos: Tensor,
        enc_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
    ) -> Dict[str, Tensor]:
        """Encode the inputs and decode the latent at ``query_pos``.

        Returns:
            Dict with ``"preds_occ"`` and ``"preds_field"``, the two outputs of
            :meth:`decode`.
        """
        latent = self.encode(
            enc_pos=enc_pos,
            enc_occ=enc_occ,
            enc_field=enc_field,
            enc_batch_index=enc_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
        )

        preds_occ, preds_field = self.decode(latent=latent, query_pos=query_pos)
        return {"preds_occ": preds_occ, "preds_field": preds_field}

    def encode(
        self,
        enc_pos: Tensor,
        enc_occ: Tensor,
        enc_field: Tensor,
        enc_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
    ) -> Tensor:
        """Run the encoder; arguments are passed positionally in this order."""
        return self.encoder(
            enc_pos,
            enc_occ,
            enc_field,
            enc_batch_index,
            supernode_index,
            supernode_batch_index,
        )

    def decode(self, latent: Tensor, query_pos: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the decoder; returns ``(preds_occ, preds_field)`` as used by :meth:`forward`."""
        return self.decoder(latent, query_pos)


class OccupancyAutoencoderLitModule(L.LightningModule):
    """LightningModule that trains an occupancy/field autoencoder.

    ``model`` is called with the encoder inputs and query positions of a collated
    batch and must return a dict with ``"preds_occ"`` and ``"preds_field"``.
    ``loss_function`` receives those predictions together with the targets and must
    return a mapping with ``"loss"``, ``"occ_loss"`` and ``"field_loss"``.

    Args:
        model: Autoencoder to train.
        loss_function: Module computing the occupancy/field loss.
        optimizer: Optimizer factory, called as ``optimizer(params=...)``.
        scheduler: Optional LR-scheduler factory, called as ``scheduler(optimizer=...)``;
            it is stepped once per epoch and monitors ``"val/loss"``.
        compile: If True, wrap ``model`` with ``torch.compile`` when fitting.
        num_classes: Number of occupancy classes used by the classification metrics.
        pos_scale: Divisor applied to query positions before they are plotted.
        vis_sample_idx: Index of the validation sample plotted at the end of each
            validation epoch; ``None`` disables those plots.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_function: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Optional[Callable[..., Any]] = None,
        compile: bool = False,
        num_classes: int = 5,
        pos_scale: float = 200,
        vis_sample_idx: Optional[int] = None,
    ):
        super().__init__()
        # Factories and flags go to ``self.hparams``; the two modules are excluded.
        self.save_hyperparameters(logger=False, ignore=["model", "loss_function"])
        self.model: torch.nn.Module = model
        self.loss_function: torch.nn.Module = loss_function

        self.initialize()

    def initialize(self) -> None:
        """Create the per-split torchmetrics objects and the validation output buffer."""
        num_classes = self.hparams.num_classes

        self.train_auroc_multi = MulticlassAUROC(num_classes=num_classes)
        self.valid_auroc_multi = MulticlassAUROC(num_classes=num_classes)

        self.train_precision_multi = MulticlassPrecision(num_classes=num_classes)
        self.valid_precision_multi = MulticlassPrecision(num_classes=num_classes)

        self.train_recall_multi = MulticlassRecall(num_classes=num_classes)
        self.valid_recall_multi = MulticlassRecall(num_classes=num_classes)

        # Confusion matrices normalised over the true labels. Only the validation
        # one is updated and plotted within this class.
        self.train_confmat_multi = MulticlassConfusionMatrix(
            num_classes=num_classes, normalize="true"
        )
        self.valid_confmat_multi = MulticlassConfusionMatrix(
            num_classes=num_classes, normalize="true"
        )

        # ``train_step_outputs`` is not used within this class; ``val_step_outputs``
        # buffers CPU copies of validation batches for the sample visualisation.
        self.train_step_outputs = []
        self.val_step_outputs = []

    def forward(
        self,
        enc_pos: Tensor,
        enc_occ: Tensor,
        enc_field: Tensor,
        query_pos: Tensor,
        enc_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
    ) -> Dict[str, Tensor]:
        """Run the wrapped model and return its prediction dict."""
        return self.model(
            enc_pos,
            enc_occ,
            enc_field,
            query_pos,
            enc_batch_index,
            supernode_index,
            supernode_batch_index,
        )

    def model_step(self, batch: Any) -> Tuple[Mapping[str, Tensor], Dict[str, Tensor]]:
        """Run the model on a collated batch and compute the loss.

        Returns:
            ``(loss, preds)``: the mapping returned by ``loss_function`` and the raw
            prediction dict of the model.
        """
        preds = self.forward(
            enc_pos=batch.enc_pos,
            enc_occ=batch.enc_occ,
            enc_field=batch.enc_field,
            query_pos=batch.query_pos,
            enc_batch_index=batch.enc_pos_batch,
            supernode_index=batch.supernode_index,
            supernode_batch_index=batch.supernode_index_batch,
        )
        # preds_field is presumably (num_query_points, 1) — confirm against the decoder.
        # Squeeze only the last dim: a bare ``squeeze()`` would also drop the point
        # dimension when a batch contains a single query point.
        preds_field = preds["preds_field"].squeeze(-1)
        loss = self.loss_function(
            preds_occ=preds["preds_occ"],
            preds_field=preds_field,
            occ=batch.query_occ,
            field=batch.query_field,
        )
        return loss, preds

    def _shared_step(self, batch: Any) -> Dict[str, Any]:
        """Run ``model_step`` and package losses and predictions for the batch-end hooks."""
        loss, preds = self.model_step(batch)
        return {
            "loss": loss["loss"],
            "occ_loss": loss["occ_loss"],
            "field_loss": loss["field_loss"],
            "preds": preds,
        }

    @staticmethod
    def _batch_size(batch: Any) -> int:
        """Number of samples in ``batch``: the largest index in ``batch.batch`` plus one."""
        return batch.batch.max().item() + 1

    def _log_losses(self, stage: str, outputs: Mapping[str, Any], batch_size: int) -> None:
        """Log ``loss``, ``occ_loss`` and ``field_loss`` under the ``{stage}/`` prefix."""
        for key in ("loss", "occ_loss", "field_loss"):
            self.log(f"{stage}/{key}", outputs[key], prog_bar=True, batch_size=batch_size)

    @staticmethod
    def _update_class_metrics(
        preds_occ: Tensor, query_occ: Tensor, metrics: Tuple[Any, ...]
    ) -> Tensor:
        """Feed softmax probabilities and argmax targets to ``metrics``; return the targets.

        ``update`` only accumulates state: the epoch value is computed when the metric
        objects are logged in the epoch-end hooks, so calling the metric (``forward``)
        would just compute a per-batch value that is never used.
        """
        probs = torch.softmax(preds_occ, dim=-1)
        labels = query_occ.argmax(dim=-1)
        for metric in metrics:
            metric.update(probs, labels)
        return labels

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        """Compute the training loss; extra keys are consumed by ``on_train_batch_end``."""
        return self._shared_step(batch)

    def on_train_batch_end(
        self, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int
    ) -> None:
        """Log the batch losses and accumulate the training classification metrics."""
        batch_size = self._batch_size(batch)
        self._log_losses("train", outputs, batch_size)

        self._update_class_metrics(
            outputs["preds"]["preds_occ"],
            batch.query_occ,
            (self.train_auroc_multi, self.train_precision_multi, self.train_recall_multi),
        )

    def on_train_epoch_end(self) -> None:
        """Log the epoch-level training metrics (Lightning computes and resets them)."""
        self.log("train/auroc_multi", self.train_auroc_multi, prog_bar=True)
        self.log("train/precision_multi", self.train_precision_multi, prog_bar=True)
        self.log("train/recall_multi", self.train_recall_multi, prog_bar=True)

    def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        """Compute the validation loss; outputs are consumed by ``on_validation_batch_end``."""
        return self._shared_step(batch)

    def on_validation_batch_end(
        self, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int
    ) -> None:
        """Log losses, update validation metrics and buffer data for the sample plots."""
        batch_size = self._batch_size(batch)
        self._log_losses("val", outputs, batch_size)

        preds = outputs["preds"]
        class_labels = self._update_class_metrics(
            preds["preds_occ"],
            batch.query_occ,
            (self.valid_auroc_multi, self.valid_precision_multi, self.valid_recall_multi),
        )
        self.valid_confmat_multi.update(preds["preds_occ"].argmax(dim=-1), class_labels)

        # The buffer is only read by the sample visualisation; skip the CPU copies
        # when it is disabled.
        if self.hparams.vis_sample_idx is None:
            return

        self.val_step_outputs.append(
            {
                "preds_occ": preds["preds_occ"].detach().cpu(),
                "preds_field": preds["preds_field"].detach().cpu(),
                "occ": batch.query_occ.detach().cpu(),
                "field": batch.query_field.detach().cpu(),
                "pos": batch.query_pos.detach().cpu(),
                "batch": batch.query_pos_batch.detach().cpu(),
                "atom_pos": batch.atom_pos,
                "atom_type": batch.atom_type,
                "batch_size": batch_size,
            }
        )

    def on_validation_epoch_end(self) -> None:
        """Log validation metrics, the confusion matrix and (optionally) one sample's plots."""
        self.log("val/auroc_multi", self.valid_auroc_multi, prog_bar=True)
        self.log("val/precision_multi", self.valid_precision_multi, prog_bar=True)
        self.log("val/recall_multi", self.valid_recall_multi, prog_bar=True)

        confmat = self.valid_confmat_multi.compute()
        fig, _ = self.valid_confmat_multi.plot(val=confmat, labels=list(ATOM_ENCODING.keys()))
        # Assumes the attached logger's ``experiment`` is a wandb run.
        # NOTE(review): ``fig`` is never closed; if torchmetrics creates it via pyplot,
        # open figures accumulate each epoch — consider ``matplotlib.pyplot.close(fig)``.
        self.logger.experiment.log(
            {
                "val/confusion_matrix": wandb.Image(
                    fig, caption=f"Confusion Matrix at Epoch {self.current_epoch}"
                ),
            },
        )

        if self.hparams.vis_sample_idx is not None:
            self._log_validation_sample(self.hparams.vis_sample_idx)

        self.valid_confmat_multi.reset()
        self.val_step_outputs.clear()

    def _log_validation_sample(self, sample_idx: int) -> None:
        """Plot occupancy and field point clouds for one buffered validation sample."""
        batch_idx, local_idx = select_single_sample_index(sample_idx, self.val_step_outputs)
        sample = self.val_step_outputs[batch_idx]

        def pick(key: str) -> Tensor:
            # Split the concatenated per-point tensor by sample and keep the chosen one.
            return unbatch(sample[key], sample["batch"])[local_idx]

        occs = pick("occ").argmax(dim=-1)
        preds_occ = pick("preds_occ").argmax(dim=-1)
        pos = pick("pos")
        field = pick("field")
        preds_field = pick("preds_field")
        atom_pos = sample["atom_pos"][local_idx]
        atom_type = sample["atom_type"][local_idx]
        self.log_occ_point_cloud(occs, preds_occ, pos, atom_pos, atom_type)
        self.log_density_point_cloud(field, preds_field, pos, atom_pos, atom_type)

    @staticmethod
    def _reversed_atom_encoding() -> Dict[Any, Any]:
        """Map encoded class indices back to the keys of ``ATOM_ENCODING``."""
        return {v: k for k, v in ATOM_ENCODING.items()}

    @classmethod
    def _atoms_dataframe(cls, atom_pos: List, atom_type: List) -> pd.DataFrame:
        """Build the X/Y/Z + decoded ``atom_type`` frame shared by both plot helpers."""
        decode = cls._reversed_atom_encoding().__getitem__
        # NOTE(review): unlike query positions, atom positions are not divided by
        # ``pos_scale`` — confirm they are already in the plotting frame.
        df_atoms = pd.DataFrame(atom_pos, columns=["X", "Y", "Z"])
        df_atoms["atom_type"] = atom_type
        df_atoms["atom_type"] = df_atoms["atom_type"].apply(decode)
        return df_atoms

    def log_occ_point_cloud(
        self, occs: Tensor, preds_occ: Tensor, pos: Tensor, atom_pos: List, atom_type: List
    ) -> None:
        """Log target and predicted occupancy point clouds to wandb.

        ``occs``/``preds_occ`` are class indices per query point; they are decoded
        to ``ATOM_ENCODING`` keys (unknown indices raise ``KeyError``).
        """
        decode = self._reversed_atom_encoding().__getitem__
        df_occs = pd.DataFrame(pos, columns=["X", "Y", "Z"]) / self.hparams.pos_scale
        df_occs["occs"] = occs
        df_occs["occs"] = df_occs["occs"].apply(decode)
        df_occs["preds_occ"] = preds_occ
        df_occs["preds_occ"] = df_occs["preds_occ"].apply(decode)

        df_atoms = self._atoms_dataframe(atom_pos, atom_type)

        fig_input = plot_occ_pointcloud(
            df_occs=df_occs,
            color_column="occs",
            title=f"Occupancy input - Epoch {self.current_epoch}",
            df_atoms=df_atoms,
        )
        wandb.log({"val/vis/occ/target": wandb.Plotly(fig_input)})

        fig_pred = plot_occ_pointcloud(
            df_occs=df_occs,
            color_column="preds_occ",
            title=f"Occupancy prediction - Epoch {self.current_epoch}",
            df_atoms=df_atoms,
        )
        wandb.log({"val/vis/occ/pred": wandb.Plotly(fig_pred)})

    def log_density_point_cloud(
        self,
        densities: Tensor,
        preds_densities: Tensor,
        pos: Tensor,
        atom_pos: List,
        atom_type: List,
    ) -> None:
        """Log target and predicted field (density) point clouds to wandb."""
        df_densities = pd.DataFrame(pos, columns=["X", "Y", "Z"]) / self.hparams.pos_scale
        df_densities["densities"] = densities
        df_densities["preds_densities"] = preds_densities

        df_atoms = self._atoms_dataframe(atom_pos, atom_type)

        fig_input = plot_density_point_cloud(
            df_densities=df_densities,
            color_column="densities",
            title=f"Density input - Epoch {self.current_epoch}",
            df_atoms=df_atoms,
        )
        wandb.log({"val/vis/dens/target": wandb.Plotly(fig_input)})

        fig_pred = plot_density_point_cloud(
            df_densities=df_densities,
            color_column="preds_densities",
            title=f"Density prediction - Epoch {self.current_epoch}",
            df_atoms=df_atoms,
        )
        wandb.log({"val/vis/dens/pred": wandb.Plotly(fig_pred)})

    def setup(self, stage: str):
        """Optionally ``torch.compile`` the model before fitting."""
        if self.hparams.compile and stage == "fit":
            self.model = torch.compile(self.model)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build the optimizer and, if configured, an epoch-level scheduler on ``val/loss``."""
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def encode(
        self,
        enc_pos: Tensor,
        enc_occ: Tensor,
        enc_field: Tensor,
        enc_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
    ) -> Tensor:
        """Encode inputs to a latent via ``self.model.encode``.

        Arguments mirror ``OccupancyAutoencoder.encode`` (the model ``forward`` above
        is called with the same inputs).
        """
        return self.model.encode(
            enc_pos=enc_pos,
            enc_occ=enc_occ,
            enc_field=enc_field,
            enc_batch_index=enc_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
        )

    def decode(self, latent: Tensor, pos: Tensor) -> Tuple[Tensor, Tensor]:
        """Decode ``latent`` at query positions ``pos``; returns ``(preds_occ, preds_field)``."""
        return self.model.decode(latent, pos)
