import argparse
import os
import json
import pandas as pd
import torch
import numpy as np
from accelerate import Accelerator
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from tqdm import tqdm
import math
import torch.distributed as dist

# System message placed at the start of every scoring conversation.
# NOTE(review): presumably this must match the system prompt used when the
# responses were generated — confirm before editing the wording.
SYSTEM_PROMPT = "".join(
    [
        "The user has two images and a textual prompt. ",
        "You need to reason carefully and produce an answer with reasoning in <think>...</think> where you should choose best image.",
    ]
)

def parse_args(argv=None):
    """Parse the command-line options for the soft-score run.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing a list lets tests and
            other scripts drive the parser directly.

    Returns:
        argparse.Namespace with ``ckpt_num``, ``model_path_template``,
        ``processor_path``, ``input_dir_base``, ``input_subdir_template``,
        ``output_filename_template``, ``use_flash_attention_2`` and
        ``use_instruct_model``.
    """
    parser = argparse.ArgumentParser(description="Run distributed soft score calculation.")
    parser.add_argument(
        "--ckpt_num",
        type=int,
        required=True,
        help="Checkpoint number used for the initial inference."
    )
    parser.add_argument(
        "--model_path_template",
        type=str,
        default='/path/to/checkpoints/checkpoint-{}',
        help="Template path for the model checkpoint directory. {} will be replaced by ckpt_num."
    )
    parser.add_argument(
        "--processor_path",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Path or name of the model processor."
    )
    parser.add_argument(
        "--input_dir_base",
        type=str,
        default="/path/to/inference_results/",
        help="Base directory where inference results were saved."
    )
    parser.add_argument(
        "--input_subdir_template",
        type=str,
        default="input_subdir_{}",
        help="Subdirectory template within input_dir_base. {} will be replaced by ckpt_num."
    )
    parser.add_argument(
        "--output_filename_template",
        type=str,
        default="output_scores_{}.csv",
        help="Filename template for the final output CSV file within the input subdirectory. {} will be replaced by ckpt_num."
    )
    parser.add_argument(
        "--use_flash_attention_2",
        action="store_true",
        help="Enable Flash Attention 2 if available.",
    )
    parser.add_argument(
        "--use_instruct_model",
        action="store_true",
        help="Use the instruct model for scoring.",
    )
    return parser.parse_args(argv)

def safe_load_image(image_path):
    """Load *image_path* as a fully decoded RGB ``PIL.Image``, or return ``None``.

    Best-effort loader: a missing, empty or non-string path, or any failure to
    open or decode the file, yields ``None`` instead of raising.

    Args:
        image_path: Filesystem path to the image. Non-``str`` values (e.g. a
            NaN coming from a DataFrame cell) are treated as missing.

    Returns:
        A new RGB image that no longer references the source file, or ``None``.
    """
    if not image_path or not isinstance(image_path, str):
        return None
    try:
        # ``Image.open`` is lazy: it keeps the file open and defers decoding
        # until pixels are accessed. Using it as a context manager releases
        # the handle, and ``convert`` forces the full decode *inside* this try,
        # so truncated/corrupt files are reported here rather than later in
        # the processor. ``convert`` always returns a new image (a copy when
        # the mode is already RGB), so the result stays valid after the file
        # is closed.
        with Image.open(image_path) as img:
            return img.convert("RGB")
    except Exception:
        # Deliberately broad: PIL plugins raise a variety of exception types
        # (OSError, ValueError, SyntaxError, ...) for unreadable files.
        return None

def extract_images_from_conversation(conversation_prompt, image_objects):
    """Pair each image placeholder in a chat conversation with an image object.

    Walks the turns in order and, for every ``{"type": "image"}`` entry inside
    a list-valued ``content``, takes the next item of *image_objects*.

    Args:
        conversation_prompt: Sequence of turn dicts.
        image_objects: Indexable sequence of images, consumed front to back.

    Returns:
        A list with one image per placeholder, in placeholder order (surplus
        images are ignored), or ``None`` as soon as a placeholder has no image
        left to take.
    """
    def _image_placeholders():
        # Yield image placeholders lazily so we stop walking at the first shortage.
        for turn in conversation_prompt:
            content = turn.get("content")
            if not isinstance(content, list):
                continue
            for entry in content:
                if isinstance(entry, dict) and entry.get("type") == "image":
                    yield entry

    matched = []
    for slot, _placeholder in enumerate(_image_placeholders()):
        if slot >= len(image_objects):
            return None
        matched.append(image_objects[slot])
    return matched

