"""
Dataset transformation with selection strategies.
"""
import json
import random
import argparse
from pathlib import Path
from typing import Callable, Dict, List

import numpy as np
import pandas as pd
from jinja2 import Template
from pandarallel import pandarallel

from selection import k_center, facility_location
from distance import cosine
from similarity import cosine_sim
from syntax import syntax_distance, ast_coverage
from utils import normalize_text


# Strategy registry
STRATEGIES: Dict[str, Callable] = {}


def strategy(name: str):
    """Return a decorator that registers a selection strategy under ``name``.

    The decorated function is stored in ``STRATEGIES[name]`` and returned
    unchanged, so it stays directly callable.

    Raises:
        ValueError: if ``name`` is already registered to a different function.
            Silently overwriting would make the CLI run the wrong strategy.
    """
    def decorator(fn):
        existing = STRATEGIES.get(name)
        if existing is not None and existing is not fn:
            raise ValueError(f"Duplicate selection strategy name: {name!r}")
        STRATEGIES[name] = fn
        return fn
    return decorator


# Helpers

RESPONSE_PREFIX = "Here is the correct Python program:\n"


def build_samples(df: pd.DataFrame, indices: pd.Series, template: Template) -> List[Dict]:
    """Expand per-problem index selections into instruction/response samples.

    Args:
        df: One row per problem; its ``instruction`` and ``candidate`` columns
            are read.
        indices: One entry per row of ``df``, in the same row order, each an
            iterable of positions into that row's ``candidate`` sequence.
        template: Jinja template rendered with ``code=<normalized candidate>``.

    Returns:
        One ``{"instruction": ..., "response": ...}`` dict per selected
        position, where the response is ``RESPONSE_PREFIX`` followed by the
        rendered template. Repeated positions yield repeated samples.

    Raises:
        ValueError: if ``indices`` and ``df`` have different lengths.
    """
    if len(indices) != len(df):
        raise ValueError(
            f"indices has {len(indices)} entries but df has {len(df)} rows"
        )

    samples = []
    # Pair rows with their selections by position, not by index label: a label
    # lookup such as ``indices[i]`` returns a Series rather than a list when
    # ``df`` has duplicate index labels (e.g. concatenated shards).
    for raw_instruction, candidates, picks in zip(df["instruction"], df["candidate"], indices):
        instruction = normalize_text(raw_instruction)
        for idx in picks:
            code = normalize_text(candidates[idx], 0)
            samples.append({
                "instruction": instruction,
                "response": RESPONSE_PREFIX + template.render(code=code),
            })
    return samples


def parallel_select(df: pd.DataFrame, budget: int, fn: Callable, template: Template) -> List[Dict]:
    """Pick candidate indices per row in parallel, then build samples.

    For each row (via pandarallel's ``parallel_apply``):

    * a non-positive ``budget`` selects nothing;
    * a row with at most ``budget`` candidates keeps all of them;
    * otherwise ``fn(row, budget)`` chooses which candidate positions to keep.

    Args:
        df: One row per problem; must have a ``candidate`` column.
        budget: Maximum number of candidates to keep per problem.
        fn: Selector called as ``fn(row, k)`` returning candidate positions.
        template: Jinja template forwarded to ``build_samples``.
    """
    def select(row):
        # Short-circuit non-positive budgets so selectors never see k <= 0.
        # They disagree on its meaning: ``order[-0:]`` keeps everything, and
        # KMeans(n_clusters=0) raises. Returning [] matches select_random.
        if budget <= 0:
            return []
        n = len(row["candidate"])
        return list(range(n)) if n <= budget else fn(row, budget)

    indices = df.parallel_apply(select, axis=1)
    return build_samples(df, indices, template)


# Strategies

@strategy("all")
def select_all(df, args, template):
    """Keep every candidate of every problem (``args`` is not consulted)."""
    every_position = df["candidate"].map(lambda candidates: list(range(len(candidates))))
    return build_samples(df, every_position, template)


@strategy("random")
def select_random(df, args, template):
    """Draw up to ``args.budget`` distinct candidates per problem at random.

    A single ``random.Random(args.seed)`` is shared across rows, which are
    visited in order, so the output is reproducible for a given seed. Drawn
    positions are kept in draw order (not sorted). A non-positive budget or
    an empty candidate list yields no selection for that row.
    """
    generator = random.Random(args.seed)

    def draw(row):
        pool_size = len(row["candidate"])
        take = min(args.budget, pool_size)
        if take <= 0:
            return []
        return generator.sample(range(pool_size), take)

    picks = df.apply(draw, axis=1)
    return build_samples(df, picks, template)


@strategy("cosine_center")
def select_cosine_center(df, args, template):
    """Choose candidates with ``k_center`` over their embeddings, using ``cosine`` distance."""
    def pick(row, k):
        vectors = row["embedding"].tolist()
        return k_center(vectors, k, cosine)

    return parallel_select(df, args.budget, pick, template)


@strategy("cosine_fl")
def select_cosine_fl(df, args, template):
    """Choose candidates with ``facility_location`` over their embeddings, using ``cosine_sim``."""
    def pick(row, k):
        vectors = row["embedding"].tolist()
        return facility_location(vectors, k, cosine_sim)

    return parallel_select(df, args.budget, pick, template)


@strategy("syntax_center")
def select_syntax_center(df, args, template):
    """Choose candidates with ``k_center`` over the raw candidate code, using ``syntax_distance``."""
    def pick(row, k):
        programs = row["candidate"].tolist()
        return k_center(programs, k, syntax_distance)

    return parallel_select(df, args.budget, pick, template)


