
import logging
import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional
import pathlib

from PIL import Image
from torch.utils.data import Dataset

from math_verify import parse, verify
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
from open_r1.vlm_modules import *
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
from transformers import TrainingArguments
import yaml
import json
import random
import math

import whisper
import librosa
from decord import VideoReader, cpu, AudioReader
import numpy as np

import torch
from typing import Tuple
import copy
from qwen_omni_utils import process_mm_info
import av

from open_r1.prompts import AFFECT_SYSTEM_PROMPT

def check_if_video_has_audio(video_path):
    """Return True if the media file at *video_path* contains an audio stream.

    This is a best-effort probe. Any failure to open or inspect the file
    (missing file, unsupported container, non-media input) is reported as
    "no audio", and the function returns False.
    """
    try:
        # The context manager closes the container; the previous version
        # leaked an open file handle on every call.
        with av.open(video_path) as container:
            return any(stream.type == "audio" for stream in container.streams)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate. The failure is logged, not raised, because
        # callers fall back to a no-audio prompt.
        logger.debug("Could not probe audio streams in %s", video_path, exc_info=True)
        return False

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """Script-level arguments for GRPO training (rewards, media limits, PAPO, reward gating)."""

    # Reward names are resolved in `main`: registry names, "logit_reward.<component>"
    # entries, or dotted "module.function" paths to custom reward callables.
    reward_funcs: list[str] = field(
        default_factory=lambda: ["format", "accuracy", "context", "reasoning"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'context', 'reasoning', "
            "'logit_reward.<component>', or a dotted 'module.function' path to a custom reward function"
        },
    )
    # --- media / preprocessing limits ---
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the image"},
    )
    # Read by LazySupervisedDataset: when True, an audio entry is attached to the prompt
    # (from a sample's "audio_path" or from the video's own audio track).
    use_audio_in_video: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to feed audio (the sample's 'audio_path', or the video's own audio track) "
            "to the model alongside the video/image input"
        },
    )

    # --- PAPO (Perception-Aware Policy Optimization); forwarded to the trainer as papo_config ---
    papo_enabled: bool = field(
        default=False,
        metadata={"help": "Enable PAPO (Perception-Aware Policy Optimization)"},
    )
    papo_version: str = field(
        default="v2",
        metadata={"help": "PAPO version: v0 (mask both), v1 (mask separately), v2 (matrix-guided)"},
    )
    papo_mask_ratio: float = field(
        default=0.6,
        metadata={"help": "PAPO mask ratio (original PAPO uses 0.6 / 60% blackening)"},
    )
    papo_kl_coef: float = field(
        default=0.001,
        metadata={"help": "PAPO KL divergence coefficient (original PAPO: 1e-3)"},
    )
    papo_entropy_coef: float = field(
        default=0.0002,
        metadata={"help": "PAPO entropy regularization coefficient (original PAPO: 0.03)"},
    )
    papo_use_noise: bool = field(
        default=False,
        metadata={"help": "Use noise instead of zeros for masking (original PAPO uses zeros/black)"},
    )
    # Also passed to affect_reward.set_score_threshold in `main` when that module is importable.
    papo_routing_threshold: float = field(
        default=0.5,
        metadata={"help": "Routing threshold for V2 modality assignment"},
    )
    papo_kl_penalty: str = field(
        default="kl",
        metadata={"help": "PAPO KL penalty type: 'kl' (standard), 'low_var_kl' (low variance), 'abs', 'mse'"},
    )

    # --- reward gating (copied onto training_args in `main`) ---
    format_gate_all: Optional[bool] = field(
        default=False,
        metadata={"help": "When True, if format reward fails, zero out all rewards for that sample."},
    )
    disable_stage2_rewards: Optional[bool] = field(
        default=False,
        metadata={"help": "When True, zero out format/perception/coherence stage2 rewards while keeping PAPO V2 routing."},
    )

    # --- logit reward (only used when a "logit_reward.<component>" reward is requested) ---
    logit_reward_scale_method: Optional[str] = field(
        default="tanh",
        metadata={"help": "Logit reward scaling method: 'tanh' (smooth, [-1,1]) or 'clip' (hard, [-2,2])"},
    )

    # Kept as a string ("true"/"false") and forwarded verbatim to the trainer.
    logit_reward_use_neg_contrast: str = field(
        default="true",
        metadata={"help": "Whether to use neg contrast in logit reward: 'true'=(S_GT_with-S_GT_no)-(S_Neg_with-S_Neg_no), 'false'=S_GT_with-S_GT_no"},
    )

