import os
import re
import json
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """Script-level options for GRPO training.

    Adds to ``ScriptArguments`` the list of reward functions to apply, the
    pixel bounds that ``main`` hands to the trainer, and the JSON-encoded
    video path configuration.
    """

    # Names looked up in ``reward_funcs_registry``; every reward is enabled by default.
    reward_funcs: list[str] = field(
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'trace_grounded'"},
        default_factory=lambda: ["accuracy", "format", "trace_grounded"],
    )

    # Upper pixel bound passed to the trainer as ``max_pixels``.
    max_pixels: Optional[int] = field(
        metadata={"help": "Maximum number of pixels for the image"},
        default=12845056,
    )

    # Lower pixel bound passed to the trainer as ``min_pixels``.
    min_pixels: Optional[int] = field(
        metadata={"help": "Minimum number of pixels for the image"},
        default=3136,
    )

    # Raw JSON text; decoded by ``parse_video_path_config`` inside ``main``.
    video_path: Optional[str] = field(
        metadata={"help": "Video path config: JSON string for multiple data sources, e.g., '{\"STAR\": \"/path/\"}'"},
        default="{}",
    )


def accuracy_reward(completions, solution, **kwargs):
    """Reward completions whose ``<answer>`` matches the ground-truth answer.

    Only ``"multiple choice"`` samples can earn a reward; every other
    question type scores 0.0.

    Args:
        completions: One entry per sample; each entry is a sequence whose
            first element is a mapping holding the generated text under
            ``"content"``.
        solution: Ground-truth strings aligned with ``completions``. The
            answer is read from ``<answer>...</answer>``; a solution without
            those tags is compared verbatim (stripped).
        **kwargs: Must contain ``problem_type``. It may be a sequence with one
            entry per sample (each sample is then scored with its own type),
            a shorter sequence (its first entry is applied to all samples),
            or a single string.

    Returns:
        A list of floats (1.0 or 0.0), one per (completion, solution) pair.

    Raises:
        KeyError: If ``problem_type`` is missing from ``kwargs``.

    Side effects:
        When the ``DEBUG_MODE`` environment variable equals ``"true"`` and
        ``LOG_PATH`` is set, each sample is appended to that log file.
    """
    def extract_answer(text):
        match = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)
        return match.group(1).strip() if match else ""

    contents = [completion[0]["content"] for completion in completions]

    problem_types = kwargs['problem_type']
    if isinstance(problem_types, str):
        problem_types = [problem_types]
    problem_types = list(problem_types)
    if len(problem_types) != len(contents):
        # Not aligned with the batch: fall back to one type for every sample.
        problem_types = [problem_types[0]] * len(contents)

    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")

    # Logging needs both switches; a missing LOG_PATH must not crash training.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None

    rewards = []
    for content, sol, question_type in zip(contents, solution, problem_types):
        try:
            output_ans = extract_answer(content)
            # Accept untagged solutions instead of silently comparing against "".
            gt_ans = extract_answer(sol) or sol.strip()
            if question_type == "multiple choice":
                # An empty extraction means no answer was given; it never matches.
                reward = 1.0 if output_ans and output_ans == gt_ans else 0.0
            else:
                reward = 0.0
        except Exception as e:
            # A malformed sample scores zero rather than aborting the batch.
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if log_path:
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards

def format_reward(completions, **kwargs):
    """Return 1.0 for each completion that follows the trace format, else 0.0.

    A completion is well formed when it contains at least one
    ``<time>``/``<caption>``/``<think>`` block in that order, every such tag
    pair in the text belongs to one of those blocks (the counts all agree),
    and there is exactly one ``<answer>`` element.
    """
    flags = re.DOTALL
    block_pattern = r'<time>.*?</time>\s*<caption>.*?</caption>\s*<think>.*?</think>'
    tag_patterns = (
        r'<time>.*?</time>',
        r'<caption>.*?</caption>',
        r'<think>.*?</think>',
    )
    answer_pattern = r'<answer>.*?</answer>'

    def occurrences(pattern, text):
        return len(re.findall(pattern, text, flags))

    def score(text):
        n_blocks = occurrences(block_pattern, text)
        if n_blocks == 0:
            return 0.0
        # Any stray tag pair outside a complete block breaks the count match.
        for pattern in tag_patterns:
            if occurrences(pattern, text) != n_blocks:
                return 0.0
        if occurrences(answer_pattern, text) != 1:
            return 0.0
        return 1.0

    return [score(completion[0]["content"]) for completion in completions]


