import json
import argparse
import subprocess
import torch
import cv2
import os
import requests
import copy
import warnings
import numpy as np

from PIL import Image
from typing import Dict, Any
from pathlib import Path
from decord import VideoReader, cpu
from operator import attrgetter
from model.llava.model.builder import load_pretrained_model
from model.llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from model.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from model.llava.conversation import conv_templates, SeparatorStyle

def load_video(video_path, max_frames_num):
    """Uniformly sample ``max_frames_num`` frames from a video with decord.

    Args:
        video_path: Path to the video as a ``str`` or ``os.PathLike``.  Any
            other value is treated as a sequence whose first element is the
            path (kept for backward compatibility with list-wrapped paths).
        max_frames_num: Number of frames to sample.  Indices are spread evenly
            from the first to the last frame inclusive, so videos shorter than
            ``max_frames_num`` frames yield repeated frames.

    Returns:
        The decoded frames as a numpy array, laid out as
        (frames, height, width, channels).

    Raises:
        ValueError: if ``max_frames_num`` < 1 or the video has no frames.
    """
    if max_frames_num < 1:
        raise ValueError(f"max_frames_num must be >= 1, got {max_frames_num}")

    # isinstance (not type()==) so Path objects and str subclasses are
    # treated as paths instead of being indexed like a sequence.
    if isinstance(video_path, (str, os.PathLike)):
        path = os.fspath(video_path)
    else:
        path = video_path[0]

    vr = VideoReader(path, ctx=cpu(0))
    total_frame_num = len(vr)
    if total_frame_num == 0:
        # linspace(0, -1, ...) would otherwise produce invalid negative indices.
        raise ValueError(f"Video has no decodable frames: {path}")

    frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    return vr.get_batch(frame_idx).asnumpy()  # (frames, height, width, channels)

