from typing import Dict

import torch
from omegaconf import OmegaConf

from ltsgns_mp.architectures.decoder.decoder import Decoder
from ltsgns_mp.util.own_types import ConfigDict


def get_decoder(config: ConfigDict, action_dim: int, device: str, input_dimensions: Dict[str, int],
                simulator_class: str) -> Decoder:
    """
    Build a decoder, which consists of a decode_module and a readout_module.

    The decode_module is selected by ``simulator_class``:
      * "StepSimulator", "CNPSimulator", "NPSimulator": a single-layer ReLU MLP (no output layer)
        whose input size is the sum of all entries of ``input_dimensions``.
      * "LTSGNS_MP_Simulator", "LTSGNS_Step_Simulator": an ``LTSGNSMPDecodeModule``.
    The readout_module is a ``torch.nn.Linear(config.latent_dimension, action_dim)``.

    Args:
        config: Decoder config. Must provide ``latent_dimension``. The MLP branch additionally
            reads ``regularization.dropout``; the LTSGNS branch receives the whole config.
        action_dim: Output dimension of the linear readout layer.
        device: Device forwarded to the decode module and to the Decoder.
        input_dimensions: Dictionary containing the dimensions of the input features. The MLP
            branch sums the values (the features are concatenated before the MLP); the LTSGNS
            branch receives the dictionary itself as ``in_feature_dict``.
        simulator_class: Name of the simulator class the decoder is built for.

    Returns: initialized Decoder

    Raises:
        NotImplementedError: If ``simulator_class`` is not one of the supported names.
    """
    if simulator_class in ("StepSimulator", "CNPSimulator", "NPSimulator"):
        # All of these simulators share the same plain MLP decode module.
        from ltsgns_mp.architectures.util.mlp import MLP
        decode_module = MLP(in_features=sum(input_dimensions.values()),  # concat of all input features
                            latent_dimension=config.latent_dimension,
                            config=OmegaConf.create(dict(activation_function="relu",
                                                         add_output_layer=False,
                                                         num_layers=1,
                                                         regularization={
                                                             "dropout": config.regularization.dropout,
                                                         },
                                                         )),
                            device=device
                            )
    elif simulator_class in ("LTSGNS_MP_Simulator", "LTSGNS_Step_Simulator"):
        from ltsgns_mp.architectures.decoder.ltsgns_mp_decode_module import LTSGNSMPDecodeModule
        decode_module = LTSGNSMPDecodeModule(decoder_config=config,
                                             in_feature_dict=input_dimensions,
                                             device=device)
    else:
        raise NotImplementedError(f"Decoder for simulator class {simulator_class} not implemented")

    # Project the decode module's latent output to the action dimension.
    readout_module = torch.nn.Linear(config.latent_dimension, action_dim)
    decoder = Decoder(decode_module, readout_module, device=device)
    return decoder
