import csv
import json
import os
from typing import Any, Dict, List

from datasets import Dataset, DatasetDict
from huggingface_hub import login

from eliciting_contexts.benchmark.internal.sae.raw_data_multiple import multiple_saes
from eliciting_contexts.benchmark.internal.sae.raw_data_single import single_saes
from eliciting_contexts.benchmark.internal.tiny_stories.data import upload_hf_dataset

# NOTE(review): this call runs at import time, so importing this module just
# for its constants or for process_csv_to_json also triggers a Hugging Face
# login. With no arguments it presumably prompts for (or looks up) a token --
# confirm against the huggingface_hub documentation. The upload function below
# separately reads HF_TOKEN from the environment.
login()

# Column names of the single-feature SAE dataset. They are the keys of every
# record built by process_csv_to_json, and the default `column_names` that
# upload_single_sae_dataset uses to turn the record list into dataset columns.
SINGLE_SAE_COLUMN_NAMES = [
    "density",
    "vocab_diversity",
    "local_vs_global",
    "tags",
    "necessary_context",
    "necessary_condition",
    "success_criterion",
    "human_explanation",
    "feature_grade",
    "neuronpedia_id",  # "gemma-2-2b/<layer>-gemmascope-res-16k" release string
    "index",  # integer value of the CSV "Feature ID" column
]

# Column names that appear to describe a pair of SAE features under one
# category (two properties, two neuronpedia ids, two indices). Not referenced
# by any code visible in this part of the file -- presumably used when
# uploading the multiple-SAE benchmark; TODO confirm.
MULTIPLE_SAE_COLUMN_NAMES = [
    "category",
    "property1",
    "property2",
    "neuronpedia_id1",
    "neuronpedia_id2",
    "index1",
    "index2",
]


def process_csv_to_json(csv_path: str, output_path: str) -> None:
    """
    Process the SAE features CSV file into a JSON format with all columns.

    Each CSV row with a non-blank "Feature ID" becomes one JSON object. The
    "Density" and "Feature Grade" cells are cut down to the text before their
    first space, empty optional cells become ``null``, and the "Layer" cell is
    embedded in a ``gemma-2-2b/<layer>-gemmascope-res-16k`` release string.
    Rows whose "Feature ID" is missing, empty or whitespace-only are skipped.

    Args:
        csv_path: Path to the input CSV file
        output_path: Path where to save the output JSON file

    Raises:
        KeyError: If a kept row lacks one of the expected column headers.
        IndexError: If a kept row is read from a CSV with fewer than three
            columns (the third column is addressed by position).
        ValueError: If a non-blank "Feature ID" is not an integer.
    """
    processed_data = []

    # newline="" hands line-ending handling to the csv module, so line breaks
    # inside quoted cells (e.g. multi-line explanations) reach the output
    # unchanged instead of being rewritten by universal-newline translation.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames

        for row in reader:
            # Skip empty rows. DictReader yields None for cells missing from
            # a short row, and exported sheets may pad blanks with spaces.
            feature_id = (row["Feature ID"] or "").strip()
            if not feature_id:
                continue

            # Create the release string
            layer = row["Layer"]
            release = f"gemma-2-2b/{layer}-gemmascope-res-16k"

            # Keep only the leading token of the density / grade cells
            density = row["Density"].split(" ")[0] if row["Density"] else None
            feature_grade = (
                row["Feature Grade"].split(" ")[0] if row["Feature Grade"] else None
            )

            # Create the data entry with all columns
            entry = {
                "density": density,
                "vocab_diversity": row["Vocab Diversity"],
                # The third column is looked up by position so the header
                # text is taken exactly as it appears in the CSV.
                "local_vs_global": row[fieldnames[2]],
                "tags": row["Tags"] or None,
                "necessary_context": row["Necessary context"] or None,
                "necessary_condition": row["Necessary condition"] or None,
                "success_criterion": row["Success Criterion"] or None,
                "human_explanation": row["Human explanation"],
                "feature_grade": feature_grade,
                "neuronpedia_id": release,
                "index": int(feature_id),
            }
            processed_data.append(entry)

    # Save to JSON file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(processed_data, f, indent=2)


def upload_single_sae_dataset(
    dataset: List[Dict[str, Any]],
    repo_id: str = "contextmodification/SAE_single_benchmark",
    column_names: list[str] = SINGLE_SAE_COLUMN_NAMES,
):
    """
    Publish a list of SAE feature records as a private Hugging Face dataset.

    The records are pivoted into one list per column, wrapped in a
    ``DatasetDict`` whose only split is ``"test"``, and passed to
    ``upload_hf_dataset`` together with the ``HF_TOKEN`` environment value
    (``None`` when the variable is unset).

    Args:
        dataset: List of dictionaries containing SAE feature data; every
            record is indexed by each name in ``column_names``
        repo_id: Hugging Face repository ID
        column_names: List of column names in the dataset
    """
    # Pivot row-oriented records into the column-oriented mapping that
    # Dataset.from_dict takes.
    columns: Dict[str, list] = {}
    for column in column_names:
        columns[column] = [record[column] for record in dataset]

    test_split = Dataset.from_dict(columns)
    splits = DatasetDict({"test": test_split})
    print("Created DatasetDict.")

    upload_hf_dataset(
        splits,
        repo_id,
        hf_token=os.environ.get("HF_TOKEN"),
        private=True,
    )


if __name__ == "__main__":
    # Both paths are relative to the current working directory.
    csv_path = (
        "src/eliciting_contexts/benchmark/internal/sae/data/SAE_Features_Single.csv"
    )
    json_path = (
        "src/eliciting_contexts/benchmark/internal/sae/data/SAE_Features_Single.json"
    )

    # Process the CSV file into JSON
    process_csv_to_json(csv_path, json_path)

    # Load the processed JSON and upload to Hugging Face. The file was written
    # as UTF-8, so read it back with the same explicit encoding rather than
    # the platform default.
    with open(json_path, "r", encoding="utf-8") as f:
        processed_data = json.load(f)

    upload_single_sae_dataset(
        processed_data, repo_id="contextmodification/SAE_single_benchmark"
    )
