import argparse
import torch
import os
import json
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import math


def split_list(lst, n):
    """Split a list into at most ``n`` contiguous chunks of roughly equal size.

    Every chunk holds ``ceil(len(lst) / n)`` items except possibly the last,
    which takes the remainder. Because the chunk size is rounded up, fewer
    than ``n`` chunks may be returned (5 items with ``n=4`` yields 3 chunks).

    Args:
        lst: Sequence to split.
        n: Requested number of chunks; must be a positive integer.

    Returns:
        A list of slices of ``lst``, in order. Empty when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if len(lst) == 0:
        # A chunk size of 0 would make range() fail on a zero step.
        return []
    chunk_size = math.ceil(len(lst) / n)  # round up so no item is left over
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of the ``n`` chunks that ``split_list`` produces.

    ``split_list`` rounds the chunk size up, so it can produce fewer than
    ``n`` chunks. A worker whose index is valid (``k < n``) but falls past the
    last produced chunk simply has no work, so it gets an empty list rather
    than an ``IndexError``.

    Args:
        lst: Sequence to split.
        n: Total number of chunks requested.
        k: Index of the chunk to return.

    Returns:
        The ``k``-th chunk, or ``[]`` when that chunk does not exist for a
        ``k`` below ``n``.

    Raises:
        IndexError: If ``k`` is out of range for reasons other than the
            rounding described above (e.g. ``k >= n``).
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]


class RWKVQwenVideoEvaluator:
    """Wrap an RWKV-Qwen hybrid model and its processor for video QA."""

    def __init__(self, model_path):
        """Load the model and processor from ``model_path``.

        The model is placed on CUDA when available, otherwise on CPU.

        Args:
            model_path: Path or identifier handed to ``from_pretrained``.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model from: {model_path}")

        # Let device_map place the weights on the selected device directly;
        # a follow-up .to() is then unnecessary.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": self.device,
            "trust_remote_code": True,
        }
        if self.device == "cuda":
            # FlashAttention-2 needs a GPU; on CPU fall back to the model's
            # default attention implementation instead of failing at load.
            load_kwargs["attn_implementation"] = "flash_attention_2"
        self.model = AutoModel.from_pretrained(model_path, **load_kwargs)

        self.processor = AutoProcessor.from_pretrained(model_path)

        print(f"Model loaded successfully on {self.device}")

    def prepare_video_messages(self, video_path, question):
        """Build the single-turn chat message list for one video question.

        Args:
            video_path: Path of the video, stored under the ``"video"`` key.
            question: Prompt text, stored under the ``"text"`` key.

        Returns:
            A one-element list holding a ``"user"`` message whose content is
            the video entry followed by the text entry.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": question},
                ],
            }
        ]
        return messages

    def generate_response(self, video_path, question, max_new_tokens=1):
        """Generate the model's answer to ``question`` about ``video_path``.

        Args:
            video_path: Path of the video to process.
            question: Full prompt text.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The decoded, stripped answer text. If anything fails, the error
            is printed and a string of the form ``"Error: <message>"`` is
            returned instead of raising, so a batch run can continue.
        """
        try:
            messages = self.prepare_video_messages(video_path, question)

            # Render the chat template to text; tokenization happens in the
            # processor call below together with the vision inputs.
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            # top_k=1 keeps only the most likely token, so although sampling
            # is enabled the decoding is effectively greedy.
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.001,
                    top_k=1,
                    temperature=0.01,
                    repetition_penalty=1.0,
                )

            # Drop the prompt tokens so only the newly generated part is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            return output_text[0].strip()

        except Exception as e:
            # Deliberate best-effort: report the failure as the answer text.
            print(f"Error processing video {video_path}: {e}")
            return f"Error: {str(e)}"


def _load_gt_questions(args):
    """Load the ground-truth questions selected by ``args``.

    When ``args`` carries both ``short_gt_file`` and ``medium_gt_file``, the
    two files are loaded and concatenated; otherwise ``args.gt_file`` is
    loaded and the range label is taken from the first of "short", "medium",
    "long" found in that path.

    Returns:
        A ``(questions, video_range)`` tuple. ``video_range`` is ``""`` when
        no label is found in the single-file path.
    """
    if hasattr(args, 'short_gt_file') and hasattr(args, 'medium_gt_file'):
        with open(args.short_gt_file, encoding="utf-8") as file:
            short_questions = json.load(file)

        with open(args.medium_gt_file, encoding="utf-8") as file:
            medium_questions = json.load(file)

        print(f"Total questions: {len(short_questions) + len(medium_questions)} (short: {len(short_questions)}, medium: {len(medium_questions)})")
        return short_questions + medium_questions, "short_medium_combined"

    with open(args.gt_file, encoding="utf-8") as file:
        gt_questions = json.load(file)

    video_range = ""
    for label in ("short", "medium", "long"):
        if label in args.gt_file:
            video_range = label
            break
    return gt_questions, video_range


def eval_model(args):
    """
    Run inference on VideoMME dataset using the RWKV-Qwen hybrid model.

    Loads the questions, keeps the chunk selected by ``args.num_chunks`` and
    ``args.chunk_idx``, queries the model once per question and writes one
    JSON object per line to ``<output_dir>/<name>.jsonl``. The file name is
    ``"<num_chunks>_<chunk_idx>"`` when more than one chunk is used, otherwise
    ``args.output_name``.

    Args:
        args: Command-line arguments. Reads ``model_path``, ``video_dir``,
            ``output_dir``, ``output_name``, ``num_chunks``, ``chunk_idx``,
            ``max_new_tokens`` and either ``gt_file`` or the pair
            ``short_gt_file`` / ``medium_gt_file``.
    """
    # Do the cheap input handling first so a bad path or chunk index fails
    # before the model is loaded.
    gt_questions, video_range = _load_gt_questions(args)

    # Get chunk for distributed evaluation
    gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)

    # exist_ok avoids a crash when several chunk workers create the
    # directory at the same time.
    os.makedirs(args.output_dir, exist_ok=True)

    # Determine output filename
    if args.num_chunks > 1:
        output_name = f"{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name

    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")

    # Initialize the model
    evaluator = RWKVQwenVideoEvaluator(args.model_path)

    # Prompt framing is the same for every sample.
    option_prompt = "Select the best answer to the following multiple-choice question based on the video and the subtitles. Respond with only the letter (A, B, C, or D) of the correct option."
    post_prompt = "The best answer is:"

    with open(answers_file, "w", encoding="utf-8") as ans_file:
        for sample in tqdm(gt_questions, desc=f"Processing {video_range} videos"):
            # Extract sample information
            video_id = sample['videoID']
            answer = sample['answer']
            duration = sample['duration']
            answer_index = sample.get('answer_index', '')

            # The options are appended via str(), i.e. in whatever form the
            # loaded JSON value prints as; this text is also what gets
            # stored under 'question' in the result.
            question = sample['question'] + "\n" + str(sample["options"])
            full_prompt = option_prompt + "\n" + question + "\n" + post_prompt

            # Construct video path
            video_path = os.path.join(args.video_dir, f"{video_id}.mp4")

            if not os.path.exists(video_path):
                # Record the miss in the output instead of aborting the run.
                print(f"Warning: Video file not found: {video_path}")
                model_answer = "Video file not found"
            else:
                model_answer = evaluator.generate_response(
                    video_path=video_path,
                    question=full_prompt,
                    max_new_tokens=args.max_new_tokens
                )
                print("model_answer:", model_answer)

            result = {
                'video_id': video_id,
                'question': question,
                'answer': answer,
                'answer_index': answer_index,
                'model_answer': model_answer,
                'duration': duration,
                'source': sample.get('source', video_range)
            }

            # Flush after every sample so partial results survive a crash.
            ans_file.write(json.dumps(result) + "\n")
            ans_file.flush()

            # Clear cache
            torch.cuda.empty_cache()

    print(f"Evaluation completed. Results saved to: {answers_file}")


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    parser = argparse.ArgumentParser(description="VideoMME Evaluation with RWKV-Qwen Hybrid Model")

    # Every option accepts both its underscore and dash spelling. The first
    # spelling listed is the original one and determines the attribute name
    # on ``args`` (dashes become underscores), so existing invocations and
    # attribute accesses keep working.
    parser.add_argument('--video_dir', '--video-dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_file', '--gt-file', help='Path to the ground truth file containing questions.', required=True)
    parser.add_argument('--output_dir', '--output-dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', '--output-name', help='Name of the file for storing results JSON.', default="pred")
    parser.add_argument("--model-path", "--model_path", type=str, required=True, help="Path to the RWKV-Qwen model")
    parser.add_argument("--num-chunks", "--num_chunks", type=int, default=1, help="Number of chunks for distributed evaluation")
    parser.add_argument("--chunk-idx", "--chunk_idx", type=int, default=0, help="Chunk index for distributed evaluation")
    parser.add_argument("--max-new-tokens", "--max_new_tokens", type=int, default=1, help="Maximum number of new tokens to generate")
    # These two are still accepted so existing command lines do not break,
    # but nothing in this script reads them: the decoding settings are fixed
    # inside RWKVQwenVideoEvaluator.generate_response.
    parser.add_argument("--temperature", type=float, default=0.7, help="Accepted for compatibility; currently ignored (decoding settings are fixed)")
    parser.add_argument("--top-p", "--top_p", type=float, default=0.8, help="Accepted for compatibility; currently ignored (decoding settings are fixed)")

    args = parser.parse_args()

    eval_model(args)