class InstantiateModel:
    """Lazy factory for the patching / ablation model wrappers.

    An instance only stores a model name. The concrete wrapper class is
    imported (from ``patching_gemma.models``) and constructed when the
    instance is called, so building a registry of factories does not import
    every model implementation up front.
    """

    # Size tags accepted by the ``Gemma2BiggerModel*`` wrappers
    # (names of the form ``gemma2_<size>_<task>``).
    _GEMMA2_BIGGER_SIZES = ("9b", "27b")
    # Model families served by the shared ``Llama3*`` wrappers
    # (names of the form ``<family>_<task>``).
    _LLAMA3_FAMILIES = ("llama3", "smollm")

    def __init__(self, model_name):
        """Store the name that selects which wrapper ``__call__`` builds.

        Args:
            model_name: Registry key, e.g. ``"gemma2_ablate_edges"``.
        """
        self.model_name = model_name

    def __repr__(self):
        return f"{type(self).__name__}({self.model_name!r})"

    def __call__(self, *args, **kwargs):
        """Import and instantiate the wrapper registered under ``self.model_name``.

        ``*args`` and ``**kwargs`` are forwarded unchanged to the wrapper's
        constructor. Some wrappers additionally receive a value derived from
        the name:

        * ``gemma2_{9b,27b}_*`` wrappers get ``model_size="9b"`` / ``"27b"``.
        * ``{llama3,smollm}_*`` wrappers get the family name (``"llama3"`` or
          ``"smollm"``) as their first positional argument.

        Returns:
            A new instance of the selected wrapper class.

        Raises:
            ValueError: if ``self.model_name`` is not a recognized model name
                (previously this silently returned ``None``).
        """
        name = self.model_name

        if name == "gemma2_find_fv":
            from patching_gemma.models.gemma2_find_fv import Gemma2FindFVHeads
            return Gemma2FindFVHeads(*args, **kwargs)
        if name == "gemma2_ablate_edges":
            from patching_gemma.models.gemma2_ablate_edges import Gemma2AblateEdges
            return Gemma2AblateEdges(*args, **kwargs)
        if name in {f"gemma2_{size}_ablate_edges" for size in self._GEMMA2_BIGGER_SIZES}:
            from patching_gemma.models.gemma2_bigger_model_ablate_edges import Gemma2BiggerModelAblateEdges
            # "gemma2_<size>_ablate_edges" -> "<size>"; the membership test
            # above guarantees this is one of _GEMMA2_BIGGER_SIZES.
            model_size = name.split("_")[1]
            return Gemma2BiggerModelAblateEdges(*args, **kwargs, model_size=model_size)
        if name == "gemma2_calc_accuracy":
            from patching_gemma.models.gemma2_calculate_accuracy import Gemma2CalculateAccuracy
            return Gemma2CalculateAccuracy(*args, **kwargs)
        if name in {f"gemma2_{size}_calc_accuracy" for size in self._GEMMA2_BIGGER_SIZES}:
            from patching_gemma.models.gemma2_bigger_model_calculate_accuracy import Gemma2BiggerModelCalculateAccuracy
            # "gemma2_<size>_calc_accuracy" -> "<size>"
            model_size = name.split("_")[1]
            return Gemma2BiggerModelCalculateAccuracy(*args, **kwargs, model_size=model_size)
        if name == "gemma2_ablate_edges_differently_for_different_tokens":
            from patching_gemma.models.gemma2_ablate_edges_differently_for_different_tokens import Gemma2AblateEdgesDifferentlyForDifferentTokens
            return Gemma2AblateEdgesDifferentlyForDifferentTokens(*args, **kwargs)
        if name == "gemma2_prune_heads_in_circuits":
            from patching_gemma.models.gemma2_prune_heads_in_circuits import Gemma2PruneHeadsInCircuits
            return Gemma2PruneHeadsInCircuits(*args, **kwargs)
        if name == "gemma2_prune_edges_in_circuits":
            from patching_gemma.models.gemma2_prune_edges_in_circuits import Gemma2PruneEdgesInCircuits
            return Gemma2PruneEdgesInCircuits(*args, **kwargs)
        if name in {f"{family}_calc_accuracy" for family in self._LLAMA3_FAMILIES}:
            from patching_gemma.models.llama3_calculate_accuracy import Llama3CalculateAccuracy
            # The shared Llama-3 wrapper takes the family name as its first argument.
            family = name[: -len("_calc_accuracy")]
            return Llama3CalculateAccuracy(family, *args, **kwargs)
        if name == "phi2_calc_accuracy":
            from patching_gemma.models.phi_calculate_accuracy import PhiCalculateAccuracy
            return PhiCalculateAccuracy(*args, **kwargs)
        if name in {f"{family}_ablate_edges" for family in self._LLAMA3_FAMILIES}:
            from patching_gemma.models.llama3_ablate_edges import Llama3AblateEdges
            family = name[: -len("_ablate_edges")]
            return Llama3AblateEdges(family, *args, **kwargs)
        if name == "phi2_ablate_edges":
            from patching_gemma.models.phi_ablate_edges import PhiAblateEdges
            return PhiAblateEdges(*args, **kwargs)
        if name == "qwen2_calc_accuracy":
            from patching_gemma.models.qwen2_calculate_accuracy import Qwen2CalculateAccuracy
            return Qwen2CalculateAccuracy(*args, **kwargs)
        if name == "gemma2_ablate_edges_differently_for_different_tokens_only_pp_fewshots":
            from patching_gemma.models.gemma2_ablate_only_from_pp_task import Gemma2AblateEdgesDifferentlyForDifferentTokensOnlyFromPPFewshots
            return Gemma2AblateEdgesDifferentlyForDifferentTokensOnlyFromPPFewshots(*args, **kwargs)

        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unknown model name: {name!r}")
        

# Registry mapping each supported model name to a lazy factory that builds the
# corresponding wrapper when called. Names are listed once each, in the order
# they are registered.
NAME_TO_MODEL = {
    registered_name: InstantiateModel(registered_name)
    for registered_name in (
        # Gemma-2 (base size) wrappers.
        "gemma2_find_fv",
        "gemma2_ablate_edges",
        "gemma2_calc_accuracy",
        "gemma2_ablate_edges_differently_for_different_tokens",
        "gemma2_prune_heads_in_circuits",
        "gemma2_prune_edges_in_circuits",
        "llama3_calc_accuracy",
        "qwen2_calc_accuracy",
        "gemma2_ablate_edges_differently_for_different_tokens_only_pp_fewshots",
        # Accuracy evaluation for the remaining model families / sizes.
        "smollm_calc_accuracy",
        "phi2_calc_accuracy",
        "gemma2_9b_calc_accuracy",
        "gemma2_27b_calc_accuracy",
        # Edge ablation for the non-base model families / sizes.
        "llama3_ablate_edges",
        "smollm_ablate_edges",
        "phi2_ablate_edges",
        "gemma2_9b_ablate_edges",
        "gemma2_27b_ablate_edges",
    )
}