@strategy("ast_coverage")
def select_ast_coverage(df, args, template):
    """Choose candidates by delegating the raw candidate code to ``ast_coverage``."""
    def pick(row, k):
        programs = row["candidate"].tolist()
        return ast_coverage(programs, k)

    return parallel_select(df, args.budget, pick, template)


@strategy("cluster")
def select_cluster(df, args, template):
    """Select one representative candidate per KMeans cluster of the embeddings.

    Each problem's candidate embeddings are clustered into ``k`` groups
    (``random_state=args.seed``). Every cluster centre, in order, then claims
    its nearest (Euclidean) candidate that no earlier centre has claimed.
    The result is ``k`` distinct positions, returned in ascending order.
    """
    from sklearn.cluster import KMeans
    from sklearn.metrics import pairwise_distances

    def fn(row, k):
        if k <= 0:
            return []
        emb = np.array(row["embedding"].tolist())
        kmeans = KMeans(n_clusters=k, random_state=args.seed, n_init="auto")
        kmeans.fit(emb)
        # Row c holds the distances from centre c to every candidate.
        dists = pairwise_distances(kmeans.cluster_centers_, emb)
        taken = np.zeros(len(emb), dtype=bool)
        chosen = []
        for centre_dists in dists:
            # Two centres can share the same nearest point. Taking the plain
            # argmin per centre would then emit duplicate samples and keep
            # fewer than k distinct candidates, so skip points already claimed.
            best = int(np.argmin(np.where(taken, np.inf, centre_dists)))
            chosen.append(best)
            taken[best] = True
        return sorted(chosen)

    return parallel_select(df, args.budget, fn, template)


@strategy("herding")
def select_herding(df, args, template):
    """Select candidates by greedy herding over their embeddings.

    At step ``t`` (1-based), picks the not-yet-chosen candidate whose embedding
    is nearest, in Euclidean distance, to ``t * mean - sum(chosen so far)``.
    Equivalently, it is the pick that keeps the mean of the chosen embeddings
    closest to the mean of all embeddings. Returns the chosen positions in
    ascending order.
    """
    def fn(row, k):
        points = np.array(row["embedding"].tolist())
        centroid = points.mean(axis=0)
        used = np.zeros(len(points), dtype=bool)
        running = np.zeros_like(centroid)
        order = []

        for step in range(k):
            target = centroid * (step + 1) - running
            gaps = np.linalg.norm(points - target, axis=1)
            # Already-chosen candidates must never win again.
            gaps[used] = np.inf
            pick = np.argmin(gaps)
            order.append(int(pick))
            used[pick] = True
            running += points[pick]

        return sorted(order)

    return parallel_select(df, args.budget, fn, template)


@strategy("difficulty")
def select_difficulty(df, args, template):
    """Keep the candidates with the highest loss ratio.

    Each candidate is scored as
    ``condition_loss / (response_loss + eps)``, reading the row's
    ``condition_loss`` and ``response_loss`` sequences. The ``k`` highest-scoring
    positions are returned, highest first.
    """
    def fn(row, k, eps=1e-12):
        if k <= 0:
            # The old ``order[-k:]`` slice became ``order[-0:]`` (everything) at k == 0.
            return []
        resp = np.asarray(row["response_loss"], dtype=np.float64)
        cond = np.asarray(row["condition_loss"], dtype=np.float64)
        score = cond / (resp + eps)
        # Descending order = reversed ascending argsort; same result as the
        # previous ``[-k:][::-1]`` for every k > 0.
        # NOTE(review): argsort places NaN last, so NaN scores are picked first
        # here. Confirm that is intended.
        return np.argsort(score)[::-1][:k].tolist()

    return parallel_select(df, args.budget, fn, template)


# CLI

def main():
    """Command-line entry point.

    Steps:

    1. Load the parquet at ``--input`` and the Jinja template.
    2. Run the chosen strategy.
    3. Optionally subsample the result to ``--max-samples`` (seeded by ``--seed``).
    4. Write one JSON object per line to ``--output``. The default name is
       ``{strategy}_b{budget}_s{seed}.jsonl``.
    """
    parser = argparse.ArgumentParser(description="Transform dataset")
    parser.add_argument("--strategy", required=True, choices=list(STRATEGIES.keys()))
    parser.add_argument("--input", required=True, help="Input parquet")
    parser.add_argument("--output", help="Output JSONL")
    parser.add_argument("--template", default="template.jinja", help="Jinja template")
    parser.add_argument("--budget", type=int, default=11)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-samples", type=int)
    parser.add_argument("--workers", type=int, default=64)
    args = parser.parse_args()

    pandarallel.initialize(nb_workers=args.workers, progress_bar=True)

    # Use explicit UTF-8 instead of the locale default, which is not UTF-8
    # everywhere. json.dumps(..., ensure_ascii=False) below emits raw non-ASCII.
    template = Template(Path(args.template).read_text(encoding="utf-8"))
    df = pd.read_parquet(args.input)
    print(f"Loaded {len(df)} problems")

    samples = STRATEGIES[args.strategy](df, args, template)
    print(f"Generated {len(samples)} samples")

    # A falsy --max-samples (unset or 0) means "no cap".
    if args.max_samples and len(samples) > args.max_samples:
        samples = random.Random(args.seed).sample(samples, args.max_samples)
        print(f"Sampled to {len(samples)}")

    output = args.output or f"{args.strategy}_b{args.budget}_s{args.seed}.jsonl"
    with open(output, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in samples)

    print(f"Saved to {output}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
