import os
import json
import random
import requests
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration
)
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
)
from accelerate import Accelerator
from qwen_vl_utils import process_vision_info

from datasets import Dataset, DatasetDict

import wandb

from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field


@dataclass
class SFTScriptArguments(ScriptArguments):
    """Command-line arguments for video SFT, extending TRL's ``ScriptArguments``.

    Adds ``video_path``: a JSON object, passed as a string, that maps each
    ``data_source`` name to the base directory its videos are stored under.
    The default ``'{}'`` means "no mapping".
    """

    # JSON string, e.g. '{"STAR": "/path/"}'; decoded by parse_video_path_config.
    video_path: Optional[str] = field(
        metadata={"help": "Video path config: JSON string for multiple data sources, e.g., '{\"STAR\": \"/path/\"}'"},
        default="{}",
    )


def parse_video_path_config(video_path_arg):
    """Parse the ``--video_path`` argument into a ``{data_source: base_dir}`` dict.

    The argument is normally a JSON object string such as
    ``'{"STAR": "/data/star/"}'``. Surrounding whitespace is ignored. An
    already-decoded ``dict`` (e.g. supplied through a YAML config file) is
    accepted as-is.

    Args:
        video_path_arg: Raw argument value; may be ``None``, empty, a JSON
            object string, or a dict.

    Returns:
        A new dict with the mapping. Returns ``{}`` when the argument is empty,
        is not a JSON object, or fails to decode. A warning is printed in the
        last two cases so a misconfigured path is not silently ignored.
    """
    if isinstance(video_path_arg, dict):
        return dict(video_path_arg)
    if not video_path_arg:
        return {}

    text = video_path_arg.strip()
    if not text:
        return {}
    if not (text.startswith('{') and text.endswith('}')):
        # Previously dropped without notice; surface it so users notice the typo.
        print(f"Warning: video_path is not a JSON object and will be ignored: {video_path_arg}")
        return {}

    try:
        # A JSON text delimited by braces can only decode to an object (dict).
        return json.loads(text)
    except json.JSONDecodeError:
        print(f"Warning: Failed to parse video_path as JSON: {video_path_arg}")
        return {}


def get_video_path(video_path_config, data_source):
    """Return the base video directory configured for ``data_source``.

    Falls back to ``""`` when ``video_path_config`` is not a dict or has no
    entry for ``data_source``. A present entry is returned unchanged.
    """
    if not isinstance(video_path_config, dict):
        return ""
    return video_path_config.get(data_source, "")


def get_current_device():
    """Return this process's local index when CUDA is available, else ``"cpu"``.

    The index is read from a freshly constructed ``Accelerator``.
    """
    if not torch.cuda.is_available():
        return "cpu"
    return Accelerator().local_process_index


def download_video(url: str, folder: str = '/tmp/videos/', timeout: float = 60.0) -> str:
    """Download ``url`` into ``folder`` and return the local file path.

    The file name is the last component of the URL *path*, so a query string
    or fragment does not leak into it. If a file with that name already exists
    in ``folder``, it is returned without downloading again.

    The body is streamed into a process-unique ``.part`` file. That file is
    renamed into place only after the download completes, so an interrupted
    download never leaves a truncated file for a later call to treat as cached.

    Args:
        url: HTTP(S) URL of the video.
        folder: Destination directory; created if it does not exist.
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot hang the caller forever.

    Returns:
        Path of the local copy.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        Exception: If the HTTP request fails (network error or error status).
    """
    from urllib.parse import urlparse

    filename = os.path.basename(urlparse(url).path)
    if not filename:
        # Without this check, a URL ending in "/" would resolve to `folder` itself.
        raise ValueError(f"Cannot derive a file name from URL: {url}")
    local_path = os.path.join(folder, filename)

    if os.path.exists(local_path):
        return local_path

    os.makedirs(folder, exist_ok=True)
    # The PID suffix keeps concurrent workers from writing to the same temp file.
    tmp_path = f"{local_path}.{os.getpid()}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, local_path)
    except requests.RequestException as e:
        raise Exception(f"Failed to download video: {e}") from e
    finally:
        # After a successful os.replace the temp file is gone; otherwise drop the partial.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return local_path


def prepare_dataset(example: Dict[str, Any], video_path_config: dict) -> Dict[str, List[Dict[str, Any]]]:
    """Convert one raw dataset record into a chat-format training example.

    Fields read from ``example``:
    - ``problem_type`` and ``problem``.
    - ``options``: only when ``problem_type == 'multiple choice'``.
    - ``data_source``: selects the base directory from ``video_path_config``.
    - ``path``: media path, relative to that base directory when one is configured.
    - ``data_type``: media type; presumably ``"video"`` or ``"image"``.
    - ``process`` and ``solution``: concatenated into the assistant target.

    Args:
        example: A single dataset record.
        video_path_config: Mapping of ``data_source`` to base directory.

    Returns:
        ``{"messages": [system, user, assistant]}``. The user turn holds one
        media entry and the formatted question text.
    """
    system_message = (
        "You are an expert video analyst specializing in temporal reasoning over long videos. "
        "When analyzing videos, you must construct structured reasoning traces that mirror the video's temporal structure. "
        "For each critical temporal segment, specify time intervals with `<time>start_time-end_time</time>`, "
        "describe visual evidence with `<caption>key visual elements</caption>`, "
        "and provide temporal analysis with `<think>reasoning about temporal relationships</think>`. "
        "Your reasoning should capture ordered event chains and temporal dependencies across segments. "
        "Employ natural cognitive expressions to articulate your temporal understanding process. "
        "After examining temporal traces, synthesize your findings and place the final answer in `<answer> </answer>` tags."
    )

    QUESTION_TEMPLATE = (
        "{Question}\n\n"
        "Analyze the video by constructing temporal reasoning traces. Identify key temporal segments and their relationships "
        "using `<time> </time>`, `<caption> </caption>`, `<think> </think>` tags. "
        "Conduct temporal analysis to derive your answer, then provide only the single option letter within `<answer> </answer>` tags."
    )

    if example["problem_type"] == 'multiple choice':
        # NOTE(review): "Options:" is appended directly after the problem text with
        # no separator. Kept unchanged so prompts match the existing data/eval format.
        question = example['problem'] + "Options:\n"
        for op in example["options"]:
            question += op + "\n"
    else:
        question = example['problem']

    media_type = example['data_type']
    video_base_path = get_video_path(video_path_config, example['data_source'])
    media_path = os.path.join(video_base_path, example["path"]) if video_base_path else example["path"]

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": media_type,
                    # qwen_vl_utils.process_vision_info looks the path up under the
                    # "image"/"video" key. Keying by the data type keeps image records
                    # from being decoded as videos; video records are unchanged.
                    media_type: media_path,
                    # Frame count and pixel budget forwarded to the vision loader.
                    "nframes": 16,
                    "max_pixels": 128 * 28 * 28,
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=question)
                }
            ]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example['process'] + "\n" + example['solution']}]
        }
    ]

    return {"messages": messages}


