import argparse
import torch
import os
import json
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor,Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import math
import random
import glob


def load_json_files_list(json_files_list):
    """Load and merge multiple JSON files from a list.

    Each file must contain a JSON array of question dicts. Every question is
    tagged in place with ``source_json``: the file's base name with a trailing
    ``.json`` extension removed. Callers use that tag to locate the video
    subdirectory and to group scores.

    Args:
        json_files_list: Sequence of paths to JSON files, loaded in the given order.

    Returns:
        One list containing the questions of all files, concatenated in order.
    """
    all_questions = []
    for json_file in json_files_list:
        print(f"Loading questions from: {json_file}")
        with open(json_file, encoding='utf-8') as file:
            questions = json.load(file)
        # Strip only the trailing extension; str.replace would also delete
        # ".json" occurring inside the name (e.g. "a.json_v2.json" -> "a_v2").
        json_name = os.path.basename(json_file)
        if json_name.endswith('.json'):
            json_name = json_name[:-len('.json')]
        for question in questions:
            question['source_json'] = json_name
        all_questions.extend(questions)

    print(f"Total questions loaded from {len(json_files_list)} files: {len(all_questions)}")
    return all_questions


def load_multiple_json_files(gt_file_pattern):
    """Load and merge all JSON files matching a glob pattern.

    Matching files are loaded in sorted path order. Each file must contain a
    JSON array of question dicts; every question is tagged in place with
    ``source_json`` (the file's base name without its trailing ``.json``).

    Args:
        gt_file_pattern: Glob pattern passed to ``glob.glob``.

    Returns:
        One list containing the questions of all matching files.

    Raises:
        FileNotFoundError: If the pattern matches no files.
    """
    json_files = sorted(glob.glob(gt_file_pattern))
    if not json_files:
        raise FileNotFoundError(f"No JSON files found matching pattern: {gt_file_pattern}")

    all_questions = []
    for json_file in json_files:
        print(f"Loading questions from: {json_file}")
        with open(json_file, encoding='utf-8') as file:
            questions = json.load(file)
        # Remove only the extension, not ".json" appearing elsewhere in the name.
        json_name = os.path.basename(json_file)
        if json_name.endswith('.json'):
            json_name = json_name[:-len('.json')]
        for question in questions:
            question['source_json'] = json_name
        all_questions.extend(questions)

    print(f"Total questions loaded from {len(json_files)} files: {len(all_questions)}")
    return all_questions


def calculate_scores(results, output_dir):
    """Calculate accuracy separately for each source JSON and overall.

    A prediction is correct when the answer parsed from ``model_answer`` equals
    ``correct_answer_letter`` (both compared case-insensitively after
    stripping whitespace). A leading option letter is parsed leniently, so
    "A", "a", "A.", "(A)" and "A. some text" all read as "A". A letter followed
    directly by another letter/digit (e.g. "Error: ...") is not treated as an
    option; such answers are compared as whole normalized strings.

    Args:
        results: Iterable of result dicts with 'model_answer' and
            'correct_answer_letter' strings and an optional 'source_json'
            (missing -> 'unknown').
        output_dir: Existing directory; scores are written to ``scores.json`` there.

    Returns:
        Dict mapping each source name, plus 'overall', to
        ``{'correct': int, 'total': int, 'accuracy': float}``.
    """
    def normalize_answer(text):
        # Upper-case and trim, then accept "X", "X.", "X)", "(X)", "X ..." as option X.
        answer = text.strip().upper()
        head = answer.lstrip('([')
        if head and 'A' <= head[0] <= 'Z' and (len(head) == 1 or not head[1].isalnum()):
            return head[0]
        return answer

    # Group results by source JSON, preserving first-seen order.
    results_by_source = {}
    for result in results:
        results_by_source.setdefault(result.get('source_json', 'unknown'), []).append(result)

    scores = {}
    total_correct = 0
    total_questions = 0
    for source, source_results in results_by_source.items():
        correct = sum(
            1
            for result in source_results
            if normalize_answer(result['model_answer'])
            == result['correct_answer_letter'].strip().upper()
        )
        scores[source] = {
            'correct': correct,
            'total': len(source_results),
            'accuracy': correct / len(source_results) if source_results else 0,
        }
        total_correct += correct
        total_questions += len(source_results)

    # NOTE: a source literally named 'overall' would be overwritten here.
    scores['overall'] = {
        'correct': total_correct,
        'total': total_questions,
        'accuracy': total_correct / total_questions if total_questions else 0,
    }

    scores_file = os.path.join(output_dir, 'scores.json')
    with open(scores_file, 'w', encoding='utf-8') as f:
        json.dump(scores, f, indent=2)

    print("\n=== Evaluation Scores ===")
    for source, score in scores.items():
        print(f"{source}: {score['correct']}/{score['total']} = {score['accuracy']:.4f}")

    return scores


