import json
import os
import argparse
import random

from datasets import load_dataset
from transformers import pipeline


def load_label_map(json_path: str) -> dict:
    """Load the task -> labels mapping stored at *json_path*.

    Args:
        json_path: Location of the JSON file to read.

    Returns:
        The decoded JSON content, or an empty dict when no file exists
        at *json_path*.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if os.path.exists(json_path):
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding, which would mis-decode non-ASCII labels.
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}

def save_label_map(label_map: dict, json_path: str):
    """Serialize *label_map* as 2-space-indented JSON into *json_path*.

    Any existing file at *json_path* is overwritten.
    """
    encoder = json.JSONEncoder(indent=2)
    with open(json_path, mode='w') as sink:
        # Stream the encoder's chunks straight to the file.
        sink.writelines(encoder.iterencode(label_map))


def infer_labels_via_llm(samples: list, llm):
    """
    Given a few text samples for a class, ask the LLM for a descriptive label.

    Args:
        samples: Example texts that all belong to the same category.
        llm: Callable invoked as ``llm(prompt, **generation_kwargs)`` that
            returns a list of dicts carrying a ``'generated_text'`` entry.

    Returns:
        The generated label with surrounding whitespace removed.
    """
    joined = "\n".join(f"- {s}" for s in samples)
    prompt = (
        "The following are text examples belonging to the same category:\n"
        f"{joined}\n"
        "Suggest a concise descriptive label for this category (single word or short phrase):"
    )
    # Bound only the newly generated tokens: a total-length cap would also
    # count the prompt on decoder-only models and leave no room for output.
    resp = llm(prompt, max_new_tokens=16, clean_up_tokenization_spaces=True)
    # pipeline returns list of dicts with 'generated_text'
    generated = resp[0]['generated_text']
    # Some pipelines echo the prompt in front of the completion; keep only
    # the completion so the prompt never ends up inside the label.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()

def extract_text_column(dataset):
    """Return the name of the first column whose feature dtype is "string".

    Args:
        dataset: Object exposing a ``features`` mapping of column name to
            feature description.

    Returns:
        The first matching column name in ``features`` order, or ``None``
        when no column has dtype ``"string"``.
    """
    # Pick first string column automatically
    for col, feat in dataset.features.items():
        # Nested features can be plain dicts/lists without a ``dtype``
        # attribute; treat those as non-string instead of crashing.
        if getattr(feat, "dtype", None) == "string":
            return col
    return None

def main(task_names, json_path="label_map.json", num_samples=3, model_name="microsoft/Phi-3.5-mini-instruct"):
    """Infer descriptive labels for each task and record them in the label map.

    Tasks already present in the label map are skipped. Every other task is
    loaded from the "glue" collection, falling back to "super_glue". If the
    train split's labels are already non-numeric strings they are stored
    sorted; otherwise a few example texts per label are sampled and the LLM
    is asked to name each label.

    Args:
        task_names: Iterable of task (dataset configuration) names.
        json_path: Path of the JSON label map to read and update.
        num_samples: Maximum number of example texts shown to the LLM per label.
        model_name: Model identifier handed to the generation pipeline.

    Raises:
        RuntimeError: If a task needs LLM labelling but its train split has
            no string column to sample from.
    """
    # Runs on device 0 only when CUDA_VISIBLE_DEVICES is set and non-empty,
    # otherwise on CPU.
    # NOTE(review): the default model looks like a decoder-only checkpoint,
    # which the "text2text-generation" task may be unable to load — confirm.
    llm = pipeline("text2text-generation", model=model_name, device=0 if os.getenv("CUDA_VISIBLE_DEVICES") else -1)
    label_map = load_label_map(json_path)
    for task in task_names:
        if task in label_map:
            print(f"[SKIP] '{task}' already in label_map.json")
            continue
        print(f"[PROCESS] Inferring labels for task: {task}")
        try:
            dataset = load_dataset("glue", task)
        except Exception:
            # Not available under "glue": try the "super_glue" collection.
            dataset = load_dataset("super_glue", task)
        train = dataset['train']
        label_column = train['label']
        raw_labels = set(label_column)
        # Check if labels are descriptive strings or numeric
        if raw_labels and all(isinstance(lbl, str) and not lbl.isdigit() for lbl in raw_labels):
            mapped = sorted(raw_labels)
        else:
            # Ambiguous or numeric labels: rule-based + LLM
            text_col = extract_text_column(train)
            if text_col is None:
                raise RuntimeError("No text column found to sample from.")
            # Group the texts by label in a single pass instead of
            # rescanning the whole split once per label.
            examples_by_label = {}
            for text, lbl in zip(train[text_col], label_column):
                examples_by_label.setdefault(lbl, []).append(text)
            mapped = []
            for lbl in sorted(raw_labels):
                # sample few examples for the label
                examples = examples_by_label[lbl]
                samples = random.sample(examples, min(num_samples, len(examples)))
                new_label = infer_labels_via_llm(samples, llm)
                print(f"  Mapped raw '{lbl}' → '{new_label}'")
                mapped.append(new_label)

        # Update map and checkpoint it, so labels already inferred survive
        # a failure while processing a later task.
        label_map[task] = mapped
        save_label_map(label_map, json_path)

    # Write back to JSON (also covers the case where every task was skipped)
    save_label_map(label_map, json_path)
    print(f"\n[label_map.json updated at '{json_path}']")

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Automatically infer and update label_map.json for given tasks"
    )
    parser.add_argument(
        "--tasks", nargs="+", required=True,
        help="List of task names (e.g., cola mnli cb, etc.)"
    )
    parser.add_argument(
        "--label_map_path", type=str, default="label_map.json",
        help="Path to label_map.json"
    )
    parser.add_argument(
        "--num_samples", type=int, default=None,
        help="Max number of example texts per label shown to the LLM "
             "(omit to use main()'s default)"
    )
    parser.add_argument(
        "--model_name", type=str, default=None,
        help="Model identifier for the generation pipeline "
             "(omit to use main()'s default)"
    )
    args = parser.parse_args()

    # Forward only the options the user actually supplied so that main()'s
    # own defaults remain the single source of truth.
    overrides = {}
    if args.num_samples is not None:
        overrides["num_samples"] = args.num_samples
    if args.model_name is not None:
        overrides["model_name"] = args.model_name
    main(args.tasks, json_path=args.label_map_path, **overrides)
