import argparse
import torch
import os
import json
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import math
import random

def split_list(lst, n, seed=0):
    """Split a list into at most ``n`` roughly equal chunks after a reproducible shuffle.

    The shuffle uses a private ``random.Random(seed)`` instance instead of the
    global RNG. Every process that calls this function with the same list,
    ``n`` and ``seed`` therefore gets the *same* partition, which is required
    when separate worker processes each pick one chunk by index: with an
    unseeded shuffle the workers' chunks would overlap and leave gaps.

    Args:
        lst: Items to partition. The input list is not modified.
        n: Requested number of chunks; must be >= 1.
        seed: Seed for the shuffle. Keep it identical across workers.

    Returns:
        A list of chunks, each of size ``ceil(len(lst) / n)`` except possibly
        the last one. Because of the ceil-based sizing there may be fewer than
        ``n`` chunks, and an empty input yields an empty list.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    shuffled = list(lst)
    if not shuffled:
        # Without this guard chunk_size would be 0 and range() rejects a zero step.
        return []

    random.Random(seed).shuffle(shuffled)

    chunk_size = math.ceil(len(shuffled) / n)
    return [shuffled[i:i + chunk_size] for i in range(0, len(shuffled), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of the ``n`` chunks that ``split_list`` produces for ``lst``.

    Args:
        lst: Items to partition.
        n: Total number of chunks (one per worker).
        k: Index of the chunk to return.

    Returns:
        The selected chunk. An empty list is returned when ``lst`` is empty or
        when ``k`` is a valid worker index (``k < n``) for which no chunk
        exists, so that such a worker simply has nothing to do.

    Raises:
        IndexError: If ``k`` is outside the range accepted by list indexing
            and is not covered by the empty-chunk case above.
    """
    if len(lst) == 0:
        return []

    chunks = split_list(lst, n)

    # split_list sizes chunks with ceil(), so it can return fewer than n
    # chunks (e.g. 4 items, n=3 -> 2 chunks). Trailing workers get no items.
    if len(chunks) <= k < n:
        return []
    return chunks[k]


class RWKVQwenLVBenchEvaluator:
    """Load a vision-language checkpoint and answer questions about videos.

    Attributes set by ``__init__``:
        model_path: Path or identifier the model and processor were loaded from.
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        model: Model returned by ``AutoModel.from_pretrained``.
        processor: Processor returned by ``AutoProcessor.from_pretrained``.
    """

    def __init__(self, model_path):
        """Initialize the RWKV-Qwen hybrid model for LVBench evaluation.

        Args:
            model_path: Path or identifier passed to ``from_pretrained``.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model from: {model_path}")

        # Place the weights on the detected device directly. A hard-coded
        # "cuda" device map makes the CPU fallback above unusable, and a
        # follow-up .to() is redundant once device_map has placed the model.
        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": self.device,
            "trust_remote_code": True,
        }
        if self.device == "cuda":
            # FlashAttention-2 kernels are GPU-only, so request them only when
            # running on CUDA; otherwise the library default is used.
            load_kwargs["attn_implementation"] = "flash_attention_2"

        self.model = AutoModel.from_pretrained(model_path, **load_kwargs)

        self.processor = AutoProcessor.from_pretrained(model_path)

        print(f"Model loaded successfully on {self.device}")

    def prepare_video_messages(self, video_path, question):
        """Build the single-turn chat message list for one video question.

        Args:
            video_path: Path of the video, stored under the ``"video"`` key.
            question: Prompt text, stored under the ``"text"`` key.

        Returns:
            A one-element list holding a ``"user"`` message whose content is
            the video entry followed by the text entry.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": question},
                ],
            }
        ]
        return messages

    def generate_response(self, video_path, question, max_new_tokens=1):
        """Generate the model's answer to ``question`` about ``video_path``.

        Args:
            video_path: Path of the video to feed to the model.
            question: Full prompt text.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded, whitespace-stripped continuation. If anything raises
            while preparing inputs or generating, the error is printed and a
            string of the form ``"Error: <message>"`` is returned instead, so
            a single bad sample does not abort the caller's loop.
        """
        try:
            messages = self.prepare_video_messages(video_path, question)

            # Render the conversation to prompt text, ending with the
            # generation prompt.
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Extract the visual inputs referenced by the messages.
            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            with torch.no_grad():
                # top_k=1 leaves a single candidate token per step, so the
                # sampling below is effectively greedy decoding.
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.001,
                    top_k=1,
                    temperature=0.01,
                    repetition_penalty=1.0,
                )

            # generate() output is sliced past the prompt length so that only
            # the newly generated tokens are decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )

            return output_text[0].strip()

        except Exception as e:
            # Deliberate best-effort: report the failure and hand back a
            # marker string rather than propagating.
            print(f"Error processing video {video_path}: {e}")
            return f"Error: {str(e)}"


