#!/usr/bin/env python3
import os, json, argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, List, Tuple, Optional
from tqdm import tqdm
import google.generativeai as genai
from google.generativeai import types
#from google import genai
#from google.genai import types

def grid_shape(g: List[List[int]]) -> Tuple[int,int]:
    """Return ``(rows, cols)`` for a row-major grid.

    The column count is read from the first row only; an empty grid yields
    ``(0, 0)``.
    """
    row_count = len(g)
    col_count = len(g[0]) if row_count else 0
    return row_count, col_count

def grid_to_str(g: List[List[int]]) -> str:
    """Render a grid as a bracketed block with one Python-repr row per line."""
    body = ",\n".join(map(str, g))
    return f"[\n{body}\n]"

def build_prompt(inp: List[List[int]], out: List[List[int]], k: int) -> str:
    """Build the instruction prompt asking the model for ``k`` intermediate grids.

    Args:
        inp: INPUT grid as a list of rows of integers.
        out: OUTPUT grid as a list of rows of integers. It need not be square;
            both its row count and its column count are reported to the model.
        k: Number of intermediate steps to request.

    Returns:
        The prompt text with leading/trailing whitespace stripped.
    """
    H, W = grid_shape(inp)
    # Use both axes of the OUTPUT grid. Describing a rows×cols grid as
    # rows×rows would instruct the model to emit the wrong shape.
    OH, OW = grid_shape(out)
    return f"""
You are given an ARC-style transformation task.

Goal:
Produce exactly {k} intermediate steps that transform the INPUT into the OUTPUT via logical, incremental changes.

Formatting (MANDATORY):
1) Return ONLY a single JSON object with a single key "steps".
2) "steps" MUST be an array of length {k}. Each element is a {OH}×{OW} grid (list of {OH} lists, each of length {OW}) of integers.
3) Do NOT include the INPUT or OUTPUT grids in the JSON; only the {k} intermediate steps.

Initialization when sizes differ:
- If the INPUT size differs from the OUTPUT size, FIRST create an internal initialization at the OUTPUT size {OH}×{OW}.
- Treat that initialization as the conceptual starting point that precedes steps[0].

Behavior constraints:
- Minimal, monotonic moves toward the OUTPUT under a consistent transformation.
- Keep {OH}×{OW} dimensions identical across all steps.

Return JSON EXACTLY in this shape:
{{
  "steps": [
    [[...],[...],...],
    [[...],[...],...],
    ...
  ]
}}

INPUT (size {H}×{W}):
{grid_to_str(inp)}

OUTPUT (size {OH}×{OW}):
{grid_to_str(out)}

Number of intermediate steps K = {k}
""".strip()

def build_schema_for_k_array() -> genai.protos.Schema:
    """Build a response schema of the form ``{"steps": [grid, grid, ...]}``.

    Each grid is an array of arrays of integers. The number of steps is not
    constrained by the schema; length is enforced later by ``ensure_k_steps``.

    All names come from ``genai.protos``, the namespace the active
    ``google.generativeai`` import exposes for ``Schema``/``Type``. The original
    mixed in ``genai.types.Schema``/``genai.types.Type``, which belong to the
    newer ``google.genai`` SDK and are not exported by
    ``google.generativeai.types``.
    """
    protos = genai.protos
    cell = protos.Schema(type=protos.Type.INTEGER)
    row = protos.Schema(type=protos.Type.ARRAY, items=cell)
    grid_schema = protos.Schema(type=protos.Type.ARRAY, items=row)
    steps_array = protos.Schema(type=protos.Type.ARRAY, items=grid_schema)
    return protos.Schema(
        type=protos.Type.OBJECT,
        required=["steps"],
        properties={"steps": steps_array},
    )

def call_model(prompt: str, model: str, thinking_budget: int) -> Dict[str, Any]:
    """Send ``prompt`` to ``model`` requesting a JSON reply, and parse it.

    Args:
        prompt: Full prompt text.
        model: Gemini model name passed to ``genai.GenerativeModel``.
        thinking_budget: Accepted for interface compatibility but NOT forwarded.
            The generation config built here sets only ``response_mime_type``.

    Returns:
        The decoded JSON value of ``resp.text``.

    Raises:
        Whatever the SDK raises on request failure or on reading ``.text``,
        and ``json.JSONDecodeError`` if the reply is not valid JSON.
    """
    generator = genai.GenerativeModel(model)
    json_config = types.GenerationConfig(response_mime_type="application/json")
    reply = generator.generate_content(prompt, generation_config=json_config)
    return json.loads(reply.text)


