import numpy as np
import os
import re
import yaml
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
from transformers import Qwen2VLForConditionalGeneration

from math_verify import parse, verify
from src.open_r1.trainer import Qwen2VLGRPOTrainer_Video_TG_Our_Ver3 as Qwen2VLGRPOTrainer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
from src.open_r1.original_qwen_utils import process_vision_info
from tqdm import tqdm
import torch
import json
import random
# Seed Python's global RNG at import time for reproducible shuffling; note that
# load_json_dataset also re-seeds with 42 before each per-dataset shuffle.
random.seed(42)

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            Names of the reward functions to combine; each must be a key of
            `reward_funcs_registry`. Possible values: 'iou', 'format', 'distance',
            'gt_pos', 'pred_pos', 'has_neg_prompt', 'qa_accuracy', 'good_gqa'.
        max_pixels (`int`, *optional*):
            Maximum number of pixels for the image; passed to the trainer.
        min_pixels (`int`, *optional*):
            Minimum number of pixels for the image; passed to the trainer.
        train_data_path (`str`):
            Path to the YAML config that lists the training datasets.
        eval_data_path (`str`):
            Path to the evaluation data (currently not read by the data loader).
        video_folder (`str`):
            Path to the folder containing video files (currently not read by the
            data loader; each dataset in the YAML config supplies its own `video_dir`).
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["iou", "format", "distance", "gt_pos", "pred_pos", "has_neg_prompt", "qa_accuracy", "good_gqa"],
        metadata={"help": "List of reward functions. Possible values: 'iou', 'format', 'distance', 'gt_pos', 'pred_pos', 'has_neg_prompt', 'qa_accuracy', 'good_gqa'"},
    )
    # 12845056 == 16384 * 28 * 28
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    # 3136 == 4 * 28 * 28
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )

    train_data_path: str = field(
        default="/",
        metadata={"help": "Path to the training data YAML config file."},
    )
    eval_data_path: str = field(
        default="/",
        metadata={"help": "Path to the evaluation data JSON file."},
    )

    video_folder: str = field(
        default="/",  # Replace with your actual video folder path
        metadata={"help": "Path to the folder containing video files."},
    )


# Compiled once at import time and reused for every completion scored.
# Contents of each <timestamp>...</timestamp> block (may span lines).
_TIMESTAMP_BLOCK_RE = re.compile(r"<timestamp>(.*?)</timestamp>", re.DOTALL)
# "<num> to <num>" / "<num> and <num>" with exactly one space on each side of the
# connective; numbers may be negative (e.g. the "-1 to -2" no-clip marker).
_TIMESTAMP_SPAN_RE = re.compile(r"(-?\d+\.?\d*) (to|and) (-?\d+\.?\d*)", re.IGNORECASE)


def parse_timestamp_output(output_string):
    """Return ``(start, end)`` parsed from the last ``<timestamp>...</timestamp>`` block.

    Within that block, the last ``"<number> to <number>"`` or
    ``"<number> and <number>"`` pair is used (connective is case-insensitive,
    numbers may be negative). Earlier timestamp blocks are ignored.

    Args:
        output_string: Model completion text.

    Returns:
        ``(start, end)`` as floats, or ``None`` when there is no timestamp block
        or the last block contains no such number pair.
    """
    blocks = _TIMESTAMP_BLOCK_RE.findall(output_string)
    if not blocks:
        return None  # No <timestamp> tags found.

    spans = _TIMESTAMP_SPAN_RE.findall(blocks[-1])
    if not spans:
        return None

    start, _connective, end = spans[-1]
    return float(start), float(end)



