import argparse, os, json, logging, re
from typing import Dict, Set

import torch
from accelerate import Accelerator
from datasets import load_dataset
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from tqdm.auto import tqdm

# Disable PIL's decompression-bomb size check so arbitrarily large images
# can be opened. NOTE(review): this removes a safety limit; it is only
# reasonable because the images come from the fixed dataset loaded in
# main() — confirm that stays true if the data source ever changes.
Image.MAX_IMAGE_PIXELS = None

# Root logging at INFO; every message in this script goes through the
# named logger below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main_logger")

# System turn prepended to every comparison chat. Its exact text is part of
# the model input for every pair, so any edit changes generation behavior.
SYSTEM_PROMPT = " ".join((
    "The user has two images and a textual prompt.",
    "You need to reason carefully and produce an answer with reasoning in",
    "<think>...</think> where you should choose best image.",
))

def add_vote_columns(example):
    """Derive vote-summary columns for one dataset row (used with ``Dataset.map``).

    Returns a dict with:
      * ``total_votes`` – votes for image1 plus votes for image2 (a missing
        count is treated as 0);
      * ``winners_dist`` – the majority image's share of the votes, or 0.5
        when no votes were cast at all.
    """
    first = example.get("votes_image1", 0)
    second = example.get("votes_image2", 0)
    total = first + second
    if total == 0:
        share = 0.5
    else:
        share = max(first, second) / total
    return {"total_votes": total, "winners_dist": share}

def compute_accuracy(out_dir: str, n_proc: int, gt_ds, results_name="results_rank_{}.jsonl"):
    """Score generated preference traces against human-vote ground truth.

    Ground truth: for every row of ``gt_ds`` whose two vote counts differ, the
    image with more votes ("first" = image1, "second" = image2). Tied rows are
    excluded.

    Predictions: each file ``results_name.format(rank)`` in ``out_dir`` for
    ``rank in range(n_proc)`` is read as JSONL. The prediction is the
    ``"preferred"`` field of the JSON object inside the first
    ``<answer>…</answer>`` span of ``generated_text``. Only the FIRST parseable
    trace per ``original_index`` (in file order, lower ranks first) is scored;
    later traces for the same index are ignored.

    Args:
        out_dir: directory holding the per-rank result files.
        n_proc: number of rank files to look for.
        gt_ds: iterable of rows with ``original_index``, ``votes_image1`` and
            ``votes_image2`` (a ``datasets.Dataset`` or e.g. a list of dicts).
        results_name: file-name template, formatted with the rank number.

    Returns:
        ``(accuracy, correct, total)``; accuracy is 0.0 when nothing was scored.
    """
    logger.info("Computing accuracy …")
    # Only three scalar columns are needed. Selecting them up front avoids
    # decoding both images of every row when gt_ds is a datasets.Dataset.
    if hasattr(gt_ds, "select_columns"):
        gt_ds = gt_ds.select_columns(["original_index", "votes_image1", "votes_image2"])
    gt: Dict[int, str] = {}
    for row in gt_ds:
        v1, v2 = row["votes_image1"], row["votes_image2"]
        if v1 != v2:
            gt[row["original_index"]] = "first" if v1 > v2 else "second"
    if not gt:
        logger.warning("No ground-truth pairs with unequal votes → accuracy 0.0")
        return 0.0, 0, 0

    answer_re = re.compile(r"<answer>(.*?)</answer>", re.S)
    seen: Set[int] = set()
    correct = total = malformed = 0
    for rank in range(n_proc):
        f_path = os.path.join(out_dir, results_name.format(rank))
        if not os.path.exists(f_path):
            logger.warning(f"Results file {f_path} missing.")
            continue
        with open(f_path, encoding="utf-8") as fh:
            for ln, line in enumerate(fh, 1):
                try:
                    obj = json.loads(line)
                    idx = obj.get("original_index")
                    gen = obj.get("generated_text")
                    if idx is None or gen is None or idx not in gt or idx in seen:
                        continue
                    m = answer_re.search(gen)
                    if not m:
                        continue
                    pref = json.loads(m.group(1)).get("preferred")
                except (ValueError, TypeError, AttributeError) as e:
                    # Truncated/non-JSON line or answer, or an unexpected
                    # structure (non-dict JSON, non-str text, unhashable index).
                    # Scoring is best-effort, so count the line and move on.
                    malformed += 1
                    logger.debug(f"Skipping malformed line {ln} in {f_path}: {e}")
                    continue
                if pref not in ("first", "second"):
                    continue
                seen.add(idx)
                total += 1
                if pref == gt[idx]:
                    correct += 1
    if malformed:
        logger.info(f"Skipped {malformed} malformed result line(s).")
    acc = 0.0 if total == 0 else correct / total
    logger.info(f"Accuracy: {correct}/{total} = {acc:.4%}")
    return acc, correct, total