def split_list(lst, n, seed=0):
    """Split a list into exactly ``n`` roughly equal-sized chunks in shuffled order.

    The shuffle uses a private ``random.Random(seed)``, so the partition is
    reproducible. Every process of a distributed run calls this independently
    with the same list, ``n`` and ``seed``, and must get the same partition so
    that chunks are disjoint and together cover every item exactly once. The
    global ``random`` state is left untouched.

    Chunk sizes differ by at most one (larger chunks first); when
    ``len(lst) < n`` the trailing chunks are empty.

    Args:
        lst: Items to split (not modified).
        n: Number of chunks, >= 1.
        seed: Seed for the deterministic shuffle.

    Returns:
        A list of exactly ``n`` lists.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"Number of chunks must be >= 1, got {n}")

    shuffled = list(lst)
    random.Random(seed).shuffle(shuffled)

    base, extra = divmod(len(shuffled), n)
    chunks = []
    start = 0
    for i in range(n):
        size = base + (1 if i < extra else 0)
        chunks.append(shuffled[start:start + size])
        start += size
    return chunks


def get_chunk(lst, n, k):
    """Return chunk ``k`` (0-based) of ``lst`` split into ``n`` chunks by ``split_list``.

    Args:
        lst: Items to split.
        n: Total number of chunks.
        k: Index of the chunk to return.

    Raises:
        IndexError: If ``k`` is not in ``range(n)``. Negative indices are
            rejected rather than silently selecting a chunk from the end.
    """
    if not 0 <= k < n:
        raise IndexError(f"Chunk index {k} out of range for {n} chunks")
    return split_list(lst, n)[k]


class RWKVQwenMLVUEvaluator:
    """Wrap a video-language model and its processor for MLVU multiple-choice QA.

    The model is loaded with ``AutoModel.from_pretrained(..., trust_remote_code=True)``
    in bfloat16, and inputs are built through the processor's chat template
    plus ``process_vision_info``.
    """

    def __init__(self, model_path):
        """Initialize the RWKV-Qwen hybrid model for MLVU evaluation.

        Args:
            model_path: Directory or hub id accepted by ``from_pretrained``.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model from: {model_path}")

        # FlashAttention-2 needs a CUDA GPU; fall back to eager attention otherwise.
        attn_implementation = "flash_attention_2" if self.device == "cuda" else "eager"
        # device_map places the weights directly on self.device, so no extra
        # .to() call is needed afterwards.
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
            device_map=self.device,
            trust_remote_code=True,
        )

        # NOTE(review): the processor is loaded without trust_remote_code;
        # confirm the checkpoint's processor needs no custom code.
        self.processor = AutoProcessor.from_pretrained(model_path)

        print(f"Model loaded successfully on {self.device}")

    def prepare_video_messages(self, video_path, question):
        """Build a single-turn chat message holding the video followed by the question text."""
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": video_path,
                    },
                    {"type": "text", "text": question},
                ],
            }
        ]
        return messages

    def generate_response(self, video_path, question, max_new_tokens=1):
        """Generate the model's answer to ``question`` about the video at ``video_path``.

        Decoding uses ``top_k=1``, so only the highest-probability token is
        kept at each step even though ``do_sample=True``.

        Args:
            video_path: Path to the video file.
            question: Full prompt text.
            max_new_tokens: Generation budget; the default of 1 suits single-letter answers.

        Returns:
            The decoded newly generated text, stripped. On any exception the
            error is printed and ``"Error: <message>"`` is returned instead,
            so one bad video does not abort the whole evaluation.
        """
        try:
            messages = self.prepare_video_messages(video_path, question)

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.001,
                    top_k=1,
                    temperature=0.01,
                    repetition_penalty=1.0,
                )

            # Drop the prompt tokens so only the newly generated ones are decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )

            return output_text[0].strip()

        except Exception as e:
            print(f"Error processing video {video_path}: {e}")
            return f"Error: {str(e)}"


def _load_questions(args):
    """Load questions from --json_files_list, a glob pattern in --gt_file, or a single --gt_file.

    Every returned question carries a 'source_json' tag derived from its file name.
    """
    if getattr(args, 'json_files_list', None):
        return load_json_files_list(args.json_files_list)
    if '*' in args.gt_file or '?' in args.gt_file:
        return load_multiple_json_files(args.gt_file)

    with open(args.gt_file, encoding='utf-8') as file:
        questions = json.load(file)
    # Strip only the trailing extension so ".json" inside the name survives.
    json_name = os.path.basename(args.gt_file)
    if json_name.endswith('.json'):
        json_name = json_name[:-len('.json')]
    for question in questions:
        question['source_json'] = json_name
    return questions


def _find_answer_letter(candidates, answer_text):
    """Return the letter ('A', 'B', ...) of the first candidate whose stripped text equals the stripped answer, else None."""
    target = answer_text.strip()
    for i, candidate in enumerate(candidates):
        if candidate.strip() == target:
            return chr(65 + i)
    return None