def load_and_flatten_jsonl(jsonl_file):
    """Load a JSONL file and flatten it into one record per question.

    Each non-blank line must be a JSON object with a ``key`` and a ``qa``
    list; every ``qa`` item must provide ``uid``, ``question`` and ``answer``.
    The optional fields ``type``, ``video_info``, ``question_type`` and
    ``time_reference`` fall back to ``'unknown'``, ``{}``, ``[]`` and ``''``.

    Lines that fail to parse or lack a required key are reported on stdout
    and skipped; questions already collected from a line before a missing key
    was hit are kept. Blank lines are ignored silently.

    Args:
        jsonl_file: Path of the JSONL file, read as UTF-8.

    Returns:
        A list of dicts with the keys ``video_key``, ``video_type``, ``uid``,
        ``question``, ``answer``, ``question_type``, ``time_reference`` and
        ``video_info``, in file order.
    """
    all_questions = []

    # Explicit UTF-8 so parsing does not depend on the platform's locale.
    with open(jsonl_file, 'r', encoding='utf-8') as file:
        for line_num, raw_line in enumerate(file, 1):
            line = raw_line.strip()
            if not line:
                # A blank or trailing line is not a malformed record.
                continue

            try:
                data = json.loads(line)
                video_key = data['key']
                video_type = data.get('type', 'unknown')
                video_info = data.get('video_info', {})

                # Extract all questions for this video
                for qa_item in data['qa']:
                    all_questions.append({
                        'video_key': video_key,
                        'video_type': video_type,
                        'uid': qa_item['uid'],
                        'question': qa_item['question'],
                        'answer': qa_item['answer'],
                        'question_type': qa_item.get('question_type', []),
                        'time_reference': qa_item.get('time_reference', ''),
                        'video_info': video_info,
                    })

            except json.JSONDecodeError as e:
                print(f"Error decoding JSON on line {line_num}: {e}")
            except KeyError as e:
                print(f"Missing key {e} on line {line_num}")

    return all_questions


def eval_model(args):
    """
    Run inference on LVBench dataset using the RWKV-Qwen hybrid model.

    Loads the model, flattens the ground-truth JSONL into questions, selects
    this worker's chunk and writes one JSON line per question to
    ``<output_dir>/<name>.jsonl``. ``<name>`` is ``"<num_chunks>_<chunk_idx>"``
    when more than one chunk is used and ``args.output_name`` otherwise.
    Questions whose video file is missing are written with the model answer
    ``"Video file not found"`` instead of being skipped.

    Args:
        args: Command-line arguments. The attributes read here are
            ``model_path``, ``gt_file``, ``num_chunks``, ``chunk_idx``,
            ``output_dir``, ``output_name``, ``video_dir`` and
            ``max_new_tokens``.
    """
    # Initialize the model
    evaluator = RWKVQwenLVBenchEvaluator(args.model_path)

    # Load and flatten JSONL file
    print("Loading and flattening JSONL data...")
    all_questions = load_and_flatten_jsonl(args.gt_file)
    print(f"Total questions: {len(all_questions)}")

    # Get chunk for distributed evaluation
    questions_chunk = get_chunk(all_questions, args.num_chunks, args.chunk_idx)
    print(f"Processing chunk {args.chunk_idx}/{args.num_chunks} with {len(questions_chunk)} questions")

    # exist_ok avoids the check-then-create race when several chunk workers
    # start at the same time and all try to create the directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Chunked runs get a per-chunk file name so workers never share a file.
    if args.num_chunks > 1:
        output_name = f"{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name

    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")

    # Prompt scaffolding is identical for every sample, so build it once.
    # The question text already contains the multiple-choice options.
    option_prompt = "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option."
    post_prompt = "The best answer is:"

    with open(answers_file, "w") as ans_file:
        for sample in tqdm(questions_chunk, desc="Processing LVBench videos"):
            video_key = sample['video_key']
            question = sample['question']

            full_prompt = option_prompt + "\n" + question + "\n" + post_prompt

            # Videos are expected as <video_dir>/<video_key>.mp4
            video_path = os.path.join(args.video_dir, f"{video_key}.mp4")

            if not os.path.exists(video_path):
                print(f"Warning: Video file not found: {video_path}")
                model_answer = "Video file not found"
            else:
                model_answer = evaluator.generate_response(
                    video_path=video_path,
                    question=full_prompt,
                    max_new_tokens=args.max_new_tokens
                )
                print(f"model_answer: {model_answer}")

            result = {
                'video_key': video_key,
                'video_type': sample['video_type'],
                'uid': sample['uid'],
                'question': question,
                'answer': sample['answer'],
                'model_answer': model_answer,
                'question_type': sample.get('question_type', []),
                'time_reference': sample.get('time_reference', ''),
                'video_info': sample.get('video_info', {})
            }

            # Flush after every record so finished answers are on disk even
            # if the run is interrupted later.
            ans_file.write(json.dumps(result) + "\n")
            ans_file.flush()

            # Release cached GPU memory between samples
            torch.cuda.empty_cache()

    print(f"Evaluation completed. Results saved to: {answers_file}")


if __name__ == "__main__":
    # Script entry point: parse the command line, validate the chunk
    # selection and run the evaluation.
    parser = argparse.ArgumentParser(description="LVBench Evaluation with RWKV-Qwen Hybrid Model")

    # Define command-line arguments
    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_file', help='Path to the ground truth JSONL file containing questions.', required=True)
    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', default="pred")
    parser.add_argument("--model-path", type=str, required=True, help="Path to the RWKV-Qwen model")
    parser.add_argument("--num-chunks", type=int, default=1, help="Number of chunks for distributed evaluation")
    parser.add_argument("--chunk-idx", type=int, default=0, help="Chunk index for distributed evaluation")
    parser.add_argument("--max-new-tokens", type=int, default=1, help="Maximum number of new tokens to generate")
    # NOTE: --temperature and --top-p are accepted but not forwarded anywhere
    # in this file; generate_response uses its own hard-coded sampling values.
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation")
    parser.add_argument("--top-p", type=float, default=0.8, help="Top-p for generation")

    args = parser.parse_args()

    # Reject an impossible chunk selection before the model is loaded. A
    # negative index would otherwise silently pick a chunk from the end of
    # the list, duplicating another worker's share.
    if args.num_chunks < 1:
        parser.error("--num-chunks must be at least 1")
    if not 0 <= args.chunk_idx < args.num_chunks:
        parser.error("--chunk-idx must satisfy 0 <= chunk-idx < num-chunks")

    eval_model(args)