import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
import re

# === Model checkpoint and compute device ===
MODEL_NAME = "/path/to/base_model"

# Pick the GPU when torch reports CUDA as available, otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# === Input / output locations ===
SOURCE_FILE = "/path/to/source.txt"  # sentences to translate
JSONL_FILE = "/path/to/output.jsonl"  # JSON-lines output
TRANSLATION_ONLY_FILE = "/path/to/output.txt"  # plain-text translations

# === Few-shot examples ===
# Interpolated verbatim into the prompt; currently empty.
FEW_SHOT_EXAMPLES = ""


# === Full Prompt ===
def build_prompt(source: str) -> str:
    """Return the full instruction prompt for translating *source*.

    The prompt asks the model to answer with a JSON object holding the keys
    "source", "translation" and "explanation". The module-level
    FEW_SHOT_EXAMPLES string is interpolated verbatim as reference examples,
    and *source* is embedded between ASCII double quotes (it is not escaped).

    Args:
        source: The sentence to translate.

    Returns:
        The prompt text, ending with the request for JSON output.
    """
    return f"""You are a professional simultaneous interpreter. You should translate the source sentence into (tgt) language with accuracy and fluency. Finally, please strictly output in JSON format, including:

- "source": Original text
- "translation": The final translation
- "explanation": Explain why this translation is natural and clear

Here are several reference examples:
{FEW_SHOT_EXAMPLES}

Now please translate the following sentences:
"{source}"

Please output in JSON format:
"""

# === Load tokenizer and model ===
# Both are read from MODEL_NAME with remote code execution enabled.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Switch to evaluation mode before running generation.
model.eval()

# === Inference ===
def _first_json_object(text):
    """Return the first JSON object that parses cleanly from *text*, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one complete JSON value starting at `start`,
            # so nested objects and braces inside strings are handled.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def _source_from_prompt(prompt):
    """Return the quoted sentence that follows the translate instruction, or ""."""
    marker = 'Now please translate the following sentences:\n"'
    _, found, tail = prompt.rpartition(marker)
    if not found:
        return ""
    # The sentence ends at the last ASCII quote after the marker, so quotes
    # inside the sentence itself are preserved.
    sentence, closing, _ = tail.rpartition('"')
    return sentence if closing else ""


def infer(prompt, max_new_tokens=256):
    """Generate a completion for *prompt* and parse it as a JSON object.

    Args:
        prompt: Full prompt text passed to the tokenizer.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first JSON object found in the generated text. If none can be
        parsed, a fallback dict with the keys "source" (the sentence recovered
        from the prompt, or ""), "translation" ("") and "explanation"
        ("Fail to parse JSON").
    """
    # The model is placed with device_map="auto", so follow the model's own
    # device instead of assuming a fixed one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Greedy decoding: sampling knobs (temperature/top_p) have no effect when
    # do_sample=False, so they are not passed.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Causal LMs return prompt + continuation; decode only the continuation so
    # braces in the prompt are never mistaken for the answer.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    parsed = _first_json_object(decoded)
    if parsed is not None:
        return parsed

    return {
        "source": _source_from_prompt(prompt),
        "translation": "",
        "explanation": "Fail to parse JSON"
    }


# === Translate every non-blank line of the source file ===
# The sentences are read up front so tqdm can display a total.
with open(SOURCE_FILE, "r", encoding="utf-8") as fin:
    sentences = [line.strip() for line in fin if line.strip()]

with open(JSONL_FILE, "w", encoding="utf-8") as jsonl_out, \
     open(TRANSLATION_ONLY_FILE, "w", encoding="utf-8") as txt_out:

    for source in tqdm(sentences):
        prompt = build_prompt(source)
        result = infer(prompt)
        jsonl_out.write(json.dumps(result, ensure_ascii=False) + "\n")

        # The parsed JSON comes from the model, so "translation" may be
        # missing, null, or a non-string value; normalise it to text instead
        # of failing on string concatenation.
        translation = result.get("translation", "")
        if translation is None:
            translation = ""
        elif not isinstance(translation, str):
            translation = json.dumps(translation, ensure_ascii=False)

        # Keep exactly one output line per source sentence even when the
        # translation itself contains line breaks.
        txt_out.write(" ".join(translation.splitlines()) + "\n")