import os
import random
import json
import re
import sys
import traceback
import time
import atexit
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any

import torch
from torch.utils.data import Dataset
from PIL import Image
from datasets import load_dataset
import io
import logging

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
)
from trl import (
    GRPOConfig,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_peft_config,
)
from peft import LoraConfig, TaskType, get_peft_model

from grpo_trainer import Qwen2VLGRPOTrainer


logger = logging.getLogger(__name__)

# Snapshot of the most recent image pair handled by PreferenceDataset, kept
# so save_debug_images() has something to dump when a failure occurs.
# All entries start out as None.
last_processed_images = dict.fromkeys(
    ("chosen", "rejected", "chosen_idx", "rejected_idx", "sample_idx", "prompt")
)

def save_debug_images(error_info=None, *, state=None, output_root="debug_images"):
    """Save the most recently processed image pair and its metadata for debugging.

    Writes ``rank{R}_chosen_idx{i}.png``, ``rank{R}_rejected_idx{j}.png`` and
    ``rank{R}_metadata.json`` into ``<output_root>/error_<unix time>/``.

    This is called from ``except`` blocks, ``sys.excepthook`` and ``atexit``, so
    it is best-effort: a failure while writing is reported and swallowed rather
    than masking the error that triggered the dump.

    Args:
        error_info: Description of the triggering error, stored in the metadata.
        state: Dict with the same keys as ``last_processed_images``; defaults
            to that module-level dict.
        output_root: Directory under which the per-call folder is created.

    Returns:
        The directory written to, or None if nothing was saved.
    """
    if state is None:
        state = last_processed_images
    chosen = state.get("chosen")
    rejected = state.get("rejected")
    if chosen is None or rejected is None:
        print("No images to save for debugging")
        return None

    # get_rank() raises when no process group is initialised (single-process
    # runs, or at exit after the group was destroyed), so guard it.
    dist = torch.distributed
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

    debug_dir = os.path.join(output_root, f"error_{int(time.time())}")
    try:
        os.makedirs(debug_dir, exist_ok=True)
        chosen.save(os.path.join(debug_dir, f"rank{rank}_chosen_idx{state.get('chosen_idx')}.png"))
        rejected.save(os.path.join(debug_dir, f"rank{rank}_rejected_idx{state.get('rejected_idx')}.png"))

        metadata = {
            "sample_idx": state.get("sample_idx"),
            "chosen_idx": state.get("chosen_idx"),
            "rejected_idx": state.get("rejected_idx"),
            "prompt": state.get("prompt"),
            "chosen_size": chosen.size,
            "rejected_size": rejected.size,
            "error": str(error_info) if error_info else "Unknown error",
        }
        with open(os.path.join(debug_dir, f"rank{rank}_metadata.json"), "w", encoding="utf-8") as f:
            # Debug dump only: stringify anything json cannot encode natively.
            json.dump(metadata, f, indent=2, default=str)
    except Exception as exc:
        print(f"Failed to save debug images to {debug_dir}: {exc}")
        return None

    print(f"Debug images saved to {debug_dir}")
    return debug_dir

def exception_handler(exc_type, exc_value, exc_tb):
    """``sys.excepthook`` replacement: dump debug images, then report the exception.

    The traceback is printed once, by the default hook ``sys.__excepthook__``.
    It is not also printed here, because that would show it twice. A failure
    while saving the debug images is reported but never hides the original
    exception.
    """
    print(f"Unhandled exception: {exc_type.__name__}: {exc_value}")
    try:
        save_debug_images(f"{exc_type.__name__}: {exc_value}")
    except Exception as save_err:  # the debug dump must not replace the real error
        print(f"Failed to save debug images: {save_err}")
    sys.__excepthook__(exc_type, exc_value, exc_tb)

# Dump the last processed images both at interpreter exit and on any
# uncaught exception (the two registrations are independent).
atexit.register(save_debug_images, "Program terminated")
sys.excepthook = exception_handler