def _visual_token_ids(proc) -> List[int]:
    """Return vocabulary ids of the vision special tokens to exclude from the loss.

    Candidates are Qwen-VL's vision start/end and image/video pad markers, plus
    ``proc.image_token`` / ``proc.video_token`` when the processor defines them.
    A candidate is kept only if its id maps back to the same token string.
    Tokens missing from this tokenizer's vocabulary are therefore skipped,
    which stops them from collapsing onto the unk id.
    """
    tokenizer = proc.tokenizer
    candidates = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>"]
    for attr in ("image_token", "video_token"):
        token = getattr(proc, attr, None)
        if isinstance(token, str):
            candidates.append(token)

    ids = set()
    for token in candidates:
        token_id = tokenizer.convert_tokens_to_ids(token)
        if isinstance(token_id, int) and tokenizer.convert_ids_to_tokens(token_id) == token:
            ids.add(token_id)
    return sorted(ids)


def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of chat examples and attach masked language-model labels.

    Uses the module-level ``processor`` created in the ``__main__`` block.

    Every example's rendered chat text and decoded images/videos are gathered
    in batch order. The n-th vision placeholder in the batch therefore lines up
    with the n-th image/video passed to ``processor``. Previously only the last
    example's media were kept, which broke batches larger than one.

    Labels copy ``input_ids``. Padding tokens and vision special tokens are set
    to -100 so the loss ignores them.

    Raises:
        ValueError: If rendering or vision decoding fails for an example.
    """
    texts = []
    image_inputs = []
    video_inputs = []

    for i, example in enumerate(examples):
        messages = example["messages"]
        try:
            for message in messages:
                if message["role"] != "user":
                    continue
                for content in message["content"]:
                    if content.get("type") == "video":
                        print(f"Processing video: {content.get('video')}")
                        break
            texts.append(processor.apply_chat_template(messages, tokenize=False))
            # NOTE(review): the returned video kwargs (e.g. sampled fps) are not
            # forwarded to the processor, matching the previous behaviour.
            images, videos, _video_kwargs = process_vision_info(messages, return_video_kwargs=True)
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}") from e
        # process_vision_info returns None (not []) when there is no media of a kind.
        image_inputs.extend(images or [])
        video_inputs.extend(videos or [])

    inputs = processor(
        text=texts,
        images=image_inputs or None,
        videos=video_inputs or None,
        return_tensors="pt",
        padding=True,
    )

    labels = inputs["input_ids"].clone()
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels[labels == pad_token_id] = -100

    # Detect the tokens from the tokenizer rather than relying on isinstance:
    # Qwen2.5-VL's processor is not a Qwen2VLProcessor, so its <|video_pad|>
    # tokens used to stay in the loss.
    for visual_token_id in _visual_token_ids(processor):
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Mapping of data_source -> base directory used to resolve relative media paths.
    video_path_config = parse_video_path_config(script_args.video_path)

    # Tokenization and media decoding happen in collate_fn. Keep the raw
    # "messages" column and skip TRL's own dataset preparation.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    train_split = getattr(script_args, "dataset_train_split", "train")
    if script_args.dataset_name.endswith(('.json', '.jsonl')):
        # A local JSON/JSONL file becomes a single split under the configured name.
        dataset = DatasetDict({train_split: Dataset.from_json(script_args.dataset_name)})
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map(),
    )

    if "Qwen2-VL" in model_config.model_name_or_path:
        model = Qwen2VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    elif "Qwen2.5-VL" in model_config.model_name_or_path:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    # Module-level on purpose: collate_fn reads this global.
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code
    )

    prepared_dataset = [prepare_dataset(example, video_path_config) for example in dataset[train_split]]

    # TrainingArguments normalises report_to to a list, so test membership
    # (the old `== "wandb"` check was always False). Only rank 0 starts a run,
    # avoiding one W&B run per process.
    if "wandb" in (training_args.report_to or []) and training_args.process_index == 0:
        wandb.init(project="videotrace-r1-sft")

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared_dataset,
        data_collator=collate_fn,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)

    if trainer.accelerator.is_main_process:
        # Only the main process writes these files, so ranks do not race on them.
        processor.save_pretrained(training_args.output_dir)
        # Re-enable the KV cache in the saved config for inference.
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    del model
    del trainer
    torch.cuda.empty_cache()
    wandb.finish()