def trace_grounded_reward(completions, **kwargs):
    """Heuristically reward reasoning traces, weighted by question complexity.

    Each completion is split into claims by pairing the i-th ``<time>``,
    ``<caption>`` and ``<think>`` elements. A claim earns up to 0.7:
    0.1 for a ``start-end`` looking time span, 0.1 for an entity keyword in
    the caption, 0.2 for an ordering keyword in the think text, and 0.3 when
    the caption has more than 5 words and the think text more than 3. The
    mean claim score is multiplied by a weight derived from the question
    (0.6 / 0.4 / 0.2 depending on the temporal keywords it contains).

    Args:
        completions: One entry per sample; each entry is a sequence whose
            first element is a mapping holding the generated text under
            ``"content"``.
        **kwargs: May contain ``problem``: a sequence with one question per
            sample (each sample is weighted by its own question), a shorter
            sequence (its first entry is used for all samples), or a single
            string. Missing or empty means the lowest weight.

    Returns:
        A list of floats, one per completion; 0.0 when no claim is found.
    """
    multi_hop_keywords = ('between', 'before', 'after', 'first', 'last', 'sequence', 'order')
    pairwise_keywords = ('when', 'during', 'while', 'as')
    entity_keywords = ('person', 'object', 'man', 'woman', 'child', 'camera', 'scene')
    order_keywords = ('before', 'after', 'then', 'next', 'following', 'previous')

    def extract_temporal_claims(content):
        time_spans = re.findall(r'<time>(.*?)</time>', content, re.DOTALL)
        captions = re.findall(r'<caption>(.*?)</caption>', content, re.DOTALL)
        thinks = re.findall(r'<think>(.*?)</think>', content, re.DOTALL)
        return [
            {'time': t.strip(), 'caption': c.strip(), 'think': th.strip()}
            for t, c, th in zip(time_spans, captions, thinks)
        ]

    def compute_temporal_complexity(problem):
        # Match whole words only: as substrings, 'as' fires on "was"/"has"/"last"
        # and 'order' on "border", which inflated the weight of nearly every question.
        words = set(re.findall(r'[a-z]+', (problem or '').lower()))
        multi_hop_count = sum(1 for kw in multi_hop_keywords if kw in words)
        pairwise_count = sum(1 for kw in pairwise_keywords if kw in words)

        if multi_hop_count >= 2:
            return 0.6
        if pairwise_count >= 1 or multi_hop_count == 1:
            return 0.4
        return 0.2

    def verify_temporal_claim(claim):
        r_time = 0.0
        r_entity = 0.0
        r_order = 0.0
        r_chain = 0.0

        if claim['time'] and re.match(r'\d+.*-\d+.*', claim['time']):
            r_time = 0.1

        # NOTE(review): claim keywords are still substring matches (so plurals
        # such as "persons" count); confirm whether whole-word matching is wanted.
        caption_lower = claim['caption'].lower()
        if any(ent in caption_lower for ent in entity_keywords):
            r_entity = 0.1

        think_lower = claim['think'].lower()
        if any(ord_kw in think_lower for ord_kw in order_keywords):
            r_order = 0.2

        if len(claim['caption'].split()) > 5 and len(claim['think'].split()) > 3:
            r_chain = 0.3

        return r_time + r_entity + r_order + r_chain

    completion_contents = [completion[0]["content"] for completion in completions]

    problems = kwargs.get('problem') or ['']
    if isinstance(problems, str):
        problems = [problems]
    problems = list(problems)
    if len(problems) != len(completion_contents):
        # Not aligned with the batch: weight every sample by the first question.
        problems = [problems[0]] * len(completion_contents)

    rewards = []
    for content, problem in zip(completion_contents, problems):
        claims = extract_temporal_claims(content)
        if not claims:
            rewards.append(0.0)
            continue

        total_trace_reward = 0.0
        for claim in claims:
            total_trace_reward += verify_temporal_claim(claim)

        trace_reward = total_trace_reward / len(claims)
        rewards.append(compute_temporal_complexity(problem) * trace_reward)

    return rewards


# Lookup table from the names accepted in ``reward_funcs`` to the reward callables.
reward_funcs_registry = dict(
    accuracy=accuracy_reward,
    format=format_reward,
    trace_grounded=trace_grounded_reward,
)