@dataclass
class MyScriptArguments(ScriptArguments):
    """Script-specific CLI options, extending trl's ``ScriptArguments``.

    Parsed by ``TrlParser`` together with ``GRPOConfig`` and ``ModelConfig``
    in the ``__main__`` block.
    """

    # Prompt/reward variant. main() uses the "pick" reward for "choice" and the
    # "rate" reward for any other value; PreferenceDataset only builds a prompt
    # for "choice".
    variant: str = field(
        default="choice",
        metadata={"help": "Which variant to run: 'choice' or 'rating'."},
    )
    # The default is a placeholder string, not a loadable dataset id.
    dataset_name: str = field(
        default="some dataset",
        metadata={"help": "Dataset name on HF Hub or local path."},
    )
    # Fraction of rows held out by the train/test split in main().
    test_size: float = field(
        default=0.1,
        metadata={"help": "Fraction of the dataset used as test set."},
    )
    # Selects which listener reward, if any, main() adds to the rule-based rewards.
    training_method: str = field(
        default="baseline", 
        metadata={"help": "Training method: 'baseline', 'naive_listener', or 'contradiction_listener'"}
    )


# System message for the policy; the user turn (built in PreferenceDataset)
# carries the two images and the text prompt.
SYSTEM_PROMPT = (
    "The user has two images and a textual prompt. "
    "You need to reason carefully inside <think>...</think> tags "
    "and produce an answer in <answer>...</answer> tags "
    "where you should choose best image."
)

# System message for the text-only listener used by ContradictionReward,
# which expects a single 'yes'/'no' reply.
CONTRADICTION_SYSTEM_PROMPT = (
    "You are an AI assistant. "
    "You will be provided with a reasoning trace from another AI "
    "and its final answer regarding an image preference task. "
    "Your task is to determine if the reasoning contradicts the final answer. "
    "For example, if the reasoning expresses uncertainty or leans towards one image, "
    "but the final answer selects a different one, this is a contradiction. "
    "Respond with only 'yes' if there is a contradiction, "
    "or 'no' if there is no contradiction."
)


class PreferenceDataset(Dataset):
    """Two-image preference samples rendered as a chat prompt for GRPO.

    For each row, the human-preferred ("chosen") and the other ("rejected")
    image are resized to a fixed size. Their order is randomised, and the
    position ("first"/"second") of the chosen image is returned as the label.
    """

    # Variants __getitem__ can build a prompt for.
    SUPPORTED_VARIANTS = ("choice",)
    # Every image is resized to this fixed size (aspect ratio is not preserved).
    IMAGE_SIZE = (512, 512)
    # Consecutive indices tried by __getitem__ before giving up.
    MAX_RETRIES = 10

    def __init__(
        self,
        split_data,
        variant: str = "choice",
    ):
        """
        Args:
            split_data: The actual dataset split (train or test). Must expose
                ``num_rows`` and return rows with 'image', 'human_preference'
                and 'prompt' entries when indexed.
            variant: Prompt variant; only "choice" is implemented.

        Raises:
            ValueError: if ``variant`` has no prompt implementation, instead of
                failing on every ``__getitem__`` call later.
        """
        if variant not in self.SUPPORTED_VARIANTS:
            raise ValueError(
                f"Unsupported variant {variant!r}: PreferenceDataset only builds "
                f"prompts for {self.SUPPORTED_VARIANTS}."
            )
        self.data = split_data
        self.variant = variant

    def __len__(self):
        return self.data.num_rows

    def _build_sample(self, row_idx: int):
        """Build the training sample for row ``row_idx``; raises on malformed rows."""
        if self.variant != "choice":
            raise ValueError(f"Unsupported variant {self.variant!r}")

        row = self.data[row_idx]
        images = row['image']
        human_preference = row['human_preference']

        # The preferred image is marked with 1, the other one with 0.
        img_chosen_idx = human_preference.index(1)
        img_rejected_idx = human_preference.index(0)
        img_chosen = images[img_chosen_idx]
        img_rejected = images[img_rejected_idx]

        # Record the raw pair so save_debug_images() can dump it on failure.
        # NOTE(review): with DataLoader num_workers > 0 this runs in worker
        # processes, so the main process's copy is not updated.
        last_processed_images.update(
            chosen=img_chosen,
            rejected=img_rejected,
            chosen_idx=img_chosen_idx,
            rejected_idx=img_rejected_idx,
            sample_idx=row_idx,
            prompt=row.get('prompt', 'Unknown prompt'),
        )

        img_chosen = img_chosen.resize(self.IMAGE_SIZE, Image.LANCZOS)
        img_rejected = img_rejected.resize(self.IMAGE_SIZE, Image.LANCZOS)

        # Randomise the chosen image's position so both labels occur.
        if random.random() < 0.5:
            first_image, second_image, correct_label = img_chosen, img_rejected, "first"
        else:
            first_image, second_image, correct_label = img_rejected, img_chosen, "second"

        user_content = [
            {"type": "image"},
            {"type": "image"},
            {
                "type": "text",
                "text": (
                    f"User prompt: {row['prompt']}\n\n"
                    "Which image is better given the prompt? Analyze aesthetics, composition, prompt alignment and other factors. "
                    "Provide your reasoning in <think>...</think> tags, "
                    'and the final JSON answer in <answer>{"preferred":"second"}</answer> or {"preferred":"first"}.\n'
                ),
            },
        ]
        conversation_prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

        return {
            "id": row_idx,
            "prompt": conversation_prompt,
            "first": first_image,
            "second": second_image,
            "correct_label": correct_label,
        }

    def __getitem__(self, idx: int):
        """Return sample ``idx``, falling back to the following rows if it is malformed.

        Up to ``MAX_RETRIES`` consecutive indices (wrapping around) are tried;
        the exception from the last attempt is re-raised.
        """
        current_idx = idx
        for retry in range(self.MAX_RETRIES):
            try:
                return self._build_sample(current_idx)
            except Exception as e:
                print(f"Error processing sample {current_idx} (attempt {retry+1}/{self.MAX_RETRIES}): {e}")
                save_debug_images(f"Error in __getitem__: {e}")

                if retry == self.MAX_RETRIES - 1:
                    print(f"Failed after {self.MAX_RETRIES} attempts, giving up.")
                    raise

                current_idx = (current_idx + 1) % len(self)
                print(f"Trying next index: {current_idx}")


