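"""Benchmark the SVD softmax LM head against the unmodified baseline model.

Sweeps a grid of ``window`` / ``top_n`` settings on Alpaca prompts and writes
the per-configuration results to ``svd_softmax_results.json``.
"""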

import gc
import json

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import pipeline

from efficient_heads.eval_outputs import TextEvaluator as Evaluator
from efficient_heads.svd_softmax import SVDSoftmaxHead

torch.manual_seed(0)

RESULTS_PATH = "./svd_softmax_results.json"


def build_prompt(item: dict) -> str:
    """Return the prompt text for one Alpaca record.

    The record's ``input`` is appended after a blank line only when it is
    present and non-empty, so instruction-only records do not end with a
    dangling blank line.

    Args:
        item: Dataset row with an ``"instruction"`` key and an optional
            ``"input"`` key (``""`` or ``None`` are treated as absent).

    Returns:
        The instruction, optionally followed by a blank line and the input.
    """
    instruction = item["instruction"]
    extra = item.get("input") or ""
    if not extra:
        return instruction
    return f"{instruction}\n\n{extra}"


def collate_prompts(batch: list[dict]) -> list[str]:
    """Turn a batch of dataset rows into a list of prompt strings.

    A module-level function (unlike a lambda) is picklable, which DataLoader
    requires when worker processes are started with the spawn method.
    """
    return [build_prompt(item) for item in batch]


def load_generation_pipeline(model_id: str, batch_size: int):
    """Create a batched text-generation pipeline for ``model_id``.

    transformers refuses to batch inputs when the tokenizer has no pad token,
    so EOS is reused as padding in that case. Decoder-only models must be
    left-padded for batched generation; with right padding, shorter prompts
    would be followed by pad tokens before the generated continuation.

    Args:
        model_id: Hugging Face Hub id or local path of the model.
        batch_size: Batch size passed through to the pipeline.

    Returns:
        The configured ``transformers`` text-generation pipeline.
    """
    generator = pipeline(
        "text-generation", model=model_id, batch_size=batch_size
    )
    tokenizer = generator.tokenizer
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return generator


def save_results(results: list, path: str = RESULTS_PATH) -> None:
    """Write ``results`` to ``path`` as indented UTF-8 JSON, overwriting it."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    BASELINE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
    BATCH_SIZE = 32
    NUM_BATCHES = 4
    DATASET_ID = "tatsu-lab/alpaca"
    WINDOWS = [64, 128, 256, 512, 1024, 2048]
    TOP_NS = [4000, 8000, 16000, 32000, 64000, 128000]

    dataset = load_dataset(DATASET_ID, split="train")
    dataloader = DataLoader(
        dataset,
        collate_fn=collate_prompts,
        batch_size=BATCH_SIZE,
    )
    baseline_model_pipeline = load_generation_pipeline(
        BASELINE_MODEL, BATCH_SIZE
    )

    all_results = []
    for window in WINDOWS:
        for top_n in TOP_NS:
            print(
                f"Running evaluation for window={window} and top_n={top_n}."
            )
            # Reload the model for each configuration so every SVDSoftmaxHead
            # is built from the original, unreplaced lm_head weights.
            model_pipeline = load_generation_pipeline(
                BASELINE_MODEL, BATCH_SIZE
            )
            model_pipeline.model.lm_head = SVDSoftmaxHead(
                weights=model_pipeline.model.lm_head.weight,
                window=window,
                top_n=top_n,
            )

            evaluator = Evaluator(
                base_model_pipeline=baseline_model_pipeline,
                model_pipeline=model_pipeline,
            )
            results = evaluator.evaluate(dataloader, num_batches=NUM_BATCHES)
            results.update({"window": window, "top_n": top_n})
            all_results.append(results)
            print(results)

            # Checkpoint after every configuration so a crash (e.g. OOM on a
            # large setting) does not discard the runs that already finished.
            save_results(all_results)

            # Release the candidate model before the next one is loaded;
            # otherwise the old and new copies are alive at the same time.
            del evaluator, model_pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
