"""GNS designed to operate on the original fine discretization of particles.
Adapted from https://github.com/wu375/simple-physics-simulator-pytorch-geometry.
"""

from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import lightning as L
import numpy as np
import torch
from torch import Tensor, nn
from torch_geometric.nn import radius, radius_graph
from torch_geometric.utils import unbatch
from torch_scatter import scatter_sum

from src.models.gns.encode_process_decode import (
    EncodeProcessDecode,
    get_random_walk_noise_for_position_sequence,
    time_diff,
)
from src.utils.data_utils import QuinticKernel, load_metadata
from src.utils.metric import mean_iou


class GNSf(nn.Module):
    """Graph Network Simulator that operates on the fine particle discretization.

    Every particle is a graph node. Node features are the normalized velocity
    history, the clipped normalized distances to the domain boundaries and, when
    more than one particle type exists, a learned particle-type embedding. Edges
    link particles closer than the dataset's ``default_connectivity_radius``. The
    encode-process-decode network predicts a normalized acceleration. That
    acceleration is integrated with an Euler step (dt = 1): the velocity is
    updated first, then the position is updated with the new velocity.

    ``set_metadata_device`` must be called before ``predict_positions`` or
    ``predict_accelerations``. It creates ``self.device``, the normalization
    statistics and the domain boundaries that those methods read.
    """

    def __init__(
        self,
        particle_dimension,
        node_in,
        edge_in,
        latent_dim,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
        noise_std,
        dataset_path,
        num_particle_types=1,
        particle_type_embedding_size=16,
    ):
        """Create the network and load the dataset metadata.

        Args:
            particle_dimension: spatial dimension; used as the decoder output width.
            node_in: input width of the node encoder. It must match the node
                features built in ``_build_graph_from_raw``.
            edge_in: input width of the edge encoder. It must match the edge
                features built in ``_build_graph_from_raw``.
            latent_dim: latent width of the encode-process-decode network.
            num_message_passing_steps: number of processor (message passing) steps.
            mlp_num_layers: number of layers of each MLP.
            mlp_hidden_dim: hidden width of each MLP.
            noise_std: std of the training noise. It is also folded into the
                normalization std in ``set_metadata_device``.
            dataset_path: path handed to ``load_metadata``.
            num_particle_types: number of distinct particle types. A type
                embedding is only created when this is > 1.
            particle_type_embedding_size: width of the particle-type embedding.
        """
        super().__init__()
        self.metadata = load_metadata(Path(dataset_path))
        self._connectivity_radius = self.metadata["default_connectivity_radius"]

        self.noise_std = noise_std
        self._num_particle_types = num_particle_types

        # With a single particle type the embedding would be a constant feature.
        if num_particle_types > 1:
            self._particle_type_embedding = nn.Embedding(
                num_particle_types, particle_type_embedding_size
            )

        self._encode_process_decode = EncodeProcessDecode(
            node_in=node_in,
            node_out=particle_dimension,
            edge_in=edge_in,
            latent_dim=latent_dim,
            num_message_passing_steps=num_message_passing_steps,
            mlp_num_layers=mlp_num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
        )

    def set_metadata_device(self, device):
        """Materialize normalization statistics and domain bounds on ``device``.

        For velocity and acceleration, the mean comes straight from the metadata.
        The std is ``sqrt(metadata_std**2 + noise_std**2)``. The domain bounds are
        stored as a float tensor in ``self._boundaries``.
        """
        self.device = device

        def _stats(prefix):
            # Reads "<prefix>_mean" / "<prefix>_std" from the metadata.
            mean = torch.FloatTensor(self.metadata[f"{prefix}_mean"])
            # Combine the dataset std with the training-noise std in quadrature.
            std = torch.sqrt(
                torch.FloatTensor(self.metadata[f"{prefix}_std"]) ** 2 + self.noise_std**2
            )
            return {"mean": mean.to(device), "std": std.to(device)}

        self._normalization_stats = {
            "acceleration": _stats("acc"),
            "velocity": _stats("vel"),
        }
        self._boundaries = (
            torch.tensor(self.metadata["bounds"], requires_grad=False).float().to(device)
        )

    def forward(self):
        """Unused placeholder.

        Use ``predict_positions`` or ``predict_accelerations`` instead.
        """
        pass

    def _build_graph_from_raw(self, position_sequence, n_particles_per_example, particle_types):
        """Build node features, edge index and edge features from a position history.

        Args:
            position_sequence: per-particle position history with the most recent
                step last. Presumably shaped (n_total_points, n_steps, dim).
            n_particles_per_example: number of particles of each example in the batch.
            particle_types: per-particle integer type. Only used when there is
                more than one particle type.

        Returns:
            ``(node_features, edge_index, edge_features)``. ``edge_index`` has
            shape (2, E), with senders in row 0 and receivers in row 1.
        """
        n_total_points = position_sequence.shape[0]
        most_recent_position = position_sequence[:, -1]  # (n_total_points, dim)
        velocity_sequence = time_diff(position_sequence)

        senders, receivers = self._compute_connectivity(
            most_recent_position, n_particles_per_example, self._connectivity_radius
        )

        node_features = []
        # Normalized velocity history, flattened over the time and space axes.
        velocity_stats = self._normalization_stats["velocity"]
        normalized_velocity_sequence = (
            velocity_sequence - velocity_stats["mean"]
        ) / velocity_stats["std"]
        node_features.append(normalized_velocity_sequence.view(n_total_points, -1))

        # Distances to the lower/upper boundaries, normalized by the connectivity
        # radius and clipped to [-1, 1]. Boundaries are (num_dimensions, 2):
        # column 0 holds the lower bound, column 1 the upper bound.
        distance_to_lower_boundary = most_recent_position - self._boundaries[:, 0][None]
        distance_to_upper_boundary = self._boundaries[:, 1][None] - most_recent_position
        distance_to_boundaries = torch.cat(
            [distance_to_lower_boundary, distance_to_upper_boundary], dim=1
        )
        node_features.append(
            torch.clamp(distance_to_boundaries / self._connectivity_radius, -1.0, 1.0)
        )

        if self._num_particle_types > 1:
            node_features.append(self._particle_type_embedding(particle_types))

        # Edge features: relative displacement (sender - receiver) and its norm,
        # both in units of the connectivity radius.
        normalized_relative_displacements = (
            most_recent_position[senders, :] - most_recent_position[receivers, :]
        ) / self._connectivity_radius
        normalized_relative_distances = torch.norm(
            normalized_relative_displacements, dim=-1, keepdim=True
        )
        edge_features = [normalized_relative_displacements, normalized_relative_distances]

        return (
            torch.cat(node_features, dim=-1),
            torch.stack([senders, receivers]),
            torch.cat(edge_features, dim=-1),
        )

    def _example_ids(self, n_particles_per_example):
        """Return a long tensor mapping every point to the index of its example.

        ``[2, 0, 3]`` maps to ``[0, 0, 2, 2, 2]``. The result is built with one
        vectorized call instead of a Python-level loop over all points.
        """
        counts = torch.as_tensor(
            n_particles_per_example, dtype=torch.long, device=self.device
        ).reshape(-1)
        example_idx = torch.arange(counts.numel(), device=self.device)
        return torch.repeat_interleave(example_idx, counts)

    def _compute_connectivity(
        self,
        node_features,
        n_particles_per_example,
        radius,
        add_self_edges=True,
        max_num_neighbors=32,
    ):
        """Connect points of the same example that lie within ``radius`` of each other.

        Args:
            node_features: point coordinates used for the radius search.
            n_particles_per_example: number of points per example. Edges never
                cross examples.
            radius: search radius.
            add_self_edges: whether each point is connected to itself.
            max_num_neighbors: forwarded to ``radius_graph``. Neighbors beyond
                this count are dropped, so raise it for dense discretizations.
                The default of 32 matches PyG's default.

        Returns:
            ``(senders, receivers)``, each a long tensor of shape (E,).
        """
        batch_ids = self._example_ids(n_particles_per_example)
        edge_index = radius_graph(
            node_features,
            r=radius,
            batch=batch_ids,
            loop=add_self_edges,
            max_num_neighbors=max_num_neighbors,
        )  # (2, n_edges)
        # With PyG's default "source_to_target" flow, row 0 is the source.
        senders, receivers = edge_index[0], edge_index[1]
        return senders, receivers

    def _decoder_postprocessor(self, normalized_acceleration, position_sequence):
        """Turn a normalized acceleration into the next position.

        The acceleration is de-normalized. The velocity is then updated first and
        the position is advanced with the new velocity, both with dt = 1 (the
        finite-difference step).
        """
        acceleration_stats = self._normalization_stats["acceleration"]
        acceleration = normalized_acceleration * acceleration_stats["std"] + acceleration_stats["mean"]

        most_recent_position = position_sequence[:, -1]
        most_recent_velocity = most_recent_position - position_sequence[:, -2]

        new_velocity = most_recent_velocity + acceleration  # * dt = 1
        return most_recent_position + new_velocity  # * dt = 1

    def predict_positions(self, current_positions, n_particles_per_example, particle_types):
        """Predict the next position of every particle from its position history."""
        node_features, edge_index, e_features = self._build_graph_from_raw(
            current_positions, n_particles_per_example, particle_types
        )
        predicted_normalized_acceleration = self._encode_process_decode(
            node_features, edge_index, e_features
        )
        return self._decoder_postprocessor(predicted_normalized_acceleration, current_positions)

    def predict_accelerations(
        self,
        next_position,
        position_sequence_noise,
        position_sequence,
        n_particles_per_example,
        particle_types,
    ):
        """Return ``(predicted, target)`` normalized accelerations for training.

        The network sees ``position_sequence + position_sequence_noise``. The
        target is derived from that noisy sequence and from ``next_position``
        shifted by the same noise as the most recent input step.
        """
        noisy_position_sequence = position_sequence + position_sequence_noise
        node_features, edge_index, e_features = self._build_graph_from_raw(
            noisy_position_sequence, n_particles_per_example, particle_types
        )
        predicted_normalized_acceleration = self._encode_process_decode(
            node_features, edge_index, e_features
        )
        next_position_adjusted = next_position + position_sequence_noise[:, -1]
        target_normalized_acceleration = self._inverse_decoder_postprocessor(
            next_position_adjusted, noisy_position_sequence
        )
        return predicted_normalized_acceleration, target_normalized_acceleration

    def _inverse_decoder_postprocessor(self, next_position, position_sequence):
        """Inverse of `_decoder_postprocessor`: next position -> normalized acceleration."""
        previous_position = position_sequence[:, -1]
        previous_velocity = previous_position - position_sequence[:, -2]
        next_velocity = next_position - previous_position
        acceleration = next_velocity - previous_velocity

        acceleration_stats = self._normalization_stats["acceleration"]
        return (acceleration - acceleration_stats["mean"]) / acceleration_stats["std"]
        return normalized_acceleration