def pick_preferred_reward(prompts, completions, correct_label, **kwargs):
    """Score 'choice' completions: 0.5 if the answered image matches the label, else 0.0.

    The first ``<answer>...</answer>`` block of an assistant completion must
    contain a ``preferred`` field ("first"/"second", either quote style; keys
    and values are matched case-insensitively). ``prompts`` and extra
    ``kwargs`` are unused.
    """
    answer_block = r"<answer>(.*?)</answer>"
    preferred_field = r'["\']preferred["\']\s*:\s*["\'](first|second)["\']'

    def judge(completion, label):
        is_assistant = completion and completion[0].get("role") == "assistant"
        text = completion[0]["content"] if is_assistant else ""
        block = re.search(answer_block, text, flags=re.DOTALL)
        if not block:
            return 0.0
        picked = re.search(preferred_field, block.group(1), flags=re.IGNORECASE)
        if not picked:
            return 0.0
        return 0.5 if picked.group(1).strip().lower() == label.lower() else 0.0

    return [judge(c, lbl) for c, lbl in zip(completions, correct_label)]


def rate_preferred_reward(prompts, completions, correct_label, **kwargs):
    """Score 'rating' completions: 1.0 if the labelled image gets the strictly higher score.

    The first ``<answer>...</answer>`` block of an assistant completion must
    contain numeric ``"first"`` and ``"second"`` entries (either quote style).
    Ties, missing or unparsable scores and non-assistant completions score
    0.0. ``prompts`` and extra ``kwargs`` are unused.
    """
    answer_block = r"<answer>(.*?)</answer>"
    first_score = r'["\']first["\']\s*:\s*([\d\.]+)'
    second_score = r'["\']second["\']\s*:\s*([\d\.]+)'

    def judge(completion, label):
        is_assistant = completion and completion[0].get("role") == "assistant"
        text = completion[0]["content"] if is_assistant else ""
        block = re.search(answer_block, text, flags=re.DOTALL)
        if not block:
            return 0.0
        body = block.group(1)
        hit_first = re.search(first_score, body)
        hit_second = re.search(second_score, body)
        if not (hit_first and hit_second):
            return 0.0
        try:
            s_first = float(hit_first.group(1))
            s_second = float(hit_second.group(1))
        except ValueError:  # e.g. "." or "1.2.3" matched by the digit/dot class
            return 0.0
        if label == "first":
            return 1.0 if s_first > s_second else 0.0
        if label == "second":
            return 1.0 if s_second > s_first else 0.0
        return 0.0

    return [judge(c, lbl) for c, lbl in zip(completions, correct_label)]



import json
import re

