from typing import Union
import torch
from trainer.base_obj import RLObjective
from trainer.base_hparams import OffPolicyRLHyperParameterSpace, OnPolicyRLHyperParameterSpace
from trainer.RLPolicy import PPOPolicy, LLM_Policy, LLM_Instruct_Policy
from trainer.RLObj import PPOObjective, LLM_Objective, LLM_Instruct_Objective
from trainer.RLHparams import PPOHyperParams, LLM_HyperParams, LLM_Instruct_HyperParams

def policy_load(policy, ckpt_path: str, device: str, is_train: bool = False):
    """Optionally restore weights into ``policy`` and switch its train/eval mode.

    Args:
        policy: Object exposing ``load_state_dict``, ``train`` and ``eval``
            (e.g. a ``torch.nn.Module``).
        ckpt_path: Checkpoint file to load, or ``None`` to skip loading.
        device: Device string handed to ``torch.device`` and used as
            ``map_location`` when loading.
        is_train: If true, call ``policy.train()``; otherwise ``policy.eval()``.

    Returns:
        The same ``policy`` object that was passed in.
    """
    if ckpt_path is not None:
        state = torch.load(ckpt_path, map_location=torch.device(device))
        # Files ending in "policy.pth" hold the state dict directly; any other
        # checkpoint (e.g. ckpt.pth) nests it under the "model" key.
        if not ckpt_path.endswith("policy.pth"):
            state = state["model"]
        policy.load_state_dict(state)

    set_mode = policy.train if is_train else policy.eval
    set_mode()
    return policy

# Registry keyed by lower-case algorithm name (the lookup helpers below
# lower-case their input before indexing). Each entry maps:
#   "hparam" -> hyper-parameter class
#   "policy" -> policy class
#   "obj"    -> objective class
#   "type"   -> policy type string (currently "continuous" for every entry)
policyLOOKUP = {
    "ppo": {
        "hparam": PPOHyperParams,
        "policy": PPOPolicy,
        "obj": PPOObjective,
        "type": "continuous",
    },
    "llm": {
        "hparam": LLM_HyperParams,
        "policy": LLM_Policy,
        "obj": LLM_Objective,
        "type": "continuous",
    },
    "llm-instruct": {
        "hparam": LLM_Instruct_HyperParams,
        "policy": LLM_Instruct_Policy,
        "obj": LLM_Instruct_Objective,
        "type": "continuous",
    },
}


def get_hparam_class(
    algo_name: str, offline: bool
) -> "Union[type[OffPolicyRLHyperParameterSpace], type[OnPolicyRLHyperParameterSpace]]":
    """Return the hyper-parameter class registered for ``algo_name``.

    Args:
        algo_name: Algorithm name, matched case-insensitively against the
            keys of ``policyLOOKUP``.
        offline: Whether an offline-RL variant is requested (not supported).

    Returns:
        The class stored under the ``"hparam"`` key of the matching entry.

    Raises:
        NotImplementedError: If ``offline`` is truthy.
        KeyError: If ``algo_name`` is not registered; the message lists the
            supported names.
    """
    key = algo_name.lower()
    if offline:
        raise NotImplementedError("Offline RL is not supported yet")
    try:
        entry = policyLOOKUP[key]
    except KeyError:
        # Re-raise as KeyError (so existing handlers still match) but with a
        # message that tells the user which names are valid.
        raise KeyError(
            f"Unknown algorithm {algo_name!r}; supported: {sorted(policyLOOKUP)}"
        ) from None
    return entry["hparam"]


def get_obj_class(algo_name: str, offline: bool) -> "type[RLObjective]":
    """Return the objective class registered for ``algo_name``.

    Args:
        algo_name: Algorithm name, matched case-insensitively against the
            keys of ``policyLOOKUP``.
        offline: Whether an offline-RL variant is requested (not supported).

    Returns:
        The class stored under the ``"obj"`` key of the matching entry.

    Raises:
        NotImplementedError: If ``offline`` is truthy.
        KeyError: If ``algo_name`` is not registered; the message lists the
            supported names.
    """
    key = algo_name.lower()
    if offline:
        raise NotImplementedError("Offline RL is not supported yet")
    try:
        entry = policyLOOKUP[key]
    except KeyError:
        # Keep the KeyError type for existing callers, add a useful message.
        raise KeyError(
            f"Unknown algorithm {algo_name!r}; supported: {sorted(policyLOOKUP)}"
        ) from None
    return entry["obj"]


def get_policy_type(algo_name: str, offline: bool) -> str:
    """Return the policy type string registered for ``algo_name``.

    Args:
        algo_name: Algorithm name, matched case-insensitively against the
            keys of ``policyLOOKUP``.
        offline: Whether an offline-RL variant is requested (not supported).

    Returns:
        The value stored under the ``"type"`` key of the matching entry.

    Raises:
        NotImplementedError: If ``offline`` is truthy.
        KeyError: If ``algo_name`` is not registered; the message lists the
            supported names.
    """
    key = algo_name.lower()
    if offline:
        raise NotImplementedError("Offline RL is not supported yet")
    try:
        entry = policyLOOKUP[key]
    except KeyError:
        # Keep the KeyError type for existing callers, add a useful message.
        raise KeyError(
            f"Unknown algorithm {algo_name!r}; supported: {sorted(policyLOOKUP)}"
        ) from None
    return entry["type"]