import argparse
import os
import json
import pandas as pd
import torch
import numpy as np
from accelerate import Accelerator
from datasets import load_dataset
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import logging
from tqdm import tqdm
import torch.distributed as dist

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# System message placed first in every conversation rendered for scoring.
# NOTE(review): presumably this should match the system prompt that was used
# when the reasoning traces were generated — confirm.
SYSTEM_PROMPT = "".join(
    (
        "The user has two images and a textual prompt. ",
        "You need to reason carefully and produce an answer with reasoning in <think>...</think> where you should choose best image.",
    )
)

def parse_args():
    """Parse the command-line options of the scoring script.

    ``--ckpt_num`` (int) is mandatory. Every other option is either a string
    with a default, or a boolean switch that defaults to ``False``.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_num", type=int, required=True)

    # (flag, default) pairs for the plain string options, in declaration order.
    string_options = (
        ("--dataset_split", "train"),
        ("--model_path_template", "MODEL_CHECKPOINT_PATH_TEMPLATE/checkpoint-{}"),
        ("--processor_path", "Qwen/Qwen2.5-VL-7B-Instruct"),
        ("--input_dir_base", "inferenced_reasoning"),
        ("--input_subdir_template", "rapidata_ood_hpsv_{}"),
        ("--output_filename_template", "rapidata_ood_hpsv_{}_scores.csv"),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)

    for switch in ("--use_flash_attention_2", "--use_instruct_model"):
        parser.add_argument(switch, action="store_true")

    return parser.parse_args()

def safe_load_image(path: str):
    """Open an image file and return it as a fully loaded RGB ``PIL.Image``.

    Args:
        path: filesystem path to the image. Any value that is not a ``str``
            (for example a float NaN) yields ``None`` instead of an error.

    Returns:
        A new RGB image that no longer references the underlying file, or
        ``None`` when *path* is not a string.

    Raises:
        Whatever ``PIL.Image.open`` raises for missing or unreadable files.
    """
    if not isinstance(path, str):
        return None
    # Image.open is lazy and keeps the file handle open until pixel data is
    # read. Converting inside the ``with`` block forces a full load into a new
    # image (convert() returns a copy even when the mode is already RGB), and
    # the file is then closed deterministically instead of leaking.
    with Image.open(path) as img:
        return img.convert("RGB")

def _read_rank_results(input_dir):
    """Collect every JSON record from the ``results_rank_*.jsonl`` files in *input_dir*.

    Files are visited in sorted filename order. Blank lines are ignored.
    Lines that fail to parse as JSON are skipped and reported in one warning
    instead of being dropped silently.

    Returns:
        list: the parsed records, in file/line order.
    """
    records = []
    n_bad = 0
    for fn in sorted(os.listdir(input_dir)):
        if not (fn.startswith("results_rank_") and fn.endswith(".jsonl")):
            continue
        with open(os.path.join(input_dir, fn), encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    n_bad += 1
    if n_bad:
        logger.warning(f"Skipped {n_bad} malformed JSONL line(s) in {input_dir}.")
    return records


def _drop_failed_generations(df_combined):
    """Return the rows of *df_combined* that have something to score.

    A row is dropped when its ``generated_text`` is missing AND an ``error``
    value is recorded. If there is no ``error`` column, no row is dropped.
    The returned frame has a fresh 0..n-1 index.
    """
    has_error = (
        df_combined["error"].notna()
        if "error" in df_combined.columns
        else pd.Series(False, index=df_combined.index)
    )
    err_mask = df_combined["generated_text"].isna() & has_error
    return df_combined[~err_mask].reset_index(drop=True)


def _truncate_at_answer(full):
    """Return *full* cut so that it ends with ``<answer>{"preferred":"``.

    The two ``</think>`` variants (one or two newlines before ``<answer>``)
    are tried first, and the text is cut at their first occurrence.
    Otherwise the text is cut at the last ``<answer>`` tag and a fresh
    ``{"preferred":"`` opener is appended.

    Returns:
        str | None: the truncated text, or ``None`` if *full* contains no
        ``<answer>`` tag at all.
    """
    for prefix in ('</think>\n<answer>{"preferred":"',
                   '</think>\n\n<answer>{"preferred":"'):
        if prefix in full:
            return full.split(prefix)[0] + prefix
    last = full.rfind("<answer>")
    if last < 0:
        return None
    return full[:last] + '<answer>{"preferred":"'


def main():
    """Score reasoning traces with a Qwen2.5-VL model and write a CSV of soft scores.

    1. The main process reads ``results_rank_*.jsonl`` from the input
       directory, drops failed generations and broadcasts the rows.
    2. Each process takes every ``num_processes``-th row (starting at its own
       rank). It re-renders the conversation (system prompt, two images plus
       the user prompt, generated reasoning) and truncates it right after
       ``<answer>{"preferred":"``. It then reads the next-token logits for
       ``first`` and ``second``.
    3. The softmax probability of ``first``, restricted to those two tokens,
       is the row's score. Rows without an ``<answer>`` tag get NaN.
    4. Results are gathered on the main process, left-merged onto all loaded
       records by ``original_index``, and written to CSV.
    """
    args = parse_args()
    accel = Accelerator()
    model_path = args.model_path_template.format(args.ckpt_num)
    input_dir = os.path.join(
        args.input_dir_base,
        args.input_subdir_template.format(args.ckpt_num),
    )
    if args.use_instruct_model:
        args.output_filename_template = args.output_filename_template.replace(".csv", "_instruct.csv")
    output_csv = os.path.join(input_dir, args.output_filename_template.format(args.ckpt_num))
    if accel.is_main_process:
        os.makedirs(input_dir, exist_ok=True)
    accel.wait_for_everyone()

    # Only the main process reads the JSONL files. The filtered rows are then
    # broadcast so that every rank shards exactly the same table.
    df_combined = df_to_process = None
    if accel.is_main_process:
        combined = _read_rank_results(input_dir)
        if combined:
            df_combined = pd.DataFrame(combined)
            df_to_process = _drop_failed_generations(df_combined)
        else:
            df_combined = pd.DataFrame()
            df_to_process = pd.DataFrame()
    objs = [df_to_process] if accel.is_main_process else [None]
    if accel.num_processes > 1:
        dist.broadcast_object_list(objs, src=0)
    df_full = objs[0]
    if df_full is None or df_full.empty:
        if accel.is_main_process:
            logger.error("No reasoning traces to score.")
        accel.wait_for_everyone()
        return
    df_subset = df_full.iloc[accel.process_index::accel.num_processes]
    logger.info(f"Rank {accel.process_index}: scoring {len(df_subset)} items.")

    ds = load_dataset("Rapidata/human-style-preferences-images",
                      split=args.dataset_split)
    ds = ds.add_column("original_index", list(range(len(ds))))
    processor = AutoProcessor.from_pretrained(
        args.processor_path, max_pixels=720 * 28 * 28
    )

    rating_tokens = ["first", "second"]
    rating_token_ids = processor.tokenizer.convert_tokens_to_ids(rating_tokens)
    unk_id = processor.tokenizer.unk_token_id
    # Each rating word must be a single vocabulary entry. Otherwise its id is
    # None/unk and the two-way softmax below would silently be meaningless.
    if any(tid is None or (unk_id is not None and tid == unk_id)
           for tid in rating_token_ids):
        raise ValueError(
            f"Rating tokens {rating_tokens} do not map to single vocabulary ids: "
            f"{rating_token_ids}"
        )
    first_pos = rating_tokens.index("first")

    model_ckpt = "Qwen/Qwen2.5-VL-7B-Instruct" if args.use_instruct_model else model_path
    model_kwargs = {"torch_dtype": torch.bfloat16}
    # Request flash attention only when asked for. It needs the optional
    # flash-attn package; without the flag, transformers' default is used.
    if args.use_flash_attention_2:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_ckpt, **model_kwargs)
    model = accel.prepare(model)

    results = []
    for _, row in tqdm(df_subset.iterrows(), total=len(df_subset),
                       desc=f"Rank {accel.process_index} Scoring"):
        idx = int(row["original_index"])
        gen_text = row.get("generated_text", "")
        if not isinstance(gen_text, str):
            # Missing values arrive as NaN/None; render an empty assistant turn.
            gen_text = ""
        prompt = row["prompt"]
        entry = ds[idx]
        img1 = entry["image1"]
        img2 = entry["image2"]
        user_ct = [
            {"type": "image"}, {"type": "image"},
            {"type": "text", "text":
                (f"User prompt: {prompt}\n\n"
                 "Which image is better given the prompt? "
                 "Provide reasoning in <think>...</think>, "
                 "and final answer in <answer>{\"preferred\":\"second\"}</answer> "
                 "or {\"preferred\":\"first\"}.\n")},
        ]
        conv = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_ct},
            {"role": "assistant", "content": gen_text},
        ]
        full = processor.apply_chat_template(
            [conv], add_generation_prompt=False, tokenize=False
        )
        # A batched (list-of-conversations) call may return a list of
        # strings; work on the single rendered string either way.
        if isinstance(full, (list, tuple)):
            full = full[0]
        text_sc = _truncate_at_answer(full)
        if text_sc is None:
            # Record NaN instead of raising: one bad trace must not abort
            # this rank (and with it the whole multi-process run).
            logger.warning(
                f"Rank {accel.process_index}: no <answer> prefix found for "
                f"original_index={idx}; recording NaN score."
            )
            results.append({"original_index": idx, "score": np.nan})
            continue
        inputs = processor(
            text=[text_sc],
            images=[[img1, img2]],
            return_tensors="pt",
            padding=True, truncation=True,
        ).to(accel.device)
        with torch.no_grad():
            # Logits at the last position = distribution of the token that
            # follows '{"preferred":"', restricted to the rating words.
            logits = model(**inputs).logits[:, -1, :]
            rating_logits = logits[:, rating_token_ids].float()
            probs = torch.softmax(rating_logits, dim=-1)
        results.append({"original_index": idx, "score": probs[0, first_pos].item()})

    accel.wait_for_everyone()
    if accel.num_processes > 1:
        gathered = [None] * accel.num_processes
        dist.all_gather_object(gathered, results)
    else:
        gathered = [results]
    if accel.is_main_process:
        all_res = [rec for part in gathered for rec in (part or [])]
        # Explicit columns keep drop_duplicates working even with no results.
        df_scores = pd.DataFrame(all_res, columns=["original_index", "score"]).drop_duplicates(
            "original_index", keep="last"
        )
        merged = df_combined.merge(df_scores,
                                   on="original_index",
                                   how="left")
        merged.to_csv(output_csv, index=False)
        logger.info(f"Saved final CSV with soft scores to {output_csv}")

# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()