# A <think>...</think> block followed (anything in between) by an
# <answer>{...}</answer> block; tag names are matched case-insensitively.
_RE_TAG_BLOCK = re.compile(
    r"<think>(?P<think>.*?)</think>.*?<answer>(?P<ans>{.*?})</answer>",
    re.DOTALL | re.IGNORECASE,
)
# Loose fallback for a double-quoted "preferred": "first"/"second" pair.
_PREFERRED_RE = re.compile(
    r'"preferred"\s*:\s*"(?P<choice>first|second)"',
    re.IGNORECASE,
)


def format_reward(completions, **kwargs):
    """Return 1.0 per completion whose first message is well formed, else 0.0.

    Well formed means: a ``<think>`` block followed by an ``<answer>{...}</answer>``
    block whose JSON has ``preferred`` equal to "first"/"second"
    (case-insensitive). If the answer is not a JSON object with that key,
    the regex fallback is used instead. Extra ``kwargs`` are unused.
    """
    def is_well_formed(comp):
        content = comp[0].get("content") if comp and isinstance(comp, list) else ""
        block = _RE_TAG_BLOCK.search(content or "")
        if block is None:
            return False

        answer_json = block.group("ans")
        try:
            parsed = json.loads(answer_json)
        except json.JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict) and "preferred" in parsed:
            return str(parsed["preferred"]).lower() in ("first", "second")
        return _PREFERRED_RE.search(answer_json) is not None

    return [1.0 if is_well_formed(comp) else 0.0 for comp in completions]


# Rule-based reward functions, looked up by key when main() assembles rewards.
reward_funcs_registry = dict(
    pick=pick_preferred_reward,
    rate=rate_preferred_reward,
    format=format_reward,
)