def distance_reward(completions, solution, durations, **kwargs):
    """Reward predicted spans whose centre lies close to the ground-truth centre.

    Each completion's last ``<timestamp>`` span is parsed with
    ``parse_timestamp_output`` and compared with the ground-truth
    ``(gt_start, gt_end)`` pair:

    * ground truth is a positive-length span (``gt_start < gt_end``):
        - reversed prediction (``start > end``) -> ``-0.1``
        - otherwise -> ``max(0, 1 - |gt_centre - pred_centre| / duration)``.
          If ``duration`` is not positive the ratio is undefined, so the
          reward is ``1.0`` for an exact centre match and ``0.0`` otherwise.
    * otherwise (e.g. the ``(-1, -2)`` "no relevant clip" marker):
        - reversed prediction -> ``0.5``
        - otherwise -> ``0.0``
    * completion without a parsable span -> ``0.0``

    Args:
        completions: Generated completion strings.
        solution: Ground-truth ``(start, end)`` pairs.
        durations: Per-sample duration used to normalise the centre distance.
        **kwargs: Must include ``prompts``; it is zipped with the other inputs
            (the shortest input bounds the number of rewards) and its text is
            written to the debug log.

    Returns:
        list[float]: One reward per sample.
    """
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # Debug logging is opt-in and needs a destination; without LOG_PATH,
    # open(None) would raise, so logging is skipped instead.
    log_path = os.getenv("LOG_PATH")
    debug = os.getenv("DEBUG_MODE") == "true" and bool(log_path)

    for prompt, content, sol, duration in zip(kwargs['prompts'], completions, solution, durations):
        reward = 0.0
        start_time, end_time = 0, 0
        gt_start, gt_end = sol
        parsed_times = parse_timestamp_output(content)
        if parsed_times:
            start_time, end_time = parsed_times
            if gt_start < gt_end:
                if start_time > end_time:
                    reward = -0.1
                else:
                    dist = abs((gt_start + gt_end) / 2 - (start_time + end_time) / 2)
                    if duration > 0:
                        reward = max(0., 1. - dist / duration)
                    else:
                        # Zero/negative duration: avoid ZeroDivisionError and rewards > 1.
                        reward = 1.0 if dist == 0 else 0.0
            elif start_time > end_time:
                # Ground truth has no valid span and the model answered with a
                # reversed span, as the "-1 to -2" prompt instruction asks.
                reward = 0.5

        rewards.append(reward)

        if debug:
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"Prompt: {prompt[0]['content'][0]['text']}\n")
                f.write(f"Content: {content}\n")
                f.write(f"pred second: {str(start_time)}, {str(end_time)}\n")
                f.write(f"gt second: {str(gt_start)}, {str(gt_end)}\n")
                f.write(f"------------- {current_time} Distance reward: {reward} -------------\n")

    return rewards

def iou_timestamp_reward(completions, solution, durations, **kwargs):
    """Reward function that calculates IoU between predicted and ground truth timestamps.

    Each completion's last ``<timestamp>`` span is parsed with
    ``parse_timestamp_output`` and compared with the ground-truth
    ``(gt_start, gt_end)`` pair:

    * ground truth is a positive-length span (``gt_start < gt_end``):
        - reversed prediction (``start > end``) -> ``-0.1``
        - otherwise -> temporal IoU of the two intervals
    * otherwise (e.g. the ``(-1, -2)`` "no relevant clip" marker):
        - reversed prediction -> ``0.5``
        - otherwise -> ``0.0``
    * completion without a parsable span -> ``0.0``

    Args:
        completions: Generated completion strings.
        solution: Ground-truth ``(start, end)`` pairs.
        durations: Per-sample durations (iterated in lockstep; not used in the score).
        **kwargs: Must include ``prompts``; it is zipped with the other inputs
            (the shortest input bounds the number of rewards) and its text is
            written to the debug log.

    Returns:
        list[float]: One reward per sample.
    """
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # Debug logging is opt-in and needs a destination; without LOG_PATH,
    # open(None) would raise, so logging is skipped instead.
    log_path = os.getenv("LOG_PATH")
    debug = os.getenv("DEBUG_MODE") == "true" and bool(log_path)

    for prompt, content, sol, duration in zip(kwargs['prompts'], completions, solution, durations):
        reward = 0.0
        start_time, end_time = 0, 0
        gt_start, gt_end = sol
        parsed_times = parse_timestamp_output(content)
        if parsed_times:
            start_time, end_time = parsed_times
            if gt_start < gt_end:
                if start_time > end_time:
                    reward = -0.1
                else:
                    intersection = max(0, min(end_time, gt_end) - max(start_time, gt_start))
                    union = max(end_time, gt_end) - min(start_time, gt_start)
                    # union >= gt_end - gt_start > 0 here; the guard is purely defensive.
                    reward = intersection / union if union > 0 else 1.
            elif start_time > end_time:
                # Ground truth has no valid span and the model answered with a
                # reversed span, as the "-1 to -2" prompt instruction asks.
                reward = 0.5

        rewards.append(reward)

        if debug:
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"Prompt: {prompt[0]['content'][0]['text']}\n")
                f.write(f"Content: {content}\n")
                f.write(f"pred second: {str(start_time)}, {str(end_time)}\n")
                f.write(f"gt second: {str(gt_start)}, {str(gt_end)}\n")
                f.write(f"------------- {current_time} IoU reward: {reward} -------------\n")

    return rewards



