"""
This script evaluates the responses of the model by comparing the generated turtle graphics code with ground truth.

Usage:
    python eval_responses.py \
        --dataset_file <dataset_file> \
        --prompt_file <prompt_file> \
        --response_file <response_file> \
        --output_file <output_file> \
        [--num_workers <num_workers>] \
        [--use_embedding]

Example:
    python eval_responses.py \
        --dataset_file data/graphics_dataset.json \
        --prompt_file exps/eval_vlms/results/prompts/prompts__dataset_graphics_sz823__pixtralhf.json \
        --response_file exps/eval_vlms/results/responses/responses__dataset_graphics_sz823__pixtralhf_12b.json \
        --output_file exps/eval_vlms/results/evaluations/evaluations__dataset_graphics_sz823__pixtralhf_12b.json
"""

import json
import argparse
from src.turtlegfx.eval.lines import compare_lines_from_code, get_normalized_images_from_code
from src.turtlegfx.emulate.emulator import Executor
from src.turtlegfx_datagen.utils.extract_python import extract_python_code_from_text
from tqdm import tqdm
import os
import numpy as np
from pebble import ProcessPool
from concurrent.futures import TimeoutError
from src.turtlegfx.eval.embeddings_resnet18 import compare_embeddings_from_batch_images

def process_response_symbol(data):
    """Run every code candidate found in one model response and check it against ground truth.

    Candidates are tried in the order returned by ``extract_python_code_from_text``.
    Only candidates whose execution status is ``'success'`` are recorded; the
    search stops at the first recorded candidate that ``compare_lines_from_code``
    reports as matching the ground-truth code.

    Args:
        data (tuple): A 3-tuple ``(dataset_sample, prompt_data, response_data)``:
            - dataset_sample (dict): Provides 'id', 'source', 'category' and 'difficulty'.
            - prompt_data (dict): Provides 'id', 'code' (ground truth) and 'task_image'.
            - response_data (dict): Provides 'id' and the model output under
              ``['choices'][0]['message']['content']``.

    Returns:
        dict: Result dictionary containing:
            - id, source, category, difficulty: Copied from ``dataset_sample``
            - prompt_id: ID of the prompt
            - response_id: ID of the response
            - model_output (str): Raw model output
            - code_true: Ground truth code
            - image_true: Ground truth image taken from the prompt
            - image_output (list): Images of the successfully executed candidates
            - exec_result (list): Execution results of the successfully executed candidates
            - success (bool): Whether one candidate matched the ground truth
            - tested_model_codes (list): The successfully executed candidates
    """
    dataset_sample, prompt_data, response_data = data
    model_output = response_data["choices"][0]["message"]["content"]

    candidate_codes = extract_python_code_from_text(model_output)

    meta_info = {
            "id": dataset_sample["id"],
            "source": dataset_sample["source"],
            "category": dataset_sample["category"],
            "difficulty": dataset_sample["difficulty"],
            "prompt_id": prompt_data["id"],
            "response_id": response_data["id"],
            "model_output": model_output,
            "code_true": prompt_data["code"],
            "image_true": prompt_data["task_image"],
    }

    executed_codes, executed_results, executed_images = [], [], []
    matched = False

    for candidate in candidate_codes:
        image, outcome = Executor().run(candidate, show_screen=False)

        # Candidates that did not execute successfully leave no trace in the result.
        if outcome['status'] != 'success':
            continue

        executed_codes.append(candidate)
        executed_results.append(outcome)
        executed_images.append(image)

        # Stop at the first candidate that matches the ground truth.
        if compare_lines_from_code(candidate, prompt_data["code"], show_screen=False):
            matched = True
            break

    return {
        **meta_info,
        "image_output": executed_images,
        "exec_result": executed_results,
        "success": matched,
        "tested_model_codes": executed_codes
    }

def _build_failure_result(dataset_sample, prompt_data, response_data, error_type, message):
    """Build the result dictionary recorded when a response could not be evaluated."""
    return {
        "id": dataset_sample["id"],
        "prompt_id": prompt_data["id"],
        "response_id": response_data["id"],
        # dataset information
        "source": dataset_sample["source"],
        "category": dataset_sample["category"],
        "difficulty": dataset_sample["difficulty"],
        # response information
        "model_output": response_data["choices"][0]["message"]["content"],
        "code_true": prompt_data["code"],
        "image_output": [],
        "image_true": prompt_data["task_image"],
        "exec_result": [{
            "status": "fail",
            "error_type": error_type,
            "message": message
        }],
        "success": False,
        "tested_model_codes": []
    }


