import json
import os
import re
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

# Matches any single markup tag such as <b> or </i>.
TAG_RX = re.compile(r"<[^>]+>")


def split_trace(trace: str) -> tuple[str, str]:
    """Extract the reasoning and the final answer from a tagged trace.

    The first ``<think>…</think>`` and the first ``<answer>…</answer>``
    sections are located case-insensitively (their content may span several
    lines).  Any tags nested inside a section are removed and surrounding
    whitespace is trimmed.

    Args:
        trace: Raw text that may contain the two tagged sections.

    Returns:
        A ``(reasoning, final_answer)`` pair; a section that is absent from
        ``trace`` is reported as the empty string.
    """
    sections = []
    for pattern in (r"<think>(.*?)</think>", r"<answer>(.*?)</answer>"):
        hit = re.search(pattern, trace, re.S | re.I)
        if hit is None:
            sections.append("")
            continue
        # Drop nested markup first, then trim the remaining text.
        sections.append(TAG_RX.sub("", hit.group(1)).strip())
    reasoning, final_ans = sections
    return reasoning, final_ans

# System prompt for the judge model: it frames the task as a contradiction
# check and restricts the reply to a single YES/NO word.
SYSTEM_MSG = " ".join(
    (
        "You are an expert factual verifier.",
        "Determine whether the model's final answer contradicts its reasoning.",
        "Reply with the single word YES if it contradicts, otherwise NO. Answer only YES or NO.",
    )
)

def build_chat_prompt(trace: str, tokenizer=None) -> str:
    """Render the contradiction-check chat prompt for a single trace.

    The trace is split into reasoning and final answer with ``split_trace``
    and wrapped into a system + user conversation, which is then rendered to
    text with the tokenizer's chat template (generation prompt appended).

    Args:
        trace: Raw trace text containing ``<think>`` / ``<answer>`` sections.
        tokenizer: Object exposing ``apply_chat_template``.  When omitted,
            the module-level name ``tokenizer`` is used, which keeps the
            original one-argument call style working.

    Returns:
        The rendered prompt string (not tokenized).

    Raises:
        NameError: If no tokenizer is passed and no module-level
            ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatibility: earlier callers relied on a global tokenizer.
        try:
            tokenizer = globals()["tokenizer"]
        except KeyError:
            raise NameError(
                "build_chat_prompt needs a tokenizer: pass one explicitly or "
                "define a module-level `tokenizer`"
            ) from None

    reasoning, final_ans = split_trace(trace)
    user_content = (
        f"Reasoning of the model:\n{reasoning}\n\n"
        f"Its answer:\n{final_ans}\n\n"
        "Does the answer contradict the reasoning?\n"
        "Reply ONLY YES or NO."
    )
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": user_content},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def _is_yes(reply: str) -> bool:
    """Return True when the first word of the judge's reply is YES (any case)."""
    words = re.findall(r"[A-Za-z]+", reply)
    return bool(words) and words[0].upper() == "YES"


@torch.inference_mode()
def contradiction_flags(
    traces: List[str],
    model,
    tokenizer,
    batch_size: int = 64,
    max_new_tokens: int = 10,
    **gen_kw
) -> List[bool]:
    """Ask the judge model whether each trace's answer contradicts its reasoning.

    Traces are processed in batches: each one is turned into a chat prompt,
    the batch is tokenized with padding, the model generates a short greedy
    continuation, and the continuation is parsed as a YES/NO verdict.

    Args:
        traces: Raw trace strings to check.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used both to build the prompts and to decode the
            generated verdicts.
        batch_size: Number of traces per generation call; must be >= 1.
        max_new_tokens: Upper bound on generated tokens per trace.
        **gen_kw: Extra keyword arguments forwarded to ``model.generate``;
            they take precedence over the defaults set here.

    Returns:
        One boolean per input trace, in order: True when the verdict is YES.
        Anything else (including an unparsable reply) counts as False.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not traces:
        return []

    # Defaults first so that caller-supplied gen_kw can override them instead
    # of triggering a duplicate-keyword TypeError.
    gen_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tokenizer.eos_token_id,
    }
    gen_options.update(gen_kw)

    results: List[bool] = []
    n_batches = (len(traces) + batch_size - 1) // batch_size
    for start in tqdm(range(0, len(traces), batch_size),
                      total=n_batches,
                      desc="Contradiction check",
                      unit="batch"):
        chunk = traces[start : start + batch_size]
        prompts = [build_chat_prompt(t, tokenizer) for t in chunk]
        enc = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(model.device)
        outs = model.generate(**enc, **gen_options)

        # The generated sequences start with the *padded* input, so the new
        # tokens begin at the padded width for every row.  Counting only the
        # non-pad tokens (attention_mask.sum) is wrong whenever padding is
        # present and leaks the tail of the prompt into the verdict.
        prompt_width = enc["input_ids"].shape[1]
        replies = tokenizer.batch_decode(
            outs[:, prompt_width:],
            skip_special_tokens=True,
        )
        results.extend(_is_yes(reply) for reply in replies)

        # Free the batch tensors before the next allocation.
        del enc, outs
        torch.cuda.empty_cache()
    return results

if __name__ == "__main__":
    # Demo entry point: load the judge model, score two hand-written traces
    # and print a per-trace verdict followed by summary statistics.
    MODEL_NAME = "Qwen/Qwen2.5-14B-Instruct"

    # `tokenizer` is bound at module level on purpose: build_chat_prompt can
    # resolve it from the module globals.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Pad on the left and reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",
        device_map="auto"
    )
    model.eval()

    example_traces = [
        "<think>The first image matches the prompt better …</think>\n<answer>{\"preferred\":\"second\"}</answer>",
        "<think>The second image is clearly superior …</think>\n<answer>{\"preferred\":\"second\"}</answer>",
    ]

    if not example_traces:
        print("No traces to process.")
    else:
        flags = contradiction_flags(example_traces, model, tokenizer, batch_size=2)
        print("\nContradictions per row:")
        separator = "-" * 60
        for idx, (trace, flag) in enumerate(zip(example_traces, flags), start=1):
            first_line = trace.splitlines()[0]
            label = "YES" if flag else "NO"
            print(f"Trace {idx}: {first_line[:60]}...")
            print(f"Contradicts? {label}")
            print(separator)
        n_flagged = sum(flags)
        print("\nTotal contradictions:", n_flagged)
        print(f"Proportion of contradictions: {n_flagged / len(flags):.2%}")
