from collections import defaultdict
from typing import List, Dict, Any, Sequence, Optional
import numpy as np
import random
import torch
import os,json
from llm_client import load_embed_model, generate, resolve_model_id
from metrics import evaluate_batched_metric
from data_utils import get_task
from prompts import build_prompt
from kmeans_pytorch import kmeans

# --------- 1) Baseline methods (all functions) ----------
def select_random(n: int, k: int, seed: int = 42) -> List[int]:
    """Return up to ``k`` distinct indices from ``range(n)`` in seeded random order.

    The selection is fully determined by ``(n, k, seed)``: ``range(n)`` is
    shuffled with a generator seeded by ``seed`` and the first ``k`` entries
    are returned.

    Args:
        n: Size of the index pool; indices are drawn from ``0 .. n-1``.
        k: Number of indices wanted. If ``k`` exceeds ``n`` all ``n`` indices
            are returned; if ``k <= 0`` the result is empty.
        seed: Seed for the shuffle.

    Returns:
        A list of at most ``k`` distinct indices.
    """
    if k <= 0:
        # A negative k would otherwise slice "all but the last |k|" indices.
        return []
    # Use a private generator so the process-wide ``random`` state is not
    # reset as a side effect. The shuffle order is the same as seeding the
    # module-level generator with ``seed``.
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    return idx[:k]

def select_similarity_topk(
    doc_emb,
    query: str,
    k: int,
    model,
) -> List[int]:
    """Return the indices of the ``k`` entries of ``doc_emb`` most similar to ``query``.

    Args:
        doc_emb: Pre-computed document embeddings, passed as the second
            argument to ``model.similarity``.
        query: Query text; encoded with ``model.encode``.
        k: Number of indices to return; must be positive.
        model: Object providing ``encode`` and ``similarity``. The result of
            ``similarity`` must support ``.cpu().numpy()``.

    Returns:
        Up to ``k`` indices ordered from most to least similar.
    """
    assert k > 0
    query_emb = model.encode([query])
    score_matrix = model.similarity(query_emb, doc_emb)
    # Only one query was encoded, so its scores are in the first row.
    query_scores = score_matrix.cpu().numpy()[0]
    # Sorting the negated scores ascending yields a descending ranking.
    ranking = np.argsort(np.negative(query_scores))
    return ranking[:k].tolist()

def select_kmeans_diversity(
    candidates: List[Dict[str, Any]],
    k: int,
    model
) -> List[int]:
    """Pick a diverse subset of candidates: one representative per k-means cluster.

    The ``"input"`` field of every candidate is embedded with
    ``model.encode``, the embeddings are clustered into ``k`` groups, and for
    each cluster the member closest (Euclidean) to the cluster centre is
    selected.

    Args:
        candidates: Candidate examples; each must have an ``"input"`` key.
        k: Number of representatives wanted.
        model: Embedding model providing ``encode``.

    Returns:
        Indices into ``candidates``. Empty when ``k <= 0`` or there are no
        candidates; every index when ``k >= len(candidates)``; otherwise one
        index per non-empty cluster (so possibly fewer than ``k``).
    """
    n = len(candidates)
    if k <= 0 or n == 0:
        return []
    if k >= n:
        # Every candidate is selected anyway, so clustering adds nothing.
        # NOTE(review): presumably k-means initialisation cannot draw more
        # clusters than points - confirm against kmeans_pytorch.
        return list(range(n))

    docs = [c["input"] for c in candidates]
    emb = torch.as_tensor(model.encode(docs))  # presumably [N, d]
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cluster_ids, cluster_centers = kmeans(
        X=emb, num_clusters=k, distance='euclidean', device=device
    )
    # Keep everything on the embeddings' device for the indexing below.
    cluster_ids = cluster_ids.to(emb.device)
    cluster_centers = cluster_centers.to(emb.device)

    chosen = []
    for center in range(k):
        # Work with the members' original row indices directly. Searching
        # ``emb`` for the winning row with an element-wise ``emb == row``
        # matches any row that shares a single coordinate value and can
        # therefore return the wrong index.
        member_idx = torch.nonzero(cluster_ids == center, as_tuple=True)[0]
        if member_idx.numel() == 0:
            # argmin over an empty cluster would raise; nothing to pick.
            continue
        distances = torch.norm(emb[member_idx] - cluster_centers[center], dim=1)
        chosen.append(member_idx[torch.argmin(distances)].item())

    return chosen