class ModelWrapper:
    """Uniform multiple-choice video-QA interface over several model backends.

    Accepted ``model_type`` values (case-insensitive):

    * ``"llava-onevision-qwen2-7b-ov"`` -- loaded via ``load_pretrained_model``
      and wired into :meth:`predict`.
    * ``"videollama"``, ``"videochatgpt"``, ``"qwen2.5-7b"`` -- accepted but not
      implemented yet; :meth:`predict` raises ``NotImplementedError`` for them.

    Any other value makes the constructor raise ``NotImplementedError``.
    """

    _LLAVA_OV = "llava-onevision-qwen2-7b-ov"
    # Backends the constructor accepts but whose loading/inference code is a stub.
    _STUB_BACKENDS = ("videollama", "videochatgpt", "qwen2.5-7b")
    # Conversation template used for the LLaVA-OneVision backend.
    _LLAVA_CONV_TEMPLATE = "qwen_1_5"

    # Frames uniformly sampled per clip.  Class-level default so an instance
    # created without running __init__ still has a value.
    num_frames: int = 16

    def __init__(self, model_type: str, model_path: str, num_frames: int = 16):
        """Load the backend named by ``model_type`` from ``model_path``.

        Args:
            model_type: Backend identifier (case-insensitive); see class docstring.
            model_path: Checkpoint path handed to the backend loader.
            num_frames: Frames sampled from each clip by :meth:`predict`.

        Raises:
            NotImplementedError: if ``model_type`` is not an accepted backend.
        """
        self.model_type = model_type.lower()
        self.num_frames = num_frames
        self.model = None

        if self.model_type in self._STUB_BACKENDS:
            # TODO: load the real backend (e.g. `from videollama import VideoLlama`,
            # `from video_chatgpt import VideoChatGPT`).  Until then predict()
            # raises a clear NotImplementedError instead of an AttributeError.
            return

        if self.model_type != self._LLAVA_OV:
            raise NotImplementedError(f"Unsupported model: {model_type}")

        llava_model_args = {
            "multimodal": True,
        }
        # NOTE(review): presumably the builder picks the architecture from this
        # model-name argument; upstream LLaVA-OneVision examples pass
        # "llava_qwen" here -- confirm this name resolves to the same class.
        self.tokenizer, self.model, self.image_processor, self.max_length = load_pretrained_model(
            model_path,
            None,
            self.model_type,
            device_map="cuda",
            attn_implementation="sdpa",
            **llava_model_args,
        )
        self.model.eval()

    def predict(self, video_path: str, question: str, options: Dict[str, str]) -> str:
        """Answer a multiple-choice question about the clip at ``video_path``.

        Args:
            video_path: Path of the (already cropped) video clip.
            question: Question text.
            options: Mapping of option key (e.g. ``"A"``) to option text.

        Returns:
            The chosen option key, or ``"Unknown"`` when the model's reply
            cannot be mapped to an option.

        Raises:
            NotImplementedError: for backends that are accepted but not implemented.
        """
        if self.model_type == self._LLAVA_OV:
            return self._predict_llava(video_path, question, options)
        # Stub backends never load a model; fail loudly instead of returning None.
        raise NotImplementedError(
            f"Inference for model '{self.model_type}' is not implemented"
        )

    def _predict_llava(self, video_path: str, question: str, options: Dict[str, str]) -> str:
        """Run LLaVA-OneVision on sampled frames and parse the reply to an option key."""
        # Keep the raw numpy frames: image_sizes is derived from them, and the
        # preprocessed tensor is a separate object.
        video_frames = self._extract_frames(video_path)
        pixel_values = (
            self.image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"]
            .half()
            .cuda()
        )

        # The image token marks where tokenizer_image_token inserts
        # IMAGE_TOKEN_INDEX; without it the video is never fed to the model.
        prompt = f"{DEFAULT_IMAGE_TOKEN}\n{self._build_prompt(question, options)}"
        conv = copy.deepcopy(conv_templates[self._LLAVA_CONV_TEMPLATE])
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)

        input_ids = (
            tokenizer_image_token(conv.get_prompt(), self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to("cuda")
        )

        cont = self.model.generate(
            input_ids,
            images=[pixel_values],
            image_sizes=[frame.size for frame in video_frames],
            do_sample=False,
            temperature=0,
            max_new_tokens=4096,
            modalities=["video"],
        )
        text_output = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
        return self._parse_output(text_output, options)

    def _extract_frames(self, video_path: str) -> list:
        """Uniformly sample ``self.num_frames`` frames from ``video_path``."""
        return load_video(video_path, self.num_frames)

    def _build_prompt(self, question: str, options: Dict[str, str]) -> str:
        """Format the question and lettered options, ending with ``"Answer:"``."""
        options_text = "\n".join(f"{k}. {v}" for k, v in options.items())
        return f"{question}\n{options_text}\nAnswer:"

    def _parse_output(self, raw: str, options: Dict[str, str]) -> str:
        """Map a free-form model reply to an option key.

        Recognises replies such as ``"B"``, ``"B. dog"``, ``"(B)"``,
        ``"Answer: B"`` and ``"The answer is B."``.  Failing that, a reply that
        equals one option's text (case-insensitive, trailing period ignored)
        maps to that option's key.

        Returns:
            The option key, or ``"Unknown"`` if nothing matches.
        """
        import re

        if not raw:
            return "Unknown"
        options = options or {}
        # Single-letter option keys; fall back to A-D when options use other keys.
        letters = [k for k in options if len(k) == 1 and k.isalpha()] or ["A", "B", "C", "D"]

        text = raw.strip()
        # Drop an echoed lead-in such as "Answer:" (the prompt from
        # _build_prompt ends with it) or "The correct answer is".
        text = re.sub(
            r"^(?:the\s+)?(?:correct\s+)?answer\b(?:\s+is)?\s*[:\-]?\s*",
            "",
            text,
            flags=re.IGNORECASE,
        )

        letter_class = "".join(re.escape(k) for k in letters)
        # A leading letter counts only when it is not the start of a longer word
        # (so "About ..." is not read as option A).
        match = re.match(rf"\(?([{letter_class}])(?![A-Za-z])", text)
        if match:
            return match.group(1)

        reply = text.rstrip(".").strip().lower()
        for key, value in options.items():
            if reply == str(value).strip().lower():
                return key
        return "Unknown"


def evaluate(args):
    """Run the model over every annotated sample and write ``results.json``.

    For each sample, the source video ``<video_root>/<field>/<video>`` is cropped
    to its first ``end_time`` seconds (cached under ``output_dir``). The model then
    answers the multiple-choice question, and ``pred`` / ``correct`` are stored
    back on the sample. Results are checkpointed every 10 samples and once at the end.

    Args:
        args: argparse namespace with ``annotation``, ``video_root``,
            ``output_dir``, ``model_type`` and ``model_path``.
    """
    with open(args.annotation, encoding="utf-8") as f:
        samples = json.load(f)

    model = ModelWrapper(args.model_type, args.model_path)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, sample in enumerate(samples):
        try:
            # NOTE(review): the clip name uses only the video stem and end_time,
            # so same-named videos in different fields would share a cached clip.
            video_name = f"{Path(sample['video']).stem}_{sample['end_time']}.mp4"
            clip_path = output_dir / video_name

            if not clip_path.exists():
                _crop_video(
                    src=Path(args.video_root) / sample['field'] / sample['video'],
                    dst=clip_path,
                    end_time=sample['end_time'],
                )

            pred = model.predict(
                video_path=str(clip_path),
                question=sample['question'],
                options=sample['options'],
            )
            sample['pred'] = pred
            sample['correct'] = (pred == sample['answer'])
        except Exception as e:
            # Best effort: one bad video or sample must not abort the whole run.
            print(f"Error processing {sample.get('questionID', i)}: {str(e)}")
            sample['pred'] = "ERROR"
            sample['correct'] = False

        if i % 10 == 0:
            _save_results(samples, args.output_dir)

    _save_results(samples, args.output_dir)
    print(f"Evaluation completed. Results saved to {args.output_dir}")

def _crop_video(src: Path, dst: Path, end_time: float):
    """Write the first ``end_time`` seconds of ``src`` to ``dst`` using ffmpeg.

    Streams are copied without re-encoding (``-c copy``), so the cut may not be
    frame-accurate. If ffmpeg fails or the call is interrupted, any partially
    written ``dst`` is deleted. Callers that check ``dst.exists()`` as a cache
    therefore never pick up a truncated clip.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    dst = Path(dst)
    cmd = [
        "ffmpeg", "-y",
        "-nostdin",  # never read from the terminal when run in a batch loop
        "-ss", "0",
        "-t", str(end_time),
        "-i", str(src),
        "-c", "copy",
        str(dst),
    ]
    try:
        subprocess.run(cmd, check=True)
    except BaseException:
        # Includes KeyboardInterrupt, the usual cause of a half-written clip.
        dst.unlink(missing_ok=True)
        raise

def _save_results(data: list, output_dir: str):
    """Write ``data`` as UTF-8, indented JSON to ``<output_dir>/results.json``.

    The JSON is written to a temporary sibling file and then moved over
    ``results.json`` with ``os.replace``. An interrupted checkpoint therefore
    never leaves a truncated results file.
    """
    path = Path(output_dir) / "results.json"
    tmp_path = path.with_name(path.name + ".tmp")
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # the platform-default encoding may not be able to represent.
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the evaluation.
    parser = argparse.ArgumentParser(description="VideoQA Evaluation")
    parser.add_argument("--annotation", type=str, required=True,
                        help="Path to constructed annotation JSON")
    parser.add_argument("--video_root", type=str, required=True,
                        help="Root directory of original videos")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for results and clips")
    # Mirror every model_type ModelWrapper accepts; previously the only
    # backend with loading code (LLaVA-OneVision) could not be selected.
    parser.add_argument("--model_type", type=str, required=True,
                        choices=["videollama", "videochatgpt", "qwen2.5-7b",
                                 "llava-onevision-qwen2-7b-ov"],
                        help="Type of model to evaluate")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to model weights/config")

    args = parser.parse_args()
    evaluate(args)