import json
import re
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ========== Configuration ==========
# Paths are left blank here; fill them in before running.
FEWSHOT_PATH = ""  # JSON object: cluster name -> list of few-shot example dicts
EMBEDDING_MODEL_PATH = ""  # Local sentence-transformer model
MODEL_NAME = ""  # Base model like Qwen3-8B
INPUT_PATH = ""  # Source language text, one sentence per line
OUTPUT_PATH = ""  # Receives one output line per input line

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

MAX_FEWSHOT = 5  # Upper bound on examples taken from the selected cluster
MAX_NEW_TOKENS = 256  # Change according to needs

# ========== Loading ==========
print("Loading embedding model")
embedder = SentenceTransformer(EMBEDDING_MODEL_PATH)

print("Loading base model")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Evaluation mode (disables dropout); this script only runs inference.
model.eval()

# Few-shot examples: a JSON object mapping each cluster name to a list of
# example dicts.
with open(FEWSHOT_PATH, 'r', encoding='utf-8') as f:
    fewshot_data = json.load(f)

# ========== Clustering ==========
def compute_cluster_centers(fewshot_data, model=None):
    """Compute one mean embedding ("center") per few-shot cluster.

    Args:
        fewshot_data: Mapping of cluster name -> list of example dicts. Each
            example is embedded from its ``'en'`` field, or from ``'src'`` when
            ``'en'`` is absent (``build_messages`` reads ``'src'``/``'tgt'``, so
            data with only those keys must still work here).
        model: Object with an ``encode(list_of_str)`` method returning a 2-D
            array (one row per text). Defaults to the module-level ``embedder``.

    Returns:
        Dict mapping cluster name -> mean embedding. Clusters with no examples
        are omitted, so they can never be selected.
    """
    if model is None:
        model = embedder
    centers = {}
    for cluster_name, examples in fewshot_data.items():
        if not examples:
            continue  # nothing to average
        texts = [ex['en'] if 'en' in ex else ex['src'] for ex in examples]
        centers[cluster_name] = model.encode(texts).mean(axis=0)
    return centers

# Precomputed once at startup: one mean embedding per non-empty few-shot cluster.
cluster_centers = compute_cluster_centers(fewshot_data=fewshot_data)

# ========== Prompt formulation ==========
def find_best_cluster(sentence):
    """Return the name of the cluster whose center is most similar to *sentence*.

    Similarity is the cosine similarity (``util.cos_sim``) between the
    sentence embedding and each center in the module-level ``cluster_centers``.
    On ties the cluster encountered first wins. Returns ``None`` only when
    ``cluster_centers`` is empty.
    """
    query_emb = embedder.encode(sentence)
    # max() needs no sentinel score, so every cluster is eligible even at the
    # minimum similarity of -1 (a `best_score = -1` start would exclude it).
    return max(
        cluster_centers,
        key=lambda name: util.cos_sim(query_emb, cluster_centers[name]).item(),
        default=None,
    )

def build_messages(sentence, fewshot_pairs, system_prompt=None):
    """Build the chat message list for one translation request.

    Layout: one system message, then a user/assistant pair per few-shot
    example (the assistant turn is the JSON answer the model should imitate),
    then a final user turn holding *sentence*.

    Args:
        sentence: Source-language sentence to translate.
        fewshot_pairs: Iterable of dicts with ``'src'`` and ``'tgt'`` keys.
        system_prompt: Optional replacement for the default system prompt.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts.
    """
    if system_prompt is None:
        # NOTE(review): "(e.g., English-to-Chinese)", "(src)" and "(tgt)" are
        # placeholders that reach the model verbatim; fill them in (or pass
        # system_prompt) for the actual language pair.
        system_prompt = (
            "You are a professional (e.g., English-to-Chinese) simultaneous interpreter. "
            "Translate the given (src) sentence into (tgt). "
            "Do not output any thinking steps, only output the final result. "
            "Output only JSON: {\"source\": \"...\", \"translation\": \"...\"}."
        )
    messages = [{"role": "system", "content": system_prompt}]

    for ex in fewshot_pairs:
        messages.append({"role": "user", "content": f"Sentence:\n{ex['src']}"})
        # ensure_ascii=False keeps non-Latin targets readable in the demo answer.
        answer = json.dumps({"source": ex["src"], "translation": ex["tgt"]}, ensure_ascii=False)
        messages.append({"role": "assistant", "content": answer})

    messages.append({"role": "user", "content": f"Sentence:\n{sentence}"})
    return messages

def build_prompt_with_template(sentence):
    """Render the complete chat prompt text for *sentence*.

    Picks the nearest few-shot cluster, takes up to MAX_FEWSHOT of its
    examples, and formats everything with the tokenizer's chat template,
    ending with the assistant generation prompt.
    """
    cluster = find_best_cluster(sentence)
    # find_best_cluster yields None when there are no cluster centers; fall
    # back to a zero-shot prompt instead of failing with KeyError.
    examples = (fewshot_data.get(cluster) or [])[:MAX_FEWSHOT]
    messages = build_messages(sentence, examples)
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        return_dict=False,
        # Qwen3-style templates open a <think> block unless this is False; the
        # system prompt forbids thinking and MAX_NEW_TOKENS is small. Extra
        # kwargs go to the template renderer, so templates that do not
        # reference this variable are unaffected.
        enable_thinking=False,
    )




