import os
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import pandas as pd
import torch
import wandb
import yaml
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch_geometric.data import Data
from torchmetrics.segmentation import MeanIoU

from src.modules.act import GEGLU
from src.utils.constants import ATOM_ENCODING
from src.utils.logging_utils import load_run_config_from_wb
from src.utils.plotting import plot_density_point_cloud, plot_occ_pointcloud
from src.utils.torch_utils import select_single_sample_index


class OccupancyPhysicsFullParticleLitModule(L.LightningModule):
    """Train a latent dynamics model on top of a (pre-trained) first-stage model.

    The first-stage model is built from a Hydra config (local file or W&B run) and
    supplies ``model.encode``, ``model.decode`` and ``model.conditioner``.
    ``latent_model`` advances the encoded latent by one jump per call. The latent
    is rolled out over every jump in the target, per-jump losses are summed, and
    the occupancy predictions of the first jump are scored with mean IoU.

    Args:
        latent_model: Factory (called with no arguments) that builds the latent
            dynamics network.
        loss_function: Callable returning a dict with ``"loss"``, ``"occ_loss"``
            and ``"field_loss"`` entries.
        optimizer: Factory called with ``params=`` to build the optimizer.
        scheduler: Optional factory called with ``optimizer=`` to build an LR
            scheduler.
        compile: If True, ``torch.compile`` the latent model before fitting.
        first_stage_model_ckpt: Optional checkpoint path whose ``state_dict`` is
            loaded into the first-stage model.
        first_stage_model_config: Local config path, used when no W&B run id is
            given.
        first_stage_model_wandb_project: W&B project of the first-stage run.
        first_stage_model_wandb_id: W&B run id of the first-stage run.
        freeze_first_stage_model: If True, freeze the first-stage model and
            optimize only the latent model.
        num_classes: Number of occupancy classes for the IoU metrics.
        first_stage_model_wandb_entity: W&B entity that owns the first-stage run.
    """

    def __init__(
        self,
        latent_model: Callable[..., torch.nn.Module],
        loss_function: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Optional[Callable[..., Any]] = None,
        compile: bool = False,
        first_stage_model_ckpt: str = None,
        first_stage_model_config: str = None,
        first_stage_model_wandb_project: str = None,
        first_stage_model_wandb_id: str = None,
        freeze_first_stage_model: bool = False,
        num_classes: int = 5,
        first_stage_model_wandb_entity: str = "add_your_wandb_here",
    ):
        super().__init__()
        # Modules are excluded from the saved hyper-parameters; every other
        # constructor argument is available as ``self.hparams.<name>``.
        self.save_hyperparameters(
            logger=False, ignore=["latent_model", "loss_function"]
        )
        self.latent_model: torch.nn.Module = latent_model()
        self.loss_function: torch.nn.Module = loss_function

        self.initialize_first_stage_model()

        # Separate metric objects so train/val accumulations never mix.
        self.mean_iou_train = MeanIoU(num_classes=self.hparams.num_classes)
        self.mean_iou_val = MeanIoU(num_classes=self.hparams.num_classes)

    def initialize_first_stage_model(self):
        """Build the first-stage model and optionally load and freeze its weights.

        The config comes from the W&B run ``first_stage_model_wandb_id`` if given,
        otherwise from the local file ``first_stage_model_config``.

        Raises:
            ValueError: If neither a W&B run id nor a config path is provided.
        """
        if self.hparams.first_stage_model_wandb_id is not None:
            cfg = load_run_config_from_wb(
                entity=self.hparams.first_stage_model_wandb_entity,
                project=self.hparams.first_stage_model_wandb_project,
                run_id=self.hparams.first_stage_model_wandb_id,
            )
        elif self.hparams.first_stage_model_config is not None:
            cfg = OmegaConf.load(self.hparams.first_stage_model_config)
        else:
            raise ValueError(
                "Please provide a value for first_stage_model_config or first_stage_model_wandb_id."
            )
        self.first_stage_model = instantiate(cfg.model)

        ckpt_path = self.hparams.first_stage_model_ckpt
        if ckpt_path:
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
            # which may reject Lightning checkpoints that pickle hyper-parameters
            # (e.g. functools.partial objects) — confirm with the pinned torch.
            checkpoint = torch.load(ckpt_path, map_location=self.device)
            self.first_stage_model.load_state_dict(checkpoint["state_dict"])

        if self.hparams.freeze_first_stage_model:
            # NOTE(review): some Lightning versions call .train() on the whole
            # module when fitting starts, which would undo this eval() — confirm.
            self.first_stage_model.eval()
            self.first_stage_model.freeze()

    def forward(
        self, input_data: Data, target_data: Data
    ) -> Tuple[Dict[str, Tensor], Dict[str, List[Tensor]]]:
        """Encode the input state and autoregressively roll the latent forward.

        For each jump ``j`` (dimension 1 of ``target_data.dec_pos``):
        - the latent is advanced once with ``latent_model``,
        - it is decoded at that jump's query positions,
        - the result is scored against that jump's targets.

        Returns:
            ``(total_loss, preds)``.
            ``total_loss`` maps ``"loss"``, ``"occ_loss"`` and ``"field_loss"``
            to their sums over all jumps.
            ``preds`` maps ``"preds_occ"`` and ``"preds_field"`` to lists with
            one entry per jump.
        """
        timestep = input_data.timestep
        pred_latent = self.encode(
            enc_pos=input_data.enc_pos,
            enc_field=input_data.enc_field,
            enc_particle_type=input_data.enc_particle_type,
            enc_pos_batch_index=input_data.enc_pos_batch,
            supernode_index=input_data.supernode_index,
            supernode_batch_index=input_data.supernode_index_batch,
            timestep=timestep,
        )
        total_loss = {"loss": 0.0, "occ_loss": 0.0, "field_loss": 0.0}
        preds = {"preds_occ": [], "preds_field": []}
        for jump_idx in range(target_data.dec_pos.shape[1]):
            # The push-forward is conditioned on the timestep of the state being
            # advanced; decoding uses the timestep of the jump's target.
            pred_latent = self.push_forward(latent=pred_latent, timestep=timestep)
            timestep = target_data.timestep[:, jump_idx]
            preds_occ, preds_field = self.decode(
                latent=pred_latent,
                dec_field_pos=target_data.dec_pos[:, jump_idx],
                dec_occ_pos=target_data.dec_occ_pos[:, jump_idx],
                timestep=timestep,
            )
            preds["preds_occ"].append(preds_occ)
            preds["preds_field"].append(preds_field)
            # Flatten per-particle (timesteps, dims) targets to match preds_field.
            target_field = einops.rearrange(
                target_data.dec_field[:, jump_idx],
                "n_particles n_timesteps n_dim -> n_particles (n_timesteps n_dim)",
            )
            loss = self.loss_function(
                preds_occ=preds_occ,
                preds_field=preds_field,
                occ=target_data.dec_occ_type[:, jump_idx],
                field=target_field,
            )
            for key in total_loss:
                total_loss[key] = total_loss[key] + loss[key]
        return total_loss, preds

    def model_step(
        self, batch: Tuple[Data, Data]
    ) -> Tuple[Dict[str, Tensor], Dict[str, List[Tensor]]]:
        """Run ``forward`` on an ``(input_data, target_data)`` batch."""
        input_data, target_data = batch
        loss, preds = self.forward(input_data=input_data, target_data=target_data)
        return loss, preds

    @staticmethod
    def _batch_size(data: Data) -> int:
        """Return the number of samples in a batch, derived from its ``batch`` vector."""
        return data.batch.max().item() + 1

    def training_step(
        self, batch: Tuple[Data, Data], batch_idx: int
    ) -> Dict[str, Tensor]:
        """Roll out the batch, log the summed losses and the first-jump mean IoU."""
        loss, preds = self.model_step(batch)
        batch_size = self._batch_size(batch[0])
        for name, value in loss.items():
            self.log(f"train/{name}", value, prog_bar=True, batch_size=batch_size)
        jump_idx = 0
        self.calc_iou(
            preds_occ=preds["preds_occ"][jump_idx],
            target_occ=batch[1].dec_occ_type[:, jump_idx],
            batch_size=batch_size,
            iou_metric=self.mean_iou_train,
        )
        self.log(
            "train/meanIoU",
            self.mean_iou_train,
            prog_bar=True,
            batch_size=batch_size,
        )
        return loss

    def validation_step(
        self, batch: Tuple[Data, Data], batch_idx: int
    ) -> Dict[str, Tensor]:
        """Roll out the batch, log epoch-level losses and the first-jump mean IoU.

        Jump 0 of the rollout is the prediction the IoU is scored on. The summed
        losses are logged as ``val/*``; ``val/loss`` is the key that
        ``configure_optimizers`` tells the LR scheduler to monitor.
        """
        loss, preds = self.model_step(batch)
        batch_size = self._batch_size(batch[0])
        for name, value in loss.items():
            self.log(
                f"val/{name}",
                value,
                prog_bar=name == "loss",
                batch_size=batch_size,
                on_step=False,
                on_epoch=True,
            )
        jump_idx = 0
        self.calc_iou(
            preds_occ=preds["preds_occ"][jump_idx],
            target_occ=batch[1].dec_occ_type[:, jump_idx],
            batch_size=batch_size,
            iou_metric=self.mean_iou_val,
        )
        self.log(
            "val/meanIoU",
            self.mean_iou_val,
            prog_bar=True,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
        )
        return loss

    def setup(self, stage: str):
        """Optionally compile the latent dynamics model before fitting."""
        if self.hparams.compile and stage == "fit":
            # torch.compile returns a wrapper module; its state_dict keys carry an
            # ``_orig_mod.`` prefix, so checkpoints saved after compiling differ.
            self.latent_model = torch.compile(self.latent_model)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build the optimizer (and optional scheduler) for the trainable parameters.

        With a frozen first-stage model only the latent model is optimized.
        ``LinearWarmupCosineAnnealingLR`` (given as a ``functools.partial``) is
        stepped every optimizer step; any other scheduler is stepped every epoch.
        """
        if self.hparams.freeze_first_stage_model:
            params = self.latent_model.parameters()
        else:
            params = list(self.latent_model.parameters()) + list(
                self.first_stage_model.parameters()
            )
        optimizer = self.hparams.optimizer(params=params)

        scheduler_factory = self.hparams.scheduler
        if scheduler_factory is None:
            return {"optimizer": optimizer}

        is_step_scheduler = (
            isinstance(scheduler_factory, partial)
            and scheduler_factory.func.__name__ == "LinearWarmupCosineAnnealingLR"
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler_factory(optimizer=optimizer),
                # Logged by validation_step.
                "monitor": "val/loss",
                "interval": "step" if is_step_scheduler else "epoch",
                "frequency": 1,
            },
        }

    def encode(
        self,
        enc_pos: Tensor,
        enc_field: Tensor,
        enc_particle_type: Tensor,
        enc_pos_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
        timestep: Tensor,
    ) -> Dict:
        """Encode input particles into a latent via the first-stage model."""
        return self.first_stage_model.model.encode(
            enc_pos=enc_pos,
            enc_field=enc_field,
            enc_particle_type=enc_particle_type,
            enc_pos_batch_index=enc_pos_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
            timestep=timestep,
        )

    def decode(
        self,
        latent: Tensor,
        dec_field_pos: Tensor,
        dec_occ_pos: Tensor,
        timestep: Tensor,
    ):
        """Decode ``latent`` at the query positions via the first-stage model.

        Returns whatever the first-stage ``decode`` returns; callers unpack it as
        ``(preds_occ, preds_field)``.
        """
        return self.first_stage_model.model.decode(
            latent=latent,
            dec_field_pos=dec_field_pos,
            dec_occ_pos=dec_occ_pos,
            timestep=timestep,
        )

    def condition(self, timestep: Tensor):
        """Return the first-stage conditioning for ``timestep``, or None if it has no conditioner."""
        conditioner = self.first_stage_model.model.conditioner
        if conditioner is None:
            return None
        return conditioner(timestep)

    def push_forward(self, latent: Tensor, timestep: Tensor):
        """Advance ``latent`` by one jump with the latent model, conditioned on ``timestep``."""
        condition = self.condition(timestep)
        return self.latent_model(latent, condition)

    def calc_iou(
        self,
        preds_occ: Tensor,
        target_occ: Tensor,
        batch_size: int,
        iou_metric: MeanIoU,
    ):
        """Update ``iou_metric`` with argmax classes of predictions vs. targets.

        Both inputs are laid out as ``(batch_size * n_points, n_classes)``; the
        rearrange below requires this and splits out the batch dimension.
        """
        preds_occ = einops.rearrange(
            preds_occ,
            "(batch_size n_points) n_classes -> batch_size n_points n_classes",
            batch_size=batch_size,
        )
        target_occ = einops.rearrange(
            target_occ,
            "(batch_size n_points) n_classes -> batch_size n_points n_classes",
            batch_size=batch_size,
        )
        # MeanIoU mishandles index tensors in some torchmetrics versions, so feed
        # it one-hot tensors instead.
        # https://github.com/Lightning-AI/torchmetrics/pull/2572
        preds_occ = torch.nn.functional.one_hot(
            preds_occ.argmax(dim=-1), num_classes=self.hparams.num_classes
        )
        target_occ = torch.nn.functional.one_hot(
            target_occ.argmax(dim=-1), num_classes=self.hparams.num_classes
        )
        # MeanIoU expects the class dimension at position 1.
        preds_occ = einops.rearrange(
            preds_occ, "batch_size n_points n_classes -> batch_size n_classes n_points"
        )
        target_occ = einops.rearrange(
            target_occ, "batch_size n_points n_classes -> batch_size n_classes n_points"
        )
        iou_metric(
            preds=preds_occ,
            target=target_occ,
        )
