from .models import OnlineDecisionTransformerModel, DiscreteDTModel, \
    DummyUDTModel, CustomContinuousCritic, MultiHeadContinuousCritic, \
    MultiDomainDiscreteDTModel, CacheDTModel, DiscreteCacheDTModel, DiscreteHelmDTModel, HelmDTModel


# Registry mapping a short "kind" name to the model class that implements it.
# Several kinds are aliases that share one implementation.
MODEL_CLASSES = dict(
    # Continuous-action decision transformers; "DT", "ODT" and "UDT" all use
    # the same online model class.
    DT=OnlineDecisionTransformerModel,
    ODT=OnlineDecisionTransformerModel,
    UDT=OnlineDecisionTransformerModel,
    DummyUDT=DummyUDTModel,
    # Discrete-action variants.
    DDT=DiscreteDTModel,
    MDDT=MultiDomainDiscreteDTModel,
    # Helm-based variants (continuous, then discrete).
    HelmDT=HelmDTModel,
    DHelmDT=DiscreteHelmDTModel,
    # Cache-based variants (continuous, then discrete).
    CDT=CacheDTModel,
    DCDT=DiscreteCacheDTModel,
)

# Lazily populated cache mapping each "kind" name to its agent class.
# Every entry starts out as None; the class is filled in on demand by
# get_agent_class so that agent modules are only imported when requested.
# NOTE(review): "MDMPDT" and "DMPDT" have no entry in MODEL_CLASSES and no
# loader in get_agent_class — confirm whether they are populated elsewhere.
AGENT_CLASSES = dict.fromkeys(
    (
        "DT",
        "ODT",
        "UDT",
        "DummyUDT",
        "DDT",
        "MDDT",
        "HelmDT",
        "DHelmDT",
        "MDMPDT",
        "DMPDT",
        "CDT",
        "DCDT",
    )
)


def get_model_class(kind):
    """Return the model class registered under ``kind`` in ``MODEL_CLASSES``.

    Args:
        kind: Short name of the model variant, e.g. ``"DT"`` or ``"DDT"``.

    Returns:
        The model class mapped to ``kind``.

    Raises:
        AssertionError: if ``kind`` is not a key of ``MODEL_CLASSES``.
    """
    registry = MODEL_CLASSES
    assert kind in registry, f"Unknown kind: {kind}"
    return registry[kind]


def _import_agent_class(kind):
    """Import and return the agent class for ``kind``, or ``None`` if this module has no loader for it.

    Imports are done inside the function so each agent module is only loaded
    when an agent of that kind is actually requested.
    """
    if kind in ("DT", "ODT", "HelmDT", "DHelmDT"):
        from .decision_transformer_sb3 import DecisionTransformerSb3
        return DecisionTransformerSb3
    if kind in ("UDT", "DummyUDT"):
        from .universal_decision_transformer_sb3 import UDT
        return UDT
    if kind in ("DDT", "MDDT"):
        from .discrete_decision_transformer_sb3 import DiscreteDecisionTransformerSb3
        return DiscreteDecisionTransformerSb3
    if kind == "CDT":
        from .cache_decision_transformer_sb3 import CacheDecisionTransformerSb3
        return CacheDecisionTransformerSb3
    if kind == "DCDT":
        from .cache_decision_transformer_sb3 import DiscreteCacheDecisionTransformerSb3
        return DiscreteCacheDecisionTransformerSb3
    # Kinds such as "MDMPDT" / "DMPDT" are listed in AGENT_CLASSES but have no
    # loader here; presumably they must be registered by other code — TODO confirm.
    return None


def get_agent_class(kind):
    """Return the agent class registered under ``kind``.

    The class is resolved lazily: on first request its module is imported and
    the class is cached in ``AGENT_CLASSES``. A non-``None`` class already
    stored in ``AGENT_CLASSES`` (by an earlier call or by other code) is
    returned as-is without re-importing.

    Args:
        kind: Short name of the agent variant; must be a key of ``AGENT_CLASSES``.

    Returns:
        The agent class for ``kind``.

    Raises:
        AssertionError: if ``kind`` is not a key of ``AGENT_CLASSES``.
        NotImplementedError: if ``kind`` is known but no agent class can be
            loaded for it (previously this case silently returned ``None``).
    """
    assert kind in AGENT_CLASSES, f"Unknown kind: {kind}"
    agent_cls = AGENT_CLASSES[kind]
    if agent_cls is None:
        agent_cls = _import_agent_class(kind)
        if agent_cls is None:
            # Fail here with a clear message rather than handing the caller a
            # None that blows up later as "'NoneType' object is not callable".
            raise NotImplementedError(f"No agent class available for kind: {kind}")
        AGENT_CLASSES[kind] = agent_cls
    return agent_cls