def process_responses_symbol(dataset, prompts, responses, num_workers, timeout=60):
    """Process multiple responses, sequentially when num_workers == 1 and in a process pool otherwise.

    The three input lists are consumed in lockstep (``zip``), so they must
    already be aligned element by element.

    Args:
        dataset (list): Dataset samples (each providing 'id', 'source', 'category', 'difficulty')
        prompts (list): List of prompt dictionaries containing ground truth data
        responses (list): List of response dictionaries containing model outputs
        num_workers (int): Number of parallel worker processes. If 1, runs sequentially.
        timeout (int, optional): Maximum execution time per response in seconds.
            Defaults to 60. Only enforced on the parallel path; the sequential
            path runs each response without a time limit.

    Returns:
        list: List of result dictionaries, one for each (sample, prompt, response)
            triple, in input order. Responses that raised or timed out are
            represented by a failure entry with ``success`` set to False and an
            ``exec_result`` holding ``ERROR_UNKNOWN_EXCEPTION`` or
            ``ERROR_EXEC_TIMEOUT``.
    """
    results = []

    # Sequential processing when num_workers=1
    if num_workers == 1:
        for dataset_sample, prompt_data, response_data in tqdm(zip(dataset, prompts, responses), total=len(prompts), desc="Processing responses"):
            try:
                results.append(process_response_symbol((dataset_sample, prompt_data, response_data)))
            except Exception as e:
                print(f"Error processing prompt {prompt_data['id']}: {e}")
                print(f"Model output: {response_data['choices'][0]['message']['content']}\n\n")
                results.append(_build_failure_result(
                    dataset_sample, prompt_data, response_data,
                    "ERROR_UNKNOWN_EXCEPTION", str(e)
                ))
        return results

    # Parallel processing when num_workers != 1
    # NOTE(review): max_tasks=8 presumably recycles each worker after 8 tasks — confirm against pebble docs.
    with ProcessPool(max_workers=num_workers, max_tasks=8) as pool:
        future_to_data = {}
        # Schedule all tasks
        for dataset_sample, prompt_data, response_data in zip(dataset, prompts, responses):
            future = pool.schedule(
                process_response_symbol,
                args=((dataset_sample, prompt_data, response_data),),
                timeout=timeout
            )
            future_to_data[future] = (dataset_sample, prompt_data, response_data)

        # Collect results in scheduling order (dict insertion order) with a progress bar
        for future in tqdm(future_to_data, total=len(future_to_data), desc="Processing responses"):
            dataset_sample, prompt_data, response_data = future_to_data[future]
            try:
                results.append(future.result())
            except TimeoutError:
                results.append(_build_failure_result(
                    dataset_sample, prompt_data, response_data,
                    "ERROR_EXEC_TIMEOUT", f"Execution timed out after {timeout} seconds."
                ))
            except Exception as e:
                print(f"compare_lines_from_code Error: {e}")
                print(f"Model output: {response_data['choices'][0]['message']['content']}\n\n")
                results.append(_build_failure_result(
                    dataset_sample, prompt_data, response_data,
                    "ERROR_UNKNOWN_EXCEPTION", str(e)
                ))

    return results

def process_single_batch_result(result):
    """Produce the (ground truth, model output) image pair for one evaluation result.

    Args:
        result (dict): One result dictionary. It is only processed when it has a
            non-empty 'image_output' and a non-None 'image_true'.

    Returns:
        tuple | None: ``(img_true, img_output)`` as returned by
            ``get_normalized_images_from_code`` for the ground-truth code and the
            longest code candidate extracted from 'model_output'; ``None`` when
            the result is not eligible or either image is None.
    """
    # Guard clauses: nothing to compare without generated images and a reference image.
    if "image_output" not in result or len(result["image_output"]) == 0:
        return None
    if 'image_true' not in result or result['image_true'] is None:
        return None

    candidates = extract_python_code_from_text(result["model_output"])
    # The longest extracted candidate is the one that gets rendered.
    longest_code = max(candidates, key=len)
    pair = get_normalized_images_from_code(result["code_true"], longest_code)
    img_true, img_output = pair
    if img_true is None or img_output is None:
        return None
    return (img_true, img_output)

