from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import pandas as pd
import torch
import wandb
from torch import Tensor, nn
from torchmetrics.segmentation import MeanIoU

from src.modules.act import GEGLU
from src.utils.constants import ATOM_ENCODING
from src.utils.plotting import plot_density_point_cloud, plot_occ_pointcloud
from src.utils.torch_utils import select_single_sample_index


class OccupancyAutoencoderParticle(nn.Module):
    """Autoencoder built from an encoder, a decoder and an optional conditioner.

    The three sub-modules are passed in as factories (for example
    ``functools.partial`` objects) and are instantiated here with the shared
    ``act`` argument. ``forward`` encodes the particle inputs into a latent and
    decodes that latent into occupancy and field predictions.
    """

    def __init__(
        self,
        encoder: Callable[..., nn.Module],
        decoder: Callable[..., nn.Module],
        conditioner: Optional[Callable[..., nn.Module]],
        act: nn.Module = GEGLU,
        initialize_weights: Optional[Callable] = None,
        num_classes: int = 2,
    ):
        """Instantiate the sub-modules and optionally initialise their weights.

        Args:
            encoder: Factory called as ``encoder(act=act)`` to build the encoder.
            decoder: Factory called as ``decoder(act=act)`` to build the decoder.
            conditioner: Factory called as ``conditioner(act=act)``, or ``None``
                to run the encoder and decoder with ``condition=None``.
            act: Activation forwarded unchanged to every factory.
            initialize_weights: Optional function handed to ``nn.Module.apply``
                once all sub-modules have been created.
            num_classes: Number of classes including the "no particle" class.
        """
        super().__init__()
        self.encoder = encoder(act=act)
        self.decoder = decoder(act=act)
        self.conditioner = conditioner(act=act) if conditioner is not None else None
        # The "no particle" class is not encoded in the encoder because inputs
        # are always valid particles, hence one class fewer than requested.
        self.num_classes = num_classes - 1
        self.initialize_weights = initialize_weights

        if self.initialize_weights is not None:
            self.apply(self.initialize_weights)

    def forward(
        self,
        enc_pos: Tensor,
        enc_field: Tensor,
        enc_particle_type: Tensor,
        dec_pos: Tensor,
        dec_occ_pos: Tensor,
        enc_pos_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
        timestep: Tensor,
    ) -> Dict[str, Tensor]:
        """Encode the inputs and decode them at the requested positions.

        Returns:
            A dict with the keys ``"preds_occ"`` and ``"preds_field"`` holding
            the two outputs of the decoder.
        """
        latent = self.encode(
            enc_pos=enc_pos,
            enc_field=enc_field,
            enc_particle_type=enc_particle_type,
            enc_pos_batch_index=enc_pos_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
            timestep=timestep,
        )
        preds_occ, preds_field = self.decode(
            latent=latent,
            dec_field_pos=dec_pos,
            dec_occ_pos=dec_occ_pos,
            timestep=timestep,
        )
        return {"preds_occ": preds_occ, "preds_field": preds_field}

    def encode(
        self,
        enc_pos: Tensor,
        enc_field: Tensor,
        enc_particle_type: Tensor,
        enc_pos_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
        timestep: Tensor,
    ) -> Tensor:
        """Run the encoder on the prepared input features.

        The field history is flattened and, when more than one particle class
        is encoded, extended by the one-hot particle type (see
        ``prepare_field``).

        Returns:
            Whatever the encoder returns; ``decode`` consumes it as ``latent``.
        """
        # Stack particle types to field variable
        field = self.prepare_field(field=enc_field, particle_type=enc_particle_type)
        return self.encoder(
            field=field,
            pos=enc_pos,
            batch_index=enc_pos_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
            condition=self._condition(timestep),
        )

    def decode(
        self,
        latent: Tensor,
        dec_field_pos: Tensor,
        dec_occ_pos: Tensor,
        timestep: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Run the decoder on a latent.

        Returns:
            The decoder output; ``forward`` unpacks it as
            ``(preds_occ, preds_field)``.
        """
        return self.decoder(
            x=latent,
            pos=dec_field_pos,
            occ_pos=dec_occ_pos,
            condition=self._condition(timestep),
        )

    def _condition(self, timestep: Tensor) -> Optional[Tensor]:
        """Return the conditioner output for ``timestep``, or ``None`` without one."""
        if self.conditioner is None:
            return None
        return self.conditioner(timestep)

    def prepare_field(
        self,
        field: Tensor,
        particle_type: Tensor,
    ) -> Tensor:
        """Build the per-particle encoder input features.

        Args:
            field: Tensor of shape ``(n_particles, n_timesteps, n_dim)``.
            particle_type: Integer class index per particle, shape
                ``(n_particles,)``.

        Returns:
            The field flattened to ``(n_particles, n_timesteps * n_dim)``. When
            ``self.num_classes`` is not 1, the one-hot particle type is appended
            along the last dimension.
        """
        # Flatten past field values
        field = einops.rearrange(
            field,
            "n_particles n_timesteps n_dim -> n_particles (n_timesteps n_dim)",
        )
        # With a single encoded class the particle type carries no information,
        # so it is not added to the input features.
        if self.num_classes == 1:
            return field
        # Add particle type to input features
        type_one_hot = nn.functional.one_hot(
            particle_type, num_classes=self.num_classes
        )
        return torch.concat((field, type_one_hot), dim=-1)


class OccupancyAutoencoderParticleLitModule(L.LightningModule):
    """Lightning module that trains and validates an occupancy autoencoder.

    Each step runs the wrapped model on a batch, evaluates the loss function
    and updates a mean-IoU metric on the occupancy predictions.
    """

    def __init__(
        self,
        model: Callable[..., torch.nn.Module],
        loss_function: torch.nn.Module,
        optimizer: Callable[..., torch.optim.Optimizer],
        scheduler: Optional[Callable[..., Any]] = None,
        compile: bool = False,
        num_classes: int = 5,
    ):
        """Build the model and the train/val mean-IoU metrics.

        Args:
            model: Factory called as ``model(num_classes=num_classes)``.
            loss_function: Callable evaluated with the keyword arguments
                ``preds_occ``, ``preds_field``, ``occ`` and ``field``. Its
                result is indexed with ``"loss"``, ``"occ_loss"`` and
                ``"field_loss"``.
            optimizer: Factory called as ``optimizer(params=...)``.
            scheduler: Optional factory called as ``scheduler(optimizer=...)``.
            compile: If true, the model is wrapped in ``torch.compile`` when
                ``setup`` is called with stage ``"fit"``.
            num_classes: Number of occupancy classes used by the model factory
                and by the IoU metrics.
        """
        super().__init__()
        self.save_hyperparameters(logger=False, ignore=["model", "loss_function"])
        self.model: torch.nn.Module = model(num_classes=num_classes)
        self.loss_function: torch.nn.Module = loss_function

        self.mean_iou_train = MeanIoU(num_classes=self.hparams.num_classes)
        self.mean_iou_val = MeanIoU(num_classes=self.hparams.num_classes)

    def forward(
        self,
        enc_pos: Tensor,
        enc_field: Tensor,
        enc_particle_type: Tensor,
        dec_pos: Tensor,
        dec_occ_pos: Tensor,
        enc_pos_batch_index: Tensor,
        supernode_index: Tensor,
        supernode_batch_index: Tensor,
        timestep: Tensor,
    ) -> Dict[str, Tensor]:
        """Forward all arguments to the wrapped model and return its output."""
        return self.model(
            enc_pos=enc_pos,
            enc_field=enc_field,
            enc_particle_type=enc_particle_type,
            dec_pos=dec_pos,
            dec_occ_pos=dec_occ_pos,
            enc_pos_batch_index=enc_pos_batch_index,
            supernode_index=supernode_index,
            supernode_batch_index=supernode_batch_index,
            timestep=timestep,
        )

    def model_step(self, batch: Any) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        """Run the model on a batch and evaluate the loss.

        Args:
            batch: Batch object exposing the attributes read below
                (``enc_pos``, ``enc_field``, ``dec_field``, ...).

        Returns:
            ``(loss, preds)``: the result of the loss function and the model
            output containing ``"preds_occ"`` and ``"preds_field"``.
        """
        preds = self.forward(
            enc_pos=batch.enc_pos,
            enc_field=batch.enc_field,
            enc_particle_type=batch.enc_particle_type,
            dec_pos=batch.dec_pos,
            dec_occ_pos=batch.dec_occ_pos,
            enc_pos_batch_index=batch.enc_pos_batch,
            supernode_index=batch.supernode_index,
            supernode_batch_index=batch.supernode_index_batch,
            timestep=batch.timestep,
        )
        # The field target is flattened over time, like the encoder input.
        target_field = einops.rearrange(
            batch.dec_field,
            "n_particles n_timesteps n_dim -> n_particles (n_timesteps n_dim)",
        )
        loss = self.loss_function(
            preds_occ=preds["preds_occ"],
            preds_field=preds["preds_field"],
            occ=batch.dec_occ_type,
            field=target_field,
        )
        return loss, preds

    @staticmethod
    def _num_samples(batch: Any) -> int:
        """Return the number of samples in the batch (largest batch index + 1)."""
        return batch.batch.max().item() + 1

    def _update_mean_iou(
        self,
        metric: MeanIoU,
        preds_occ: Tensor,
        target_occ: Tensor,
        batch_size: int,
    ) -> None:
        """Update ``metric`` with flat per-point class scores.

        Both inputs have shape ``(batch_size * n_points, n_classes)``; the
        reshape below requires the same number of points for every sample.
        """
        num_classes = self.hparams.num_classes

        def as_one_hot(scores: Tensor) -> Tensor:
            # mean_iou has problems with index tensors -> one hot
            # https://github.com/Lightning-AI/torchmetrics/pull/2572
            # is fixed in a new version
            one_hot = torch.nn.functional.one_hot(
                scores.argmax(dim=-1), num_classes=num_classes
            )
            # Split off the batch dimension and move the classes to dim 1,
            # where the metric expects them.
            return einops.rearrange(
                one_hot,
                "(batch_size n_points) n_classes -> batch_size n_classes n_points",
                batch_size=batch_size,
            )

        metric(preds=as_one_hot(preds_occ), target=as_one_hot(target_occ))

    def _log_losses(self, stage: str, loss: Mapping[str, Tensor], batch_size: int) -> None:
        """Log the total, occupancy and field loss under the ``stage/`` prefix."""
        for name in ("loss", "occ_loss", "field_loss"):
            self.log(
                f"{stage}/{name}", loss[name], prog_bar=True, batch_size=batch_size
            )

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Tensor]:
        """Compute the training loss, update the train IoU and log metrics.

        Returns:
            The result of the loss function; Lightning reads its ``"loss"``
            entry for the backward pass.
        """
        loss, preds = self.model_step(batch)
        # Computed once: `.item()` forces a device-to-host synchronisation.
        batch_size = self._num_samples(batch)
        self._update_mean_iou(
            self.mean_iou_train, preds["preds_occ"], batch.dec_occ_type, batch_size
        )

        # Log metrics
        self._log_losses("train", loss, batch_size)
        self.log(
            "train/meanIoU",
            self.mean_iou_train,
            prog_bar=True,
            batch_size=batch_size,
        )
        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Compute the validation loss, update the val IoU and log metrics."""
        loss, preds = self.model_step(batch)
        batch_size = self._num_samples(batch)
        self._update_mean_iou(
            self.mean_iou_val, preds["preds_occ"], batch.dec_occ_type, batch_size
        )

        self._log_losses("val", loss, batch_size)
        self.log(
            "val/meanIoU",
            self.mean_iou_val,
            prog_bar=True,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
        )

    def setup(self, stage: str):
        """Wrap the model in ``torch.compile`` when enabled and fitting."""
        if self.hparams.compile and stage == "fit":
            self.model = torch.compile(self.model)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Create the optimizer and, if configured, the LR scheduler.

        A scheduler given as a ``functools.partial`` of
        ``LinearWarmupCosineAnnealingLR`` is stepped every optimizer step;
        every other scheduler is stepped once per epoch.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        scheduler_factory = self.hparams.scheduler
        if scheduler_factory is None:
            return {"optimizer": optimizer}

        steps_per_batch = (
            isinstance(scheduler_factory, partial)
            and scheduler_factory.func.__name__ == "LinearWarmupCosineAnnealingLR"
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler_factory(optimizer=optimizer),
                "monitor": "val/loss",
                "interval": "step" if steps_per_batch else "epoch",
                "frequency": 1,
            },
        }