def _warn(message):
    """Print *message* to stderr, flushed so output from several ranks shows up promptly."""
    import sys

    print(message, file=sys.stderr, flush=True)


def _load_inference_results(input_dir):
    """Load every ``results_rank_*.jsonl`` file in *input_dir*.

    Args:
        input_dir: Directory holding the per-rank inference result files.

    Returns:
        ``(df_combined, df_to_process)``. ``df_combined`` holds every parsed
        record (empty if nothing could be read). ``df_to_process`` drops rows
        that have no ``generated_text`` but do carry an ``error`` value, and is
        empty when a required column is missing. Problems are reported, never
        raised: the caller broadcasts the result to all ranks, so an exception
        here would leave the other ranks blocked in the broadcast.
    """
    required_cols = ['original_index', 'prompt', 'image_path', 'anchor_image_path', 'generated_text']
    try:
        # Sorted so the combined row order is deterministic across runs.
        filenames = sorted(
            name
            for name in os.listdir(input_dir)
            if name.startswith("results_rank_") and name.endswith(".jsonl")
        )
    except OSError as err:
        _warn(f"Cannot list input directory {input_dir!r}: {err}")
        filenames = []

    records = []
    for filename in tqdm(filenames, desc="Loading .jsonl files"):
        filepath = os.path.join(input_dir, filename)
        bad_lines = 0
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        bad_lines += 1
                        continue
                    # Only JSON objects can become DataFrame rows.
                    if isinstance(record, dict):
                        records.append(record)
                    else:
                        bad_lines += 1
        except (OSError, UnicodeDecodeError) as err:
            _warn(f"Failed to read {filepath!r}: {err}")
        if bad_lines:
            _warn(f"Skipped {bad_lines} malformed line(s) in {filepath!r}")

    if not records:
        return pd.DataFrame(), pd.DataFrame()
    df_combined = pd.DataFrame(records)
    # Check the schema before touching any column so a missing one cannot raise.
    missing = [col for col in required_cols if col not in df_combined.columns]
    if missing:
        _warn(f"Inference results are missing required column(s): {missing}")
        return df_combined, pd.DataFrame()
    if "error" in df_combined.columns:
        # Rows whose generation failed (no text, an error recorded) are not scored.
        error_mask = df_combined['generated_text'].isna() & df_combined['error'].notna()
        df_to_process = df_combined[~error_mask].copy()
    else:
        df_to_process = df_combined.copy()
    return df_combined, df_to_process


def _build_scoring_text(processor, prompt_text, generated_text):
    """Render the scoring conversation and cut it right before the preference word.

    The returned text ends with ``<answer>{"preferred":"`` so the next-token
    logits at the final position can compare "first" against "second".

    Args:
        processor: Processor whose ``apply_chat_template`` renders the chat.
        prompt_text: The original user prompt for the image pair.
        generated_text: The saved assistant response to be scored.

    Returns:
        The truncated prompt text, or ``None`` if *generated_text* contains no
        recognisable answer section.
    """
    user_content = [
        {"type": "image"},
        {"type": "image"},
        {
            "type": "text",
            "text": (
                f"User prompt: {prompt_text}\n\n"
                "Which image is better given the prompt? Analyze aesthetics, composition, prompt alignment and other factors. "
                "Provide your reasoning in <think>...</think> tags, "
                'and the final JSON answer in <answer>{"preferred":"second"}</answer> or {"preferred":"first"}.\n'
            ),
        },
    ]
    conversation_prompt = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {'role': 'assistant', 'content': generated_text}
    ]
    full_text = processor.apply_chat_template(conversation_prompt, add_generation_prompt=False, tokenize=False)
    # Preferred cut points, tried in order; the first occurrence is used.
    for answer_prefix in ('</think>\n<answer>{"preferred":"', '</think>\n\n<answer>{"preferred":"'):
        head, found, _ = full_text.partition(answer_prefix)
        if found:
            return head + found
    # Fallback: cut at the last "<answer>" tag. The user instruction above also
    # contains "<answer>", so only do this when the assistant text has its own
    # tag; otherwise rfind would land inside the user turn and we would score a
    # prompt that no longer contains the model's response at all.
    if "<answer>" not in generated_text:
        return None
    return full_text[:full_text.rfind("<answer>")] + '<answer>{"preferred":"'