def process_responses_embedding(results, batch_size=128):
    """Calculate embedding-based similarity scores between generated and ground truth images.

    Results are handled in consecutive batches. For each batch, image pairs are
    prepared in a process pool via ``process_single_batch_result``; the pairs
    that could be prepared are scored with ``compare_embeddings_from_batch_images``
    using the "euclidean" metric. Every other result in the batch gets a score of 0.

    Args:
        results (list): List of result dictionaries from process_responses_symbol.
            The dictionaries are updated in place.
        batch_size (int, optional): Number of results to process in each batch. Defaults to 128.

    Returns:
        list: The same results list, each entry carrying an 'embedding_score' field:
            - embedding_score: Score returned by the embedding comparison, or 0
              when no image pair could be prepared (ineligible result, worker
              error or timeout).
    """
    n_batches = (len(results) + batch_size - 1) // batch_size

    for i in tqdm(range(n_batches), desc="Computing embedding scores"):
        batch_start = i * batch_size
        batch_end = min((i + 1) * batch_size, len(results))
        batch_results = results[batch_start:batch_end]

        # Prepare the image pairs of this batch in parallel
        batch_outputs = []
        batch_trues = []
        valid_indices = []

        with ProcessPool(max_workers=8) as pool:
            future_to_idx = {
                pool.schedule(process_single_batch_result, args=(result,), timeout=60): idx
                for idx, result in enumerate(batch_results)
            }

            for future, idx in future_to_idx.items():
                try:
                    image_pair = future.result()
                except Exception as e:
                    # Best effort: a failed or timed-out sample falls back to score 0
                    # below, but the failure is reported instead of silently dropped.
                    sample_id = batch_results[idx].get("id")
                    print(f"Skipping embedding for sample {sample_id}: {type(e).__name__}: {e}")
                    continue

                if image_pair is not None:
                    img_true, img_output = image_pair
                    batch_trues.append(img_true)
                    batch_outputs.append(img_output)
                    valid_indices.append(idx)

        if valid_indices:
            # Calculate embedding scores for the batch
            embedding_scores = compare_embeddings_from_batch_images(
                batch_outputs,
                batch_trues,
                metric="euclidean",
                batch_size=min(batch_size, len(valid_indices))
            )

            # Update results with embedding scores
            for idx, score in zip(valid_indices, embedding_scores):
                results[batch_start + idx]["embedding_score"] = score

        # Set embedding_score to 0 for the rest of the batch
        scored = set(valid_indices)
        for idx in range(len(batch_results)):
            if idx not in scored:
                results[batch_start + idx]["embedding_score"] = 0

    return results

