"""Utilities and functions to construct different efficient heads."""

import os
from dataclasses import dataclass
from typing import Any, Dict, Literal

from efficient_heads.flash_head import get_flash_head_pipeline
from efficient_heads.midx_head import get_midx_pipeline
from efficient_heads.pipeline import get_standard_pipeline
from efficient_heads.svd_softmax import get_svd_softmax_pipeline
from efficient_heads.vocab_pruning import get_pruned_pipeline

# Registry mapping each supported head name (the same values allowed by
# ``PipelineConfig.head_type``) to its pipeline constructor.
PIPELINE_CONSTRUCTORS = dict(
    standard=get_standard_pipeline,
    flash_head=get_flash_head_pipeline,
    vocab_pruning=get_pruned_pipeline,
    svd_softmax=get_svd_softmax_pipeline,
    midx=get_midx_pipeline,
)


@dataclass
class PipelineConfig:
    """Settings that select and parameterise one efficient-head pipeline.

    Attributes:
        kwargs: Keyword arguments intended for the selected pipeline
            (presumably forwarded to its constructor by the caller — confirm).
        head_type: Name of the head to build; the permitted values match the
            keys of ``PIPELINE_CONSTRUCTORS``.
    """

    # Free-form options for the chosen head.
    kwargs: Dict[str, Any]
    # Which efficient head to construct.
    head_type: Literal["standard", "flash_head", "vocab_pruning", "svd_softmax", "midx"]


def get_clustering_cache_dir(n_clusters):
    """Locate the clustering cache directory for a given number of clusters.

    A primary and a fallback candidate are checked in that order, and the
    first one that is an existing directory is returned. The fallback
    candidate is passed through ``os.path.expanduser`` so it may use a ``~``
    prefix.

    Args:
        n_clusters: Number of clusters. It is interpolated into the directory
            name (``Llama-3.2-1B-Instruct-cluster-{n_clusters}/``).

    Returns:
        The first existing candidate path, exactly as constructed (relative,
        with a trailing slash).

    Raises:
        FileNotFoundError: If neither candidate directory exists.
    """
    # These must be f-strings; otherwise "{n_clusters}" would be searched for
    # literally and the lookup could never succeed.
    # NOTE(review): both candidates currently resolve to the same relative
    # path; the primary was presumably meant to point at a Google Drive
    # mount — confirm the intended locations.
    primary_path = f"Llama-3.2-1B-Instruct-cluster-{n_clusters}/"
    fallback_path = os.path.expanduser(f"Llama-3.2-1B-Instruct-cluster-{n_clusters}/")

    for candidate in (primary_path, fallback_path):
        # The cache is a directory; a plain file with that name is not usable.
        if os.path.isdir(candidate):
            return candidate

    raise FileNotFoundError(
        "Clustering cache directory not found in either location:\n"
        f" - {primary_path}\n"
        f" - {fallback_path}\n"
        "Please set CLUSTERING_CACHE_DIR manually."
    )
