import os
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import pandas as pd
import torch
import wandb
import yaml
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch_geometric.data import Data
from torchmetrics.segmentation import MeanIoU

from src.modules.act import GEGLU
from src.utils.constants import ATOM_ENCODING
from src.utils.logging_utils import load_run_config_from_wb
from src.utils.plotting import plot_density_point_cloud, plot_occ_pointcloud
from src.utils.torch_utils import select_single_sample_index


class OccupancyPhysicsParticleLitModule(L.LightningModule):
    """Train a latent dynamics model on top of a frozen first-stage (encoder/decoder) model.

    The first-stage model is instantiated from a Hydra config, which is loaded either from a
    W&B run or from a local file. Its weights can optionally be restored from a checkpoint. It is
    kept frozen and used only through its ``model.encode``, ``model.decode`` and
    ``model.conditioner`` members. Only ``latent_model`` is optimized. It maps a latent state,
    plus an optional timestep condition from the first-stage conditioner, to the latent state of
    the next step.

    * Training objective: the sum over all target steps of
      ``loss_function(predicted_latent, encoded_target_latent)``.
    * Validation metric: mean IoU of the decoded occupancy classes after one step.
    """

    def __init__(
        self,
        latent_model: Callable[[], torch.nn.Module],
        loss_function: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Optional[Callable[..., Any]] = None,
        compile: bool = False,
        first_stage_model_ckpt: Optional[str] = None,
        first_stage_model_config: Optional[str] = None,
        first_stage_model_wandb_project: Optional[str] = None,
        first_stage_model_wandb_id: Optional[str] = None,
        num_classes: int = 5,
        first_stage_model_wandb_entity: str = "add_your_wandb_here",
    ):
        """Build the latent model, load the frozen first stage, and create the IoU metrics.

        Args:
            latent_model: zero-argument factory (e.g. a Hydra partial) that builds the
                trainable latent dynamics network.
            loss_function: module comparing predicted and target latents.
            optimizer: factory called as ``optimizer(params=...)``.
            scheduler: optional factory called as ``scheduler(optimizer=...)``.
            compile: if True, ``latent_model`` is wrapped with ``torch.compile`` in ``setup("fit")``.
            first_stage_model_ckpt: optional checkpoint path with the first-stage ``state_dict``.
            first_stage_model_config: local path to the first-stage Hydra config (used when no
                W&B run id is given).
            first_stage_model_wandb_project: W&B project holding the first-stage run.
            first_stage_model_wandb_id: W&B run id whose config describes the first stage.
            num_classes: number of occupancy classes for the IoU metrics.
            first_stage_model_wandb_entity: W&B entity owning the first-stage run.
        """
        super().__init__()
        # The model and loss objects are not stored in the saved hyperparameters.
        self.save_hyperparameters(logger=False, ignore=["latent_model", "loss_function"])
        self.latent_model: torch.nn.Module = latent_model()
        self.loss_function: torch.nn.Module = loss_function

        self.initialize_first_stage_model()

        # Note: mean_iou_train is not updated by any method of this class.
        self.mean_iou_train = MeanIoU(num_classes=self.hparams.num_classes)
        self.mean_iou_val = MeanIoU(num_classes=self.hparams.num_classes)

    def initialize_first_stage_model(self):
        """Instantiate the frozen first-stage model and optionally load its weights.

        The config comes from the W&B run ``first_stage_model_wandb_id`` if set, otherwise from
        the local file ``first_stage_model_config``. Weights are loaded from
        ``first_stage_model_ckpt`` only when that path is set. Otherwise the freshly
        instantiated weights are kept.

        Raises:
            ValueError: if neither a W&B run id nor a config path was provided.
        """
        if self.hparams.first_stage_model_wandb_id is not None:
            cfg = load_run_config_from_wb(
                entity=self.hparams.first_stage_model_wandb_entity,
                project=self.hparams.first_stage_model_wandb_project,
                run_id=self.hparams.first_stage_model_wandb_id,
            )
        elif self.hparams.first_stage_model_config is not None:
            cfg = OmegaConf.load(self.hparams.first_stage_model_config)
        else:
            raise ValueError(
                "Please provide a value for first_stage_model_config or first_stage_model_wandb_id."
            )
        self.first_stage_model = instantiate(cfg.model)

        ckpt_path = self.hparams.first_stage_model_ckpt
        if ckpt_path:
            # NOTE(review): Lightning checkpoints can contain pickled non-tensor objects. On torch
            # versions where torch.load defaults to weights_only=True, this may need
            # weights_only=False, which is only safe for trusted checkpoints.
            checkpoint = torch.load(ckpt_path, map_location=self.first_stage_model.device)
            self.first_stage_model.load_state_dict(checkpoint["state_dict"])

        self.first_stage_model.eval()
        self.first_stage_model.freeze()

    def train(self, mode: bool = True):
        """Set the training mode of this module while keeping the first stage in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override, any
        ``train()`` call on this module (as Lightning's loops may issue) would switch the frozen
        first-stage model's dropout/normalization layers back to training behaviour.
        """
        super().train(mode)
        first_stage_model = getattr(self, "first_stage_model", None)
        if first_stage_model is not None:
            first_stage_model.eval()
        return self

    def forward(self, batch: Data) -> Tensor:
        """Compute the summed multi-step latent rollout loss for a training batch.

        The input state is encoded with the frozen first stage. ``latent_model`` is then applied
        autoregressively once per target step. After each step the prediction is compared with
        the encoding of the corresponding ground-truth target.

        Args:
            batch: batch with ``input_*`` fields for the starting state and ``target_*`` fields
                whose second dimension indexes the target steps.

        Returns:
            The sum of per-step losses (``0.0`` if the batch has no target steps).
        """
        timestep = batch.input_timestep
        pred_latent = self.encode(
            enc_pos=batch.input_enc_pos,
            enc_field=batch.input_enc_field,
            enc_particle_type=batch.input_enc_particle_type,
            enc_pos_batch_index=batch.batch,
            supernode_index=batch.supernode_index,
            supernode_batch_index=batch.supernode_index_batch,
            timestep=timestep,
        )
        n_targets = batch.target_enc_pos.shape[1]
        loss = 0.0
        for jump_idx in range(n_targets):
            # Advance one step, conditioned on the timestep of the state being advanced.
            pred_latent = self.push_forward(latent=pred_latent, timestep=timestep)
            target_timestep = batch.target_timestep[:, jump_idx]
            target_latent = self.encode(
                enc_pos=batch.target_enc_pos[:, jump_idx],
                enc_field=batch.target_enc_field[:, jump_idx],
                enc_particle_type=batch.target_enc_particle_type,
                enc_pos_batch_index=batch.batch,
                supernode_index=batch.supernode_index,
                supernode_batch_index=batch.supernode_index_batch,
                timestep=target_timestep,
            )
            # The prediction now represents the target step, so the next push starts from there.
            timestep = target_timestep
            loss = loss + self.loss_function(pred_latent, target_latent)
        return loss

    def model_step(self, batch: Data) -> Dict[str, Tensor]:
        """Run ``forward`` and wrap the loss in the dict format Lightning expects."""
        loss = self.forward(batch=batch)
        return {"loss": loss}

    def training_step(self, batch: Data, batch_idx: int) -> Dict[str, Tensor]:
        """Compute and log the training rollout loss."""
        loss = self.model_step(batch)
        # Number of samples = highest sample index in ``batch.batch`` + 1.
        batch_size = batch.batch.max().item() + 1
        self.log("train/loss", loss["loss"], prog_bar=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch: Tuple[Data, Data], batch_idx: int) -> None:
        """Predict one latent step ahead, decode it, and accumulate occupancy mean IoU.

        Args:
            batch: ``(input_data, target_data)``. ``target_data`` fields are indexed along their
                second dimension by target step; only the first step (index 0) is evaluated.
        """
        input_data, target_data = batch
        timestep = input_data.timestep
        pred_latent = self.encode(
            enc_pos=input_data.enc_pos,
            enc_field=input_data.enc_field,
            enc_particle_type=input_data.enc_particle_type,
            enc_pos_batch_index=input_data.enc_pos_batch,
            supernode_index=input_data.supernode_index,
            supernode_batch_index=input_data.supernode_index_batch,
            timestep=timestep,
        )
        jump_idx = 0
        pred_latent = self.push_forward(latent=pred_latent, timestep=timestep)
        timestep = target_data.timestep[:, jump_idx]
        # Decode
        preds_occ, preds_field = self.decode(
            latent=pred_latent,
            dec_field_pos=target_data.dec_pos[:, jump_idx],
            dec_occ_pos=target_data.dec_occ_pos[:, jump_idx],
            timestep=timestep,
        )
        # IoU calculation.
        # NOTE(review): batch size comes from ``input_data.batch``, while encoding uses
        # ``input_data.enc_pos_batch``. Confirm both index the same samples.
        batch_size = input_data.batch.max().item() + 1
        preds_occ = einops.rearrange(
            preds_occ,
            "(batch_size n_points) n_classes -> batch_size n_points n_classes",
            batch_size=batch_size,
        )
        target_occ = einops.rearrange(
            target_data.dec_occ_type[:, jump_idx],
            "(batch_size n_points) n_classes -> batch_size n_points n_classes",
            batch_size=batch_size,
        )
        # mean_iou has problems with index tensors -> one hot
        # https://github.com/Lightning-AI/torchmetrics/pull/2572
        # is fixed in a new version
        preds_occ = torch.nn.functional.one_hot(
            preds_occ.argmax(dim=-1), num_classes=self.hparams.num_classes
        )
        target_occ = torch.nn.functional.one_hot(
            target_occ.argmax(dim=-1), num_classes=self.hparams.num_classes
        )
        # MeanIoU expects the class dimension at dim 1.
        preds_occ = einops.rearrange(
            preds_occ,
            "batch_size n_points n_classes -> batch_size n_classes n_points",
        )
        target_occ = einops.rearrange(
            target_occ,
            "batch_size n_points n_classes -> batch_size n_classes n_points",
        )
        self.mean_iou_val(
            preds=preds_occ,
            target=target_occ,
        )
        self.log(
            "val/meanIoU",
            self.mean_iou_val,
            prog_bar=True,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
        )

    def setup(self, stage: str):
        """Optionally compile the latent model for fitting and re-freeze the first stage."""
        if self.hparams.compile and stage == "fit":
            # The trainable network lives in ``latent_model``; this class has no ``self.model``.
            self.latent_model = torch.compile(self.latent_model)
        # Freeze first stage model
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def configure_optimizers(self) -> Dict[str, Any]:
        """Create the optimizer, and the optional LR scheduler, for ``latent_model`` only.

        A scheduler built from a ``functools.partial`` of ``LinearWarmupCosineAnnealingLR`` is
        stepped every optimizer step. Any other scheduler is stepped once per epoch.
        """
        # Only train parameters of latent
        optimizer = self.hparams.optimizer(params=self.latent_model.parameters())
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler_factory = self.hparams.scheduler
        is_per_step = (
            isinstance(scheduler_factory, partial)
            and scheduler_factory.func.__name__ == "LinearWarmupCosineAnnealingLR"
        )
        interval = "step" if is_per_step else "epoch"
        scheduler = scheduler_factory(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # NOTE(review): this module never logs "val/loss" (validation logs only
                # "val/meanIoU"). A scheduler that reads ``monitor`` would not find it.
                "monitor": "val/loss",
                "interval": interval,
                "frequency": 1,
            },
        }

    def encode(
        self,
        enc_pos: Tensor,
        enc_field: Tensor,
        enc_particle_type: Tensor,
        enc_pos_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
        timestep: Tensor,
    ) -> Dict:
        """Encode a particle state with the frozen first-stage encoder (no gradients)."""
        with torch.no_grad():
            return self.first_stage_model.model.encode(
                enc_pos=enc_pos,
                enc_field=enc_field,
                enc_particle_type=enc_particle_type,
                enc_pos_batch_index=enc_pos_batch_index,
                supernode_index=supernode_index,
                supernode_batch_index=supernode_batch_index,
                timestep=timestep,
            )

    def decode(
        self,
        latent: Tensor,
        dec_field_pos: Tensor,
        dec_occ_pos: Tensor,
        timestep: Tensor,
    ):
        """Decode a latent with the frozen first-stage decoder (no gradients).

        Returns whatever ``first_stage_model.model.decode`` returns. ``validation_step`` unpacks
        it as ``(occupancy_predictions, field_predictions)``.
        """
        with torch.no_grad():
            return self.first_stage_model.model.decode(
                latent=latent,
                dec_field_pos=dec_field_pos,
                dec_occ_pos=dec_occ_pos,
                timestep=timestep,
            )

    def condition(self, timestep: Tensor):
        """Return the first-stage conditioner output for ``timestep``, or None if it has none."""
        with torch.no_grad():
            conditioner = self.first_stage_model.model.conditioner
            return None if conditioner is None else conditioner(timestep)

    def push_forward(self, latent: Tensor, timestep: Tensor):
        """Advance ``latent`` by one step with ``latent_model``, conditioned on ``timestep``."""
        condition = self.condition(timestep)
        return self.latent_model(latent, condition)