def calculate_statistics(results):
    """Calculate comprehensive statistics from evaluation results for each dataset type.

    Args:
        results (list): List of result dictionaries. Each must provide 'source',
            'difficulty', 'category', 'success', 'exec_result', 'embedding_score'
            and 'prompt_id'.

    Returns:
        dict: Dictionary containing statistics for each non-empty dataset type
            ('midi', 'handdrawn', 'synthetic', 'all'). Empty dataset types are
            skipped with a printed warning. Each dataset type contains:
                - overall_stats (dict):
                    - n_total (int): Total number of evaluated responses
                    - n_success (int): Number of successful code generations
                    - success_rate (float): Success rate as a fraction in [0, 1]
                    - stats_num (dict): Count of each error type
                    - stats_dist (dict): Fraction of responses per error type
                    - stats_embedding (dict): Embedding score statistics
                - difficulty_stats (dict): Same statistics grouped by difficulty level
                - category_stats (dict): Same statistics grouped by category
            Difficulty and category keys appear in first-seen order, so the
            output is reproducible from run to run.

    Raises:
        ValueError: If a result's 'source' is not 'midi', 'handdrawn' or 'synthetic'.
    """
    known_sources = ('midi', 'handdrawn', 'synthetic')

    def calculate_group_statistics(group_results):
        """Helper function to calculate statistics for a non-empty group of results."""
        n_response = len(group_results)

        # The last exec_result entry is treated as the final status of a response.
        type_counts = {}
        for result in group_results:
            exec_results = result["exec_result"]
            error_type = exec_results[-1]["error_type"] if exec_results else "NO_EXECUTION"
            type_counts[error_type] = type_counts.get(error_type, 0) + 1

        embedding_scores = [result['embedding_score'] for result in group_results]
        n_scores = len(embedding_scores)

        def fraction_of_scores(predicate):
            """Fraction of embedding scores satisfying ``predicate``, rounded to 4 places."""
            return round(sum(1 for score in embedding_scores if predicate(score)) / n_scores, 4)

        # Base statistics
        embedding_stats = {
            "n_embeddings": n_scores,
            "metric": "euclidean",
            "mean": round(float(np.mean(embedding_scores)), 4),
            "std": round(float(np.std(embedding_scores)), 4),
            "max": round(float(np.max(embedding_scores)), 4),
            "min": round(float(np.min(embedding_scores)), 4),
            "embeddings_eq_0": fraction_of_scores(lambda score: score == 0),
            "embeddings_neq_0": fraction_of_scores(lambda score: score > 0),
        }

        # Share of scores at or above each threshold 0.0, 0.01, ..., 1.0
        for step in range(101):
            threshold = step / 100
            embedding_stats[f"embeddings_geq_{threshold}"] = fraction_of_scores(
                lambda score, threshold=threshold: score >= threshold
            )

        n_success = sum(result["success"] for result in group_results)
        return {
            "n_total": n_response,
            "n_success": n_success,
            "success_rate": round(n_success / n_response, 4),
            "stats_num": type_counts,
            "stats_dist": {error_type: round(count / n_response, 4)
                          for error_type, count in type_counts.items()},
            "stats_embedding": embedding_stats
        }

    def group_by_field(group_results, field):
        """Bucket results by ``result[field]`` in a single pass, keeping first-seen key order."""
        buckets = {}
        for result in group_results:
            buckets.setdefault(result[field], []).append(result)
        return buckets

    # Group results by dataset type
    results_groups = {source: [] for source in known_sources}
    results_groups['all'] = []

    for result in results:
        results_groups['all'].append(result)

        source = result['source']
        if source not in known_sources:
            raise ValueError(f"Unknown dataset type for prompt_id: {result['prompt_id']}")
        results_groups[source].append(result)

    # Calculate statistics for each dataset type
    stats = {}
    for dataset_type, group_results in results_groups.items():
        if not group_results:  # Skip empty groups
            print(f"Warning: No results found for dataset type: {dataset_type}")
            continue

        stats[dataset_type] = {
            "overall_stats": calculate_group_statistics(group_results),
            "difficulty_stats": {
                difficulty: calculate_group_statistics(members)
                for difficulty, members in group_by_field(group_results, 'difficulty').items()
            },
            "category_stats": {
                category: calculate_group_statistics(members)
                for category, members in group_by_field(group_results, 'category').items()
            },
        }

    return stats

def align_prompts_and_responses(dataset, prompts, responses):
    """Reorder prompts and responses so they follow the order of the dataset.

    Entries are matched on their 'id' field. Prompts or responses whose ID does
    not occur in the dataset are dropped from the returned lists.

    Args:
        dataset (list): List of dataset sample dictionaries (order will be preserved)
        prompts (list): List of prompt dictionaries (will be reordered)
        responses (list): List of response dictionaries (will be reordered)

    Returns:
        tuple: (aligned_prompts, aligned_responses), both in dataset order

    Raises:
        ValueError: If a dataset sample has no response or no prompt with the same ID.
    """
    # Index prompts and responses by ID for constant-time lookup
    response_map = {r['id']: r for r in responses}
    prompt_map = {p['id']: p for p in prompts}

    aligned_responses = []
    aligned_prompts = []
    for dataset_sample in dataset:
        sample_id = dataset_sample['id']

        matching_response = response_map.get(sample_id)
        if matching_response is None:
            raise ValueError(f"No matching response found for prompt ID: {sample_id}")

        # Fail with a clear error instead of letting a None prompt leak downstream.
        matching_prompt = prompt_map.get(sample_id)
        if matching_prompt is None:
            raise ValueError(f"No matching prompt found for dataset ID: {sample_id}")

        aligned_responses.append(matching_response)
        aligned_prompts.append(matching_prompt)

    # Alignment holds by construction: every entry was looked up by the sample's own ID.
    return aligned_prompts, aligned_responses