def run_eval(
    task: str,
    model: str = "qwen3-4b",
    method: str = "zero",
    k: int = 10,
    seed: int = 42,
    emb_m: str = "Qwen/Qwen3-Embedding-0.6B",
):
    """Evaluate one in-context example-selection baseline on a task's test split.

    For every test example a prompt is built from the query plus the selected
    training examples, a prediction is generated, and all predictions are
    scored together with ``evaluate_batched_metric``.

    Args:
        task: Task name, passed to ``get_task``, ``build_prompt`` and
            ``evaluate_batched_metric``.
        model: Generation model name, resolved through ``resolve_model_id``.
        method: Selection strategy:
            ``"zero"`` - no examples;
            ``"random"`` - the same seeded random examples for every query;
            ``"kmeans"`` - the same cluster representatives for every query;
            ``"sim"`` - the ``k`` training examples most similar to each query.
        k: Number of in-context examples to select.
        seed: Seed for the ``"random"`` strategy.
        emb_m: Embedding model identifier passed to ``load_embed_model``;
            only loaded for ``"sim"`` and ``"kmeans"``.

    Returns:
        A dict with keys ``"method"``, ``"k"`` and ``"performance"``.

    Raises:
        ValueError: If ``method`` is not one of the supported strategies.
    """
    # Fail fast, before any data or model is loaded.
    if method not in ("zero", "random", "sim", "kmeans"):
        raise ValueError(f"Unknown method: {method}")

    trainset = get_task(task, 'train')
    testset = get_task(task, 'test')

    # Query-independent selections are computed once up front; the embedding
    # model is only loaded by the strategies that need it.
    emb_model = None
    doc_emb = None
    fixed_examples = []
    if method == "random":
        idx = select_random(len(trainset), k, seed=seed)
        fixed_examples = [trainset[i] for i in idx]
    elif method == "kmeans":
        emb_model = load_embed_model(emb_m)
        idx = select_kmeans_diversity(trainset, k, model=emb_model)
        fixed_examples = [trainset[i] for i in idx]
    elif method == "sim":
        emb_model = load_embed_model(emb_m)
        doc_emb = emb_model.encode([c["input"] for c in trainset])

    backend_id = resolve_model_id(model)
    preds = []
    for ex in testset:
        if method == "sim":
            idx = select_similarity_topk(doc_emb, ex["input"], k, model=emb_model)
            chosen = [trainset[i] for i in idx]
        else:
            # Fresh list per prompt so nothing downstream can alter the
            # shared selection.
            chosen = list(fixed_examples)
        prompt = build_prompt(task=task, query=ex["input"], examples=chosen)
        preds.append(generate(backend_id, prompt, max_new_tokens=512))

    # Score once and reuse the result for both the log line and the return.
    performance = evaluate_batched_metric(task, testset, preds)
    print({"task": task, "method": method, "performance": performance})
    return {"method": method, "k": k, "performance": performance}


if __name__ == "__main__":
    # Benchmark configuration: every task is evaluated with every method.
    tasks = ["gsm8k", "gpqa", "fp", "xsum", "date", "salient"]
    methods = ["zero", "random", "kmeans", "sim"]
    model = "llama3.1-8b"  # alternative: "qwen3-4b"
    k = 10

    out_dir = "./results/"
    os.makedirs(out_dir, exist_ok=True)

    for task in tasks:
        # One result record per selection method, in the order listed above.
        results = [
            run_eval(task=task, model=model, k=k, method=method)
            for method in methods
        ]

        # NOTE(review): the file name says "openai" regardless of `model`, so
        # runs with different models write to the same path - confirm intent.
        json_path = os.path.join(out_dir, f"{task}_baselines_openai_k{k}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
