"""Utilities and functions to construct different efficient heads."""

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal

from efficient_heads.fgd import get_fgd_pipeline
from efficient_heads.flash_head import get_flash_head_pipeline
from efficient_heads.midx_head import get_midx_pipeline
from efficient_heads.pipeline import get_standard_pipeline
from efficient_heads.svd_softmax import get_svd_softmax_pipeline
from efficient_heads.vocab_pruning import get_pruned_pipeline


def get_spherical_k_means_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    n_clusters=100,
    **kwargs,
):
    """Get the L2S pipeline.

    This is a thin wrapper around ``get_flash_head_pipeline`` that always
    requests a single probe (``n_probes=1``).

    Args:
        model_id: Model identifier forwarded as the first positional argument.
        n_clusters: Number of clusters forwarded to the flash-head builder.
        **kwargs: Extra keyword arguments forwarded unchanged. Passing
            ``n_probes`` or ``n_clusters`` here raises ``TypeError`` because
            those keywords are already supplied by this wrapper.

    Returns:
        Whatever ``get_flash_head_pipeline`` returns.
    """
    # Keep the pinned options in their own mapping; unpacking two mappings
    # into one call still rejects duplicated keywords, so a caller-supplied
    # `n_probes` is an error rather than a silent override.
    pinned_options = {"n_probes": 1, "n_clusters": n_clusters}
    return get_flash_head_pipeline(model_id, **pinned_options, **kwargs)


# Registry mapping each head-type name to the function that builds the
# corresponding pipeline. The names are the same strings accepted by
# `PipelineConfig.head_type`.
PIPELINE_CONSTRUCTORS = dict(
    standard=get_standard_pipeline,
    flash_head=get_flash_head_pipeline,
    vocab_pruning=get_pruned_pipeline,
    svd_softmax=get_svd_softmax_pipeline,
    midx=get_midx_pipeline,
    spherical_k_means=get_spherical_k_means_pipeline,
    fgd=get_fgd_pipeline,
)


# Closed set of head-type names a `PipelineConfig` may carry.
_HeadType = Literal[
    "standard",
    "flash_head",
    "vocab_pruning",
    "svd_softmax",
    "midx",
    "spherical_k_means",
    "fgd",
]


@dataclass
class PipelineConfig:
    """Describe one efficient-head pipeline to build.

    Attributes:
        kwargs: Keyword arguments stored for the pipeline of this head type.
        head_type: Name of the head type; one of the literals listed in the
            annotation.
    """

    kwargs: Dict[str, Any]
    head_type: _HeadType


def create_experiments(
    size: str, cluster_cache_dir_path: Path
) -> List[PipelineConfig]:
    """Return a list of pipeline configs of given `size` to be evaluated.

    Args:
        size: Which experiment set to build. ``"standard"`` yields only the
            standard head, ``"small"`` yields a single flash-head config, and
            ``"large"`` yields the small config followed by a flash-head
            grid, an FGD config, two spherical k-means configs, a vocab
            pruning config and an SVD-softmax config.
        cluster_cache_dir_path: Base directory under which each clustered
            config's ``cache_dir`` is built. A plain string is accepted and
            converted to a ``Path``.

    Returns:
        The pipeline configs, in the order listed above.

    Raises:
        ValueError: If `size` is not ``"standard"``, ``"small"`` or
            ``"large"``.
    """
    valid_sizes = ("standard", "small", "large")
    if size not in valid_sizes:
        # Raise instead of asserting: assertions are stripped under
        # `python -O`, which would silently return the "large" set for a
        # mistyped size.
        raise ValueError(
            f"Unknown experiment size {size!r}; expected one of {valid_sizes}."
        )

    if size == "standard":
        return [PipelineConfig(kwargs={}, head_type="standard")]

    cache_root = Path(cluster_cache_dir_path)

    def cluster_cache_dir(tag) -> str:
        """Return the cache directory string for the clustering `tag`."""
        return str(cache_root / f"Llama-3.2-1B-Instruct-cluster-{tag}/")

    pipeline_configs = [
        PipelineConfig(
            kwargs={
                "n_clusters": 8016,
                "n_probes": 256,
                "cache_dir": cluster_cache_dir("8016_eq"),
            },
            head_type="flash_head",
        )
    ]

    if size == "small":
        return pipeline_configs

    # Everything below belongs to the "large" set only.
    for n_clusters in [1024, 4096, 16384]:
        for n_probes in [128, 256, 512]:
            pipeline_configs.append(
                PipelineConfig(
                    kwargs={
                        "n_clusters": n_clusters,
                        "n_probes": n_probes,
                        "cache_dir": cluster_cache_dir(n_clusters),
                    },
                    head_type="flash_head",
                )
            )

    pipeline_configs.append(PipelineConfig(kwargs={}, head_type="fgd"))

    for n_clusters in [100, 200]:
        pipeline_configs.append(
            PipelineConfig(
                kwargs={
                    "n_clusters": n_clusters,
                    "cache_dir": cluster_cache_dir(n_clusters),
                },
                head_type="spherical_k_means",
            )
        )

    pipeline_configs.extend(
        [
            PipelineConfig(
                kwargs={"vocab_size": 64000},
                head_type="vocab_pruning",
            ),
            PipelineConfig(
                kwargs={"window": 256, "top_n": 12000},
                head_type="svd_softmax",
            ),
        ]
    )

    return pipeline_configs