# def call_model(client: genai.Client, prompt: str, model: str, thinking_budget: int) -> Dict[str, Any]:
#     resp = client.models.generate_content(
#         model=model,
#         contents=[types.Content(role="user", parts=[types.Part.from_text(text=prompt)])],
#         config=types.GenerateContentConfig(
#             thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget),
#             response_mime_type="application/json",
#             response_schema=build_schema_for_k_array(),
#         ),
#     )
#     return json.loads(resp.text)

def load_dataset(path: str) -> Dict[str, Any]:
    """Load and return the JSON document at ``path``.

    The file is decoded as UTF-8 explicitly. Without this, ``open()`` uses the
    locale's preferred encoding, which varies by platform.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def iter_examples(chal: Dict[str, Any]):
    """Yield ``(split, index, input, output)`` for each complete example.

    The ``"train"`` split is visited before ``"test"``; splits absent from
    ``chal`` are skipped. Examples lacking either ``"input"`` or ``"output"``
    are skipped, but ``index`` is always the example's position in its split.
    """
    for split_name in ("train", "test"):
        if split_name not in chal:
            continue
        yield from (
            (split_name, pos, example["input"], example["output"])
            for pos, example in enumerate(chal[split_name])
            if "input" in example and "output" in example
        )

def steps_dict_to_list(steps_obj: Any) -> List[List[List[int]]]:
    """Normalize a model's ``steps`` value into an ordered list of grids.

    A list is returned as-is (same object). A dict is read as consecutive keys
    ``step_0``, ``step_1``, ... and stops at the first missing index.

    Raises:
        ValueError: the dict has no ``step_0`` key.
        TypeError: ``steps_obj`` is neither a list nor a dict.
    """
    if isinstance(steps_obj, list):
        return steps_obj
    if not isinstance(steps_obj, dict):
        raise TypeError("Unexpected steps structure")
    collected: List[Any] = []
    key = "step_0"
    while key in steps_obj:
        collected.append(steps_obj[key])
        key = f"step_{len(collected)}"
    if collected:
        return collected
    raise ValueError("steps dict missing step_0/step_1/...")

def ensure_k_steps(steps: List[List[List[int]]], k: int) -> List[List[List[int]]]:
    """Return a copy of ``steps`` trimmed or padded to exactly ``k`` entries.

    Extra steps beyond ``k`` are dropped. A short, non-empty list is padded by
    repeating its last grid. An empty list is returned empty (no padding is
    possible). The caller's list is not modified.
    """
    result = steps[:k]
    shortfall = k - len(result)
    if result and shortfall > 0:
        result = result + [result[-1]] * shortfall
    return result

def _write_checkpoint(args: argparse.Namespace, completed: int,
                      successes: int, failures: int) -> None:
    """Overwrite ``args.checkpoint`` with a small JSON progress summary."""
    ckpt = {"completed": completed, "successes": successes, "failures": failures,
            "k": args.k, "model": args.model, "thinking_budget": args.thinking_budget,
            "out_jsonl": args.out_jsonl}
    with open(args.checkpoint, "w", encoding="utf-8") as cf:
        json.dump(ckpt, cf, ensure_ascii=False, indent=2)


def _build_nested(train: Dict[str, Any],
                  nested_acc: Dict[str, Dict[str, List[Tuple[int, Dict[str, Any]]]]]
                  ) -> Dict[str, Dict[str, List[Any]]]:
    """Mirror ``train`` split-by-split, placing each successful record at its original index.

    Examples without a successful generation keep their original dataset entry,
    so the output has the same ordering and length as the input dataset.
    """
    nested: Dict[str, Dict[str, List[Any]]] = {}
    for chal_id, chal in train.items():
        # Shallow copies so replacing entries below never mutates ``train``.
        nested[chal_id] = {
            "train": list(chal.get("train", [])),
            "test": list(chal.get("test", [])),
        }
    for chal_id, splits in nested_acc.items():
        for split_name, indexed_recs in splits.items():
            target = nested.setdefault(chal_id, {}).setdefault(split_name, [])
            for idx, rec in indexed_recs:
                if idx < len(target):
                    target[idx] = rec
                else:
                    target.append(rec)
    return nested


def main():
    """CLI entry point: generate K intermediate steps for every ARC example.

    Reads the dataset from ``--train_json`` and queries the model concurrently.
    Each result is appended as one JSON line to ``--out_jsonl``. Optionally,
    periodic ``--checkpoint`` summaries and a nested ``--out_nested`` copy of
    the dataset with steps are written. The API key is read from the
    ``GEMINI_API_KEY`` environment variable.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_json", required=True,
                    help="Path to the ARC training dataset JSON file.")
    ap.add_argument("--out_jsonl", default="arc_transitions_k10.jsonl",
                    help="Append-only JSONL with one object per example.")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--workers", type=int, default=8)
    ap.add_argument("--model", default="gemini-2.5-pro")
    ap.add_argument("--thinking_budget", type=int, default=4096)
    # 0 (the default) means "no limit", which matches the previous behaviour
    # where this flag was parsed but ignored.
    ap.add_argument("--max", type=int, default=0,
                    help="Process at most this many examples (0 = all).")
    ap.add_argument("--checkpoint", default="", help="Optional compact checkpoint JSON path.")
    ap.add_argument("--checkpoint_every", type=int, default=100)
    ap.add_argument("--with_metadata", action="store_true",
                    help="If set, include debug fields instead of wrapping by ID.")
    ap.add_argument("--out_nested", default="",
                    help="Also write a SINGLE nested JSON that mirrors the ARC dataset with steps.")
    args = ap.parse_args()

    # Read the key from the environment; credentials must never be committed to source.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise SystemExit("Please set GEMINI_API_KEY")
    genai.configure(api_key=api_key)

    train = load_dataset(args.train_json)

    tasks = [
        (chal_id, split, idx, in_grid, out_grid)
        for chal_id, chal in train.items()
        for split, idx, in_grid, out_grid in iter_examples(chal)
    ]
    if args.max > 0:
        tasks = tasks[:args.max]

    os.makedirs(os.path.dirname(os.path.abspath(args.out_jsonl)) or ".", exist_ok=True)

    completed = successes = failures = 0
    # Per challenge/split: (original index, record) pairs. The index lets the
    # nested output restore dataset order despite as_completed() ordering.
    nested_acc: Optional[Dict[str, Dict[str, List[Tuple[int, Dict[str, Any]]]]]] = (
        {} if args.out_nested else None
    )

    def worker(task):
        # Never raises: any failure is recorded in meta["error"].
        chal_id, split, idx, in_grid, out_grid = task
        prompt = build_prompt(in_grid, out_grid, args.k)
        rec_trio = {"input": in_grid, "output": out_grid}
        meta = {"challenge_id": chal_id, "split": split, "index": idx}
        try:
            result = call_model(prompt, args.model, args.thinking_budget)
            steps = ensure_k_steps(steps_dict_to_list(result.get("steps")), args.k)
            rec_trio["steps"] = steps
            meta["success"] = True
        except Exception as e:
            meta["success"] = False
            meta["error"] = str(e)
        return rec_trio, meta

    with open(args.out_jsonl, "a", encoding="utf-8") as out_f, \
            ThreadPoolExecutor(max_workers=args.workers) as pool:
        futures = [pool.submit(worker, t) for t in tasks]
        with tqdm(total=len(futures), desc="Generating transitions", unit="ex") as pbar:
            for fut in as_completed(futures):
                rec, meta = fut.result()
                ok = bool(meta.get("success"))
                completed += 1
                if ok:
                    successes += 1
                else:
                    failures += 1

                if args.with_metadata:
                    out_obj = {**meta, **rec}
                else:
                    out_obj = {meta["challenge_id"]: {meta["split"]: [rec]}}

                out_f.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
                out_f.flush()

                if nested_acc is not None and ok:
                    nested_acc.setdefault(meta["challenge_id"], {}) \
                              .setdefault(meta["split"], []) \
                              .append((meta["index"], rec))

                # Guard against --checkpoint_every <= 0 (modulo by zero).
                if (args.checkpoint and args.checkpoint_every > 0
                        and completed % args.checkpoint_every == 0):
                    _write_checkpoint(args, completed, successes, failures)

                pbar.set_postfix_str(f"ok={successes} fail={failures}")
                pbar.update(1)

    if args.out_nested:
        nested = _build_nested(train, nested_acc or {})
        os.makedirs(os.path.dirname(os.path.abspath(args.out_nested)) or ".", exist_ok=True)
        with open(args.out_nested, "w", encoding="utf-8") as f:
            json.dump(nested, f, ensure_ascii=False, indent=4)

    if args.checkpoint:
        _write_checkpoint(args, completed, successes, failures)

if __name__ == "__main__":
    main()