def _score_row(row_data, processor, model, device, rating_token_ids, first_token_index):
    """Return P("first") for one inference row, or NaN when it cannot be scored.

    The probability is a softmax over the logits of the rating tokens only,
    taken at the last position of the truncated scoring text.
    """
    generated_text = row_data.get('generated_text')
    if not generated_text or not isinstance(generated_text, str):
        return np.nan
    compare_image_path = row_data.get('image_path')
    anchor_image_path = row_data.get('anchor_image_path')
    prompt_text = row_data.get('prompt')
    # A missing prompt arrives from pandas as NaN, which is truthy, so check the type.
    if (
        not compare_image_path
        or not anchor_image_path
        or not isinstance(prompt_text, str)
        or not prompt_text
    ):
        return np.nan
    compare_image = safe_load_image(compare_image_path)
    anchor_image = safe_load_image(anchor_image_path)
    if compare_image is None or anchor_image is None:
        return np.nan
    text_for_scoring = _build_scoring_text(processor, prompt_text, generated_text)
    if text_for_scoring is None:
        return np.nan
    # Images are passed as [compare, anchor] for the two image placeholders
    # (presumably consumed in order by the processor). No truncation: cutting
    # the text would drop the trailing answer prefix (and could split image
    # tokens), making the last-position logits meaningless.
    inputs = processor(
        text=[text_for_scoring],
        images=[[compare_image, anchor_image]],
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        last_token_logits = model(**inputs).logits[:, -1, :]
    rating_logits = last_token_logits[:, rating_token_ids].float()
    probs = torch.softmax(rating_logits, dim=-1)
    return probs[0, first_token_index].item()


def _write_scores(df_combined, all_results, output_file):
    """Left-join gathered scores onto *df_combined* by ``original_index`` and write a CSV.

    Rows without a gathered score get NaN; if *all_results* is empty every row
    gets NaN. Duplicate ``original_index`` entries keep the last gathered
    score. Errors are reported rather than raised so the caller still reaches
    its final barrier.
    """
    if df_combined is None or df_combined.empty:
        return
    try:
        if all_results:
            scores_df = pd.DataFrame(all_results).drop_duplicates(subset=['original_index'], keep='last')
            final_df = pd.merge(df_combined, scores_df, on='original_index', how='left')
        else:
            final_df = df_combined.assign(score=np.nan)
        final_df.to_csv(output_file, index=False)
    except Exception as err:
        _warn(f"Failed to write scores to {output_file!r}: {err}")


def main():
    """Score saved pairwise-preference responses with a Qwen2.5-VL model.

    Rank 0 loads the ``results_rank_*.jsonl`` files from the checkpoint's input
    directory and broadcasts the rows. Each rank scores a strided share of them.
    Rank 0 gathers the scores, left-joins them on ``original_index`` and writes
    a CSV into the input directory. A row's ``score`` is the probability of
    "first" (versus "second") as the next token after ``<answer>{"preferred":"``;
    rows that cannot be scored get NaN.

    Every rank always reaches the same collectives (broadcast, barriers,
    all-gather). Failures are reported and turned into NaN scores instead of
    early returns, which would leave the other ranks blocked.
    """
    args = parse_args()
    accelerator = Accelerator()
    rank = accelerator.process_index
    num_processes = accelerator.num_processes
    model_path = args.model_path_template.format(args.ckpt_num)
    input_dir = os.path.join(args.input_dir_base, args.input_subdir_template.format(args.ckpt_num))
    if args.use_instruct_model:
        args.output_filename_template = args.output_filename_template.replace('.csv', '_instruct.csv')
    output_file = os.path.join(input_dir, args.output_filename_template.format(args.ckpt_num))

    df_combined = None
    df_to_process = None
    if accelerator.is_main_process:
        df_combined, df_to_process = _load_inference_results(input_dir)
    # Non-main ranks contribute None; broadcast_object_list fills it in place.
    object_list_to_broadcast = [df_to_process]
    if num_processes > 1:
        dist.broadcast_object_list(object_list_to_broadcast, src=0)
    df_to_process_full = object_list_to_broadcast[0]
    if df_to_process_full is None or df_to_process_full.empty:
        # Every rank received the same object, so all of them return here together.
        accelerator.wait_for_everyone()
        return
    df_subset = df_to_process_full.iloc[rank::num_processes]

    rating_tokens = ["first", "second"]
    first_token_index = rating_tokens.index("first")
    processor = None
    rating_token_ids = None
    try:
        processor = AutoProcessor.from_pretrained(args.processor_path, max_pixels=720 * 28 * 28)
        rating_token_ids = processor.tokenizer.convert_tokens_to_ids(rating_tokens)
        if any(tid == processor.tokenizer.unk_token_id for tid in rating_token_ids):
            raise ValueError("Rating token(s) unknown to tokenizer.")
    except Exception as err:
        _warn(f"[rank {rank}] Could not prepare processor from {args.processor_path!r}: {err}")
        processor = None

    model = None
    if processor is not None and len(df_subset) > 0:
        model_kwargs = {"torch_dtype": torch.bfloat16}
        if (
            args.use_flash_attention_2
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 8
        ):
            model_kwargs["attn_implementation"] = "flash_attention_2"
        model_source = 'Qwen/Qwen2.5-VL-7B-Instruct' if args.use_instruct_model else model_path
        try:
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_source, **model_kwargs)
            # Inference only: move to this rank's device instead of calling
            # accelerator.prepare(). On multi-process runs prepare() wraps the
            # model in DDP, whose collectives (state broadcast at construction,
            # buffer sync in forward) must be matched by every rank — but ranks
            # with no rows never build the model, and ranks run different
            # numbers of forward passes, so that can hang.
            model = model.to(accelerator.device)
            model.eval()
        except Exception as err:
            _warn(f"[rank {rank}] Could not load model from {model_source!r}: {err}")
            model = None
    accelerator.wait_for_everyone()

    results_list = []
    progress = tqdm(
        df_subset.iterrows(),
        total=len(df_subset),
        desc=f"Rank {rank} Scoring",
        disable=not accelerator.is_main_process,
    )
    for _, row_data in progress:
        original_idx = row_data.get('original_index', 'unknown')
        score = np.nan
        if model is not None:
            try:
                score = _score_row(
                    row_data, processor, model, accelerator.device, rating_token_ids, first_token_index
                )
            except Exception as err:
                _warn(f"[rank {rank}] Scoring failed for original_index={original_idx!r}: {err}")
                score = np.nan
        # Every assigned row gets an entry, NaN when unscored (including when
        # the processor or model failed on this rank).
        results_list.append({'original_index': original_idx, 'score': score})
    accelerator.wait_for_everyone()

    all_results = None
    try:
        if num_processes > 1:
            gathered = [None] * num_processes
            dist.all_gather_object(gathered, results_list)
        else:
            gathered = [results_list]
        if accelerator.is_main_process:
            all_results = [entry for per_rank in gathered if per_rank is not None for entry in per_rank]
    except Exception as err:
        _warn(f"[rank {rank}] Gathering scores failed: {err}")

    if accelerator.is_main_process and all_results is not None:
        _write_scores(df_combined, all_results, output_file)
    accelerator.wait_for_everyone()

if __name__ == "__main__":
    main()