import os
from abc import ABC, abstractmethod
from typing import Mapping, Any, Tuple, List

import torch
from hmpn import AbstractMessagePassingBase, get_hmpn_from_graph
from torch_geometric.data import Batch, Data

from ltsgns_mp.architectures.decoder import get_decoder
from ltsgns_mp.architectures.decoder.decoder import Decoder
from ltsgns_mp.util import keys
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ValueDict, ConfigDict


class AbstractSimulator(torch.nn.Module, ABC):
    """
    Abstract class for a simulator. A simulator is used to predict the next state of a graph or a trajectory.
    We use a pytorch module for that, input is the current state (and possibly z) and output is the next state
    or the trajectory.

    The simulator is composed of a message-passing GNN (built from an example batch via ``get_hmpn_from_graph``)
    and a decoder (built via ``get_decoder``). Subclasses must implement ``forward`` and ``_input_dimensions``.
    """

    def __init__(self, config, example_input_batch: Data | Batch, decoder_output_dim: int, loading_config: ConfigDict,
                 device):
        """
        Build the GNN and the decoder of the simulator.

        Args:
            config: Simulator config. Reads ``config.gnn.latent_dimension``, ``config.gnn.base`` and
                ``config.decoder``.
            example_input_batch: Example graph/batch from which the GNN's input structure is inferred.
            decoder_output_dim: Output dimension of the decoder (passed to ``get_decoder`` as ``action_dim``).
            loading_config: Checkpoint-loading config. It is not used by this constructor; weights are only
                loaded when ``load_weights`` is called explicitly.
            device: Device on which the GNN and decoder are created.
        """
        super().__init__()
        self.config = config
        self._device = device

        gnn_config = config.gnn
        self._gnn: AbstractMessagePassingBase = get_hmpn_from_graph(example_graph=example_input_batch,
                                                                    latent_dimension=gnn_config.latent_dimension,
                                                                    node_name=keys.MESH,
                                                                    unpack_output=False,  # return full graph
                                                                    base_config=gnn_config.base,
                                                                    device=device)

        # The decoder is built after self._gnn is assigned, so a subclass's _input_dimensions() may read it.
        # type(self).__name__ is the concrete subclass name (e.g. "MyCustomSimulator"), which the decoder
        # factory receives as `simulator_class`.
        self._decoder: Decoder = get_decoder(config=self.config.decoder,
                                             action_dim=decoder_output_dim,
                                             device=device,
                                             input_dimensions=self._input_dimensions(),
                                             simulator_class=type(self).__name__)

    def load_weights(self, loading_config, device):
        """
        Load GNN and decoder parameters from a checkpoint, if loading is enabled.

        The file ``{checkpoint_path}/mpn_simulator_state_dict_{epoch}.pt`` is loaded. The epoch is chosen by
        ``get_checkpoint_iteration``. The file must contain the dict layout produced by ``_get_all_state_dicts``.

        Args:
            loading_config: Config with ``enable_loading`` and ``checkpoint_path`` (plus whatever
                ``get_checkpoint_iteration`` reads).
            device: ``map_location`` passed to ``torch.load``.
        """
        if not loading_config.enable_loading:
            return
        # NOTE(review): save_checkpoint writes files named f"{keys.STATEDICT}_..."; this reads
        # "mpn_simulator_state_dict_...". Confirm keys.STATEDICT == "mpn_simulator_state_dict".
        state_dict_name = "mpn_simulator_state_dict"
        epoch = get_checkpoint_iteration(loading_config, state_dict_name=state_dict_name)
        # torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dicts = torch.load(os.path.join(loading_config.checkpoint_path, f"{state_dict_name}_{epoch}.pt"),
                                 map_location=device)
        self.load_all_state_dicts(state_dicts)

    @abstractmethod
    def forward(self, batch, **kwargs):
        """Run the simulator on ``batch``. Must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def _input_dimensions(self) -> ValueDict:
        """Return the input dimensions passed to ``get_decoder``. Must be implemented by subclasses."""
        raise NotImplementedError("This method needs to be implemented by the subclass.")

    @property
    def gnn(self):
        """The message-passing GNN of this simulator."""
        return self._gnn

    @property
    def decoder(self):
        """The decoder applied on top of the GNN output."""
        return self._decoder

    def save_checkpoint(self, directory: str, iteration: int | str, is_initial_save: bool, is_final_save: bool = False):
        """
        Saves the state dict of the simulator to the specified directory.

        Args:
            directory: Directory into which the checkpoint file is written.
            iteration: Current iteration, or a string describing this save. It is used as the file-name suffix
                unless ``is_final_save`` is set.
            is_initial_save: Currently unused; kept for interface compatibility.
            is_final_save: If True, the file is named ``{keys.STATEDICT}_final.pt`` instead of using
                ``iteration``.

        Returns:
            None.
        """
        save_dict = self._get_all_state_dicts()
        suffix = "final" if is_final_save else iteration
        file_name = f"{keys.STATEDICT}_{suffix}.pt"
        torch.save(save_dict, os.path.join(directory, file_name))

    def _get_all_state_dicts(self) -> ValueDict:
        """Return a dict with the GNN and decoder state dicts under "gnn_params" / "decoder_params"."""
        save_dict = {"gnn_params": self.gnn.state_dict(),
                     "decoder_params": self.decoder.state_dict(),
                     }
        return save_dict

    def load_all_state_dicts(self, state_dict: ValueDict):
        """Load the GNN and decoder parameters from a dict in the layout of ``_get_all_state_dicts``."""
        self.gnn.load_state_dict(state_dict["gnn_params"])
        self.decoder.load_state_dict(state_dict["decoder_params"])

    def get_parameter_lists_for_optimizer(self) -> Tuple[List, List]:
        """
        Returns a list containing all parameters that should be optimized. Also returns a list to indicate what
        kind of model this is and hence what LR should be used.

        The parameter collections are materialized as lists rather than returned as ``.parameters()``
        generators. A generator is exhausted after one pass, so a caller that iterates a group twice would
        otherwise silently see no parameters.

        :return: Tuple of two lists, the first one containing the parameters and the second one containing the model type
        """
        return [list(self.gnn.parameters()), list(self.decoder.parameters())], ["gnn", "decoder"]
