import os

import torch

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.architectures.simulators.get_simulator import get_simulator
from ltsgns_mp.architectures.util.multiple_optimizer import MultipleOptimizer
from ltsgns_mp.envs.env import Env
from ltsgns_mp.util.loading import get_checkpoint_iteration
from ltsgns_mp.util.own_types import ConfigDict


def get_algorithm(config: ConfigDict, env: Env, loading_config: ConfigDict, device: str) -> AbstractAlgorithm:
    algorithm_name = config.name
    match algorithm_name:
        case "mgn":
            from ltsgns_mp.algorithms.mgn import MGN
            algorithm_class = MGN
        case "mgn_cnp_iterator":
            from ltsgns_mp.algorithms.mgn_cnp_iterator import MGNCNPIterator
            algorithm_class = MGNCNPIterator
        case "mgn_mp":
            from ltsgns_mp.algorithms.mgn_mp import MGN_MP
            algorithm_class = MGN_MP
        case "ggns":
            from ltsgns_mp.algorithms.ggns import GGNS
            algorithm_class = GGNS
        case "ltsgns_mp":
            from ltsgns_mp.algorithms.ltsgns_mp import LTSGNS_MP
            algorithm_class = LTSGNS_MP
        case "ltsgns_step":
            from ltsgns_mp.algorithms.ltsgns_step import LTSGNS_Step
            algorithm_class = LTSGNS_Step
        case "cnp":
            from ltsgns_mp.algorithms.cnp import CNP
            algorithm_class = CNP
        case "cnp_mp":
            from ltsgns_mp.algorithms.cnp_mp import CNP_MP
            algorithm_class = CNP_MP
        case "np":
            from ltsgns_mp.algorithms.np import NP
            algorithm_class = NP
        case "np_mp":
            from ltsgns_mp.algorithms.np_mp import NP_MP
            algorithm_class = NP_MP
        case _:
            raise ValueError(f"Unknown algorithm {algorithm_name}")
    simulator = get_simulator(config, env, loading_config=loading_config,
                              device=device)  # pass the algorithm config here, since some of the sims need it
    optimizer = _get_optimizer(config.optimizer, simulator, loading_config=loading_config, device=device)
    algorithm = algorithm_class(config, simulator, env, optimizer, loading_config=loading_config, device=device)

    return algorithm


def _get_optimizer(config: ConfigDict, simulator: AbstractSimulator, loading_config: ConfigDict, device: str) -> torch.optim.Optimizer:
    """Build the optimizer for the simulator's parameters and optionally restore its state.

    ``simulator.get_parameter_lists_for_optimizer()`` returns a list of parameter
    collections and a parallel list of model-type tags (e.g. "gnn", "decoder",
    "transformer").

    All non-"transformer" collections are optimized by one ``torch.optim.Adam``, with
    one parameter group per collection. Each group uses the learning rate configured
    for its tag.

    If any "transformer" collections exist, they are optimized by a separate
    ``torch.optim.AdamW`` with weight decay. The two optimizers are then wrapped in a
    ``MultipleOptimizer``.

    Args:
        config: Optimizer config. Reads ``gnn_learning_rate`` and
            ``decoder_learning_rate``. When transformer parameters are present, it also
            reads ``transformer_learning_rate`` and ``transformer_weight_decay``.
        simulator: Simulator providing the parameter collections and their tags.
        loading_config: If ``enable_loading`` is truthy, the optimizer state is
            restored from ``<checkpoint_path>/optimizer_<iteration>.pt``. The iteration
            comes from ``get_checkpoint_iteration``.
        device: Used as ``map_location`` when loading the checkpointed state.

    Returns:
        The Adam optimizer, or a ``MultipleOptimizer`` combining Adam and AdamW when
        transformer parameters are present.

    Raises:
        KeyError: If a parameter collection's tag has no configured learning rate.
            This includes "transformer" collections when ``transformer_learning_rate``
            is absent from ``config``.
    """
    parameters, model_types = simulator.get_parameter_lists_for_optimizer()
    lr = {
        "gnn": config.gnn_learning_rate,
        "decoder": config.decoder_learning_rate,
    }
    if "transformer_learning_rate" in config:
        lr["transformer"] = config.transformer_learning_rate

    gnn_optimizer = torch.optim.Adam([
        {"params": param, "lr": lr[model_type]}
        for param, model_type in zip(parameters, model_types)
        if model_type != "transformer"
    ])

    # Collect *every* transformer collection, not just the first. The Adam optimizer
    # above excludes all of them, so anything dropped here would never be trained.
    transformer_params = [
        param for param, model_type in zip(parameters, model_types) if model_type == "transformer"
    ]
    if transformer_params:
        trafo_optimizer = torch.optim.AdamW(
            # One group per collection. A single collection yields the same one-group
            # state-dict layout as passing its parameters directly.
            [{"params": param} for param in transformer_params],
            lr=lr["transformer"],
            weight_decay=config.transformer_weight_decay,
        )
        optimizer = MultipleOptimizer(gnn=gnn_optimizer, transformer=trafo_optimizer)
    else:
        optimizer = gnn_optimizer

    if loading_config.enable_loading:
        state_dict_name = "optimizer"
        epoch = get_checkpoint_iteration(loading_config, state_dict_name=state_dict_name)
        checkpoint_file = os.path.join(loading_config.checkpoint_path, f"{state_dict_name}_{epoch}.pt")
        optimizer.load_state_dict(torch.load(checkpoint_file, map_location=device))
    return optimizer