def parse_video_path_config(video_path_arg):
    """Decode the ``video_path`` argument into a mapping of data source to path.

    Args:
        video_path_arg: A JSON object encoded as a string, e.g.
            ``'{"STAR": "/path/"}'``. Surrounding whitespace is ignored.
            ``None``, an empty string and ``'{}'`` mean "no configuration".
            An already-decoded ``dict`` is returned as a copy.

    Returns:
        The decoded dict, or ``{}`` when the argument is empty or cannot be
        used. Unusable input is reported with a printed warning, never raised.
    """
    if not video_path_arg:
        return {}
    if isinstance(video_path_arg, dict):
        return dict(video_path_arg)

    text = video_path_arg.strip()
    if not text or text == '{}':
        return {}

    try:
        config = json.loads(text)
    except json.JSONDecodeError:
        # Also covers plain (non-JSON) paths, which used to be dropped silently.
        print(f"Warning: Failed to parse video_path as JSON: {video_path_arg}")
        return {}

    if not isinstance(config, dict):
        print(f"Warning: video_path must be a JSON object, got: {video_path_arg}")
        return {}
    return config


# System prompt describing the expected output structure: per-segment
# `<time>`, `<caption>` and `<think>` elements followed by a final `<answer>`.
# NOTE(review): nothing in the visible part of this file references this
# constant — confirm whether it is consumed elsewhere before changing it.
SYSTEM_PROMPT = (
    "You are an expert video analyst specializing in temporal reasoning over long videos. "
    "When analyzing videos, you must construct structured reasoning traces that mirror the video's temporal structure. "
    "For each critical temporal segment, specify time intervals with `<time>start_time-end_time</time>`, "
    "describe visual evidence with `<caption>key visual elements</caption>`, "
    "and provide temporal analysis with `<think>reasoning about temporal relationships</think>`. "
    "Your reasoning should capture ordered event chains and temporal dependencies across segments. "
    "Employ natural cognitive expressions to articulate your temporal understanding process. "
    "After examining temporal traces, synthesize your findings and place the final answer in `<answer> </answer>` tags."
)


def main(script_args, training_args, model_args):
    """Build the dataset and trainer from the parsed arguments, then train.

    Steps: resolve the reward functions, decode the video path config, load
    the dataset (a local ``.json``/``.jsonl`` file becomes a single ``train``
    split, anything else goes through ``load_dataset``), attach a ``prompt``
    column to every example, run training, save the model to
    ``training_args.output_dir`` and optionally push it to the hub.

    Args:
        script_args: Parsed ``GRPOScriptArguments``.
        training_args: Parsed ``GRPOConfig``.
        model_args: Parsed ``ModelConfig``.

    Raises:
        KeyError: If ``script_args.reward_funcs`` names an unregistered reward.
        ImportError: If ``training_args.use_vllm`` is set but the vLLM trainer
            cannot be imported from the ``trainer`` package.
    """
    unknown = [name for name in script_args.reward_funcs if name not in reward_funcs_registry]
    if unknown:
        raise KeyError(
            f"Unknown reward function(s) {unknown}; available: {sorted(reward_funcs_registry)}"
        )
    reward_funcs = [reward_funcs_registry[name] for name in script_args.reward_funcs]

    video_path_config = parse_video_path_config(script_args.video_path)

    if script_args.dataset_name.endswith(('.json', '.jsonl')):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    QUESTION_TEMPLATE = (
        "{Question}\n\n"
        "Analyze the video by constructing temporal reasoning traces. Identify key temporal segments and their relationships "
        "using `<time> </time>`, `<caption> </caption>`, `<think> </think>` tags. "
        "Conduct temporal analysis to derive your answer, then provide only the single option letter within `<answer> </answer>` tags."
    )

    def make_conversation_image_and_video(example):
        """Return the ``prompt`` column for one example (media slot + question text)."""
        question = example['problem']
        if example["problem_type"] == 'multiple choice':
            # Options are appended one per line after the "Options:" header.
            question += "Options:\n" + "".join(op + "\n" for op in example["options"])

        return {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            # Placeholder entry typed by the example's 'data_type' value.
                            "type": example['data_type'],
                        },
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question)
                        }
                    ]
                }]
        }

    dataset = dataset.map(make_conversation_image_and_video)

    if training_args.use_vllm:
        # The vLLM trainer is not imported at module level; the original code
        # referenced the bare name here and failed with NameError.
        try:
            from trainer import Qwen2VLGRPOVLLMTrainerModified
        except ImportError as err:
            raise ImportError(
                "use_vllm=True requires `Qwen2VLGRPOVLLMTrainerModified` to be "
                "importable from the `trainer` package"
            ) from err
        trainer_cls = Qwen2VLGRPOVLLMTrainerModified
    else:
        trainer_cls = Qwen2VLGRPOTrainer
    print("using: ", trainer_cls)

    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        video_path_config=video_path_config,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # Parse the three argument groups (CLI and/or config file) and launch training.
    arg_parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = arg_parser.parse_args_and_config()
    main(
        script_args=script_args,
        training_args=training_args,
        model_args=model_args,
    )