def gt_pos(completions, solution, durations, **kwargs):
    """Return 1.0 per sample whose ground-truth span is ordered (start <= end), else 0.0.

    ``kwargs['prompts']``, ``completions``, ``solution`` and ``durations`` are
    walked in lockstep (stopping at the shortest); only ``solution`` affects
    the value.
    """
    samples = zip(kwargs['prompts'], completions, solution, durations)
    return [0. if start > end else 1. for _prompt, _content, (start, end), _duration in samples]


def pred_pos(completions, solution, durations, **kwargs):
    """Return 1.0 for each completion whose parsed span is ordered, else 0.0.

    A completion scores 0.0 when it has no parsable ``<timestamp>`` span or when
    the span is reversed (``start > end``, e.g. the ``-1 to -2`` marker).
    """
    scores = []
    for _prompt, completion, _sol, _duration in zip(kwargs['prompts'], completions, solution, durations):
        span = parse_timestamp_output(completion)
        if span and not span[0] > span[1]:
            scores.append(1.)
        else:
            scores.append(0.)
    return scores

def has_neg_prompt(completions, solution, durations, **kwargs):
    """Flag prompts that offer the "no relevant clip" escape hatch.

    Returns 1.0 for each sample whose first message's first content part text
    contains the instruction to answer ``<timestamp> -1 to -2 </timestamp>``,
    and 0.0 otherwise.
    """
    marker = '''If no relevant video clip exists, indicate this by returning "<timestamp> -1 to -2 </timestamp>".'''
    batch = zip(kwargs['prompts'], completions, solution, durations)
    return [
        1. if marker in conversation[0]['content'][0]['text'] else 0.
        for conversation, _completion, _sol, _duration in batch
    ]

def extract(text):
    '''Return the first alphabetic character of ``text`` upper-cased, or 'Z' if there is none.'''
    letters = (ch for ch in text if ch.isalpha())
    return next(letters, 'Z').upper()

def qa_accuracy(completions, solution, durations, **kwargs):
    """Score multiple-choice correctness for 'gqa' samples.

    A sample scores 1.0 only when all of the following hold:
    * its source is 'gqa';
    * the completion contains both ``<answer>`` and ``</answer>``;
    * the ground-truth span is valid (``start < end``);
    * ``extract`` yields the same letter for the reference answer and for the
      text of the last ``<answer>`` block.

    Every other sample scores 0.0.
    """
    scores = []
    rows = zip(kwargs['prompts'], completions, solution, durations, kwargs['source'], kwargs['qa_answer'])
    for _prompt, completion, sol, _duration, source, answer in rows:
        eligible = (
            source == 'gqa'
            and '<answer>' in completion
            and '</answer>' in completion
            and sol[0] < sol[1]
        )
        if not eligible:
            scores.append(0.)
            continue
        expected = extract(answer)
        predicted = extract(completion.split('<answer>')[-1].split('</answer>')[0])
        scores.append(float(expected == predicted))
    return scores

