#!/usr/bin/env python3
"""
Merge an ARC-style JSONL file (first file) with a companion JSONL that contains
step-wise annotations (second file), per task and per train example.

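Each task from the first file keeps its structure; the only change is a "steps"
field added to each corresponding train example. The merged tasks are written as
a single JSON object keyed by task name (the standard ARC dictionary layout), not
as JSONL. Input formats: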

first line (task):
{
  "task_name": "22168020",
  "train": [{"input": [...], "output": [...]}, ...],
  "test": [{"input": [...]} , ...]
}

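second file lines (steps), one line per train example, keyed by task name:
{
  "22168020": {"train": [{"steps": [...]}]}
}
(only the first element of each line's "train" list is read; lines for the same
task appear in the same order as that task's train examples in the first file)

If the second file has three lines for the same task name, they map to train[0], train[1], train[2] respectively.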

USAGE:
    python merge_arc_with_steps.py --first first.jsonl --second steps.jsonl --out merged.jsonl

NOTES / ASSUMPTIONS:
- Both inputs are read as JSON Lines (one JSON object per line). The first file may
  instead hold a single line containing an ARC-style dict {task_name: task, ...}.
- Each task in FIRST is identified by "task_name" (or "task", "id", "taskId", "name").
  If none is present, we fall back to positional labels "task_000000", "task_000001", ...
- Each line in SECOND carries the steps for one train example, nested under its task name.
- The order of SECOND lines for a given task name corresponds to the order of the
  train examples in FIRST for that task.
"""
import json
import argparse
from collections import defaultdict
from typing import List, Dict, Any, Tuple

def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """
    Load a JSON Lines file into a list of parsed values, in file order.

    Whitespace-only lines are ignored; every other line must contain exactly
    one JSON value.

    Raises:
        ValueError: if a non-blank line is not valid JSON. The message names the
            1-based line number and the file path; the original
            json.JSONDecodeError is chained as the cause.
    """
    records: List[Dict[str, Any]] = []
    with open(path, 'r', encoding='utf-8') as handle:
        for number, raw in enumerate(handle, start=1):
            text = raw.strip()
            if text:
                try:
                    records.append(json.loads(text))
                except json.JSONDecodeError as err:
                    raise ValueError(f"Invalid JSON on line {number} of {path}: {err}") from err
    return records

def get_task_name(obj: Dict[str, Any], fallback_index: int) -> str:
    """
    Resolve the identifier of a task object.

    The keys "task_name", "task", "id", "taskId" and "name" are tried in that
    order; the value of the first one holding a non-empty string is returned.
    When none qualifies, a synthetic positional name "task_" followed by
    fallback_index zero-padded to width six is returned instead.
    """
    candidate_keys = ("task_name", "task", "id", "taskId", "name")
    resolved = next(
        (obj[k] for k in candidate_keys
         if k in obj and isinstance(obj[k], str) and obj[k]),
        None,
    )
    if resolved is None:
        # No usable identifier: name the task after its position.
        resolved = f"task_{fallback_index:06d}"
    return resolved

def group_steps_by_task(second_items: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """
    Group the 'second' file's step annotations by task name, in file order.

    Each line is expected to look like::

        {"<task_name>": {"train": [{"steps": [...]}]}}

    i.e. one line per train example, with the task name as the top-level key.
    Only the first element of each line's "train" list is used. The following
    are skipped:
      - lines that are not dicts;
      - payloads that are not dicts;
      - payloads without a non-empty "train" list (e.g.
        ``{"<task_name>": {"test": [...]}}``).
    If a line has several top-level keys, each one is processed in order.

    Every usable line contributes exactly one entry to its task's list. The
    entry is ``{"steps": <payload>}`` when the example has a non-None "steps"
    value, and an empty placeholder ``{}`` otherwise. The placeholder keeps
    later entries aligned with their train index, because ``merge`` maps
    entry k to ``train[k]`` purely by position. ``merge`` reports a
    placeholder as a line with no 'steps' field and leaves that example
    untouched.

    Returns:
        A ``defaultdict(list)`` mapping task name -> list of entries.
    """
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    for item in second_items:
        if not isinstance(item, dict):
            continue  # malformed line

        for task_name, payload in item.items():
            if not isinstance(payload, dict):
                continue

            # The merge only cares about 'train' examples; test-only lines are ignored.
            train_examples = payload.get("train")
            if not isinstance(train_examples, list) or not train_examples:
                continue

            # The generator is expected to write one line per example, so only
            # the first element is used.
            example_obj = train_examples[0]
            steps_payload = example_obj.get("steps") if isinstance(example_obj, dict) else None

            # Placeholder instead of skipping: dropping this line would shift
            # every later example of this task onto the wrong train index.
            grouped[task_name].append({} if steps_payload is None else {"steps": steps_payload})

    return grouped

def merge(first_items: List[Dict[str, Any]], grouped_steps: Dict[str, List[Dict[str, Any]]]) -> Tuple[List[Dict[str, Any]], List[str]]:
    """
    Attach per-example "steps" from grouped_steps to each task's train examples.

    For task i of first_items, the name is resolved with ``get_task_name``
    (i is the fallback index). Entry k of ``grouped_steps[name]`` then has its
    "steps" value stored as ``train[k]["steps"]``. Only the first
    ``min(len(train), len(entries))`` positions are mapped; a count mismatch
    produces a warning.

    Every returned task dict carries a "task_name" key, added from the
    resolved name when absent. This includes tasks that received no steps or
    whose "train" is not a list, so name-keyed writers do not drop them.

    The inputs are not modified: each task dict, and each train example that
    receives steps, is shallow-copied before being written to. Non-dict task
    entries are passed through unchanged with a warning.

    Returns:
        (merged, warnings): merged tasks in input order, plus human-readable
        warning strings.
    """
    warnings: List[str] = []
    merged: List[Dict[str, Any]] = []

    for i, task_obj in enumerate(first_items):
        tname = get_task_name(task_obj, fallback_index=i)

        if not isinstance(task_obj, dict):
            warnings.append(f"[{tname}] task entry is not an object; passing it through unchanged.")
            merged.append(task_obj)
            continue

        # Shallow copy so the caller's task dict is not modified.
        out_task = dict(task_obj)
        # Mirror the name on EVERY task, not only those that receive steps.
        # write_json_dict skips tasks lacking "task_name", so a task without
        # steps must still carry it.
        if "task_name" not in out_task:
            out_task["task_name"] = tname
        merged.append(out_task)  # same object is updated below when steps apply

        # normalize presence of "train"
        train_list = task_obj.get("train", [])
        if not isinstance(train_list, list):
            warnings.append(f"[{tname}] 'train' field is not a list; skipping steps injection for this task.")
            continue

        steps_for_task = grouped_steps.get(tname, [])
        if not steps_for_task:
            warnings.append(f"[{tname}] No steps found in second file; task will remain unchanged.")
            continue

        # If the # of steps entries doesn't match # of train examples, map min() of them.
        n_train = len(train_list)
        n_steps_lines = len(steps_for_task)
        n_map = min(n_train, n_steps_lines)

        if n_steps_lines != n_train:
            warnings.append(f"[{tname}] steps entries ({n_steps_lines}) != train examples ({n_train}); mapping first {n_map} only.")

        # New list so the caller's train list (and its example dicts) stay intact.
        new_train = list(train_list)
        for k in range(n_map):
            steps_payload = steps_for_task[k].get("steps")
            if steps_payload is None:
                warnings.append(f"[{tname}] steps line #{k} has no 'steps' field; skipping.")
                continue
            if not isinstance(new_train[k], dict):
                warnings.append(f"[{tname}] train[{k}] is not an object; cannot attach steps.")
                continue
            new_train[k] = {**new_train[k], "steps": steps_payload}

        out_task["train"] = new_train

    return merged, warnings

def write_jsonl(items: List[Dict[str, Any]], path: str) -> None:
    """
    Write items to path as JSON Lines, one compact JSON value per line.

    The file is truncated first and written as UTF-8. Non-ASCII characters are
    emitted verbatim (``ensure_ascii=False``), and every line, including the
    last, ends with a newline.
    """
    with open(path, 'w', encoding='utf-8') as out:
        out.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in items)

def write_json_dict(items: List[Dict[str, Any]], path: str) -> None:
    """
    Writes a list of task objects into a single JSON dictionary,
    using 'task_name' as the key.

    Each value is a shallow copy of the task object with "task_name" removed,
    which matches the standard ARC dictionary format. The input objects are
    left unmodified.

    Tasks whose "task_name" is missing or falsy are skipped with a printed
    warning. When several tasks share a name, the last one wins and a warning
    is printed. The file is UTF-8, with non-ASCII characters written verbatim
    and an indent of 2.
    """
    output_dict: Dict[Any, Dict[str, Any]] = {}
    for obj in items:
        task_name = obj.get("task_name")
        # Build the value without "task_name" instead of popping it, so the
        # caller's objects keep their names after writing.
        payload = {k: v for k, v in obj.items() if k != "task_name"}
        if not task_name:
            print(f"Warning: Skipping task in output dict, 'task_name' missing: {str(payload)[:100]}")
            continue
        if task_name in output_dict:
            print(f"Warning: Duplicate task_name {task_name!r} in output dict; keeping the last occurrence.")
        output_dict[task_name] = payload

    with open(path, 'w', encoding='utf-8') as f:
        # indent=2 makes it human-readable
        json.dump(output_dict, f, ensure_ascii=False, indent=2)

# Default paths used when main() is called without arguments. FIRST holds the
# ARC tasks, SECOND the per-train-example steps.
DEFAULT_FIRST_PATH = 'kaggle/combined/arc-agi_training2_challenges.json'
DEFAULT_SECOND_PATH = 'kaggle/combined/arc-agi_training2_transitions_k10.jsonl'
# The output is a single .json file (a dict keyed by task name), not .jsonl.
DEFAULT_OUTPUT_PATH = 'kaggle/combined/arc-agi_training2_converted_k10.json'


def _load_first_items(path: str) -> List[Any]:
    """
    Load the FIRST (task) file as a list of top-level JSON values.

    The whole file is first parsed as one JSON document, which also accepts a
    pretty-printed (multi-line) ARC challenges dict or a JSON array. An array
    is returned as-is; any other document is wrapped in a one-element list, so
    a task dict still reaches ``_expand_task_dict``. A file that is not a single
    JSON document (e.g. multi-line JSONL, or an empty file) is read line by
    line with ``read_jsonl``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    try:
        doc = json.loads(text)
    except json.JSONDecodeError:
        return read_jsonl(path)
    return doc if isinstance(doc, list) else [doc]


def _expand_task_dict(first_items: List[Any]) -> List[Any]:
    """
    Convert ``[{task_name: task_obj, ...}]`` into a list of task objects.

    first_items is returned unchanged unless it holds exactly one dict without
    a "train" key. That is the heuristic: a real task object has "train", while
    the name->task mapping does not. Each dict-valued task is shallow-copied
    and given its key as "task_name"; non-dict values are dropped.
    """
    if not (len(first_items) == 1
            and isinstance(first_items[0], dict)
            and "train" not in first_items[0]):
        return first_items

    print("Detected single-line dictionary format in --first file (1 item found). Converting...")
    converted = [
        {**task_obj, "task_name": task_name}  # inject the name
        for task_name, task_obj in first_items[0].items()
        if isinstance(task_obj, dict)
    ]
    print(f"Converted to {len(converted)} task objects.")
    return converted


def main(argv=None):
    """
    Merge ARC tasks (FIRST) with per-train-example steps (SECOND) and write a JSON dict.

    Args:
        argv: Command-line arguments without the program name. ``None`` (a
            plain ``main()`` call) uses the DEFAULT_* paths and does not read
            sys.argv; the script entry point passes ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(description="Merge ARC tasks with per-train-example steps into a JSON dict.")
    ap.add_argument("--first", default=DEFAULT_FIRST_PATH,
                    help="Path to FIRST file: ARC tasks (JSON dict/array or JSONL).")
    ap.add_argument("--second", default=DEFAULT_SECOND_PATH,
                    help="Path to SECOND JSONL (steps, one line per train example).")
    ap.add_argument("--out", default=DEFAULT_OUTPUT_PATH,
                    help="Path to output JSON file (a single dict keyed by task name, not JSONL).")
    # argv=None means "use the defaults", not "read sys.argv". A direct main()
    # call from an environment whose sys.argv holds unrelated flags (e.g. a
    # notebook kernel) keeps working as it did with hard-coded paths.
    args = ap.parse_args([] if argv is None else argv)

    first_items = _expand_task_dict(_load_first_items(args.first))
    second_items = read_jsonl(args.second)

    grouped_steps = group_steps_by_task(second_items)
    merged_items, warnings = merge(first_items, grouped_steps)

    write_json_dict(merged_items, args.out)

    # Print a concise report
    print(f"Merged {len(first_items)} tasks from FIRST with {len(second_items)} step entries from SECOND.")
    if warnings:
        print("\nWarnings:")
        for w in warnings:
            print(" -", w)


if __name__ == "__main__":
    import sys
    main(sys.argv[1:])
