#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Preprocess FinQA for program synthesis (consistent with evaluator + chat template).

What this script does
---------------------
- Loads FinQA JSON (train/val/test) where each item has: pre_text, post_text, table, qa{question, program, exe_ans}
- Builds a single-turn chat conversation (system, user, gold assistant reply); the user turn ends with an instruction to output only the program (no "Final Answer:")
- Converts the gold `qa.program` (list or string) to a linearized string of steps like:
    divide(100, 50) add(#0, 10)
  and appends a literal " EOF" token so the model learns to stop.
- Tokenizes with the **same tokenizer as the base model**, using `apply_chat_template(..., add_generation_prompt=True)`
- Creates tensors: input_ids, attention_mask, labels where labels are -100 for prompt tokens and the target program tokens for the answer
- Saves Hugging Face Datasets to disk for train and eval.

Notes
-----
- We purposely do NOT include the numeric final answer in labels; only the program is learned.
- At inference you will parse the generated text back into 5-token steps and then append "EOF" in the JSON file you feed to the evaluator.

Usage
-----
python preprocess_fixed.py \
  --model_name meta-llama/Llama-3.2-1B-Instruct \
  --train_json /path/to/train.json \
  --dev_json   /path/to/dev.json \
  --save_dir   ./finqa_program_chat_v1
"""

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Union

from datasets import Dataset, Features, Sequence, Value
from transformers import AutoTokenizer
import torch
import numpy as np

# ----------------------
# Helpers
# ----------------------

def render_table(tbl: Union[List, Dict, str]) -> str:
    """
    Flatten a table object into newline-separated text.

    - ``None`` yields ``""`` and a string is returned unchanged.
    - A list is rendered one line per row: list rows become ``"c1 | c2"``,
      dict rows become ``"k1: v1 | k2: v2"``, anything else is ``str(row)``.
    - A dict is rendered as one ``"key: value"`` line per item.
    - Any other type yields ``""``.
    """
    if tbl is None:
        return ""
    if isinstance(tbl, str):
        return tbl

    def _format_row(row) -> str:
        # One output line per table row, depending on the row's container type.
        if isinstance(row, list):
            return " | ".join(map(str, row))
        if isinstance(row, dict):
            return " | ".join(f"{key}: {val}" for key, val in row.items())
        return str(row)

    if isinstance(tbl, list):
        rendered = [_format_row(row) for row in tbl]
    elif isinstance(tbl, dict):
        # some variants store a dict of rows/cols
        rendered = [f"{key}: {val}" for key, val in tbl.items()]
    else:
        rendered = []
    return "\n".join(rendered)

def linearize_program(program: Union[str, List[str]]) -> str:
    """
    Convert a FinQA program into the compact text the model should generate.

    Accepts either a token list like ['divide','(','100',',','50',')', ...] or a string.
    Produces: 'divide(100, 50) add(#0, 10)' (steps separated by single spaces).

    Only two-argument steps whose operation is add, subtract, multiply, divide,
    exp, greater or table_* are recognised (case-insensitively; the operation
    name is lower-cased in the output). Everything else in the input is skipped.
    Arguments cannot contain ',' or ')', so a number written with thousands
    separators (e.g. '1,000') is not recognised as a single argument.

    Returns "" when `program` is None or contains no recognisable step.
    """
    if program is None:
        return ""
    if isinstance(program, str):
        # Normalize whitespace
        s = " ".join(program.strip().split())
    else:
        # Token list: stitch into a string first. Tokens are coerced with str()
        # so numeric tokens (e.g. JSON numbers) do not make join() raise TypeError.
        s = " ".join(str(tok) for tok in program)

    # op ( arg1 , arg2 ) with optional whitespace around every piece.
    step_pattern = r'(add|subtract|multiply|divide|exp|greater|table_[a-z_]+)\s*\(\s*([^,\)]+)\s*,\s*([^,\)]+)\s*\)'
    steps = [
        f"{m.group(1).lower()}({m.group(2).strip()}, {m.group(3).strip()})"
        for m in re.finditer(step_pattern, s, flags=re.IGNORECASE)
    ]
    return " ".join(steps)

def build_prompt(pre_text: List[str], post_text: List[str], table_obj: Any, question: str, program: Union[str, List[str]]) -> List[Dict[str, str]]:
    """
    Build the chat messages (system, user, assistant) for one FinQA example.

    Args:
        pre_text: Text before the table. A list is joined with single spaces;
            any other falsy value becomes "".
        post_text: Text after the table, handled like `pre_text`.
        table_obj: Table object, rendered to text with `render_table`.
        question: Question text inserted into the user message.
        program: Gold program, converted with `linearize_program`; the result
            is used verbatim as the assistant message.

    Returns:
        A list of three {"role": ..., "content": ...} dicts in the order
        system, user, assistant.
    """
    pre = " ".join(pre_text) if isinstance(pre_text, list) else (pre_text or "")
    post = " ".join(post_text) if isinstance(post_text, list) else (post_text or "")
    table_str = render_table(table_obj)
    prog = linearize_program(program)

    # Clear, minimal formatting.
    # The example must follow the same flat "op(arg1, arg2)" step format (with
    # #k back-references) that the instructions describe and that the assistant
    # targets use; a nested or otherwise malformed example would contradict them.
    system_message = (
        "You are a program synthesis model for financial question answering.\n"
        "Given the context (before/after text) and a table, write a reasoning Program as a sequence of operations.\n"
        "Each step must look like: op(arg1, arg2). Use #k to reference the result of step k (starting at #0 after first step).\n"
        "**Example:** `divide(100, 2) add(#0, 50)`"
    )
    user_message = (
        "Please generate a concise program sequence for the question based on the following information:\n"
        "Context (before table):\n"
        f"{pre}\n\n"
        "Table:\n"
        f"{table_str}\n\n"
        "Context (after table):\n"
        f"{post}\n\n"
        f"Question: {question}\n\n"
        # The user turn ends with this instruction; the program follows as the assistant turn.
        "Please generate the program directly without any additional explanation:"
    )

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": prog}
    ]

    return messages

def load_jsonl_or_json(path: Path) -> List[Dict]:
    """
    Load records from a JSON or JSON Lines file.

    The file is read as UTF-8 and first parsed as one JSON document:
    - a top-level dict with a "data" key yields the value stored under "data";
    - any other top-level dict is treated as a single record and returned as
      a one-element list (this is what a one-line JSONL file parses to);
    - any other top-level value (normally a list) is returned as-is.

    If the text is not a single valid JSON document, it is parsed as JSON
    Lines: one JSON value per non-blank line.

    Raises:
        json.JSONDecodeError: if the JSON Lines fallback hits an invalid line.
        OSError: if the file cannot be read.
    """
    text = path.read_text(encoding="utf-8")
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Not one JSON document -> JSON Lines, skipping blank lines.
        return [json.loads(line) for line in text.splitlines() if line.strip()]

    if isinstance(data, dict):
        # Either a {"data": [...]} wrapper or a lone record. Returning [] for a
        # lone record would silently drop it.
        return data["data"] if "data" in data else [data]
    return data

def convert(examples: List[Dict]) -> Dataset:
    """
    Turn raw FinQA examples into a Dataset with a single 'messages' column.

    Each example contributes one entry: the message list returned by
    `build_prompt` for its pre_text, post_text, table, question and program.
    Missing fields default to empty strings; the table is read from "table",
    falling back to "table_ori" when "table" is absent.
    """
    def _to_messages(example: Dict):
        # A missing or falsy "qa" entry is treated as an empty dict.
        qa_block = example.get("qa", {}) or {}
        table = example["table"] if "table" in example else example.get("table_ori", "")
        return build_prompt(
            example.get("pre_text", ""),
            example.get("post_text", ""),
            table,
            qa_block.get("question", ""),
            qa_block.get("program", ""),
        )

    return Dataset.from_dict({'messages': [_to_messages(example) for example in examples]})

def load_finqa_dataset(json_path):
    """
    Load the file at `json_path` and return it as a chat-message Dataset.

    The path is wrapped in `Path`, read with `load_jsonl_or_json`, and the
    resulting records are passed to `convert`.
    """
    return convert(load_jsonl_or_json(Path(json_path)))