def good_gqa(completions, solution, **kwargs):
    """Return 1.0 for 'gqa' samples whose ground-truth span is valid (start < end), else 0.0."""
    has_span = [sol[0] < sol[1] for sol in solution]
    scores = []
    for span_ok, src, _completion in zip(has_span, kwargs['source'], completions):
        weight = 1.0 if src == 'gqa' else 0.0
        scores.append(span_ok * weight)
    return scores

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format.

    A completion scores 1.0 when, after stripping surrounding whitespace, it
    matches ``<think>...</think>...<timestamp>...</timestamp>`` in full. For
    sources other than 'tg' it must additionally contain both ``<answer>`` and
    ``</answer>``. All other completions score 0.0.
    """
    layout = re.compile(r'<think>.*?</think>.*?<timestamp>.*?</timestamp>', re.DOTALL)
    structural = [1.0 if layout.fullmatch(text.strip()) else 0.0 for text in completions]

    scores = []
    for base, src, text in zip(structural, kwargs['source'], completions):
        # Only grounding ('tg') samples are exempt from the <answer> requirement.
        ok = src == 'tg' or ('<answer>' in text and '</answer>' in text)
        scores.append(base * (1.0 if ok else 0.0))
    return scores

# Maps the names accepted by --reward_funcs to their implementations.
reward_funcs_registry = dict(
    iou=iou_timestamp_reward,
    format=format_reward,
    distance=distance_reward,
    gt_pos=gt_pos,
    pred_pos=pred_pos,
    has_neg_prompt=has_neg_prompt,
    qa_accuracy=qa_accuracy,
    good_gqa=good_gqa,
)



def load_json_dataset(train_data_path, eval_data_path, video_folder):
    """Build a ``DatasetDict`` of video temporal-grounding / QA examples.

    Only ``train_data_path`` is read. ``eval_data_path`` and ``video_folder`` are
    accepted but not used here: each dataset entry in the YAML config supplies its
    own ``video_dir``, and no evaluation split is built.

    Args:
        train_data_path: Path to a ``.yaml`` config whose top-level ``datasets``
            list has entries with ``name``, ``json_path``, ``video_dir``,
            ``strategy``, ``question_type`` and ``prompt`` keys.
        eval_data_path: Unused.
        video_folder: Unused.

    Returns:
        ``DatasetDict({"train": <Dataset>, "eval": None})``.
    """
    def create_dataset_from_json(file_path, split_name):
        """Load and shuffle every dataset listed in the YAML config at ``file_path``.

        Each ``json_path`` holds either:
          * a dict ``video_id -> {"sentences", "timestamps", "duration"[, "queryvideo"]}``,
            producing one example per sentence; or
          * a list of QA records with ``"question"``, ``"options"``, ``"glue"``,
            ``"video"``, ``"answer"`` and ``"duration"`` or ``"queryvideo"``.
        The literal ``[bm.o.O.md]`` in the dataset's ``prompt`` is replaced by the
        sentence / question text. A ``strategy`` of ``"random:N"`` keeps the first
        N examples after a seeded shuffle; any other value keeps them all.
        ``split_name`` is unused.

        The returned ``Dataset`` gets an instance-level ``__getitem__`` attached
        that additionally decodes the video via ``process_vision_info``.
        """
        
        # Only YAML dataset configs are supported.
        assert file_path.endswith('.yaml')

        with open(file_path, 'r') as file:
            config = yaml.safe_load(file)

        all_examples = []

        # Process each dataset listed in the config
        for dataset in config['datasets']:
            name = dataset['name']
            json_path = dataset['json_path']
            video_dir = dataset['video_dir']
            strategy = dataset['strategy']
            question_type = dataset['question_type']  # stored as each example's 'source' (the reward functions compare it to 'tg' / 'gqa')
            prompt = dataset['prompt']

            examples = []

            with open(json_path, 'r', encoding="utf-8") as f:
                data = json.load(f)
                if isinstance(data, dict):
                    # Grounding format: one example per (sentence, timestamp) pair of each video.
                    for video_id, video_data in data.items():
                        for i in range(len(video_data['sentences'])):
                            sentence = video_data['sentences'][i]
                            timestamps = video_data['timestamps'][i]
                            example = {
                                "problem": prompt.replace('[bm.o.O.md]', sentence),
                                "solution": (timestamps[0], timestamps[1]),
                                "video_path": os.path.join(video_dir, f"{video_id}.mp4"),
                                "durations": video_data['duration'],
                                # (-1, -2) = no sub-clip: __getitem__ then leaves video_start/video_end unset.
                                "querystamps": (-1, -2),
                                'source': question_type,
                                'dataset_name': name,
                                'qa_answer': 'Z',
                            }
                            if 'queryvideo' in video_data:
                                # Restrict the input to a sub-clip; 'durations' becomes the clip length.
                                querystamps = video_data['queryvideo'][i]
                                example['durations'] = querystamps[1] - querystamps[0]
                                example['querystamps'] = (querystamps[0], querystamps[1])
                            examples.append(example)
                else:
                    # QA format: a list of multiple-choice records; only the first 'glue' span is used.
                    assert isinstance(data, list)
                    for video_data in data:
                        question_str = video_data['question']
                        options_str = '\n'.join(video_data['options'])
                        sentence = f'''"{question_str}
{options_str}"'''
                        timestamps = video_data['glue'][0]
                        example = {
                            "problem": prompt.replace('[bm.o.O.md]', sentence),
                            "solution": (timestamps[0], timestamps[1]),
                            'video_path': os.path.join(video_dir, video_data['video']),
                            'source': question_type,
                            'dataset_name': name,
                            # 'Z' (the same fallback extract() returns) when the sample has no valid span.
                            'qa_answer': (video_data['answer'] if timestamps[0] < timestamps[1] else 'Z'),
                        }
                        if 'queryvideo' in video_data:
                            example['durations'] = video_data['queryvideo'][1] - video_data['queryvideo'][0]
                            example['querystamps'] = (video_data['queryvideo'][0], video_data['queryvideo'][1])
                        else:
                            example['durations'] = video_data['duration']
                            example['querystamps'] = (-1, -2)
                        examples.append(example)
            
            # Re-seed before every per-dataset shuffle so each dataset's order (and its
            # "random:N" subsample) does not depend on the datasets loaded before it.
            random.seed(42)
            random.shuffle(examples)
            if strategy.startswith('random:'):
                rat = int(strategy.split('random:')[-1])
                examples = examples[:rat]
            
            print(f'!!! load {len(examples)} items from {name}')
            
            all_examples.extend(examples)

        # Interleave the datasets (RNG state continues from the last per-dataset shuffle).
        random.shuffle(all_examples)
        # Debug output: total count and the first few examples.
        print(len(all_examples))
        print(all_examples[:5])
        dataset = Dataset.from_list(all_examples)

        # NOTE(review): Python looks up special methods such as __getitem__ on the
        # type, not the instance, so `dataset[idx]` keeps using Dataset.__getitem__;
        # this override only runs where it is called explicitly as
        # `dataset.__getitem__(idx)`. Confirm how the trainer reads samples.
        def __getitem__(self, idx): # Define getitem within the scope where dataset is available
            # idx = idx[0]

            # Class-level Dataset indexing (not this override). The `[0]` accesses
            # below suggest idx is a list of indices, so each field is a list —
            # TODO confirm.
            example = dataset[idx]
            data_to_return = {k: v for k, v in example.items()} # Shallow copy so the keys added below do not touch `example`


            # SECURITY NOTE(review): eval() executes whatever these environment
            # variables contain; acceptable only if the environment is trusted.
            total_pixels = eval(os.environ.get('PARAM_TOTAL_PIXELS', "3584 * 28 * 28"))
            min_pixels = eval(os.environ.get('PARAM_MIN_PIXELS', "16 * 28 * 28"))
            try:
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "video",
                                "video": example["video_path"][0],
                                "total_pixels": total_pixels,
                                "min_pixels": min_pixels,
                            },
                        ]
                    }
                ]

                qs1, qs2 = example["querystamps"][0][0], example["querystamps"][0][1]

                # qs1 > qs2 is only allowed for the (-1, -2) "whole video" marker;
                # otherwise decoding is restricted to [qs1, qs2].
                if qs1 > qs2:
                    assert qs1 == -1 and qs2 == -2
                else:
                    messages[0]['content'][0]['video_start'] = qs1
                    messages[0]['content'][0]['video_end'] = qs2

                # Optional cap on decoded frames, read from the environment.
                if "FPS_MAX_FRAMES" in os.environ:
                    messages[0]['content'][0]['max_frames'] = float(os.environ['FPS_MAX_FRAMES'])

                image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
                fps_inputs = video_kwargs['fps']  # unused below; a missing 'fps' key raises KeyError (handled by the except)
                data_to_return["video_inputs"] = [video_inputs]
                data_to_return["video_kwargs"] = [video_kwargs]
                data_to_return["use_preprocessed"] = [True] # True when the video decoding above succeeded
            except Exception as e:
                # NOTE(review): retries recursively with idx + 1 and no bound, so a run
                # of unreadable videos can exhaust the recursion limit or index past
                # the end; `idx + 1` also raises TypeError if idx is a plain list.
                # The [False] flag set here is discarded by the recursive return.
                print(f"Warning: Error loading preprocessed data from {example['video_path'][0]}, falling back to video_path. Error: {e}")
                data_to_return["use_preprocessed"] = [False] # Fallback to video_path if loading fails
                print(idx)
                idx = idx + 1
                return self.__getitem__(idx)

            return data_to_return

        dataset.__getitem__ = __getitem__.__get__(dataset, Dataset) # Bind getitem to this dataset instance as an instance attribute

        return dataset

    train_dataset = create_dataset_from_json(train_data_path, "train")
    eval_dataset = None  # no evaluation split is built; eval_data_path is ignored
    return DatasetDict({"train": train_dataset, "eval": eval_dataset})

def main(script_args, training_args, model_args):
    """Train a Qwen2-VL model with GRPO on the datasets described by ``script_args``.

    Steps:
    1. Resolve ``script_args.reward_funcs`` through ``reward_funcs_registry``.
    2. Load the data with ``load_json_dataset`` and build the trainer.
    3. Train, resuming when ``training_args.output_dir`` already contains a
       ``checkpoint-*`` directory.
    4. Save the model, and push it to the Hub if requested.

    Args:
        script_args: Parsed ``GRPOScriptArguments``.
        training_args: Parsed ``GRPOConfig``.
        model_args: Parsed ``ModelConfig``.

    Raises:
        KeyError: If a requested reward function name is not registered.
        NotImplementedError: If ``training_args.use_vllm`` is set; no vLLM
            trainer is wired up for this script.
    """
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    dataset = load_json_dataset(
        script_args.train_data_path,
        script_args.eval_data_path,
        script_args.video_folder,
    )

    if training_args.use_vllm:
        raise NotImplementedError("use_vllm=True is not supported by this training script")
    trainer_cls = Qwen2VLGRPOTrainer

    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        # NOTE(review): load_json_dataset returns only "train" and "eval" keys, so
        # dataset_test_split must name one of them when evaluation is enabled.
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    # Resume only when a real checkpoint *directory* exists. A stray file named
    # "checkpoint-*" (or an output_dir that is itself a file) must not trigger
    # resume_from_checkpoint=True, which fails when no checkpoint folder is found.
    output_dir = training_args.output_dir
    has_ckpt = os.path.isdir(output_dir) and any(
        entry.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, entry))
        for entry in os.listdir(output_dir)
    )
    if has_ckpt:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse CLI flags / config file into the three argument dataclasses, then train.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    parsed_args = arg_parser.parse_args_and_config()
    script_args, training_args, model_args = parsed_args
    main(script_args, training_args, model_args)