def _build_prompt(question, candidates):
    """Build the multiple-choice prompt: instructions, question, lettered options, answer cue."""
    options_text = "".join(
        f"{chr(65 + i)}. {candidate}\n" for i, candidate in enumerate(candidates)
    )
    option_prompt = "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, D, or E) of the correct option."
    full_question = question + "\n" + options_text.strip()
    post_prompt = "The best answer is:"
    return option_prompt + "\n" + full_question + "\n" + post_prompt


def eval_model(args):
    """
    Run inference on MLVU dataset using the Qwen 2.5 VL model.
    Supports multiple JSON files with separate scoring.

    Questions are loaded and chunked before the model is loaded, so a bad
    input file or chunk index fails before the expensive model load.
    Per-question results are streamed to ``<output_dir>/<output_name>.jsonl``.
    Scores go to ``<output_dir>/scores.json`` for a single-chunk run. In a
    multi-chunk run each chunk writes ``<output_dir>/scores_<num_chunks>_<chunk_idx>/scores.json``,
    so chunks running concurrently do not overwrite each other's partial scores.

    Args:
        args: Command-line arguments.
    """
    gt_questions = _load_questions(args)
    print(f"Total questions: {len(gt_questions)}")

    # Get chunk for distributed evaluation
    gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)

    os.makedirs(args.output_dir, exist_ok=True)

    if args.num_chunks > 1:
        output_name = f"{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name
    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")

    # Load the model only once the inputs are known to be usable.
    evaluator = RWKVQwenMLVUEvaluator(args.model_path)

    results = []
    with open(answers_file, "w") as ans_file:
        for sample in tqdm(gt_questions, desc="Processing MLVU videos"):
            video_name = sample['video']
            question = sample['question']
            candidates = sample['candidates']
            correct_answer_text = sample['answer']
            question_type = sample.get('question_type', '')
            duration = sample.get('duration', 0)
            source_json = sample.get('source_json', 'unknown')

            correct_answer_letter = _find_answer_letter(candidates, correct_answer_text)
            if correct_answer_letter is None:
                print(f"Warning: Could not find correct answer '{correct_answer_text}' in candidates: {candidates}")
                correct_answer_letter = "Unknown"

            full_prompt = _build_prompt(question, candidates)

            # Videos live in a subdirectory named after the question's source JSON.
            video_path = os.path.join(args.video_dir, source_json, video_name)

            if not os.path.exists(video_path):
                print(f"Warning: Video file not found: {video_path}")
                model_answer = "Video file not found"
            else:
                model_answer = evaluator.generate_response(
                    video_path=video_path,
                    question=full_prompt,
                    max_new_tokens=args.max_new_tokens
                )
                print("model_answer:", model_answer)

            result = {
                'video': video_name,
                'question': question,
                'candidates': candidates,
                'answer': correct_answer_text,
                'correct_answer_letter': correct_answer_letter,
                'model_answer': model_answer,
                'question_type': question_type,
                'duration': duration,
                'source_json': source_json
            }
            results.append(result)

            # Flush per sample so partial results survive an interrupted run.
            ans_file.write(json.dumps(result) + "\n")
            ans_file.flush()

            torch.cuda.empty_cache()

    # Each chunk gets its own scores directory so parallel chunks don't clobber scores.json.
    if args.num_chunks > 1:
        scores_dir = os.path.join(args.output_dir, f"scores_{output_name}")
        os.makedirs(scores_dir, exist_ok=True)
    else:
        scores_dir = args.output_dir
    calculate_scores(results, scores_dir)

    print(f"Evaluation completed. Results saved to: {answers_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MLVU Evaluation with Qwen 2.5 VL Model")

    # Define command-line arguments
    parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
    parser.add_argument('--gt_file', help='Path to the ground truth file containing questions. Used if --json_files_list is not provided.', required=False)
    parser.add_argument('--json_files_list', nargs='+', help='List of JSON files containing questions.', required=False)
    parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
    parser.add_argument('--output_name', help='Name of the file for storing results JSON.', default="pred")
    parser.add_argument("--model-path", type=str, required=True, help="Path to the RWKV-Qwen model")
    parser.add_argument("--num-chunks", type=int, default=1, help="Number of chunks for distributed evaluation")
    parser.add_argument("--chunk-idx", type=int, default=0, help="Chunk index for distributed evaluation")
    parser.add_argument("--max-new-tokens", type=int, default=1, help="Maximum number of new tokens to generate")
    # These two are accepted for compatibility but eval_model never reads them;
    # generation settings are fixed inside the evaluator.
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (currently unused)")
    parser.add_argument("--top-p", type=float, default=0.8, help="Top-p for generation (currently unused)")

    args = parser.parse_args()

    # Check that either gt_file or json_files_list is provided
    if not args.gt_file and not args.json_files_list:
        parser.error("Either --gt_file or --json_files_list must be provided")
    # Validate numeric options up front so mistakes fail before the model loads.
    if args.num_chunks < 1:
        parser.error("--num-chunks must be >= 1")
    if not 0 <= args.chunk_idx < args.num_chunks:
        parser.error("--chunk-idx must satisfy 0 <= chunk-idx < num-chunks")
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be >= 1")

    eval_model(args)