class ReferenceModelReward:
    """Listener reward from a frozen reference model's belief about the answer.

    The reference model sees the prompt (with both images) followed by the
    policy's completion up to its answer, ending with the anchor
    ``<answer>{"preferred":"``. The next-token probabilities of the rating
    tokens are renormalised over just those tokens. The reward is
    ``max(0, p(correct label) - 0.5)``.
    """

    # Text the listener continues; its next token should be a rating token.
    _ANCHOR = '<answer>{"preferred":"'

    def __init__(self, ref_model, processor, device, rating_tokens=("first", "second")):
        """
        Args:
            ref_model: Frozen listener model; put in eval mode here.
            processor: Processor exposing ``tokenizer``, ``apply_chat_template``
                and ``__call__``.
            device: Device the processed inputs are moved to.
            rating_tokens: Tokens whose renormalised probabilities are compared.
        """
        self.ref_model = ref_model
        self.ref_model.eval()
        self.processor = processor
        self.device = device

        self.rating_tokens = list(rating_tokens)
        self.rating_token_ids = processor.tokenizer.convert_tokens_to_ids(
            self.rating_tokens
        )
        if any(tid == processor.tokenizer.unk_token_id for tid in self.rating_token_ids):
            logger.warning(" one of %s is <unk> for the tokenizer", rating_tokens)

    def _listener_prefix(self, prompt_conv, assistant_text):
        """Build the text the listener scores for one completion.

        The anchor is searched for only inside the assistant's own text. The
        user turn built by PreferenceDataset already contains the example
        ``<answer>{"preferred":"second"}</answer>``, so searching the whole
        rendered conversation would cut inside the user message and drop the
        policy's reasoning.
        """
        prompt_text = self.processor.apply_chat_template(
            prompt_conv, add_generation_prompt=True, tokenize=False
        )
        anchor = self._ANCHOR

        # Cut at the first exact anchor, else at the last "<answer>" tag.
        cut = assistant_text.find(anchor)
        if cut == -1:
            cut = assistant_text.rfind("<answer>")
        if cut != -1:
            return prompt_text + assistant_text[:cut] + anchor

        # No answer tag: keep everything up to the last </think> (or all text).
        last_think = assistant_text.rfind("</think>")
        cut = last_think + len("</think>") if last_think != -1 else len(assistant_text)
        return prompt_text + assistant_text[:cut] + "\n" + anchor

    # --------------------------------------------------------------------- #
    def __call__(self, prompts: list, completions: list, **kwargs) -> list[float]:
        """Score each completion; requires ``correct_label``, ``first`` and ``second`` kwargs.

        Completions whose first message is not an assistant turn score 0.0.
        """
        rewards = []
        samples = zip(
            prompts, completions, kwargs["correct_label"], kwargs["first"], kwargs["second"]
        )
        for p_conv, comp, gt_label, img1, img2 in samples:
            if not comp or comp[0].get("role") != "assistant":
                rewards.append(0.0)
                continue

            prefix_text = self._listener_prefix(p_conv, comp[0]["content"])
            inputs = self.processor(
                text=[prefix_text],
                images=[[img1, img2]],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(self.device)

            with torch.no_grad():
                logits = self.ref_model(**inputs, use_cache=False).logits[:, -1, :]

            rating_logits = logits[:, self.rating_token_ids].float()
            probs = torch.softmax(rating_logits, dim=-1)[0]

            if gt_label in self.rating_tokens:
                p_correct = probs[self.rating_tokens.index(gt_label)].item()
            else:
                p_correct = 0.0
            rewards.append(max(0.0, p_correct - 0.5))

        return rewards


class ContradictionReward:
    """Listener reward that penalises reasoning which contradicts the final answer.

    A frozen text-only reference model reads a completion's ``<think>``
    reasoning and its parsed final choice and is asked whether they
    contradict. Its next-token probabilities over the rating tokens "yes"
    (contradiction) and "no" (consistent) are renormalised over those two
    tokens. The reward is 0.0 if P(yes) exceeds ``contradiction_threshold``,
    otherwise P(no). Completions that are not well-formed assistant turns
    score 0.0.
    """

    def __init__(self, ref_model, processor, device, contradiction_threshold=0.2):
        """
        Args:
            ref_model: Frozen listener model; may be None, in which case every
                call returns zeros.
            processor: Processor exposing ``tokenizer``, ``apply_chat_template``
                and ``__call__``.
            device: Device the tokenized inputs are moved to.
            contradiction_threshold: P("yes") above which the reward is 0.0.
        """
        self.ref_model = ref_model
        if self.ref_model is not None:
            self.ref_model.eval()
        self.processor = processor
        self.device = device
        self.contradiction_threshold = contradiction_threshold

        # "yes" means contradiction, "no" means no contradiction
        self.rating_tokens = ["yes", "no"]
        self.rating_token_ids = processor.tokenizer.convert_tokens_to_ids(self.rating_tokens)

        unk_id = processor.tokenizer.unk_token_id
        unk_tokens = [
            token
            for token, token_id in zip(self.rating_tokens, self.rating_token_ids)
            if token_id == unk_id
        ]
        if unk_tokens:
            logger.warning(
                "ContradictionReward: One or more rating tokens %s are UNK for the tokenizer.",
                unk_tokens,
            )

        # Positions of "yes"/"no" in the renormalised two-token distribution.
        self.yes_token_idx_in_rating = self.rating_tokens.index("yes")
        self.no_token_idx_in_rating = self.rating_tokens.index("no")

    @staticmethod
    def _parse_completion(assistant_text):
        """Return ``(reasoning, choice)`` from a well-formed completion, else None."""
        match = _RE_TAG_BLOCK.search(assistant_text)
        if not match:
            return None
        answer_json = match.group("ans")

        preferred = _PREFERRED_RE.search(answer_json)
        if preferred:
            return match.group("think"), preferred.group("choice").lower()

        try:
            answer_obj = json.loads(answer_json)
        except json.JSONDecodeError:
            return None
        if isinstance(answer_obj, dict):
            choice = str(answer_obj.get("preferred", "")).lower()
            if choice in ("first", "second"):
                return match.group("think"), choice
        return None

    @staticmethod
    def _last_token_logits(logits, attention_mask):
        """Return the logits at each row's last non-padding position.

        ``logits[:, -1]`` is only correct for left padding. With right padding
        it reads a pad position for every shorter prompt in the batch, so the
        position is taken from the attention mask when one is available.
        """
        if attention_mask is None:
            return logits[:, -1, :]
        positions = torch.arange(attention_mask.shape[-1], device=attention_mask.device)
        last_idx = (attention_mask.long() * positions).argmax(dim=-1)
        rows = torch.arange(logits.shape[0], device=logits.device)
        return logits[rows, last_idx.to(logits.device)]

    def __call__(self, prompts: list, completions: list, **kwargs) -> list[float]:
        """Score a batch of completions; returns one float per completion.

        ``prompts`` and extra ``kwargs`` are accepted for the reward-function
        interface but not used.
        """
        if self.ref_model is None:
            logger.warning("ContradictionReward called without a reference model. Returning 0 rewards.")
            return [0.0] * len(completions)

        batch_rewards = [0.0] * len(completions)
        listener_texts = []
        owners = []  # index into `completions` for each entry of listener_texts

        for i, comp_list in enumerate(completions):
            if not isinstance(comp_list, list) or not comp_list or comp_list[0].get("role") != "assistant":
                continue

            parsed = self._parse_completion(comp_list[0].get("content") or "")
            if parsed is None:
                continue
            reasoning_text, answer_choice = parsed

            user_message_for_contradiction_model = (
                f"Reasoning:\n{reasoning_text}\n\n"
                f"Final Answer: The model stated its preference is for the '{answer_choice}' image.\n\n"
                "Does the reasoning contradict the final answer? Respond with only 'yes' or 'no'."
            )
            contradiction_eval_chat = [
                {"role": "system", "content": CONTRADICTION_SYSTEM_PROMPT},
                {"role": "user", "content": user_message_for_contradiction_model},
            ]

            try:
                templated_input_list = self.processor.apply_chat_template(
                    [contradiction_eval_chat],
                    add_generation_prompt=True,
                    tokenize=False,
                )
                if not templated_input_list:
                    raise ValueError("apply_chat_template returned empty list for a valid input")
                prefix_text = templated_input_list[0]
            except Exception as e:
                logger.error(
                    "Error applying chat template for ContradictionReward sample %s: %s",
                    i, e, exc_info=True,
                )
                continue

            if not prefix_text.strip():
                logger.warning("Empty prefix_text for sample %s after chat templating for ContradictionReward.", i)
                continue
            listener_texts.append(prefix_text)
            owners.append(i)

        if not listener_texts:
            return batch_rewards

        try:
            max_length = getattr(self.processor.tokenizer, "model_max_length", 2048)
            inputs = self.processor(
                text=listener_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(self.device)

            with torch.no_grad():
                all_logits = self.ref_model(**inputs, use_cache=False).logits
            last_token_logits = self._last_token_logits(all_logits, inputs.get("attention_mask"))

            rating_specific_logits = last_token_logits[:, self.rating_token_ids].float()
            all_probs = torch.softmax(rating_specific_logits, dim=-1)

            for row, original_idx in enumerate(owners):
                prob_contradiction = all_probs[row, self.yes_token_idx_in_rating].item()
                prob_no_contradiction = all_probs[row, self.no_token_idx_in_rating].item()
                if prob_contradiction > self.contradiction_threshold:
                    batch_rewards[original_idx] = 0.0
                else:
                    batch_rewards[original_idx] = prob_no_contradiction
        except Exception as e:
            # Best-effort: a failed listener pass yields zero reward rather than aborting training.
            logger.error(
                "Error during batched ContradictionReward model inference or logit processing: %s",
                e, exc_info=True,
            )

        return batch_rewards

def main(script_args, training_args, model_args):
    """Train a Qwen2.5-VL policy with GRPO on pairwise image preferences.

    Loads and splits the preference dataset and picks the rule-based rewards
    for ``script_args.variant``. It optionally appends a listener reward
    selected by ``script_args.training_method``, trains, then saves and
    optionally pushes the result.

    Args:
        script_args: ``MyScriptArguments`` (variant, dataset, split size, method).
        training_args: ``GRPOConfig`` forwarded to the trainer.
        model_args: trl ``ModelConfig`` (model path and PEFT options).

    Raises:
        ValueError: if ``script_args.training_method`` is not a known method.
    """
    listener_methods = ("naive_listener", "contradiction_listener")
    # Fail fast on a typo instead of silently training the baseline.
    if script_args.training_method not in ("baseline", *listener_methods):
        raise ValueError(
            f"Unknown training_method {script_args.training_method!r}; expected "
            f"'baseline' or one of {listener_methods}."
        )

    sys.stdout.flush()

    # Honour --dataset_name. The dataclass default is only a placeholder, in
    # which case fall back to the HPDv2 dataset this script was written for.
    dataset_name = script_args.dataset_name
    if not dataset_name or dataset_name == "some dataset":
        dataset_name = "ymhao/HPDv2"
    ds_all = load_dataset(dataset_name, trust_remote_code=True)
    raw_ds = ds_all["train"] if "train" in ds_all else ds_all[next(iter(ds_all))]

    # Seed the split as well as the shuffle. An unseeded train_test_split draws
    # fresh entropy, so each process of a distributed run would get a different split.
    ds_split = raw_ds.shuffle(seed=42).train_test_split(
        test_size=script_args.test_size, seed=42
    )
    train_ds = ds_split["train"]
    test_ds = ds_split["test"]
    print("Train size:", train_ds.num_rows, "Test size:", test_ds.num_rows)

    train_dataset = PreferenceDataset(split_data=train_ds, variant=script_args.variant)

    if script_args.variant == "choice":
        chosen_reward_keys = ["pick", "format"]
    else:
        chosen_reward_keys = ["rate", "format"]
    reward_funcs = [reward_funcs_registry[k] for k in chosen_reward_keys]
    print("Will use standard reward funcs:", chosen_reward_keys)

    peft_config = get_peft_config(model_args)
    if peft_config:
        print(f"Using PEFT with LoRA. Config: {peft_config}")

    trainer = Qwen2VLGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        peft_config=peft_config,
        attn_implementation='flash_attention_2',
        max_pixels=720 * 28 * 28,  # to save compute and memory
        min_pixels=3136,
    )

    ref_model_for_reward = trainer.ref_model
    if ref_model_for_reward is not None:
        print("Using trainer.ref_model for ReferenceModelReward.")
    elif peft_config is not None:
        logger.warning("PEFT is enabled but trainer.ref_model is None. ReferenceModelReward might not work correctly without explicit base model access.")

    method = script_args.training_method
    if method in listener_methods and ref_model_for_reward is None:
        # Make the fallback explicit instead of silently training the baseline.
        logger.warning(
            "training_method=%r needs a reference model but trainer.ref_model is None; "
            "falling back to baseline rewards.",
            method,
        )
        method = "baseline"

    if method == "naive_listener":
        print("Using naive_listener training method with ReferenceModelReward")
        naive_listener = ReferenceModelReward(
            ref_model=ref_model_for_reward,
            processor=trainer.processing_class,
            device=trainer.accelerator.device,
        )

        def masked_naive_listener(prompts, completions, **kwargs):
            # Only well-formatted completions can earn the listener reward.
            fmt_rewards = format_reward(prompts=prompts, completions=completions, **kwargs)
            naive_rewards = naive_listener(prompts=prompts, completions=completions, **kwargs)
            return [r if f > 0 else 0.0 for f, r in zip(fmt_rewards, naive_rewards)]

        reward_funcs.append(masked_naive_listener)

    elif method == "contradiction_listener":
        print("Using contradiction_listener training method with ContradictionReward")
        contradiction_rew = ContradictionReward(
            ref_model=ref_model_for_reward,
            processor=trainer.processing_class,
            device=trainer.accelerator.device,
            contradiction_threshold=0.2,
        )

        def masked_contradiction_reward(prompts, completions, **kwargs):
            # Only well-formatted completions can earn the listener reward.
            fmt_rewards = format_reward(prompts=prompts, completions=completions, **kwargs)
            contr_rewards = contradiction_rew(prompts=prompts, completions=completions, **kwargs)
            return [r_contr if f_fmt > 0 else 0.0 for f_fmt, r_contr in zip(fmt_rewards, contr_rewards)]

        reward_funcs.append(masked_contradiction_reward)
    else:
        print("Using baseline training method (no additional listeners)")

    # NOTE(review): reward_funcs=[] is passed to the constructor and the real
    # list is assigned here. Confirm Qwen2VLGRPOTrainer derives no other
    # per-reward state (e.g. weights or metric names) from reward_funcs in __init__.
    trainer.reward_funcs = reward_funcs
    trainer.reward_processing_classes = [None] * len(reward_funcs)

    trainer.train()

    # Always save through trainer.save_model, which handles PEFT models
    # (adapter weights) as well as distributed / main-process-only saving.
    # model_args.lora_r is not a reliable PEFT switch, so peft_config only
    # decides which message is printed.
    trainer.save_model(training_args.output_dir)
    if peft_config is not None:
        print(f"Saved LoRA adapters to {training_args.output_dir}")

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=dataset_name)



if __name__ == "__main__":
    # TrlParser is already imported at module level; parse the three
    # argument dataclasses (CLI and/or config file) and run training.
    parser = TrlParser((MyScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