# ========== Inference ==========
def extract_last_json_block(s: str):
    """Locate the last brace-balanced ``{...}`` span in *s*.

    Starting from the right-most ``'}'``, scan leftwards counting braces until
    the matching ``'{'`` is found. If that closing brace never balances, retry
    from the next ``'}'`` to its left.

    Braces are counted literally, so a ``'{'`` or ``'}'`` inside a JSON string
    value can throw the match off; the result is a balanced span, not
    necessarily valid JSON.

    Returns:
        ``(block, start, end)`` with ``block == s[start:end]``, or
        ``(None, -1, -1)`` when no balanced span exists.
    """
    end = s.rfind('}')
    while end != -1:
        depth = 0
        for i in range(end, -1, -1):
            ch = s[i]
            if ch == '}':
                depth += 1
            elif ch == '{':
                depth -= 1
                if depth == 0:
                    return s[i:end + 1], i, end + 1
        # This '}' has no matching '{' to its left; try the previous '}'.
        end = s.rfind('}', 0, end)
    return None, -1, -1


def _decode_json_string_body(raw: str) -> str:
    """Decode JSON escape sequences in the text between a string's quotes.

    Returns *raw* unchanged if it is not a valid JSON string body (for
    example, it contains an unknown escape such as ``\\q``).
    """
    try:
        # strict=False accepts literal newlines/tabs, which models often emit.
        return json.loads('"' + raw + '"', strict=False)
    except ValueError:
        return raw


def extract_translation(s: str):
    """Pull the ``"translation"`` value out of model output *s*.

    1. Parse the last brace-balanced block as JSON; if it is an object with a
       string ``"translation"``, return it (stripped) together with the block.
    2. Otherwise regex-search *s* for the last ``"translation": "..."`` pair,
       honouring backslash-escaped quotes, and decode its JSON escapes.

    Returns:
        ``(translation, block)``. ``block`` is the parsed JSON text on the
        primary path and ``None`` on the regex path; ``(None, block)`` when
        nothing was found (``block`` may itself be ``None``).
    """
    block, _, _ = extract_last_json_block(s)
    if block:
        try:
            obj = json.loads(block)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            obj = None
        if isinstance(obj, dict) and isinstance(obj.get("translation"), str):
            return obj["translation"].strip(), block

    # Fallback for malformed JSON. The value pattern consumes escaped quotes
    # (\") instead of stopping at them. Like step 1, prefer the LAST occurrence,
    # since the final answer comes last.
    matches = re.findall(r'"translation"\s*:\s*"((?:[^"\\]|\\.)*)"', s, flags=re.DOTALL)
    if matches:
        # Decode as a JSON string: 'unicode_escape' would re-read the UTF-8
        # bytes as Latin-1 and garble any non-ASCII (e.g. Chinese) text.
        return _decode_json_string_body(matches[-1]).strip(), None
    return None, block


def infer(prompt, input_en, idx):
    """Greedy-decode one prompt and extract the translation from the output.

    Args:
        prompt: Fully rendered chat prompt text (from build_prompt_with_template).
        input_en: Source sentence; used only for console logging.
        idx: Zero-based input line index; used only for console logging.

    Returns:
        The extracted translation if found; otherwise the last brace-balanced
        block; otherwise the whole decoded generation.
    """
    # The chat template already inserted the special tokens the model needs;
    # add_special_tokens=False stops BOS-prepending tokenizers from doubling it.
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    # device_map="auto" decides where the model lives, so follow it, not DEVICE.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy; a temperature setting would be ignored here
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.05,
        )

    # Keep only the newly generated tokens, not the echoed prompt.
    new_token_ids = output_ids[0, input_ids.shape[1]:]
    new_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    new_text = new_text.replace("<|end|>", "").replace("<|begin|>", "").strip()

    print(f"\n================= Number {idx+1} =================")
    print(f"Source sentence: {input_en}")
    print("Model output")
    print(new_text)

    translation, last_json_block = extract_translation(new_text)

    if translation:
        print(f"Successfully extracted the translation: {translation}")
        return translation

    if last_json_block:
        print("The JSON exists but the parsing failed, returning the final JSON string.")
        return last_json_block.strip()

    print("JSON/translation was not found. Return the newly added original output")
    return new_text



# ========== Main ==========
with open(INPUT_PATH, 'r', encoding='utf-8') as infile, open(OUTPUT_PATH, 'w', encoding='utf-8') as outfile:
    for idx, line in enumerate(tqdm(infile.readlines(), desc="Running Inference")):
        sentence = line.strip()
        if not sentence:
            # Keep blank lines so output line N corresponds to input line N.
            outfile.write("\n")
            continue
        prompt = build_prompt_with_template(sentence)
        result = infer(prompt, sentence, idx)
        # A translation, or the raw-output fallback from infer, may span several
        # lines; join them so each input line yields exactly one output line.
        result = " ".join(part.strip() for part in result.splitlines() if part.strip())
        outfile.write(result + "\n")
