import os
from typing import Tuple, List

import torch
import torch_scatter
from hmpn import get_hmpn_from_graph
from omegaconf import OmegaConf
from torch_geometric.data import Data, Batch
from transformer_blocks.transformer_decoders import TransformerDecoder

from ltsgns_mp.architectures.decoder import get_decoder, Decoder
from ltsgns_mp.architectures.prodmp import build_prodmp, ProDMPPredictor
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.architectures.util.context_transformer import ContextTransformer
from ltsgns_mp.architectures.util.mlp import MLP
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features, node_type_mask
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ValueDict


class NPSimulator(AbstractSimulator):
    """
    Neural-process style simulator that conditions a GNN simulator on a latent task variable ``z``.

    A context GNN followed by a context decoder encodes a batch of context graphs into per-sample
    representations, which are aggregated (``"bayesian"``, ``"mean"`` or ``"max"``) into the mean and
    variance of ``z`` (see :meth:`compute_task_posterior`). ``z`` is appended to the node features of
    the target graph (see :meth:`add_z_to_batch`), which is then processed by the GNN/decoder set up
    by the base class. If ``config.use_prodmp`` is set, the decoder output is passed as basis weights
    to a ProDMP predictor instead of being returned directly.
    """

    def __init__(self, config, example_input_batch: CNPTrainBatch, loading_config, device, trajectory_length: int | None = None,):
        """
        :param config: simulator config. Fields read here include ``gnn``, ``z_insertion``, ``z_dim``,
            ``use_prodmp``, ``context_aggregation``, ``decoder`` and ``softplus_beta``.
        :param example_input_batch: example train batch providing ``target_batch`` and ``context_batch``;
            used to infer the input dimensions of the networks. Note: ``target_batch.x`` is extended
            in place with zero placeholder z features.
        :param loading_config: forwarded to the base class and to ``load_weights``.
        :param device: torch device on which modules and tensors are created.
        :param trajectory_length: number of predicted trajectory steps; required if ``config.use_prodmp``.
        :raises NotImplementedError: if ``config.z_insertion`` is not ``"gnn"``.
        """
        target_batch = example_input_batch.target_batch
        context_batch = example_input_batch.context_batch
        gnn_config = config.gnn
        # Append zero placeholder z features to the target nodes so that the networks built from this
        # example batch expect the z-augmented input dimension.
        if config.z_insertion == "gnn":
            proxy_z = torch.zeros(target_batch.x.shape[0], config.z_dim, device=device)
            target_batch.x = torch.cat([target_batch.x, proxy_z], dim=-1)
        else:
            raise NotImplementedError("Decoder insertion not implemented yet.")

        if config.use_prodmp:
            assert trajectory_length is not None
            self._trajectory_length = trajectory_length
            # the ProDMP predictor and the base networks are built from the first target graph only
            example_input_batch = target_batch[0]
            mp_predictor = build_prodmp(example_input_batch=example_input_batch,
                                        simulator_config=config,
                                        trajectory_length=self._trajectory_length,
                                        device=device)
            decoder_output_dim = mp_predictor.output_size
        else:
            # decoder predicts one value per spatial dimension of the node positions
            decoder_output_dim = target_batch.pos.shape[-1]
            example_input_batch = target_batch

        super().__init__(config, example_input_batch,
                         decoder_output_dim=decoder_output_dim,
                         loading_config=loading_config,
                         device=device)
        if config.use_prodmp:
            self._mp_predictor: ProDMPPredictor = mp_predictor
        self._context_gnn = get_hmpn_from_graph(example_graph=context_batch,
                                                latent_dimension=gnn_config.latent_dimension,
                                                node_name=keys.MESH,
                                                unpack_output=False,  # return full graph
                                                base_config=gnn_config.base,
                                                device=device)
        if self.config.context_aggregation == "bayesian":
            action_dim = self.config.z_dim * 2  # one for mean, one for var of r
            self._prior_mean = torch.zeros(self.config.z_dim, device=device)
            self._prior_var = torch.ones(self.config.z_dim, device=device)
        else:
            action_dim = self.config.z_dim
            # define r_to_z networks (separate name so the simulator config is not shadowed)
            mlp_config = OmegaConf.create(dict(activation_function="relu",
                                               add_output_layer=False,
                                               num_layers=1,
                                               regularization={
                                                   "dropout": config.decoder.regularization.dropout,
                                               },
                                               ))
            self._r_to_mean = torch.nn.Sequential(
                MLP(in_features=self.config.z_dim, latent_dimension=self.config.decoder.latent_dimension, config=mlp_config, device=device),
                torch.nn.Linear(self.config.decoder.latent_dimension, self.config.z_dim, device=device))
            self._r_to_var = torch.nn.Sequential(
                MLP(in_features=self.config.z_dim, latent_dimension=self.config.decoder.latent_dimension, config=mlp_config, device=device),
                torch.nn.Linear(self.config.decoder.latent_dimension, self.config.z_dim, device=device))
        self._context_decoder: Decoder = get_decoder(config=self.config.decoder,
                                                     action_dim=action_dim,
                                                     device=device,
                                                     input_dimensions=self._input_dimensions(),
                                                     simulator_class=type(self).__name__)

        self.load_weights(loading_config, device)
        self._softplus = torch.nn.Softplus(self.config.softplus_beta)

    def forward(self, batch: Batch | Data, initial_time: int | None = None) -> torch.Tensor:
        """
        Run the GNN and decoder on a (z-augmented) target batch.

        :param batch: target graph(s), presumably already augmented with z via :meth:`add_z_to_batch`.
        :param initial_time: time index the ProDMP prediction starts from; required if ``config.use_prodmp``.
        :return: the decoder output for the mesh nodes, or, with ProDMP, the predictor's output
            computed from those basis weights.
        """
        processed_batch = self.gnn(batch)
        mesh_features = unpack_node_features(processed_batch, node_type=keys.MESH)
        decoded_batch = self.decoder(mesh_features)
        if self.config.use_prodmp:
            assert initial_time is not None
            # initial conditions are taken from the first graph and shared by all graphs in the batch
            mesh_type_mask = node_type_mask(batch[0], keys.MESH)
            pos = batch[0].pos[mesh_type_mask]
            # NOTE(review): assumes context_node_positions is indexed (graph, time, ...) — confirm
            prev_pos = batch.context_node_positions[0, max(initial_time - 1, 0)]
            vel = pos - prev_pos
            pos = pos.repeat(len(batch), 1)
            vel = vel.repeat(len(batch), 1)
            initial_time = initial_time * torch.ones(pos.shape[0], device=pos.device)
            prediction = self._mp_predictor(
                pos=pos,
                vel=vel,
                basis_weights=decoded_batch,
                prediction_times=None,
                output_vel=False,
                initial_time=initial_time
            )
            return prediction
        else:
            # velocities of shape (num_nodes, action_dimension)
            return decoded_batch

    def add_z_to_batch(self, batch: Batch | Data, z: torch.Tensor) -> Batch | Data:
        """
        Write the latent task variable ``z`` into the node features of ``batch``.

        If ``config.node_aggregation`` is set, ``z`` is a single (global) vector that is repeated for
        every node; otherwise it is expected to hold one row per node of a single graph. For a ``Batch``
        it is additionally repeated once per graph. Existing ``"z_feature"`` columns are overwritten in
        place, otherwise ``z`` is appended and ``x_description`` is extended accordingly.

        :param batch: graph(s) to modify.
        :param z: latent variable, e.g. the ``z_mean`` returned by :meth:`compute_task_posterior`.
        :return: the modified ``batch``.
        """
        if isinstance(batch, Batch):
            if self.config.node_aggregation is not None:
                # global z for all nodes, repeat for all nodes
                z = z.repeat(batch.num_nodes // len(batch), 1)
            # z needs to be repeated for len(batch) times
            z = z.repeat(len(batch), 1)
            x_description = batch.x_description[0]
        else:
            if self.config.node_aggregation is not None:
                z = z.repeat(batch.num_nodes, 1)
            x_description = batch.x_description

        if "z_feature" in x_description:
            # overwrite previous z features (assumes the z columns are contiguous)
            z_idx = x_description.index("z_feature")
            batch.x[:, z_idx:z_idx + self.config.z_dim] = z
        else:
            batch.x = torch.cat([batch.x, z], dim=-1)
            z_description = ["z_feature"] * self.config.z_dim
            if isinstance(batch, Batch):
                batch.x_description = [desc + z_description for desc in batch.x_description]
            else:
                # build a new list rather than extending in place, so a description list shared
                # with other graph objects is not modified
                batch.x_description = batch.x_description + z_description
        return batch

    def compute_task_posterior(self, context_batch: Batch | Data) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the posterior of the task, i.e., the mean and var of the latent variable z.

        :param context_batch: the context graphs, either a ``Batch`` of context samples or a single
            ``Data`` (treated as one context sample). All graphs are assumed to have the same number
            of nodes.
        :return: ``(z_mean, z_var)``, each with leading dimension 1 if ``config.node_aggregation`` is
            set and ``num_nodes_per_graph`` otherwise.
        :raises ValueError: if the total node count is not divisible by the number of graphs.
        :raises NotImplementedError: for an unknown node or context aggregation mode.
        """
        processed_batch = self._context_gnn(context_batch)
        node_features = processed_batch.x
        context_decoder_output = self._context_decoder(node_features)
        # len() of a Batch is its number of graphs, whereas len() of a plain Data object counts its
        # attributes, so a single Data has to be treated as exactly one graph.
        num_graphs = len(context_batch) if isinstance(context_batch, Batch) else 1
        if context_batch.num_nodes % num_graphs != 0:
            raise ValueError(f"Expected all context graphs to have the same number of nodes, but "
                             f"{context_batch.num_nodes} nodes cannot be split evenly over {num_graphs} graphs.")
        num_nodes_per_graph = context_batch.num_nodes // num_graphs
        context_decoder_output = context_decoder_output.reshape((num_graphs, num_nodes_per_graph, -1))
        # context_decoder_output has shape (context_size, num_nodes_per_graph, r_dim (*2 if bayesian))
        node_aggregation = self.config.node_aggregation
        if node_aggregation == "mean":
            context_decoder_output = context_decoder_output.mean(dim=1, keepdim=True)
        elif node_aggregation == "max":
            context_decoder_output = context_decoder_output.max(dim=1, keepdim=True).values
        elif node_aggregation is not None:
            raise NotImplementedError(f"Node aggregation {node_aggregation} not implemented.")
        # None: keep one representation per node

        if self.config.context_aggregation == "bayesian":
            # split the feature dim to get mean and var
            r_dim = self.config.z_dim
            non_aggregated_r = context_decoder_output[..., :r_dim]
            non_aggregated_var = self._softplus(context_decoder_output[..., r_dim:])  # do the output transformation to get a valid variance
            # both have shape  (context_size, 1 or num_nodes_per_graph, r_dim)
            # do the Bayesian Aggregation from Volpp. et. al 2021
            z_var = (self._prior_var.reciprocal() + torch.sum(non_aggregated_var.reciprocal(),
                                                              dim=0)).reciprocal()  # has shape (1 or num_nodes_per_graph, r_dim)
            z_mean = self._prior_mean + z_var * torch.sum((non_aggregated_r - self._prior_mean) / non_aggregated_var,
                                                          dim=0)  # has shape (1 or num_nodes_per_graph, r_dim)
            return z_mean, z_var
        elif self.config.context_aggregation == "max":
            r = context_decoder_output.max(dim=0, keepdim=False).values
        elif self.config.context_aggregation == "mean":
            r = context_decoder_output.mean(dim=0, keepdim=False)
        else:
            raise NotImplementedError(f"Context aggregation {self.config.context_aggregation} not implemented.")
        # use r_to_z networks to get z
        z_mean = self._r_to_mean(r)
        z_var = self._softplus(self._r_to_var(r))
        return z_mean, z_var

    def _input_dimensions(self) -> ValueDict:
        """Input dimensions passed to the context decoder: the latent size of the GNN processor."""
        return {keys.PROCESSOR_DIMENSION: self.config.gnn.latent_dimension, }

    def _get_all_state_dicts(self) -> ValueDict:
        """
        Collect the state dicts of all trainable sub-networks.

        :return: dict keyed by ``"<network>_params"``; the r-to-z networks are only included when
            ``context_aggregation`` is not ``"bayesian"`` (they do not exist otherwise).
        """
        save_dict = {"gnn_params": self.gnn.state_dict(),
                     "decoder_params": self.decoder.state_dict(),
                     "context_gnn_params": self._context_gnn.state_dict(),
                     "context_decoder_params": self._context_decoder.state_dict()
                     }
        if self.config.context_aggregation != "bayesian":
            save_dict["r_to_mean_params"] = self._r_to_mean.state_dict()
            save_dict["r_to_var_params"] = self._r_to_var.state_dict()
        return save_dict

    def load_all_state_dicts(self, state_dict: ValueDict):
        """
        Load sub-network weights from a dict in the format produced by ``_get_all_state_dicts``.

        :param state_dict: dict keyed by ``"<network>_params"``.
        """
        self.gnn.load_state_dict(state_dict["gnn_params"])
        self.decoder.load_state_dict(state_dict["decoder_params"])
        self._context_gnn.load_state_dict(state_dict["context_gnn_params"])
        self._context_decoder.load_state_dict(state_dict["context_decoder_params"])
        if self.config.context_aggregation != "bayesian":
            self._r_to_mean.load_state_dict(state_dict["r_to_mean_params"])
            self._r_to_var.load_state_dict(state_dict["r_to_var_params"])

    def get_parameter_lists_for_optimizer(self) -> Tuple[List, List]:
        """
        Return the parameter iterators of all sub-networks together with a parallel list of group types.

        :return: ``(params, types)`` where ``types[i]`` is ``"gnn"`` or ``"decoder"`` for ``params[i]``.
        """
        params = [self.gnn.parameters(), self.decoder.parameters(), self._context_gnn.parameters(), self._context_decoder.parameters()]
        types = ["gnn", "decoder", "gnn", "decoder"]
        if self.config.context_aggregation != "bayesian":
            params.extend([self._r_to_mean.parameters(), self._r_to_var.parameters()])
            types.extend(["decoder", "decoder"])
        return params, types