def eval_responses(dataset_file, prompt_file, response_file, output_file, num_workers, use_embedding=False):
    """Main evaluation pipeline for comparing model-generated code with ground truth.

    This function orchestrates the entire evaluation process:
    1. Load dataset, prompt and response data
    2. Align prompts and responses to the dataset when the prompt file name contains '__gpt'
    3. Do symbolic evaluation of code correctness
    4. Do embedding-based similarity scores (optional)
    5. Calculate comprehensive statistics
    6. Save results to disk

    Args:
        dataset_file (str): Path to JSON file containing dataset information
        prompt_file (str): Path to JSON file containing ground truth prompts and code
        response_file (str): Path to JSON file containing model responses
        output_file (str): Path where evaluation results will be saved. Missing
            parent directories are created.
        num_workers (int): Number of parallel workers for processing
        use_embedding (bool): Whether to run embedding-based evaluation. When
            False, every result gets an 'embedding_score' of 0.
    """
    # Load data (context managers make sure the file handles are closed)
    with open(dataset_file, "r") as f:
        dataset = json.load(f)
    with open(prompt_file, "r") as f:
        prompts = json.load(f)
    with open(response_file, "r") as f:
        responses = json.load(f)

    # align the prompts and responses
    # This is for openai api, the order of the responses may not be the same as the prompts
    if '__gpt' in prompt_file:
        prompts, responses = align_prompts_and_responses(dataset, prompts, responses)

    print(f"Evaluating {len(prompts)} responses...")
    print(f"Dataset file: {dataset_file}")
    print(f"Prompt file: {prompt_file}")
    print(f"Response file: {response_file}")
    print(f"Output file: {output_file}\n")

    # Part 1: Compute success rate (symbolic evaluation)
    results = process_responses_symbol(dataset, prompts, responses, num_workers)

    # print the success rate (guarded so an empty input does not divide by zero)
    n_success = sum(result["success"] for result in results)
    if results:
        print(f"Success rate: {n_success / len(results)}")
    else:
        print("Success rate: n/a (no responses evaluated)")

    # Part 2: Compute embedding scores (optional)
    if use_embedding:
        print("Running embedding-based evaluation...")
        results = process_responses_embedding(results)
    else:
        print("Skipping embedding-based evaluation...")
        # Add default embedding score of 0 to maintain result structure
        for result in results:
            result["embedding_score"] = 0

    # Calculate statistics
    statistics = calculate_statistics(results)

    # Prepare final results
    eval_results = {
        "dataset_file": dataset_file,
        "prompt_file": prompt_file,
        "response_file": response_file,
        "evaluations": results,
        **statistics
    }

    print(f"#response: {len(prompts)}")

    # Print statistics for each dataset group
    for dataset_type in ['midi', 'handdrawn', 'synthetic', 'all']:
        if dataset_type in statistics:
            overall = statistics[dataset_type]['overall_stats']
            print(f"\n=== {dataset_type.upper()} Statistics ===")
            print(f"#samples: {overall['n_total']}")
            print(f"#success: {overall['n_success']}")
            print(f"Success rate: {overall['n_success'] / overall['n_total']:.4f}")
            print(f"Distribution: {overall['stats_dist']}")
            print(f"Embedding stats: {overall['stats_embedding']}")

    # Save results. os.makedirs('') raises, so only create a directory when the
    # output path actually has a directory component.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, "w") as f:
        json.dump(eval_results, f, indent=2)

    # Report success only after the file has been written.
    print(f"Evaluation done. Results saved to {output_file}")

def main(args):
    """Run the evaluation pipeline with the parsed command-line arguments.

    Args:
        args (Namespace): Parsed arguments providing:
            - dataset_file (str): Path to dataset JSON file
            - prompt_file (str): Path to prompts JSON file
            - response_file (str): Path to responses JSON file
            - output_file (str): Path for saving results
            - num_workers (int): Number of parallel workers
            - use_embedding (bool): Whether to run embedding-based evaluation
    """
    eval_responses(
        dataset_file=args.dataset_file,
        prompt_file=args.prompt_file,
        response_file=args.response_file,
        output_file=args.output_file,
        num_workers=args.num_workers,
        use_embedding=args.use_embedding,
    )

if __name__ == "__main__":
    # Command-line interface: four required file paths plus two optional switches.
    parser = argparse.ArgumentParser()
    for required_flag in ("--dataset_file", "--prompt_file", "--response_file", "--output_file"):
        parser.add_argument(required_flag, type=str, required=True)
    # Worker count defaults to whatever os.cpu_count() reports on this machine.
    parser.add_argument("--num_workers", type=int, default=os.cpu_count())
    parser.add_argument("--use_embedding", action="store_true", default=False)
    args = parser.parse_args()
    main(args)
