#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Unified inference for FinQA program synthesis using vLLM (works with base models, merged checkpoints, and mergeable PEFT adapters such as LoRA/QLoRA).

Examples:
# Zero-shot / merged model / small-data full FT
python inference_fixed.py \
  --model_path meta-llama/Llama-3.2-1B-Instruct \
  --test_json test.json \
  --out_json predictions_zeroshot.json \
  --max_new_tokens 384 --eg_k 1

# LoRA/QLoRA (unmerged adapter; it is merged into the base model before loading into vLLM)
python inference_fixed.py \
  --base_model meta-llama/Llama-3.2-1B-Instruct \
  --adapter_path /path/to/adapter_dir \
  --test_json test.json \
  --out_json predictions_qlora.json \
  --max_new_tokens 384 --eg_k 1
"""

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Union

import math
import torch
from tqdm import tqdm

# vLLM imports
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os

# Optional PEFT imports guarded at runtime
try:
    from peft import PeftModel, PeftConfig
except Exception:
    PeftModel, PeftConfig = None, None

# ---------------- Helpers ----------------

# Arithmetic op names accepted in generated programs; ``table_*`` ops are
# accepted in addition to these (see is_valid_op).
ALLOWED_OPS = {"add", "subtract", "multiply", "divide", "exp", "greater"}

# Default to the first four GPUs, but respect a CUDA_VISIBLE_DEVICES value the
# user already exported instead of silently overriding it. This only has an
# effect if it runs before CUDA is initialised in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")

def is_valid_op(op: str) -> bool:
    """Return True if *op* (compared case-insensitively) is an allowed operation.

    An operation is allowed when it is listed in ``ALLOWED_OPS`` or when its
    name starts with ``table_``.
    """
    name = op.lower()
    if name.startswith("table_"):
        return True
    return name in ALLOWED_OPS

def render_table(tbl: Union[List, Dict, str]) -> str:
    """Render a table object as plain text for the prompt.

    - ``None`` renders as ``""``; a string is returned unchanged.
    - A dict renders as one ``key: value`` line per item.
    - A list renders one line per row. List rows are joined with ``" | "``,
      dict rows become ``key: value`` pairs joined with ``" | "``, and any
      other row is converted with ``str``.
    - Any other type renders as ``""``.
    """
    if tbl is None:
        return ""
    if isinstance(tbl, str):
        return tbl
    if isinstance(tbl, dict):
        return "\n".join(f"{key}: {value}" for key, value in tbl.items())
    if not isinstance(tbl, list):
        return ""

    def _row_to_text(row) -> str:
        if isinstance(row, list):
            return " | ".join(map(str, row))
        if isinstance(row, dict):
            return " | ".join(f"{key}: {value}" for key, value in row.items())
        return str(row)

    return "\n".join(_row_to_text(row) for row in tbl)

def build_prompt(
    pre_text: Union[List[str], str, None],
    post_text: Union[List[str], str, None],
    table_obj: Any,
    question: str,
) -> List[Dict[str, str]]:
    """Build the chat messages (system + user) for one FinQA example.

    Args:
        pre_text: Context before the table. A list of sentences is joined with
            spaces; a string is used as-is; None or empty becomes "".
        post_text: Context after the table, same conventions as ``pre_text``.
        table_obj: Table in any form accepted by ``render_table``.
        question: The question to answer.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts (system
        message first), ready for ``tokenizer.apply_chat_template``.

    NOTE(review): the system-message example uses a nested call, while the step
    description asks for flat ``op(arg1, arg2)`` steps chained via ``#k``.
    Confirm which output format is intended.
    """
    pre = " ".join(pre_text) if isinstance(pre_text, list) else (pre_text or "")
    post = " ".join(post_text) if isinstance(post_text, list) else (post_text or "")
    table_str = render_table(table_obj)
    system_message = (
        "You are a program synthesis model for financial question answering.\n"
        "Given the context (before/after text) and a table, write a reasoning Program as a sequence of operations.\n"
        "Each step must look like: op(arg1, arg2). Use #k to reference the result of step k (starting at #0 after first step).\n"
        "**Example:** `add(divide(100, 2), 50)#0 EOF`"
    )
    user_message = (
        "Please generate a concise program sequence for the question based on the following information:\n"
        "Context (before table):\n"
        f"{pre}\n\n"
        "Table:\n"
        f"{table_str}\n\n"
        "Context (after table):\n"
        f"{post}\n\n"
        f"Question: {question}\n\n"
        # The request ends with this instruction line (there is no trailing
        # "Program:" cue for the model to continue from).
        "Please generate the program directly without any additional explanation:"
    )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

def to_chat(tokenizer, prompt: List[Dict[str, str]]) -> str:
    """Render chat messages into the model's prompt text.

    Args:
        tokenizer: A tokenizer exposing ``apply_chat_template`` (Hugging Face
            style).
        prompt: Chat messages as ``{"role": ..., "content": ...}`` dicts, such
            as the list returned by ``build_prompt``.

    Returns:
        The templated prompt as a string (``tokenize=False``), with the
        generation prompt appended so the model answers as the assistant.
    """
    # Chat templates that support extra switches (e.g. an ``enable_thinking``
    # flag) can receive them as additional keyword arguments here.
    return tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=False,
    )

# Head of a call such as ``divide(`` or ``table_sum (``.
_CALL_HEAD_RE = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(")
# Section headers after which the model has stopped writing its program.
_OUTPUT_STOP_MARKERS = ("\n\nQuestion:", "\n\nContext:", "\n\nAnswer:", "\n\nFinal", "Explanation:")
# Nesting cap so pathological output like "a(a(a(..." cannot exhaust the stack.
_MAX_CALL_DEPTH = 32


def _skip_spaces(s, i):
    """Return the first index >= i in *s* that is not whitespace."""
    while i < len(s) and s[i].isspace():
        i += 1
    return i


def _parse_arg(s, i, steps, depth):
    """Parse one call argument starting at ``s[i]``.

    Returns ``(ok, value, end)``, or None on a syntax error. A nested call is
    flattened into *steps* and its value is the ``#k`` reference to that step.
    A plain argument is the stripped text up to the next ``,``, ``(`` or ``)``.
    """
    i = _skip_spaces(s, i)
    if _CALL_HEAD_RE.match(s, i):
        return _parse_call(s, i, steps, depth + 1)
    j = i
    while j < len(s) and s[j] not in ",()":
        j += 1
    atom = s[i:j].strip()
    if not atom:
        return None
    return True, atom, j


def _parse_call(s, i, steps, depth=0):
    """Parse ``op(arg1, arg2)`` starting at ``s[i]``.

    Returns None on a syntax error. Otherwise returns ``(ok, ref, end)``:
    ``end`` is the index just past the closing ``)``. ``ok`` is False (and
    ``ref`` None) when this op or any nested op is not allowed. When ``ok`` is
    True, the call has been appended to *steps* after its nested calls, and
    ``ref`` is its ``#k`` reference.
    """
    if depth > _MAX_CALL_DEPTH:
        return None
    head = _CALL_HEAD_RE.match(s, i)
    if head is None:
        return None
    op = head.group(1)
    ok = is_valid_op(op)
    args = []
    pos = head.end()
    for closer in (",", ")"):
        parsed = _parse_arg(s, pos, steps, depth)
        if parsed is None:
            return None
        arg_ok, arg, pos = parsed
        pos = _skip_spaces(s, pos)
        if pos >= len(s) or s[pos] != closer:
            return None
        ok = ok and arg_ok
        args.append(arg)
        pos += 1
    if not ok:
        return False, None, pos
    steps.append([op, args[0], args[1]])
    return True, f"#{len(steps) - 1}", pos


def parse_raw_to_steps(text: str) -> List[List[str]]:
    """Parse raw model output into a list of ``[op, arg1, arg2]`` steps.

    Preprocessing:
    - Keep only the text after the first ``Program:`` marker, if present.
    - Cut the text at any marker in ``_OUTPUT_STOP_MARKERS``.
    - Collapse all whitespace runs to single spaces.

    Extraction:
    - ``op(arg1, arg2)`` calls are extracted left to right.
    - Nested calls such as ``add(divide(100, 2), 50)`` are flattened: the inner
      call becomes its own step (emitted first) and the outer call refers to it
      as ``#k``.
    - Explicit ``#k`` references written by the model are kept verbatim, so
      mixing them with nested calls may shift indices.
    - A call is skipped entirely when its op, or any nested op, is rejected by
      ``is_valid_op``.
    - Text that never forms a complete call is ignored.
    - Arguments cannot contain commas, because commas separate arguments.

    Returns:
        The extracted steps; an empty list if nothing parseable was found.
    """
    if "Program:" in text:
        text = text.split("Program:", 1)[1]
    for stop in _OUTPUT_STOP_MARKERS:
        if stop in text:
            text = text.split(stop, 1)[0]
    s = " ".join(text.strip().split())
    if not s:
        return []

    steps: List[List[str]] = []
    pos = 0
    while True:
        head = _CALL_HEAD_RE.search(s, pos)
        if head is None:
            break
        mark = len(steps)
        parsed = _parse_call(s, head.start(), steps)
        if parsed is None:
            # Syntax error: drop partial nested steps and retry one character
            # later, so a well-formed inner call can still be recovered.
            del steps[mark:]
            pos = head.start() + 1
            continue
        ok, _, pos = parsed
        if not ok:
            # Well-formed but uses a disallowed op: consume it, keep nothing.
            del steps[mark:]
    return steps

def steps_to_eval_tokens(steps):
    """Flatten ``[op, arg1, arg2]`` steps into FinQA evaluator tokens.

    Each step contributes four tokens ``["op(", arg1, arg2, ")"]``, and the
    sequence always ends with ``"EOF"``. Empty or missing steps yield
    ``["EOF"]``.
    """
    flat = []
    for step in steps or ():
        op, first, second = step
        flat += [f"{op}(", first, second, ")"]
    flat.append("EOF")
    return flat

def verify_eval_tokens(tokens):
    """Check that *tokens* follow the FinQA 4-tokens-per-step layout plus EOF.

    A valid sequence:
    - ends with ``"EOF"``;
    - has a body (everything before EOF) whose length is a multiple of four;
    - consists of steps ``["op(", arg1, arg2, ")"]`` whose op passes
      ``is_valid_op``.

    ``["EOF"]`` alone is valid. Arguments are not inspected.

    Returns:
        True if the layout is valid, otherwise False.
    """
    if not tokens or tokens[-1] != "EOF":
        return False
    body = tokens[:-1]
    if len(body) % 4:
        return False
    for i in range(0, len(body), 4):
        head, _, _, close = body[i:i + 4]
        if not head.endswith("(") or close != ")":
            return False
        # Share the op rule with the parser instead of duplicating it here.
        if not is_valid_op(head[:-1]):
            return False
    return True

def load_json(path: Union[str, Path]) -> List[Dict]:
    """Load the examples stored in a UTF-8 JSON file.

    The file may hold either a top-level list of examples or a dict that wraps
    them under ``"data"``. A dict without that key yields an empty list.
    """
    parsed = json.loads(Path(path).read_text(encoding="utf-8"))
    return parsed.get("data", []) if isinstance(parsed, dict) else parsed

# ---- Execution-guided utilities ----

def _exec_const(tok: str):
    """Return the numeric value of a ``const_*`` token, or None.

    ``const_m1`` is FinQA's spelling of -1. Other ``const_<number>`` tokens
    (e.g. ``const_100``) are parsed with ``float``. Tokens that do not start
    with ``const_``, or whose suffix is not a number, return None.
    """
    if not tok.startswith("const_"):
        return None
    suffix = tok[len("const_"):]
    if suffix == "m1":
        return -1.0
    try:
        return float(suffix)
    except ValueError:
        return None

def _exec_steps(steps):
    """Execute parsed ``[op, arg1, arg2]`` steps and return the final value.

    An argument may be any of the following:
    - a ``#k`` reference to the result of an earlier step (0 <= k < step index);
    - a ``const_*`` token (see ``_exec_const``);
    - a percentage such as ``"5%"``, which is divided by 100 (intended to
      mirror FinQA's reference number parsing);
    - a plain number.

    Supported ops are add, subtract, multiply, divide, exp (power) and greater
    (1.0 / 0.0).

    Returns:
        The last step's value as a float. Returns None if *steps* is empty or
        any step cannot be executed: an unknown or ``table_*`` op, a bad
        reference, an unparsable number, division by zero, overflow, or a
        NaN/inf/complex result.
    """
    vals = []

    def val(x):
        if x.startswith("#"):
            idx = int(x[1:])
            # Only earlier steps may be referenced. Without this check, "#-1"
            # would silently resolve to the latest value via negative indexing.
            if not 0 <= idx < len(vals):
                raise IndexError(f"invalid step reference: {x}")
            return vals[idx]
        c = _exec_const(x)
        if c is not None:
            return c
        if x.endswith("%"):
            return float(x[:-1]) / 100.0
        return float(x)

    for op, a1, a2 in steps:
        try:
            x, y = val(a1), val(a2)
            if op == "add":
                z = x + y
            elif op == "subtract":
                z = x - y
            elif op == "multiply":
                z = x * y
            elif op == "divide":
                z = x / y if y != 0 else math.nan
            elif op == "exp":
                z = x ** y
            elif op == "greater":
                z = 1.0 if x > y else 0.0
            else:
                return None  # table_* and unknown ops cannot be executed here
            if math.isnan(z) or math.isinf(z):
                return None
            vals.append(z)
        except Exception:
            # Best-effort execution: any failure (bad reference, unparsable
            # number, OverflowError, complex power result) means "not executable".
            return None
    return vals[-1] if vals else None

def _tokens_to_steps(tokens):
    """Convert evaluator tokens back into ``[op, arg1, arg2]`` steps.

    The final token (expected to be ``"EOF"``) is dropped. Each remaining group
    of four tokens ``["op(", arg1, arg2, ")"]`` becomes one step whose op has
    the trailing ``(`` removed and is lower-cased. The input is assumed to have
    passed ``verify_eval_tokens``.
    """
    body = tokens[:-1]
    return [
        [body[start][:-1].lower(), body[start + 1], body[start + 2]]
        for start in range(0, len(body), 4)
    ]

# ---------------- vLLM Model loading ----------------

def load_model_and_tokenizer(args):
    """Create the vLLM engine and a matching tokenizer.

    Without ``--adapter_path``, ``args.model_path`` is loaded directly.

    With ``--adapter_path``:
    - The PEFT adapter is merged into its base model (``--base_model``, or the
      base recorded in the adapter config).
    - The result is written, weights plus tokenizer, to
      ``./merged_model_<digest>`` and vLLM loads that directory.
    - Only adapters that can be merged (e.g. LoRA/QLoRA) are supported.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``base_model`` and
            ``adapter_path``.

    Returns:
        ``(llm, tokenizer)``.

    Raises:
        RuntimeError: If peft is not installed, or the adapter is a
            prompt-learning (prefix/prompt tuning) adapter that cannot be
            merged.
    """
    import hashlib  # only needed to name the merged-model directory

    vllm_kwargs = {
        "model": args.model_path,
        # device_count() is 0 when no GPU is visible; vLLM needs a size >= 1.
        "tensor_parallel_size": max(1, torch.cuda.device_count()),
        "trust_remote_code": True,
        "gpu_memory_utilization": 0.85,
        "max_num_seqs": 256,
        "max_num_batched_tokens": 4096,
    }

    if args.adapter_path:
        if PeftConfig is None:
            raise RuntimeError("peft not installed but --adapter_path specified.")

        from transformers import AutoModelForCausalLM

        peft_cfg = PeftConfig.from_pretrained(args.adapter_path)
        if getattr(peft_cfg, "is_prompt_learning", False):
            raise RuntimeError(
                "Prompt-learning adapters (prefix/prompt tuning) cannot be merged into "
                "the base model; only mergeable adapters such as LoRA/QLoRA are supported."
            )
        base_name = args.base_model or peft_cfg.base_model_name_or_path

        # Built-in hash() of a str is randomised per process (PYTHONHASHSEED),
        # so use a stable digest to get the same directory name on every run.
        digest = hashlib.sha1(
            str(Path(args.adapter_path).resolve()).encode("utf-8")
        ).hexdigest()[:12]
        merged_model_path = f"./merged_model_{digest}"

        # torch_dtype="auto" keeps the checkpoint's dtype instead of the fp32 default.
        base_model = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype="auto")
        merged = PeftModel.from_pretrained(base_model, args.adapter_path).merge_and_unload()
        merged.save_pretrained(merged_model_path)
        del merged, base_model  # release host memory before vLLM starts

        # Saving the model writes no tokenizer files, but vLLM and the
        # AutoTokenizer call below both load the tokenizer from this directory.
        # Prefer the adapter's tokenizer (it may carry training-time changes).
        try:
            merge_tokenizer = AutoTokenizer.from_pretrained(args.adapter_path)
        except (OSError, ValueError):
            merge_tokenizer = AutoTokenizer.from_pretrained(base_name)
        merge_tokenizer.save_pretrained(merged_model_path)

        vllm_kwargs["model"] = merged_model_path

    llm = LLM(**vllm_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(vllm_kwargs["model"])
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    return llm, tokenizer


def _pick_eg_candidate(candidates) -> str:
    """Return the text of the first sampled candidate with a usable program.

    A candidate is usable when its parsed program is non-empty and passes
    ``verify_eval_tokens``. In addition, an arithmetic-only program must
    execute via ``_exec_steps``. Programs containing ``table_*`` ops cannot be
    executed there, so they are accepted on format alone. If no candidate is
    usable, the first candidate is returned.
    """
    for cand in candidates:
        tokens = steps_to_eval_tokens(parse_raw_to_steps(cand.text))
        if len(tokens) == 1 or not verify_eval_tokens(tokens):
            continue
        steps = _tokens_to_steps(tokens)
        if any(op.startswith("table_") for op, _, _ in steps):
            return cand.text
        if _exec_steps(steps) is not None:
            return cand.text
    return candidates[0].text


def vllm_generate(llm, tokenizer, prompts, args):
    """Generate one completion text per chat prompt with vLLM.

    Greedy decoding is the default. When ``args.eg_k > 1`` (execution-guided
    mode), ``eg_k`` candidates are sampled per prompt with ``args.eg_temp`` and
    ``args.eg_top_p``, and the first usable one is kept (see
    ``_pick_eg_candidate``).

    Args:
        llm: The vLLM engine.
        tokenizer: Tokenizer used to apply the chat template.
        prompts: Chat message lists, one per example.
        args: CLI namespace; reads ``eg_k``, ``eg_temp``, ``eg_top_p`` and
            ``max_new_tokens``.

    Returns:
        A list of generated strings, aligned with *prompts*.
    """
    use_eg = bool(args.eg_k) and args.eg_k > 1
    if use_eg:
        # n=eg_k: without it vLLM returns a single sample per prompt, and EG
        # would degrade to plain temperature sampling.
        sampling_params = SamplingParams(
            n=args.eg_k,
            temperature=args.eg_temp,
            top_p=args.eg_top_p,
            max_tokens=args.max_new_tokens,
            skip_special_tokens=True,
        )
    else:
        sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=args.max_new_tokens,
            skip_special_tokens=True,
        )

    chat_prompts = [to_chat(tokenizer, prompt) for prompt in prompts]
    # vLLM returns one RequestOutput per prompt, in submission order.
    outputs = llm.generate(chat_prompts, sampling_params)

    if not use_eg:
        return [output.outputs[0].text for output in outputs]
    return [_pick_eg_candidate(output.outputs) for output in outputs]

# ---------------- Main ----------------

def main():
    """Command-line entry point: generate FinQA programs for a test set.

    Steps:
    1. Load the test examples and build one chat prompt per example.
    2. Generate completions with vLLM.
    3. Convert each completion into FinQA evaluator tokens.
    4. Write ``[{"id": ..., "predicted": [...]}, ...]`` to ``--out_json``.
    5. Print format statistics.
    """
    ap = argparse.ArgumentParser()
    # Pathing
    ap.add_argument("--model_path", type=str, help="Path to a direct/merged model")
    ap.add_argument("--base_model", type=str, default=None, help="Base model to load when using a PEFT adapter")
    ap.add_argument("--adapter_path", type=str, default=None, help="PEFT adapter directory (mergeable adapters such as LoRA/QLoRA)")
    # Data & output
    ap.add_argument("--test_json", type=str, required=True)
    ap.add_argument("--out_json", type=str, required=True)
    # Decoding
    ap.add_argument("--max_new_tokens", type=int, default=1024)
    ap.add_argument("--subset_size", type=int, default=None, help="Number of samples to test (subset)")
    ap.add_argument("--eg_k", type=int, default=1, help="#samples for EG; 1 = disabled (greedy)")
    ap.add_argument("--eg_temp", type=float, default=0.3, help="Sampling temperature for EG")
    ap.add_argument("--eg_top_p", type=float, default=0.9, help="Top-p for EG sampling")
    # NOTE(review): parsed but not applied anywhere in the visible code — confirm
    # whether prompt truncation is still intended.
    ap.add_argument("--input_max_len", type=int, default=2048, help="Truncation length for the prompt")

    args = ap.parse_args()

    if not args.model_path and not args.adapter_path:
        raise ValueError("Provide either --model_path (direct/merged) OR --adapter_path (+ optional --base_model).")

    # Load the data before the (slow) model so a bad --test_json fails fast.
    test_items = load_json(args.test_json)
    if args.subset_size is not None:
        test_items = test_items[:args.subset_size]
        print(f"Testing on subset: {len(test_items)} samples")
    else:
        print(f"Loaded test items: {len(test_items)}")

    llm, tokenizer = load_model_and_tokenizer(args)

    all_prompts = []
    for ex in test_items:
        qa = ex.get("qa", {}) or {}
        q = qa.get("question", "")
        table_obj = ex.get("table", ex.get("table_ori", ""))
        all_prompts.append(build_prompt(ex.get("pre_text", ""), ex.get("post_text", ""), table_obj, q))

    print("vLLM batched inferecing...")
    generated_texts = vllm_generate(llm, tokenizer, all_prompts, args) if all_prompts else []

    preds = []
    total_empties = total_valids = total_has_ops = 0
    for i, (ex, gen_text) in enumerate(zip(test_items, generated_texts)):
        tokens = steps_to_eval_tokens(parse_raw_to_steps(gen_text))

        if tokens == ["EOF"]:
            total_empties += 1
        if verify_eval_tokens(tokens):
            total_valids += 1
            if len(tokens) > 1:
                total_has_ops += 1

        preds.append({"id": ex.get("id", f"sample_{i}"), "predicted": tokens})

    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(preds, ensure_ascii=False, indent=2), encoding="utf-8")

    total = len(preds)

    def pct(count):
        # An empty test set (or --subset_size 0) would otherwise divide by zero.
        return count * 100.0 / total if total else 0.0

    print("\n================= STATISTICS =================")
    print(f"Total: {total}")
    print(f"Valid format: {total_valids} ({pct(total_valids):.1f}%)")
    print(f"Has operations: {total_has_ops} ({pct(total_has_ops):.1f}%)")
    print(f"Empty: {total_empties} ({pct(total_empties):.1f}%)")
    print("==============================================")
    print(f"✓ the result has be saved to: {args.out_json}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()