@dataclass
class GRPOModelConfig(ModelConfig):
    """Model configuration extended with the vision-freezing switch used by `main`."""

    # Forwarded to VLMGRPOTrainer(freeze_vision_modules=...); off by default.
    freeze_vision_modules: bool = field(default=False)

class LazySupervisedDataset(Dataset):
    """Multimodal (video / image / audio) dataset that builds prompts on access.

    Sample records are read up front, either from one JSON/JSONL file or from
    several files listed in a YAML manifest. The chat conversation and the
    decoded media inputs (via ``process_mm_info``) are only built in
    ``__getitem__``.
    """

    # Answer-format instruction appended to the prompt, keyed by "problem_type".
    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "emer_ov": " Please provide the words to describe emotions within the  <answer> </answer> tags.",
        "emer_ov_mc": " Please provide only the single or multiple option letter (e.g., A for single option or A,E for multi option, etc.) within the <answer> </answer> tags.",
        "judge": " Please answer Yes or No within the <answer> </answer> tags.",

    }

    # Prompt used when a record has no (or an empty) "problem" field.
    _DEFAULT_PROBLEM = (
        "As an expert in the field of emotions, please focus on the facial expressions, body movements, tone, "
        "subtitle content, etc., in the video to discern clues related to the emotions of the individual. "
        "Please provide a detailed description and ultimately predict the emotional state of the individual in the video."
    )

    def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
        """Load every sample record.

        Args:
            data_path: A ``.json`` / ``.jsonl`` file, or a ``.yaml`` manifest
                whose ``datasets`` list holds entries with a ``json_path`` and
                optional ``sampling_strategy`` and ``data_root`` keys.
            script_args: Parsed script arguments; ``use_audio_in_video`` is read.
            question_template: Stored on the instance (not used by this class).

        Raises:
            ValueError: For an unsupported file extension or a malformed manifest.
        """
        super(LazySupervisedDataset, self).__init__()
        self.script_args = script_args
        self.list_data_dict = []
        self.question_template = question_template
        self.use_audio_in_video = script_args.use_audio_in_video

        if data_path.endswith(".yaml"):
            with open(data_path, "r", encoding="utf-8") as file:
                yaml_data = yaml.safe_load(file)
            datasets = yaml_data.get("datasets") if isinstance(yaml_data, dict) else None
            if not isinstance(datasets, list):
                raise ValueError(f"YAML manifest {data_path} must contain a 'datasets' list")

            for data in datasets:
                json_path = data.get("json_path")
                if not json_path:
                    raise ValueError(f"Dataset entry without 'json_path' in {data_path}: {data}")
                cur_data_dict = self._apply_sampling(
                    self._load_records(json_path), data.get("sampling_strategy", "all")
                )
                if data.get("data_root", None):
                    self._prefix_paths(cur_data_dict, data["data_root"])
                print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                self.list_data_dict.extend(cur_data_dict)
        else:
            # Previously an unknown extension fell through to an UnboundLocalError.
            self.list_data_dict = self._load_records(data_path)

        self.mel_size = 128
        self.frames_upbound = 16

    @staticmethod
    def _load_records(json_path):
        """Load a list of records from a ``.json`` array or a ``.jsonl`` file (blank lines skipped)."""
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
        if json_path.endswith(".jsonl"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                return [json.loads(line) for line in json_file if line.strip()]
        if json_path.endswith(".json"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                return json.load(json_file)
        raise ValueError(f"Unsupported file type: {json_path}")

    @staticmethod
    def _apply_sampling(records, sampling_strategy):
        """Subsample *records* according to a ``"<strategy>[:<count>|:<pct>%]"`` spec.

        Strategies are "first", "end" and "random". A spec without a count
        (e.g. "all") or with an unknown strategy keeps every record.
        "random" shuffles *records* in place with the global ``random`` state.
        """
        strategy, _, number = (sampling_strategy or "all").partition(":")
        if not number:
            return records
        if "%" in number:
            count = math.ceil(float(number.split("%")[0]) * len(records) / 100)
        else:
            count = int(number)

        if strategy == "first":
            return records[:count]
        if strategy == "end":
            # `records[-count:]` would return EVERYTHING when count == 0.
            return records[max(len(records) - count, 0):]
        if strategy == "random":
            random.shuffle(records)
            return records[:count]
        return records

    @staticmethod
    def _prefix_paths(records, data_root):
        """Join *data_root* onto each record's "path" (a string, or every value of a dict), in place."""
        # NOTE(review): "audio_path" is not prefixed here; confirm it is stored as an absolute path.
        for record in records:
            path = record.get("path")
            if isinstance(path, str):
                record["path"] = os.path.join(data_root, path)
            elif isinstance(path, dict):
                for key, value in path.items():
                    path[key] = os.path.join(data_root, value)

    def __len__(self):
        return len(self.list_data_dict)

    def _make_conversation_image_and_video(self, example, use_audio_in_video=False):
        """Build the [system, user] chat messages for one sample.

        Defaults for missing "problem", "problem_type" and "data_type" are
        written back into *example* (it is mutated in place).

        Args:
            example: Raw record; must contain "path". A str path is a
                video/image. When *use_audio_in_video* is True, a dict path
                with "image" and "audio" keys is a paired image + audio sample.
            use_audio_in_video: When True, attach an audio entry. It comes from
                "audio_path" if present, else from the video's own audio track
                if it has one, else a no-audio note is added to the prompt.

        Returns:
            A list of two messages: the system prompt and the user turn.
        """
        if not example.get("problem"):
            example["problem"] = self._DEFAULT_PROBLEM
        example.setdefault("problem_type", "emer_ov")
        example.setdefault("data_type", "video")

        question = example["problem"]
        if example["problem_type"] in ("multiple choice", "emer_ov_mc"):
            options = example.get("options") or []
            question += " Options:\n" + "".join(op + "\n" for op in options)

        subtitle = example.get("subtitle")
        subtitle_prompt = ""
        if isinstance(subtitle, str) and subtitle.strip():
            subtitle_prompt = f"\nThe subtitle of this video is: <Subtitle>{subtitle.strip()}</Subtitle>."

        text_prompt = f"{subtitle_prompt}\n{question}\n" + self.TYPE_TEMPLATE[example["problem_type"]]

        data_type = example["data_type"]
        path = example["path"]
        media_entry = {"type": data_type, data_type: path}

        if not use_audio_in_video:
            content = [media_entry, {"type": "text", "text": text_prompt}]
        elif not isinstance(path, str):
            # Paired image + audio sample: path is {"image": ..., "audio": ...}.
            content = [
                {"type": "image", "image": path["image"]},
                {"type": "audio", "audio": path["audio"]},
                {"type": "text", "text": "Here is the image, with the corresponding audio.\n" + text_prompt},
            ]
        else:
            audio_source = example.get("audio_path")
            # Only videos can carry an audio track; check data_type first so
            # images are never opened with av.
            if not audio_source and data_type == "video" and check_if_video_has_audio(path):
                audio_source = path
            if audio_source:
                content = [
                    media_entry,
                    {"type": "audio", "audio": audio_source},
                    {
                        "type": "text",
                        "text": f"Here is a {data_type}, with the audio from the video.\n" + text_prompt,
                    },
                ]
            else:
                content = [
                    media_entry,
                    {
                        "type": "text",
                        "text": f"Here is the {data_type}, and there is no audio information, you don't need to process the audio.\n"
                        + text_prompt,
                    },
                ]

        return [
            {"role": "system", "content": [{"type": "text", "text": AFFECT_SYSTEM_PROMPT}]},
            {"role": "user", "content": content},
        ]

    def __getitem__(self, i):
        """Return sample *i*, falling back to up to 3 random samples if it fails to load.

        Raises:
            RuntimeError: If sample *i* and every fallback sample fail. Before,
                this case silently returned None into the collator.
        """
        import traceback

        num_base_retries = 3
        last_error = None

        try:
            return self._get_item(i)
        except Exception as e:
            last_error = e
            print(f"Failed to fetch sample {i}. Exception:", e)
            traceback.print_exc()

        if not self.list_data_dict:
            raise RuntimeError("Dataset is empty; no fallback sample available") from last_error

        for attempt_idx in range(num_base_retries):
            sample_idx = random.randrange(len(self))
            try:
                return self._get_item(sample_idx)
            except Exception as e:
                last_error = e
                traceback.print_exc()
                print(f'[try other #{attempt_idx}] Failed to fetch sample {sample_idx}. Exception:', e)

        raise RuntimeError(
            f"Failed to fetch sample {i} and {num_base_retries} random fallback samples"
        ) from last_error

    def _get_item(self, i):
        """Build the conversation and media inputs for record *i*."""
        source = self.list_data_dict[i]
        if "path" not in source:
            # Previously this surfaced as an UnboundLocalError on `conversation`.
            raise KeyError(f"Sample {i} has no 'path' field")

        conversation = self._make_conversation_image_and_video(source, use_audio_in_video=self.use_audio_in_video)
        problem_type = source.get("problem_type", "emer_ov")

        # use_audio_in_video=False here: when audio is wanted, the conversation
        # already carries an explicit "audio" entry (see
        # _make_conversation_image_and_video).
        audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)

        return {
            "images": images,
            "audios": audios,
            "videos": videos,
            "conversation": conversation,
            "prompt": conversation,
            "openset": source.get("openset"),
            "solution": source.get("solution"),
            "problem_type": problem_type,
            "use_audio_in_video": False,
            "path": source.get("path"),
            "extracted_clues": source.get("extracted_clues"),
        }

def get_vlm_module(model_name_or_path):
    """Return the VLM module class used for training.

    *model_name_or_path* is accepted for interface compatibility but ignored:
    QwenOmniModule is always returned.
    """
    selected_module = QwenOmniModule
    return selected_module

def _resolve_reward_functions(reward_func_names, raw_weights, registry):
    """Map requested reward names to callables, weights and logit-reward components.

    Each name is handled as follows:
      * ``"logit_reward.<component>"`` records ``<component>: weight`` in the
        logit-reward config instead of adding a callable.
      * A key of *registry* adds that registered callable.
      * Any other dotted name is imported as ``module.function``.

    Returns:
        ``(reward_funcs, reward_weights, logit_reward_config)``.

    Raises:
        ValueError: If a name cannot be resolved or a custom function fails to load.
    """
    import importlib

    if len(raw_weights) != len(reward_func_names):
        print(
            f"[GRPO] Warning: {len(raw_weights)} reward weights given for {len(reward_func_names)} reward "
            f"functions; missing weights default to 1.0 and extra weights are ignored."
        )

    reward_funcs = []
    reward_weights = []
    logit_reward_config = {}

    for i, func_name in enumerate(reward_func_names):
        weight = raw_weights[i] if i < len(raw_weights) else 1.0

        if func_name.startswith('logit_reward.'):
            component = func_name.split('.')[1]
            logit_reward_config[component] = weight
            print(f"[GRPO] Registered logit reward component: {component} with weight {weight}")
        elif func_name in registry:
            reward_funcs.append(registry[func_name])
            reward_weights.append(weight)
        elif "." in func_name:
            module_name, function_name = func_name.rsplit(".", 1)
            try:
                reward_func = getattr(importlib.import_module(module_name), function_name)
            except Exception as e:
                # Chain the cause so the original import/attribute error stays visible.
                raise ValueError(f"Failed to load reward function '{func_name}': {e}") from e
            reward_funcs.append(reward_func)
            reward_weights.append(weight)
            print(f"[GRPO] Successfully loaded custom reward function: {func_name}")
        else:
            raise ValueError(f"Reward function '{func_name}' not found in registry or as a module path.")

    return reward_funcs, reward_weights, logit_reward_config


def main(script_args, training_args, model_args):
    """Configure and run VLM GRPO training, then save the final model.

    Resolves the reward functions and copies the PAPO and reward-gating options
    onto *training_args*. It then builds the dataset and trainer, resuming from
    an existing ``checkpoint-*`` in ``output_dir`` if one is present.
    """
    # Seed Python's `random` (used by "random" dataset sampling and by the
    # dataset's retry fallback). Prefer data_seed and fall back to seed.
    seed = getattr(training_args, "data_seed", None)
    if seed is None:
        seed = getattr(training_args, "seed", None)
    if seed is not None:
        random.seed(seed)

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)

    reward_funcs_registry = {
        "accuracy": vlm_module_cls.accuracy_reward,
        "format": vlm_module_cls.format_reward,
        "reasoning": vlm_module_cls.patial_reasoning_reward,
        "context": vlm_module_cls.patial_context_reward
    }

    raw_weights = training_args.reward_weights if training_args.reward_weights else [1.0] * len(script_args.reward_funcs)
    reward_funcs, reward_weights, logit_reward_config = _resolve_reward_functions(
        script_args.reward_funcs, raw_weights, reward_funcs_registry
    )

    if logit_reward_config:
        logit_reward_config['scale_method'] = script_args.logit_reward_scale_method
    training_args.logit_reward_config = logit_reward_config

    # Only the callable rewards keep weights here; logit-reward weights live in logit_reward_config.
    training_args.reward_weights = reward_weights

    training_args.format_gate_all = script_args.format_gate_all
    training_args.disable_stage2_rewards = script_args.disable_stage2_rewards

    papo_config = {
        "enabled": script_args.papo_enabled,
        "version": script_args.papo_version,
        "mask_ratio": script_args.papo_mask_ratio,
        "use_noise": script_args.papo_use_noise,
        "kl_coef": script_args.papo_kl_coef,
        "entropy_coef": script_args.papo_entropy_coef,
        "kl_penalty": script_args.papo_kl_penalty,
        "routing_threshold": script_args.papo_routing_threshold,
    }
    training_args.papo_config = papo_config

    training_args.logit_reward_use_neg_contrast = script_args.logit_reward_use_neg_contrast

    print(f"[GRPO] reward_funcs: {script_args.reward_funcs}")
    print(f"[GRPO] Regular reward_funcs: {reward_funcs}, weights: {reward_weights}")
    print(f"[GRPO] Logit reward config: {logit_reward_config}")
    print(f"[GRPO] Logit reward use_neg_contrast: {script_args.logit_reward_use_neg_contrast}")
    print(f"[GRPO] PAPO config: {papo_config}")

    # Optional module: share the routing threshold with the affect reward scorer if available.
    try:
        import affect_r1.affect_reward as affect_reward
        affect_reward.set_score_threshold(script_args.papo_routing_threshold)
    except ImportError:
        print("[GRPO] Warning: Could not import affect_reward to set score threshold")

    dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))

    trainer = VLMGRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        vlm_module=vlm_module_cls(),
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        freeze_vision_modules=model_args.freeze_vision_modules,
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        max_anyres_num=script_args.max_anyres_num,
        torch_dtype=model_args.torch_dtype,
    )

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_model(training_args.output_dir)
    
if __name__ == "__main__":
    # Parse CLI flags / config file into (script, training, model) argument dataclasses.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
    parsed_script_args, parsed_training_args, parsed_model_args = arg_parser.parse_args_and_config()
    main(parsed_script_args, parsed_training_args, parsed_model_args)