class GNSfLitModule(L.LightningModule):
    """Lightning wrapper that trains ``GNSf`` on next-step normalized accelerations.

    Training minimizes the per-particle squared error between predicted and
    target normalized accelerations, averaged over non-kinematic particles.
    Validation also predicts the next particle positions and scores them on the
    dataset grid. It reports occupancy mean-IoU and the MSE of the
    SPH-interpolated velocity field.
    """

    # Particles of this type are excluded from noise injection and from the loss.
    KINEMATIC_PARTICLE_TYPE = 3

    def __init__(
        self,
        model: torch.nn.Module,
        loss_function: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler = None,
        compile: bool = False,
        num_classes: int = 5,
        pos_scale: float = 200,
        vis_sample_idx: Optional[int] = None,
    ):
        """Store the model and hyper-parameters.

        Args:
            model: the ``GNSf`` network.
            loss_function: stored but not used by ``model_step``, which computes
                a masked MSE itself.
            optimizer: called as ``optimizer(params=...)`` in
                ``configure_optimizers``. Presumably a ``functools.partial``.
            scheduler: optional; called as ``scheduler(optimizer=...)``.
            compile: wrap the model with ``torch.compile`` during ``fit`` setup.
            num_classes: forwarded to ``mean_iou``.
            pos_scale: stored in ``hparams``; not used by this class.
            vis_sample_idx: stored in ``hparams``; not used by this class.
        """
        super().__init__()
        self.save_hyperparameters(logger=False, ignore=["model", "loss_function"])
        self.model: torch.nn.Module = model
        self.loss_function: torch.nn.Module = loss_function

    def on_fit_start(self):
        """Move the model's normalization statistics to the training device."""
        self.model.set_metadata_device(self.device)

    def setup_grid_occ_and_field(self):
        """Cache the validation dataset's helpers for grid occupancy/field evaluation."""
        ds = self.trainer.datamodule.val_dataset
        self.normalize_vel = ds.normalize_vel
        self.occupancy_radius = ds.occupancy_radius

        # SPH kernel used to interpolate particle velocities onto the grid.
        self.sph_kernel = QuinticKernel(h=ds.occupancy_radius / 3, dim=ds.ndim)

    def get_grid_occ_and_field(
        self,
        particle_pos: Tensor,  # (n_particles, dim)
        particle_vel: Tensor,  # (n_particles, dim)
        particle_batch_idx: Tensor,  # (n_particles,)
        grid_pos: Tensor,  # (n_grid_points, dim)
        grid_batch_idx: Tensor,  # (n_grid_points,)
        is_normalize_vel: bool = True,
        max_num_neighbors: int = 32,
    ) -> Tuple[Tensor, Tensor]:
        """Batched version of src.datasets.mm_occ_datamodule.get_grid_occ_and_field.

        Uses PyG's ``radius()`` to find, per grid point, the particles of the
        same example within ``occupancy_radius``.

        Args:
            max_num_neighbors: forwarded to ``radius``. Particles beyond this
                count per grid point are dropped, which biases the interpolated
                field. The default of 32 matches PyG's default.

        Returns:
            ``(occupancy, field)``. ``occupancy`` is a float 0/1 tensor of shape
            (n_grid_points,). ``field`` is the Shepard-normalized SPH average of
            particle velocities, of shape (n_grid_points, dim); it is zero where
            no particle contributes.
        """
        if is_normalize_vel:
            particle_vel = self.normalize_vel(particle_vel)

        # radius(x, y) returns, for each y (grid point), the x (particles) in range:
        # row 0 indexes the grid, row 1 the particles.
        batch_size = particle_batch_idx.max().item() + 1
        edge_index = radius(
            particle_pos,
            grid_pos,
            r=self.occupancy_radius,
            batch_x=particle_batch_idx,
            batch_y=grid_batch_idx,
            max_num_neighbors=max_num_neighbors,
            batch_size=batch_size,
        )
        grid_idx, particle_idx = edge_index[0], edge_index[1]
        num_grid_points = len(grid_batch_idx)

        # A grid point is occupied if at least one particle is within the radius.
        neighbor_count = scatter_sum(
            torch.ones_like(grid_idx), grid_idx, dim=0, dim_size=num_grid_points
        )
        pred_grid_occ = torch.clamp(neighbor_count, 0, 1).to(torch.float)

        # Kernel-weighted velocity average (Shepard normalization).
        dist = torch.norm(
            particle_pos[particle_idx] - grid_pos[grid_idx], dim=-1, keepdim=True
        )
        sph_weights = self.sph_kernel.w(dist)
        summed_vels = scatter_sum(
            sph_weights * particle_vel[particle_idx], grid_idx, dim=0, dim_size=num_grid_points
        )
        shepard_denominator = scatter_sum(
            sph_weights, grid_idx, dim=0, dim_size=num_grid_points
        )
        # squeeze(-1) rather than squeeze(): keeps the mask 1-D even with a single grid point.
        mask = shepard_denominator.squeeze(-1) > 0
        pred_grid_field = torch.zeros_like(grid_pos)
        pred_grid_field[mask] = summed_vels[mask] / shepard_denominator[mask]

        return pred_grid_occ, pred_grid_field

    def on_validation_start(self):
        """Move normalization statistics to the device and prepare grid evaluation."""
        self.model.set_metadata_device(self.device)
        self.setup_grid_occ_and_field()

    def forward(self, batch: Tuple[Any, Any]) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute predicted/target normalized accelerations for a batch.

        ``batch[0]`` carries ``enc_pos``, ``n_particles_per_example``,
        ``particle_type`` and ``target_pos``. Random-walk noise is added only in
        training mode and only to non-kinematic particles.

        Returns:
            ``(pred, target, non_kinematic_mask)``.
        """
        particle_data, _ = batch
        position = particle_data.enc_pos
        particle_type = particle_data.particle_type
        labels = particle_data.target_pos  # next-step target positions

        sampled_noise = get_random_walk_noise_for_position_sequence(
            position, noise_std_last_step=self.model.noise_std
        ).to(labels.device)
        non_kinematic_mask = particle_type != self.KINEMATIC_PARTICLE_TYPE
        sampled_noise *= non_kinematic_mask.view(-1, 1, 1)
        if not self.training:
            sampled_noise *= 0.0

        pred, target = self.model.predict_accelerations(
            next_position=labels,
            position_sequence_noise=sampled_noise,
            position_sequence=position,
            n_particles_per_example=particle_data.n_particles_per_example,
            particle_types=particle_type,
        )
        return pred, target, non_kinematic_mask

    def model_step(self, batch: Tuple[Any, Any]) -> Tuple[Tensor, Tensor]:
        """Return ``(loss, pred)``.

        The loss is the summed squared acceleration error averaged over
        non-kinematic particles.
        """
        pred, target, non_kinematic_mask = self.forward(batch)

        per_particle_loss = ((pred - target) ** 2).sum(dim=-1)
        per_particle_loss = torch.where(
            non_kinematic_mask.bool(), per_particle_loss, torch.zeros_like(per_particle_loss)
        )
        # clamp(min=1) avoids 0/0 = NaN when every particle in the batch is kinematic.
        num_non_kinematic = non_kinematic_mask.sum().clamp(min=1)
        loss = per_particle_loss.sum() / num_non_kinematic
        return loss, pred

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int) -> Tensor:
        """Compute and log the training loss."""
        loss, _ = self.model_step(batch)
        batch_size = batch[0].enc_pos_batch.max().item() + 1
        self.log("train/loss", loss, prog_bar=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> None:
        """Log the acceleration loss plus grid field MSE and occupancy mean-IoU."""
        particle_data, grid_data = batch

        loss, _ = self.model_step(batch)
        batch_size = particle_data.enc_pos_batch.max().item() + 1
        self.log("val/loss", loss, prog_bar=True, batch_size=batch_size)

        # One-step position prediction, evaluated on the grid.
        new_positions = self.model.predict_positions(
            current_positions=particle_data.enc_pos,
            n_particles_per_example=particle_data.n_particles_per_example,
            particle_types=particle_data.particle_type,
        )
        pred_grid_occ, pred_grid_field = self.get_grid_occ_and_field(
            particle_pos=new_positions,
            particle_vel=new_positions - particle_data.enc_pos[:, -1],
            particle_batch_idx=particle_data.enc_pos_batch,
            grid_pos=grid_data.enc_pos,
            grid_batch_idx=grid_data.enc_pos_batch,
        )

        field_mse = nn.functional.mse_loss(pred_grid_field, grid_data.target_grid_field)
        iou = mean_iou(
            pred=pred_grid_occ.to(torch.long),
            target=grid_data.target_grid_occ.to(torch.long),
            num_classes=self.hparams.num_classes,
            batch=grid_data.enc_pos_batch,
        )

        self.log("val/fieldMSE", field_mse, prog_bar=True, batch_size=batch_size)
        self.log(
            "val/meanIoU",
            iou.mean().item(),
            prog_bar=True,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
        )

    def setup(self, stage: str):
        """Optionally wrap the model with ``torch.compile`` when fitting."""
        if self.hparams.compile and stage == "fit":
            self.model = torch.compile(self.model)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build the optimizer and, if configured, its LR scheduler."""
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        scheduler_factory = self.hparams.scheduler
        if scheduler_factory is None:
            return {"optimizer": optimizer}

        # LinearWarmupCosineAnnealingLR is stepped every optimizer step;
        # any other scheduler once per epoch.
        is_warmup_cosine = (
            isinstance(scheduler_factory, partial)
            and scheduler_factory.func.__name__ == "LinearWarmupCosineAnnealingLR"
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler_factory(optimizer=optimizer),
                "monitor": "val/loss",
                "interval": "step" if is_warmup_cosine else "epoch",
                "frequency": 1,
            },
        }
