import os
from typing import Tuple, List

import torch
import torch_scatter
from hmpn import get_hmpn_from_graph
from torch_geometric.data import Data, Batch
from transformer_blocks.transformer_decoders import TransformerDecoder

from ltsgns_mp.architectures.decoder import get_decoder, Decoder
from ltsgns_mp.architectures.prodmp import build_prodmp, ProDMPPredictor
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.architectures.util.context_transformer import ContextTransformer
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import unpack_node_features, node_type_mask
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ValueDict


class CNPSimulator(AbstractSimulator):
    """Simulator conditioned on a latent variable ``z`` inferred from a set of context graphs.

    ``compute_z`` encodes a batch of context graphs with a dedicated context GNN and
    context decoder, optionally aggregates the per-node encodings over nodes, aggregates
    over the context graphs, and optionally maps the result ``r`` to ``z`` with a small MLP.
    ``add_z_to_batch`` appends ``z`` to the node features of a batch, which ``forward``
    then processes with the main GNN and decoder. When ``config.use_prodmp`` is set, the
    decoder output is passed as basis weights to a ProDMP predictor.
    """

    def __init__(self, config, example_input_batch: CNPTrainBatch, loading_config, device,
                 trajectory_length: int | None = None):
        """Build the main and context networks, then load weights.

        Args:
            config: Simulator config. Fields read here include ``gnn``, ``use_meta_learning``,
                ``z_dim``, ``use_prodmp``, ``decoder``, ``r_dim``, ``context_aggregation``,
                ``context_transformer`` (only for transformer aggregation) and ``use_r_to_z``.
            example_input_batch: Example batch; its ``target_batch`` is used to build the main
                networks (via the base class) and its ``context_batch`` the context GNN.
                NOTE: with ``config.use_meta_learning`` set, zero-valued placeholder z features
                are concatenated onto ``example_input_batch.target_batch.x`` (the attribute
                is reassigned on that object).
            loading_config: Forwarded to the base class and to ``load_weights``.
            device: Device the networks are created on.
            trajectory_length: Required when ``config.use_prodmp`` is set.
        """
        target_batch = example_input_batch.target_batch
        context_batch = example_input_batch.context_batch
        gnn_config = config.gnn
        if config.use_meta_learning:
            # Placeholder z features so the example batch has the feature width that
            # ``add_z_to_batch`` will produce later (presumably used by the base class to
            # size the main GNN input — TODO confirm).
            proxy_z = torch.zeros(target_batch.x.shape[0], config.z_dim, device=device)
            target_batch.x = torch.cat([target_batch.x, proxy_z], dim=-1)

        if config.use_prodmp:
            assert trajectory_length is not None, "trajectory_length is required when use_prodmp is set"
            self._trajectory_length = trajectory_length
            simulator_example_batch = target_batch[0]
            mp_predictor = build_prodmp(example_input_batch=simulator_example_batch,
                                        simulator_config=config,
                                        trajectory_length=self._trajectory_length,
                                        device=device)
            decoder_output_dim = mp_predictor.output_size
        else:
            simulator_example_batch = target_batch
            decoder_output_dim = target_batch.pos.shape[-1]

        super().__init__(config, simulator_example_batch,
                         decoder_output_dim=decoder_output_dim,
                         loading_config=loading_config,
                         device=device)
        if config.use_prodmp:
            self._mp_predictor: ProDMPPredictor = mp_predictor
        self._context_gnn = get_hmpn_from_graph(example_graph=context_batch,
                                                latent_dimension=gnn_config.latent_dimension,
                                                node_name=keys.MESH,
                                                unpack_output=False,  # return full graph
                                                base_config=gnn_config.base,
                                                device=device)
        self._context_decoder: Decoder = get_decoder(config=self.config.decoder,
                                                     action_dim=self.config.r_dim,
                                                     device=device,
                                                     input_dimensions=self._input_dimensions(),
                                                     simulator_class=type(self).__name__)
        if self.config.context_aggregation == "transformer":
            self._context_transformer = ContextTransformer(self.config.context_transformer, self.config.r_dim, device)

        self._use_r_to_z = self.config.use_r_to_z
        if self._use_r_to_z:
            # Two-layer MLP mapping the aggregated context encoding r to z.
            self._r_to_z = torch.nn.Sequential(
                torch.nn.Linear(self.config.r_dim, self.config.z_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(self.config.z_dim, self.config.z_dim)
            ).to(device)
        else:
            # Without the MLP, r is used directly as z, so the widths must agree.
            assert self.config.r_dim == self.config.z_dim, "r_dim must equal z_dim when use_r_to_z is disabled"

        self.load_weights(loading_config, device)

    def forward(self, batch: Batch | Data, initial_time: int | None = None) -> torch.Tensor:
        """Run the main GNN and decoder on ``batch``.

        Args:
            batch: Graph(s) to process. If the model was built with meta learning, z is
                presumably already appended to the node features (see ``add_z_to_batch``).
            initial_time: Required when ``config.use_prodmp`` is set. Used to look up the
                previous positions in ``batch.context_node_positions`` (index clamped at 0)
                and forwarded to the ProDMP predictor.

        Returns:
            Without ProDMP: the decoder output for the mesh nodes. With ProDMP: the output
            of the ProDMP predictor called with ``output_vel=False``.
        """
        processed_batch = self.gnn(batch)
        mesh_features = unpack_node_features(processed_batch, node_type=keys.MESH)
        # One row per mesh node: velocities, or ProDMP basis weights when use_prodmp is set.
        decoded_batch = self.decoder(mesh_features)
        if not self.config.use_prodmp:
            return decoded_batch

        assert initial_time is not None, "initial_time is required when use_prodmp is set"
        # Initial conditions are read from the first graph only and shared by all graphs in
        # the batch (presumably all graphs start from the same state — TODO confirm).
        first_graph = batch[0]
        mesh_type_mask = node_type_mask(first_graph, keys.MESH)
        pos = first_graph.pos[mesh_type_mask]
        # Finite-difference initial velocity from the previous context step.
        prev_pos = batch.context_node_positions[0, max(initial_time - 1, 0)]
        vel = pos - prev_pos
        num_graphs = len(batch)
        pos = pos.repeat(num_graphs, 1)
        vel = vel.repeat(num_graphs, 1)
        initial_times = initial_time * torch.ones(pos.shape[0], device=pos.device)
        return self._mp_predictor(
            pos=pos,
            vel=vel,
            basis_weights=decoded_batch,
            prediction_times=None,
            output_vel=False,
            initial_time=initial_times
        )

    def add_z_to_batch(self, batch: Batch | Data, z: torch.Tensor) -> Batch | Data:
        """Concatenate the latent ``z`` onto the node features of ``batch``.

        With ``config.node_aggregation`` set, ``z`` is repeated once per node of a graph
        (so it presumably holds a single shared row). Otherwise it must already contain
        one row per node of a single graph. For a ``Batch``, the result is then tiled once
        per graph. ``batch.x_description`` is extended by ``z_dim`` ``"z_feature"`` entries.

        Args:
            batch: Batch or single graph to modify; ``x`` and ``x_description`` are reassigned.
            z: Latent variable, e.g. as returned by ``compute_z``.

        Returns:
            The same ``batch`` object.
        """
        is_batch = isinstance(batch, Batch)
        if self.config.node_aggregation is not None:
            # Global z shared by all nodes: one copy per node of a single graph.
            nodes_per_graph = batch.num_nodes // len(batch) if is_batch else batch.num_nodes
            z = z.repeat(nodes_per_graph, 1)
        if is_batch:
            # One copy of the per-graph z block for every graph in the batch.
            z = z.repeat(len(batch), 1)
        batch.x = torch.cat([batch.x, z], dim=-1)
        z_description = ["z_feature"] * self.config.z_dim
        if is_batch:
            batch.x_description = [desc + z_description for desc in batch.x_description]
        else:
            # Rebind rather than ``+=`` so a description list shared with other Data objects
            # (e.g. after a shallow copy) is not mutated in place.
            batch.x_description = batch.x_description + z_description
        return batch

    def compute_z(self, context_batch: Batch | Data) -> torch.Tensor:
        """Encode the context graphs into the latent ``z``.

        The pipeline has five steps:

        1. Context GNN.
        2. Context decoder, producing one ``r`` per node.
        3. Optional node aggregation (``config.node_aggregation``: None, "mean" or "max").
        4. Aggregation over the context graphs (``config.context_aggregation``: "max",
           "mean" or "transformer").
        5. Optional ``r -> z`` MLP.

        All context graphs are assumed to have the same number of nodes.

        Raises:
            ValueError: If the node count is not divisible by the number of context graphs.
            NotImplementedError: For an unknown node or context aggregation mode.
        """
        processed_batch = self._context_gnn(context_batch)
        node_features = processed_batch.x
        non_aggregated_r = self._context_decoder(node_features)
        # A Batch reports its graph count via len(); a plain Data object is a single graph
        # (PyG's Data.__len__ counts attributes, not graphs).
        num_graphs = len(context_batch) if isinstance(context_batch, Batch) else 1
        if context_batch.num_nodes % num_graphs != 0:
            raise ValueError(
                f"Context graphs must all have the same number of nodes, but {context_batch.num_nodes} "
                f"nodes cannot be split evenly across {num_graphs} graphs."
            )
        num_nodes_per_graph = context_batch.num_nodes // num_graphs
        # shape: (num_context_graphs, num_nodes_per_graph, r_dim)
        non_aggregated_r = non_aggregated_r.reshape((num_graphs, num_nodes_per_graph, -1))
        if self.config.node_aggregation is None:
            per_graph_r = non_aggregated_r
        elif self.config.node_aggregation == "mean":
            per_graph_r = non_aggregated_r.mean(dim=1, keepdim=True)
        elif self.config.node_aggregation == "max":
            per_graph_r = non_aggregated_r.max(dim=1, keepdim=True).values
        else:
            raise NotImplementedError(f"Node aggregation {self.config.node_aggregation} not implemented.")
        if self.config.context_aggregation == "max":
            r = per_graph_r.max(dim=0, keepdim=False).values
        elif self.config.context_aggregation == "mean":
            r = per_graph_r.mean(dim=0, keepdim=False)
        elif self.config.context_aggregation == "transformer":
            # Move the context-graph axis to dim 1:
            # (nodes or 1, num_context_graphs, r_dim) for the transformer.
            per_graph_r = per_graph_r.permute(1, 0, 2)
            r = self._context_transformer(per_graph_r)
        else:
            raise NotImplementedError(f"Context aggregation {self.config.context_aggregation} not implemented.")
        return self._r_to_z(r) if self._use_r_to_z else r

    def _input_dimensions(self) -> ValueDict:
        """Input dimensions for the context decoder: the GNN latent dimension."""
        return {keys.PROCESSOR_DIMENSION: self.config.gnn.latent_dimension, }

    def _get_all_state_dicts(self) -> ValueDict:
        """Collect the state dicts of all trainable sub-networks, keyed for checkpointing."""
        save_dict = {"gnn_params": self.gnn.state_dict(),
                     "decoder_params": self.decoder.state_dict(),
                     "context_gnn_params": self._context_gnn.state_dict(),
                     "context_decoder_params": self._context_decoder.state_dict()
                     }
        if self.config.context_aggregation == "transformer":
            save_dict["context_transformer_params"] = self._context_transformer.state_dict()
        if self._use_r_to_z:
            save_dict["r_to_z_params"] = self._r_to_z.state_dict()
        return save_dict

    def load_all_state_dicts(self, state_dict: ValueDict):
        """Load sub-network weights from a dict produced by ``_get_all_state_dicts``."""
        self.gnn.load_state_dict(state_dict["gnn_params"])
        self.decoder.load_state_dict(state_dict["decoder_params"])
        self._context_gnn.load_state_dict(state_dict["context_gnn_params"])
        self._context_decoder.load_state_dict(state_dict["context_decoder_params"])
        if self.config.context_aggregation == "transformer":
            self._context_transformer.load_state_dict(state_dict["context_transformer_params"])
        if self._use_r_to_z:
            self._r_to_z.load_state_dict(state_dict["r_to_z_params"])

    def get_parameter_lists_for_optimizer(self) -> Tuple[List, List]:
        """Return parallel lists of parameter iterators and their optimizer group types.

        The types are "gnn", "decoder" or "transformer". The ``r -> z`` MLP is grouped
        with the decoders.
        """
        params = [self.gnn.parameters(), self.decoder.parameters(), self._context_gnn.parameters(), self._context_decoder.parameters()]
        types = ["gnn", "decoder", "gnn", "decoder"]
        if self._use_r_to_z:
            params.append(self._r_to_z.parameters())
            types.append("decoder")
        if self.config.context_aggregation == "transformer":
            params.append(self._context_transformer.parameters())
            types.append("transformer")
        return params, types
