import os
import json
import time
import argparse
import multiprocessing as mp
from multiprocessing import set_start_method
from datetime import datetime
from utils import setup_logger, format_time

def merge_results(world_size: int, output_file: str, task: str):
    """
    Merges result files generated by multiple processes into a single output file.

    Args:
        world_size (int): The total number of processes that generated result files.
        output_file (str): The path to the final merged output file.
        task (str): The name of the evaluation task, used to determine the naming
                    pattern of individual process result files.
    """
    with open(output_file, 'w') as outfile:
        for rank in range(world_size):
            # Construct the path to the result file for each process
            process_file = os.path.join(os.path.dirname(output_file), f"{task}_results_rank_{rank}.jsonl")
            if os.path.exists(process_file):
                with open(process_file, 'r') as infile:
                    for line in infile:
                        outfile.write(line)
            else:
                print(f"Warning: Process file {process_file} not found for merging.")
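    # Per-rank files are left in place after merging; remove them manually if a
    # clean output directory is needed.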

def prepare_data(task: str) -> tuple[str | list[str], str]:
    """
    Prepares and returns the data file path(s) and video directory for a given task.

    Args:
        task (str): The name of the evaluation task.

    Returns:
        tuple[str | list[str], str]: A tuple containing:
            - The path to the data file (or a list of paths for SPAR-Bench).
            - The path to the directory containing video files.
    """
    if 'VSI-Bench' in task:
        return "/datasets/VSI-Bench/test-00000-of-00001.parquet", "/datasets/VSI-Bench"
    elif task == 'SPBench-SI':
        return "/datasets/SPBench-SI/SPBench-SI.parquet", "/datasets/SPBench-SI/images"
    elif task == 'SPBench-MV':
        return "/datasets/SPBench-MV/SPBench-MV.parquet", "/datasets/SPBench-MV/images"
    elif task == 'SPAR-Bench':
        spar_dir = "/datasets/SPAR-Bench"
        return [os.path.join(spar_dir, parquet_file) for parquet_file in os.listdir(spar_dir) if parquet_file.endswith('.parquet')], spar_dir
    elif task == 'ViewSpatial-Bench':
        return "/datasets/ViewSpatial-Bench/ViewSpatial-Bench.json", "/datasets/ViewSpatial-Bench"
    elif task == 'CV-Bench':
        return ["/datasets/CV-Bench/test_2d.parquet", "/datasets/CV-Bench/test_3d.parquet"], ""
    else:
        raise ValueError(f"Task {task} not recognized for data preparation.")
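
# Example usage (dataset paths above are fixed to this repo's layout):
#   data_file, video_dir = prepare_data('VSI-Bench')  # single parquet file
#   data_files, _ = prepare_data('CV-Bench')          # list of parquet paths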

# List of supported evaluation benchmark tasks
SUPPORTED_TASKS = ['VSI-Bench', 'SPBench-SI', 'SPBench-MV', 'SPAR-Bench', 'ViewSpatial-Bench', 'CV-Bench']

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_task", default="")
    parser.add_argument("--log_dir", default="")
    parser.add_argument("--model_config", default="")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument("--num_processes", type=int, default=1)
    parser.add_argument("--prompt_type", default="")
    parser.add_argument("--num_frames", type=int, default=32)
    parser.add_argument("--max_pixels", type=int, default=528*28*28)
    parser.add_argument("--min_pixels", type=int, default=16*28*28)
    parser.add_argument("--debug_mode", action="store_true")
    parser.add_argument("--debug_size", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()
    
    # Set the multiprocessing start method to 'spawn'.
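    # ('fork' would have workers inherit CUDA and other library state from the
    # parent; 'spawn' starts each worker in a fresh interpreter instead.)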
    set_start_method('spawn', force=True)

    main_start_time = time.time() # Record the start time of the main script

    # --- Configuration ---
    eval_task = args.eval_task
    
    if eval_task not in SUPPORTED_TASKS:
        print(f"Error: Task '{eval_task}' is not supported. Supported tasks are: {SUPPORTED_TASKS}")
        exit(1)

    data_file, video_dir = prepare_data(eval_task)

    # Create a timestamped directory for this specific run's outputs
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.log_dir, eval_task, timestamp_str)
    os.makedirs(output_dir, exist_ok=True)

    # Path for the final merged JSONL results file
    output_jsonl_file = os.path.join(output_dir, f"{eval_task}_results.jsonl")
    # Path for the main evaluation log file
    log_output_file = os.path.join(output_dir, f"{eval_task}_eval.log")

    # Collect parameters to log
    params_to_log = {
        "model_config": args.model_config,
        "eval_task": eval_task,
        "data_file": data_file,
        "video_dir": video_dir,
        "num_frames": args.num_frames,
        "max_pixels": args.max_pixels,
        "min_pixels": args.min_pixels,
        "debug_mode": args.debug_mode,
        "batch_size": args.batch_size,
        "debug_size": args.debug_size if args.debug_mode else "N/A",
        "gpu_ids": args.gpu_ids,
        "num_processes": args.num_processes,    
        "prompt_type": args.prompt_type,
        "output_dir": output_dir,
    }
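    # setup_logger is expected to echo this dict into the log header so each run
    # records the exact configuration it was launched with.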

    # Setup main logger for the script
    main_logger = setup_logger(0, log_output_file, params_to_log)
    main_logger.info("Main script started. Configuration logged.")

    process_runtimes = [] # To store runtime of each process
        
    # --- Run Evaluation ---
    # Dispatch to task-specific evaluation functions, either in parallel or sequentially
    num_processes = args.num_processes
    
    # Note: when the model backend is vLLM, run with --num_processes 1, since
    # vLLM manages its own worker processes and conflicts with this script's
    # multiprocessing.
    if num_processes > 1:
        main_logger.info(f"Starting evaluation of {eval_task} with {num_processes} processes.")
        with mp.Pool(processes=num_processes) as pool:
            args_list = [
                (rank, num_processes, data_file, video_dir, args.model_config, output_dir, log_output_file,
                args.gpu_ids, args.num_frames, args.max_pixels, args.min_pixels, args.debug_mode, args.batch_size,
                args.debug_size, params_to_log, args.prompt_type)
                for rank in range(num_processes)
            ]
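            # Each worker receives its rank and the world size so it can select its
            # shard of the data and write {task}_results_rank_{rank}.jsonl, which
            # merge_results() concatenates afterwards.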
            if eval_task == 'VSI-Bench':
                from data_utils.vsibench import evaluate_vsibench, vsibench_eval
                results = pool.starmap(evaluate_vsibench, args_list)
            elif eval_task in ['SPBench-SI', 'SPBench-MV']:
                from data_utils.spbench import evaluate_spbench
                from data_utils.vsibench import vsibench_eval
                results = pool.starmap(evaluate_spbench, args_list)
            elif eval_task == 'SPAR-Bench':
                from data_utils.sparbench import evaluate_sparbench, sparbench_eval
                results = pool.starmap(evaluate_sparbench, args_list)
            elif eval_task == 'ViewSpatial-Bench':
                from data_utils.viewspatial_bench import evaluate_viewspatial_bench, viewspatial_bench_eval
                results = pool.starmap(evaluate_viewspatial_bench, args_list)
            elif eval_task == 'CV-Bench':
                from data_utils.cvbench import evaluate_cvbench, cvbench_eval
                results = pool.starmap(evaluate_cvbench, args_list)
            else:
                main_logger.error(f"Task '{eval_task}' not recognized for multiprocessing dispatch.")
                exit(1)
                
            # Each worker returns (output_file_path, elapsed_seconds); collect the runtimes.
            process_runtimes = [res[1] for res in results if isinstance(res, tuple) and len(res) == 2]
        # Merge the .jsonl files produced by each process
        merge_results(num_processes, output_jsonl_file, eval_task)
        main_logger.info(f"Results from {num_processes} processes merged into {output_jsonl_file}")
        
    else:
        # Single process execution
        main_logger.info(f"Starting evaluation of {eval_task} with a single process.")
        common_args = (0, 1, data_file, video_dir, args.model_config, output_dir, log_output_file,
                       args.gpu_ids, args.num_frames, args.max_pixels, args.min_pixels, args.debug_mode, args.batch_size,
                       args.debug_size, params_to_log, args.prompt_type)
        process_output_file, elapsed_time_process = None, 0

        if eval_task == 'VSI-Bench':
            from data_utils.vsibench import evaluate_vsibench, vsibench_eval
            process_output_file, elapsed_time_process = evaluate_vsibench(*common_args)
        elif eval_task in ['SPBench-SI', 'SPBench-MV']:
            from data_utils.spbench import evaluate_spbench
            from data_utils.vsibench import vsibench_eval
            process_output_file, elapsed_time_process = evaluate_spbench(*common_args)
        elif eval_task == 'SPAR-Bench':
            from data_utils.sparbench import evaluate_sparbench, sparbench_eval
            process_output_file, elapsed_time_process = evaluate_sparbench(*common_args)
        elif eval_task == 'ViewSpatial-Bench':
            from data_utils.viewspatial_bench import evaluate_viewspatial_bench, viewspatial_bench_eval
            process_output_file, elapsed_time_process = evaluate_viewspatial_bench(*common_args)
        elif eval_task == 'CV-Bench':
            from data_utils.cvbench import evaluate_cvbench, cvbench_eval
            process_output_file, elapsed_time_process = evaluate_cvbench(*common_args)
        else:
            main_logger.error(f"Task '{eval_task}' not recognized for single process dispatch.")
            exit(1)

        process_runtimes = [elapsed_time_process]
        if process_output_file and os.path.exists(process_output_file):
            # Promote the worker's rank-specific output file to the final results path.
            os.rename(process_output_file, output_jsonl_file)
            main_logger.info(f"Single process result saved to {output_jsonl_file}")
        else:
            main_logger.error(f"Single process output file {process_output_file} not found or not generated.")

    main_end_time = time.time()
    main_elapsed_time = main_end_time - main_start_time
    max_process_runtime = max(process_runtimes) if process_runtimes else 0

    main_logger.info(f"All evaluation tasks of {eval_task} completed. Final results are in: {output_jsonl_file}")
    main_logger.info(f"Maximum individual process runtime: {format_time(max_process_runtime)}")
    main_logger.info(f"Total script runtime: {format_time(main_elapsed_time)}")

    print(f"All evaluation tasks of {eval_task} completed. Final results are in: {output_jsonl_file}")
    print(f"Total script runtime: {format_time(main_elapsed_time)}")

    # --- Perform Final Evaluation Scoring ---
    # This section calculates metrics based on the generated results file.
    evaluation_results = {}
    log_str = "" # Initialize log string for evaluation results

    main_logger.info(f"Starting final scoring for {eval_task} using results from {output_jsonl_file}")

    if eval_task in ["VSI-Bench", "SPBench-SI", "SPBench-MV"]: 
        evaluation_results = vsibench_eval(output_jsonl_file, args.prompt_type)
        print(f"{eval_task} Evaluation Results:", evaluation_results)
        log_str = f"{eval_task} Evaluation Complete. Results file: {output_jsonl_file}\n"
    elif eval_task == "SPAR-Bench":
        evaluation_results = sparbench_eval(output_jsonl_file, args.prompt_type)
        print("SPAR-Bench Evaluation Results:", evaluation_results)
        log_str = f"SPAR-Bench Evaluation Complete. Results file: {output_jsonl_file}\n"
    elif eval_task == "ViewSpatial-Bench":
        evaluation_results = viewspatial_bench_eval(output_jsonl_file, args.prompt_type)
        print("ViewSpatial-Bench Evaluation Results:", evaluation_results)
        log_str = f"ViewSpatial-Bench Evaluation Complete. Results file: {output_jsonl_file}\n"
        log_str += f"Overall Accuracy: {evaluation_results.get('overall_accuracy', 0.0) * 100.:.2f}%\n"
    elif eval_task == "CV-Bench":
        evaluation_results = cvbench_eval(output_jsonl_file, args.prompt_type)
        print("CV-Bench Evaluation Results:", evaluation_results)
        log_str = f"CV-Bench Evaluation Complete. Results file: {output_jsonl_file}\n"

    if evaluation_results: # If any evaluation was performed
        print(log_str)
        main_logger.info("--- Final Evaluation Metrics ---")
        main_logger.info(log_str)
        main_logger.info(f"Full Evaluation Metrics Dictionary: {json.dumps(evaluation_results, indent=2)}")
    else:
        main_logger.info(f"No final evaluation metrics calculated for {eval_task} or evaluation failed.")