#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse, jsonlines, re, sys
from typing import List
import pandas as pd
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import random
import numpy as np
import torch
import random
import numpy as np
import torch
from transformers import set_seed

def seed_everything(seed: int = 111):
    """Make a run reproducible by seeding every RNG this script touches.

    Seeds Python's ``random``, NumPy's global generator, PyTorch on the CPU and
    on every CUDA device, and ``transformers.set_seed``; additionally forces
    cuDNN onto deterministic kernels and turns off its benchmark autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Autotuned cuDNN kernel selection can vary between runs; pin it down.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)

    set_seed(seed)


# Default for the --end CLI option: a record index no dataset will reach,
# i.e. "evaluate through the end of the file".
MAX_INT = sys.maxsize

# ───────────────────────────────────────────────────────────────────────────────
# 0. Batching helper
# ───────────────────────────────────────────────────────────────────────────────
def batch_data(items: List[str], bs: int) -> List[List[str]]:
    """Split *items* into consecutive chunks of at most *bs* elements.

    A non-positive *bs* disables batching: the original list is returned as
    the single chunk (``[items]``, even when *items* is empty).
    """
    if bs <= 0:
        return [items]
    starts = range(0, len(items), bs)
    return [items[s:s + bs] for s in starts]

# ───────────────────────────────────────────────────────────────────────────────
# 1. Interpreter
# ───────────────────────────────────────────────────────────────────────────────
class StringFlowInterpreter:
    def __init__(self):
        self.broth, self.ingredients = "", []
        self.stations, self.recipe = {}, []
        self.pc, self.cooking = 0, True

    def prep(self, recipe_str: str):
        for line in recipe_str.strip().splitlines():
            dish = line.strip()
            if not dish:
                continue
            if re.match(r'^[\w]+:$', dish):
                self.stations[dish[:-1]] = len(self.recipe)
            else:
                step, *args = re.findall(r'"[^"]*"|\S+', dish)
                self.recipe.append((step, args))

    def cook(self) -> str:
        self.pc = 0
        while self.cooking and self.pc < len(self.recipe):
            step, args = self.recipe[self.pc]
            self._saute(step, args)
            self.pc += 1
        return self.broth

    def _saute(self, step: str, a: List[str]):
        q = lambda s: s[1:-1] if s.startswith('"') and s.endswith('"') else s

        if step == "pour":   self.broth = q(a[0])
        elif step == "slice": self.ingredients = self.broth.split(q(a[0]))
        elif step == "stir":  self.broth = q(a[0]).join(self.ingredients)
        elif step == "flip":  self.broth = self.broth[::-1]
        elif step == "toss":  self.ingredients.reverse()
        elif step == "season": old, new = map(q, a[:2]); self.broth = self.broth.replace(old, new)
        elif step == "fillet": start, length = map(int, a[:2]); self.broth = self.broth[start:start+length]
        elif step == "flambe": self.broth = self.broth.upper()
        elif step == "simmer": self.broth = self.broth.lower()
        elif step == "garnish": self.broth += q(a[0])
        elif step == "plate":   self.broth = q(a[0]) + self.broth
        elif step == "taste_then":
            if q(a[0]) in self.broth: self.pc = self.stations[a[1]] - 1
        elif step == "move_to": self.pc = self.stations[a[0]] - 1
        elif step == "serve":   self.cooking = False
        else: raise ValueError(f"Unknown step: {step}")

# ───────────────────────────────────────────────────────────────────────────────
# 2. DSL code extraction
# ───────────────────────────────────────────────────────────────────────────────
_TRIPLE_QUOTE = re.compile(r'"""(.*?)"""', re.DOTALL)
_BACKTICK     = re.compile(r"```[^\n]*\n?(.*?)```", re.DOTALL)   # ```python\ncode```

def extract_dsl(text: str) -> str | None:
    # 1) If there is a CoT tag, use only the text after </think>
    if '</think>' in text:
        text = text.split('</think>', 1)[1]



    # 3) If there is no triple-quote block, find all backtick code blocks
    bt_blocks = _BACKTICK.findall(text)
    if bt_blocks:
        return bt_blocks[-1].strip()
    
        # 2) Find all triple-quote blocks
    qt_blocks = _TRIPLE_QUOTE.findall(text)
    if qt_blocks:                       # If there is at least one, use the last one
        return qt_blocks[-1].strip()

    # 4) If there is nothing, return None
    return None
# ───────────────────────────────────────────────────────────────────────────────
# 3. Evaluation routine
# ───────────────────────────────────────────────────────────────────────────────
def _score_completion(label: str, completion: str) -> tuple[str, int]:
    """Extract, execute and grade one completion; return ``(prediction, correct)``."""
    code = extract_dsl(completion)
    if code is None:
        return "[invalid]", 0
    chef = StringFlowInterpreter()
    try:
        chef.prep(code)
        pred = chef.cook()
    except Exception:
        # Deliberately broad: model-written programs can fail in arbitrary ways
        # (unknown step, missing label, bad int, ...); any failure scores 0.
        return "[error]", 0
    return pred, int(pred.strip() == label.strip())


def run_eval(model_path, data_file, start, end, batch_size,
             tensor_parallel_size, out_csv, seed):
    """Generate DSL programs with vLLM, execute them, and score against references.

    Args:
        model_path: Model id/path used for both the tokenizer and ``vllm.LLM``.
        data_file: JSON-lines file; each record's ``query`` is the prompt and
            ``response`` the expected interpreter output.
        start, end: Half-open range ``[start, end)`` of record indices to evaluate.
        batch_size: Prompts per ``llm.generate`` call (``<= 0`` means one batch).
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
        out_csv: Path of the per-sample results CSV.
        seed: Forwarded to both ``vllm.LLM`` and ``SamplingParams``.
    """
    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    has_chat = bool(getattr(tok, "chat_template", None))

    def build_prompt(instr: str) -> str:
        """Wrap *instr* in the chat template, or an Alpaca-style prompt without one."""
        if has_chat:
            return tok.apply_chat_template([{"role":"user","content":instr}],
                                           tokenize=False, add_generation_prompt=True)
        return ("Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{instr}\n\n### Response: Let's think step by step.")

    # --- Load the [start, end) slice of the dataset ----------------------------
    queries, labels = [], []
    with open(data_file, encoding="utf-8") as f:
        for i, item in enumerate(jsonlines.Reader(f)):
            if i < start:
                continue
            if i >= end:
                break
            queries.append(build_prompt(item["query"]))
            labels.append(item["response"].strip())

    llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size, seed=seed,
              max_model_len=4096)

    # tok.eos_token may be None for some tokenizers, and a None entry is not a
    # valid stop string; drop empties and duplicates (e.g. eos == '</s>').
    stop = list(dict.fromkeys(s for s in ('</s>', '<|endoftext|>', tok.eos_token) if s))
    params = SamplingParams(temperature=0, repetition_penalty=1.1, max_tokens=4096, stop=stop, seed=seed)

    completions: List[str] = []
    for batch in batch_data(queries, batch_size):
        for o in llm.generate(batch, params):
            completions.append(o.outputs[0].text)

    # --- Length correction -----------------------------------------------------
    # Pad with a sentinel / truncate so completions line up 1:1 with labels.
    if len(completions) != len(queries):
        print(f"[warn] completions {len(completions)} vs queries {len(queries)}", file=sys.stderr)
        missing = len(queries) - len(completions)
        if missing > 0:
            completions.extend(["[no_output]"] * missing)
        del completions[len(queries):]

    # --- SousChef execution ----------------------------------------------------
    scored = [_score_completion(label, comp) for label, comp in zip(labels, completions)]
    preds = [pred for pred, _ in scored]
    flags = [correct for _, correct in scored]

    # --- Create DataFrame ------------------------------------------------------
    df = pd.DataFrame({
        "ground_truth": labels,
        "prediction":   preds,
        "correct":      flags,
        "raw_output":   completions
    })
    df.to_csv(out_csv, index=False)

    # mean() of an empty column is NaN; report 0 for an empty slice instead.
    accuracy = df.correct.mean() if len(df) else 0.0

    print("="*50)
    print(f'{len(df)} samples')
    print(f'model: {model_path}')
    print(f'accuracy={accuracy:.4%}')
    print(f"{model_path} CSV saved 👉  {out_csv}")
    print("="*50)

# ───────────────────────────────────────────────────────────────────────────────
# 4. CLI
# ──────────────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse options, seed all RNGs, run the evaluation."""
    parser = argparse.ArgumentParser()
    for required_flag in ("--model", "--data_file"):
        parser.add_argument(required_flag, required=True)
    int_options = (
        ("--start", 0),
        ("--end", MAX_INT),
        ("--batch_size", 512),
        ("--tensor_parallel_size", 4),
        ("--seed", 1111),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--out_csv", default="dsl_eval_results.csv")
    args = parser.parse_args()

    seed_everything(args.seed)

    run_eval(args.model, args.data_file, args.start, args.end,
             args.batch_size, args.tensor_parallel_size, args.out_csv, args.seed)

    print('seed: ', args.seed)

if __name__ == "__main__":
    main()