import argparse
import torch
import os
import json
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from qwen_vl_utils import process_vision_info
import math
import random


def split_list(lst, n, seed=0):
    """Split a list into n (roughly) equal-sized chunks after a seeded shuffle.

    The shuffle uses a private ``random.Random(seed)`` instead of the global
    RNG, so every process that calls this with the same items and seed sees
    the same permutation. That is what makes the chunks a true partition when
    each worker of a distributed run computes the split independently.

    Args:
        lst: Items to split. The input itself is not modified.
        n: Requested number of chunks; must be >= 1.
        seed: Seed for the shuffle. Use the same value in every worker.

    Returns:
        A list of chunks, each holding ``ceil(len(lst) / n)`` items (the last
        one may be shorter). Fewer than ``n`` chunks are returned when the
        items run out early, and an empty list is returned for empty input.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    shuffled = list(lst)
    if not shuffled:
        # Avoids a zero chunk size, which would make range() raise below.
        return []

    random.Random(seed).shuffle(shuffled)

    chunk_size = math.ceil(len(shuffled) / n)
    return [
        shuffled[start:start + chunk_size]
        for start in range(0, len(shuffled), chunk_size)
    ]


def get_chunk(lst, n, k):
    """Return chunk ``k`` of the ``n``-way split produced by ``split_list``.

    Args:
        lst: Items to split.
        n: Number of chunks requested from ``split_list``.
        k: Index of the chunk to return.

    Returns:
        The selected chunk. An empty list is returned when ``k`` is a valid
        worker index (``k < n``) but the split produced fewer than ``n``
        chunks, i.e. this worker simply has nothing to process.

    Raises:
        IndexError: If ``k`` is outside the produced chunks and not covered
            by the empty-chunk case above.
    """
    chunks = split_list(lst, n)
    # split_list sizes chunks with ceil(len / n), so it can yield fewer than
    # n chunks (e.g. 5 items, n=4 -> 3 chunks); the trailing workers get none.
    if len(chunks) <= k < n:
        return []
    return chunks[k]


class RWKVQwenVNBenchEvaluator:
    """Load an RWKV-Qwen hybrid checkpoint and answer questions about videos."""

    def __init__(self, model_path):
        """Initialize the RWKV-Qwen hybrid model for VNBench evaluation.

        Args:
            model_path: Path or hub id passed to ``from_pretrained`` for both
                the model and the processor.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model from: {model_path}")

        # Place the weights on the device that was actually detected, so the
        # CPU fallback above is usable instead of failing on a hard-coded
        # "cuda" placement.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": self.device,
            "trust_remote_code": True,
        }
        if self.device == "cuda":
            # FlashAttention-2 needs a GPU; on CPU the library default
            # attention implementation is used instead.
            load_kwargs["attn_implementation"] = "flash_attention_2"

        # device_map already puts the model on self.device, so no extra
        # .to(...) call is needed afterwards.
        self.model = AutoModel.from_pretrained(model_path, **load_kwargs)

        self.processor = AutoProcessor.from_pretrained(model_path)

        print(f"Model loaded successfully on {self.device}")

    def prepare_video_messages(self, video_path, question):
        """Prepare input messages for video question answering.

        Args:
            video_path: Path of the video to attach to the user turn.
            question: Text prompt placed after the video.

        Returns:
            A single-turn chat (list with one ``user`` message) whose content
            holds a ``video`` entry followed by a ``text`` entry.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": question},
                ],
            }
        ]
        return messages

    def generate_response(self, video_path, question, max_new_tokens=1):
        """Generate response for a video question.

        Args:
            video_path: Path of the video to query.
            question: Full text prompt for the model.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded, whitespace-stripped model output. If anything in the
            pipeline raises, the error is printed and a string of the form
            ``"Error: <message>"`` is returned instead, so a batch run can
            continue with the next sample.
        """
        try:
            # Prepare messages
            messages = self.prepare_video_messages(video_path, question)

            # Apply chat template
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Process vision info
            image_inputs, video_inputs = process_vision_info(messages)

            # Prepare inputs
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )

            # Move to device
            inputs = inputs.to(self.device)

            # Generate response.
            # NOTE(review): sampling is enabled but constrained with top_k=1,
            # a tiny top_p and a low temperature — presumably meant to be
            # near-greedy decoding; confirm before changing.
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.001,
                    top_k=1,
                    temperature=0.01,
                    repetition_penalty=1.0,
                )

            # Drop the prompt tokens so only the newly generated part is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            return output_text[0].strip()

        except Exception as e:
            # Deliberate best-effort: one failing video must not abort the run.
            print(f"Error processing video {video_path}: {e}")
            return f"Error: {str(e)}"


def _build_vnbench_prompt(question, options):
    """Compose the multiple-choice prompt for one sample (options lettered from A)."""
    option_lines = [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    options_text = "\n".join(option_lines).strip()

    option_prompt = "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option."
    post_prompt = "The best answer is:"
    return "\n".join([option_prompt, question, options_text, post_prompt])


def eval_model(args):
    """
    Run inference on VNBench dataset using the RWKV-Qwen hybrid model.

    Loads the questions from ``args.gt_file``, keeps the chunk selected by
    ``args.num_chunks`` / ``args.chunk_idx``, queries the model once per
    sample and writes one JSON object per line to
    ``<args.output_dir>/<name>.jsonl``. The name is
    ``<num_chunks>_<chunk_idx>`` for chunked runs and ``args.output_name``
    otherwise.

    Args:
        args: Command-line arguments. Reads ``model_path``, ``gt_file``,
            ``num_chunks``, ``chunk_idx``, ``output_dir``, ``output_name``,
            ``video_dir`` and ``max_new_tokens``.
    """
    # Initialize the model
    evaluator = RWKVQwenVNBenchEvaluator(args.model_path)

    # Load ground truth file (explicit encoding: do not depend on the locale)
    with open(args.gt_file, encoding="utf-8") as file:
        gt_questions = json.load(file)

    print(f"Total questions: {len(gt_questions)}")

    # Get chunk for distributed evaluation
    gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)

    # exist_ok avoids a FileExistsError when several chunk workers start at
    # the same time and race to create the same output directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Determine output filename
    if args.num_chunks > 1:
        output_name = f"{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name

    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")

    with open(answers_file, "w", encoding="utf-8") as ans_file:
        for sample in tqdm(gt_questions, desc="Processing VNBench videos"):
            # Extract sample information
            video_path = sample['video']
            question = sample['question']
            options = sample['options']

            full_prompt = _build_vnbench_prompt(question, options)

            # Only the file name is kept; the video is looked up in args.video_dir.
            video_filename = os.path.basename(video_path)
            video_full_path = os.path.join(args.video_dir, video_filename)

            # Check if video file exists
            if not os.path.exists(video_full_path):
                print(f"Warning: Video file not found: {video_full_path}")
                model_answer = "Video file not found"
            else:
                # Generate model response
                model_answer = evaluator.generate_response(
                    video_path=video_full_path,
                    question=full_prompt,
                    max_new_tokens=args.max_new_tokens
                )
                print("model_answer:", model_answer)

            # Prepare result
            result = {
                'video': video_path,
                'video_filename': video_filename,
                'question': question,
                'options': options,
                'gt': sample['gt'],
                'gt_option': sample['gt_option'],
                'model_answer': model_answer,
                'needle_time': sample.get('needle_time', []),
                'length': sample.get('length', 0),
                'type': sample.get('type', ''),
                'try': sample.get('try', 0)
            }

            # Write and flush per sample so partial results survive a crash.
            ans_file.write(json.dumps(result) + "\n")
            ans_file.flush()

            # Clear cache
            torch.cuda.empty_cache()

    print(f"Evaluation completed. Results saved to: {answers_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="VNBench Evaluation with RWKV-Qwen Hybrid Model")
    
    # Define command-line arguments
    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_file', help='Path to the ground truth file containing questions.', required=True)
    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', default="pred")
    parser.add_argument("--model-path", type=str, required=True, help="Path to the RWKV-Qwen model")
    parser.add_argument("--num-chunks", type=int, default=1, help="Number of chunks for distributed evaluation")
    parser.add_argument("--chunk-idx", type=int, default=0, help="Chunk index for distributed evaluation")
    parser.add_argument("--max-new-tokens", type=int, default=1, help="Maximum number of new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation")
    parser.add_argument("--top-p", type=float, default=0.8, help="Top-p for generation")
    
    args = parser.parse_args()
    
    eval_model(args)