def get_args(argv=None):
    """Parse command-line options for trace generation and accuracy scoring.

    Args:
        argv: argument list to parse. ``None`` (the default) reads
            ``sys.argv`` exactly as before; passing a list lets the parser be
            driven programmatically.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt_num", type=int, required=True,
                   help="Checkpoint number; substituted into --model_path_template "
                        "and --output_subdir_template.")
    p.add_argument("--dataset_split", type=str, default="train",
                   help="Dataset split to evaluate.")
    p.add_argument("--model_path_template",
                   default="/path/to/model/checkpoint-{}",
                   help="Model path; '{}' is replaced by --ckpt_num.")
    p.add_argument("--processor_path",
                   default="model/processor",
                   help="Path of the processor passed to AutoProcessor.")
    p.add_argument("--output_dir_base", default="output_dir",
                   help="Root directory for all outputs.")
    # The template is formatted positionally (template.format(ckpt_num)), so it
    # must use '{}' -- a named field such as '{ckpt}' would raise KeyError.
    p.add_argument("--output_subdir_template",
                   default="subdir_{}",
                   help="Output subdirectory name; '{}' is replaced by --ckpt_num "
                        "(_T<temperature> and, with --small, _small are appended).")
    p.add_argument("--max_new_tokens", type=int, default=512,
                   help="Maximum number of new tokens per trace.")
    p.add_argument("--temperature", type=float, default=1.1,
                   help="Sampling temperature.")
    p.add_argument("--num_traces", type=int, default=16,
                   help="Number of sampled completions per image pair.")
    # Thresholds must match the filter applied in main() (total_votes > 25).
    p.add_argument("--small", action="store_true",
                   help="Filter to strong-agreement subset (total_votes>25 & winners_dist>0.8)")
    return p.parse_args(argv)

def _resolve_out_dir(args) -> str:
    """Return ``<output_dir_base>/<subdir>_T<temperature>[_small]`` for this run."""
    subdir = f"{args.output_subdir_template.format(args.ckpt_num)}_T{args.temperature}"
    if args.small:
        subdir += "_small"
    return os.path.join(args.output_dir_base, subdir)


def _build_chat(prompt: str) -> list:
    """Build the system + user turns asking the model to compare two images for ``prompt``.

    The two image placeholders correspond, in order, to image1 and image2 as
    passed to the processor in ``main``.
    """
    user_content = [
        {"type": "image"}, {"type": "image"},
        {"type": "text", "text":
            f"User prompt: {prompt}\n\n"
            "Which image is better given the prompt? Analyze aesthetics, "
            "composition, prompt alignment and other factors. "
            "Provide your reasoning in <think>…</think> tags and the "
            'final JSON answer in '
            '<answer>{"preferred":"second"}</answer> or '
            '{"preferred":"first"}.'}]
    return [{"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": user_content}]


def main():
    """Generate preference traces for one checkpoint on every rank, then score them.

    Each process samples ``--num_traces`` completions per non-tied image pair in
    its shard and writes them as JSONL to ``results_rank_<rank>.jsonl``. After
    all ranks finish, the main process computes accuracy over every rank's file
    and writes ``accuracy_summary_ckpt_<ckpt>.json``.
    """
    args = get_args()
    accelerator = Accelerator()
    rank = accelerator.process_index
    out_dir = _resolve_out_dir(args)
    if accelerator.is_main_process:
        logger.info(f"Output dir → {out_dir}")
        os.makedirs(out_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # ─── load & possibly filter dataset ───────────────────────────────
    dataset = load_dataset("Rapidata/human-style-preferences-images",
                           split=args.dataset_split, trust_remote_code=True)
    dataset = dataset.map(add_vote_columns, num_proc=8)
    dataset = dataset.add_column("original_index", range(len(dataset)))
    if args.small:
        dataset = dataset.filter(lambda ex: ex["total_votes"] > 25
                                           and ex["winners_dist"] > 0.8,
                                 num_proc=8)
    if accelerator.is_main_process:
        logger.info(f"Total samples after filter: {len(dataset)}")
    if len(dataset) == 0:
        logger.warning("Dataset empty after filtering – aborting.")
        return
    # Keep the unsharded dataset for scoring. The main process needs ground
    # truth for every rank's examples, not only those in its own shard.
    full_dataset = dataset
    shard = dataset.shard(accelerator.num_processes, rank)
    logger.info(f"Rank {rank}: will process {len(shard)} examples.")

    processor = AutoProcessor.from_pretrained(args.processor_path,
                                              trust_remote_code=True,
                                              max_pixels=720*28*28)
    m_kwargs = {"torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2"}
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                args.model_path_template.format(args.ckpt_num),
                trust_remote_code=True, **m_kwargs)
    model = accelerator.prepare(model)
    # Call generate() on the underlying model rather than any distributed wrapper.
    model_to_generate = accelerator.unwrap_model(model)

    out_file = os.path.join(out_dir, f"results_rank_{rank}.jsonl")
    skipped_equal = skipped_data = errors = 0
    processed = 0
    # "w" truncates results left over from an earlier run with the same config.
    with open(out_file, "w") as fout:
        progress = tqdm(shard, desc=f"Rank {rank}",
                        disable=not accelerator.is_main_process, total=len(shard))
        for row in progress:
            # Tied votes carry no ground-truth preference.
            if row["votes_image1"] == row["votes_image2"]:
                skipped_equal += 1
                continue
            prompt, img1, img2 = row["prompt"], row["image1"], row["image2"]
            # Validate before resizing. Calling .resize on a missing image would
            # raise outside the try below and abort this rank, leaving the
            # other ranks blocked at wait_for_everyone().
            if not (isinstance(prompt, str)
                    and isinstance(img1, Image.Image)
                    and isinstance(img2, Image.Image)):
                skipped_data += 1
                continue
            img1, img2 = img1.resize((512, 512)), img2.resize((512, 512))
            try:
                # Render the chat to prompt text (fed to the processor as text=).
                prompt_text = processor.apply_chat_template([_build_chat(prompt)],
                                                            add_generation_prompt=True)
                inputs = processor(text=prompt_text, images=[[img1, img2]],
                                   return_tensors="pt", padding=True,
                                   padding_side="left").to(accelerator.device)
                with torch.no_grad():
                    gen_ids = model_to_generate.generate(
                        **inputs, max_new_tokens=args.max_new_tokens,
                        do_sample=True, temperature=args.temperature,
                        num_return_sequences=args.num_traces, top_p=0.9)
                # Each returned sequence begins with the full prompt; drop it so
                # only the sampled completion is decoded.
                prompt_len = inputs["input_ids"].shape[1]
                texts = processor.batch_decode(gen_ids[:, prompt_len:].cpu(),
                                               skip_special_tokens=True,
                                               clean_up_tokenization_spaces=False)
                for t_id, txt in enumerate(texts):
                    fout.write(json.dumps({
                        "original_index": row["original_index"],
                        "trace_id":       t_id,
                        "prompt":         prompt,
                        "generated_text": txt,
                        "image1_path":    row["image1_path"],
                        "image2_path":    row["image2_path"],
                        "votes_image1":   row["votes_image1"],
                        "votes_image2":   row["votes_image2"]
                    }) + "\n")
                    processed += 1
                # Flush per pair so finished work survives if the job is killed.
                fout.flush()
            except Exception as e:
                # Per-row boundary: log with traceback, count, and keep going so
                # one bad pair cannot stall the whole multi-process run.
                logger.error(f"Error on idx {row['original_index']}: {e}", exc_info=True)
                errors += 1
    logger.info(f"Rank {rank}: processed {processed} traces "
                f"({processed/args.num_traces:.0f} pairs), "
                f"skipped_equal={skipped_equal}, skipped_data={skipped_data}, errors={errors}")
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        acc, corr, tot = compute_accuracy(out_dir, accelerator.num_processes, full_dataset)
        summ_path = os.path.join(out_dir, f"accuracy_summary_ckpt_{args.ckpt_num}.json")
        with open(summ_path, "w") as fh:
            json.dump({"ckpt_num": args.ckpt_num, "accuracy": acc,
                       "correct": corr, "total": tot,
                       "num_traces": args.num_traces,
                       "temperature": args.temperature,
                       "small_filter": args.small}, fh, indent=4)
        logger.info(f"Accuracy summary → {summ_path}")


if __name__ == "__main__":
    main()
