#!/usr/bin/env python3


import argparse
import json
import math
import os
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from scipy.special import comb
from tqdm import tqdm
from vllm import LLM, SamplingParams

from rewards import DockerTestReward, DockerIOTestReward


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder for values the stock ``json`` encoder rejects.

    Handles:
      * ``datetime`` / ``date`` / ``time`` objects, written as ISO-8601 strings;
      * NumPy scalars (e.g. ``np.float32``, ``np.int64``, ``np.bool_``),
        converted to the equivalent Python scalar;
      * NumPy arrays, converted to nested lists.

    Anything else falls through to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # Imported locally: at module level ``datetime`` names the class and
        # ``time`` names the time module, so reach the date/time types via the module.
        import datetime as _dt

        # ``datetime`` is a subclass of ``date``, so full timestamps are covered too.
        if isinstance(obj, (_dt.date, _dt.time)):
            return obj.isoformat()
        if isinstance(obj, np.generic):
            # The returned Python value is encoded again, so e.g. a datetime
            # produced by ``np.datetime64.item()`` still reaches the branch above.
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def process_completion(compl: str) -> str:
    """Cut a raw completion down to the indented body that continues the prompt.

    A completion containing a "```py" fence is returned untouched. Otherwise
    lines are kept up to, but not including, the first line that starts at
    column 0 and is not a ``#`` comment. Lines that begin with whitespace are
    kept, and that includes bare newline lines.
    """
    if "```py" in compl:
        return compl

    kept_lines = []
    for line in compl.splitlines(keepends=True):
        at_column_zero = line == line.lstrip()
        if at_column_zero and not line.startswith('#'):
            # First top-level statement: the function body has ended.
            break
        kept_lines.append(line)
    return ''.join(kept_lines)


@dataclass
class EvaluationConfig:
    """Configuration for a pass@k / max@k evaluation run.

    Only ``model_name`` and ``dataset_path`` are required; every other field
    has a default.
    """
    model_name: str  # model id or path handed to vLLM's ``LLM(model=...)``
    dataset_path: str  # local path (load_from_disk) or Hub dataset id (load_dataset)
    dataset_split: str = "validation"
    k: int = 1024  # completions sampled per problem (vLLM ``n``)
    output_dir: str = "./eval_results"
    temperature: float = 0.8
    max_tokens: int = 512
    top_p: float = 0.95
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1  # Number of GPUs to use for tensor parallelism
    max_model_len: int = 4096
    trust_remote_code: bool = True
    prompt_column: str = "prompt"
    test_column: str = "test_list"
    batch_size: int = 1  # NOTE(review): not read by the generation/scoring code in this file
    visible_gpus: Optional[str] = None  # e.g., "0,1,2,3" or None to use all available GPUs
    format: str = "completion"  # "completion" or "io"; selects the reward class used for scoring


def calculate_pass_at_k(n: int, c: int, k: int) -> float:
    """Return the unbiased pass@k estimate ``1 - C(n - c, k) / C(n, k)``.

    Args:
        n: Number of sampled completions.
        c: How many of those completions are correct.
        k: Sample budget being estimated.

    Returns:
        The probability that at least one of ``k`` completions drawn without
        replacement from the ``n`` samples is correct. This is exactly 1.0
        when fewer than ``k`` samples are incorrect, because every draw of
        size ``k`` must then include a correct one.
    """
    failures = n - c
    if failures < k:
        return 1.0
    # Exact integer binomials; Python's int / int true division rounds correctly.
    return 1.0 - math.comb(failures, k) / math.comb(n, k)


def load_evaluation_dataset(dataset_path: str, split: str = "validation", prompt_column: str = "prompt",
                            test_column: str = "test_list") -> List[Dict]:
    """Load one split of a dataset and normalize it into a list of problem dicts.

    Each returned dict has these keys:
      * ``prompt`` (renamed from ``prompt_column``);
      * ``test_list`` (renamed from ``test_column``);
      * ``metadata`` (all other columns of the row);
      * ``id`` (the row's position after filtering).

    Rows whose prompt is missing or whitespace-only, or whose test list is
    missing or empty, are dropped.

    Args:
        dataset_path: Existing local path (read with ``load_from_disk``) or a
            Hugging Face Hub dataset id (read with ``load_dataset``).
        split: Name of the split to evaluate.
        prompt_column: Column holding the prompt text.
        test_column: Column holding the tests.

    Returns:
        The filtered problems as plain Python dicts.
    """
    if os.path.exists(dataset_path):
        dataset = load_from_disk(dataset_path)
    else:
        # Fetch from the Hugging Face Hub, ignoring any cached copy.
        # ``load_dataset`` has no ``force_download`` parameter: unknown kwargs are
        # forwarded to the builder config and rejected. ``download_mode`` is the
        # supported switch.
        dataset = load_dataset(dataset_path, download_mode="force_redownload")

    # A tuple (not a set) keeps the pickled closure deterministic for datasets' fingerprinting.
    task_columns = (prompt_column, test_column)
    eval_dataset = dataset[split]
    # Stash every non-task column under "metadata" so it survives select_columns.
    eval_dataset = eval_dataset.map(
        lambda row: {"metadata": {key: value for key, value in row.items() if key not in task_columns}})
    eval_dataset = eval_dataset.select_columns([prompt_column, test_column, "metadata"])

    if prompt_column != "prompt":
        eval_dataset = eval_dataset.rename_column(prompt_column, "prompt")
    if test_column != "test_list":
        eval_dataset = eval_dataset.rename_column(test_column, "test_list")
    # Return a real bool (not the stripped string) and tolerate None values.
    eval_dataset = eval_dataset.filter(
        lambda row: bool((row["prompt"] or "").strip()) and len(row["test_list"] or []) > 0)
    eval_dataset = eval_dataset.map(lambda row, idx: {"id": idx}, with_indices=True)
    return eval_dataset.to_list()


def initialize_vllm_model(config: EvaluationConfig) -> LLM:
    """Create a vLLM ``LLM`` engine described by ``config``.

    When ``config.visible_gpus`` is set, ``CUDA_VISIBLE_DEVICES`` is overwritten
    with it before the engine is built. The engine always loads in bfloat16.
    Any loading error is printed and then re-raised unchanged.
    """
    print(f"Loading model {config.model_name} with tensor_parallel_size={config.tensor_parallel_size}...")

    gpu_list = config.visible_gpus
    if gpu_list is None:
        print("Using all available GPUs")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
        print(f"Using GPUs: {gpu_list}")

    try:
        engine_kwargs = dict(
            model=config.model_name,
            tensor_parallel_size=config.tensor_parallel_size,  # shard across this many GPUs
            gpu_memory_utilization=config.gpu_memory_utilization,
            max_model_len=config.max_model_len,
            trust_remote_code=config.trust_remote_code,
            dtype=torch.bfloat16,
        )
        engine = LLM(**engine_kwargs)
        print(f"Successfully loaded model: {config.model_name}")
        return engine
    except Exception as load_error:
        print(f"Error loading model {config.model_name}: {load_error}")
        raise


def generate_completions(model: LLM, prompts: List[str], config: EvaluationConfig) -> List[List[str]]:
    """Sample ``config.k`` completions for every prompt and post-process them.

    Returns one inner list per element of ``model.generate``'s output. Each
    inner list holds the ``process_completion``-trimmed text of every sample.
    """
    params = SamplingParams(
        n=config.k,  # a single request yields k independent samples
        temperature=config.temperature,
        top_p=config.top_p,
        max_tokens=config.max_tokens,
    )

    print(f"Generating {config.k} completions for {len(prompts)} prompts...")
    request_outputs = model.generate(prompts, params)

    return [
        [process_completion(sample.text) for sample in request_output.outputs]
        for request_output in request_outputs
    ]


def score_completions(completions: List[List[str]], problems: List[Dict],
                      config: EvaluationConfig) -> List[List[float]]:
    """Score every completion against the tests of its problem.

    ``config.format`` selects the reward: ``"completion"`` uses
    ``DockerTestReward`` and ``"io"`` uses ``DockerIOTestReward``. If scoring a
    completion raises, the error is printed and the completion gets 0.0, so a
    single failure does not abort the run.

    Args:
        completions: Completions per problem, aligned with ``problems``.
        problems: Problem dicts with ``prompt`` and ``test_list`` keys.
        config: Run configuration; only ``format`` is read.

    Returns:
        One list of scores per problem, in the order of ``completions``.

    Raises:
        ValueError: If ``config.format`` is not ``"completion"`` or ``"io"``.
    """
    if config.format == "completion":
        reward_function = DockerTestReward(code_extractor_fn_name='default_completion',
                                           run_tests_separately=True)
    elif config.format == "io":
        reward_function = DockerIOTestReward(code_extractor_fn_name='default_chat',
                                             run_tests_separately=True)
    else:
        raise ValueError(f"Invalid format: {config.format}")

    # Plain instances have no ``__name__``, so fall back to the class name.
    reward_name = getattr(reward_function, "__name__", type(reward_function).__name__)
    print(f"Scoring completions using {reward_name}...")

    all_scores = []
    pairs = zip(completions, problems)
    # zip has no len(); give tqdm the count so it can show progress.
    total = min(len(completions), len(problems))

    for i, (problem_completions, problem) in enumerate(tqdm(pairs, total=total, desc="Scoring")):
        problem_scores = []

        for completion in problem_completions:
            try:
                score = reward_function.get_reward_for_one(
                    prompt=problem['prompt'],
                    completion=completion,
                    test_list=problem['test_list']
                )
                problem_scores.append(score)
            except Exception as e:
                # Best-effort: a scoring failure counts as a failed completion.
                print(f"Error scoring completion for problem {i}: {e}")
                problem_scores.append(0.0)

        all_scores.append(problem_scores)

    return all_scores


def calculate_pass_at_k_metrics(scores: List[List[float]], k_values: Optional[List[int]] = None,
                                threshold: float = 0.9999) -> Dict[str, float]:
    """Average the unbiased pass@k estimate over problems for several k.

    Args:
        scores: One list of completion scores per problem.
        k_values: The k values to report. Defaults to the powers of two
            1, 2, 4, ..., 1024.
        threshold: A completion counts as correct when its score is
            ``>= threshold``.

    Returns:
        A dict with one ``'pass@k'`` entry per k, holding the mean pass@k over
        problems.
          * Only k values no larger than the first problem's completion count
            are reported.
          * Problems with fewer than k completions are left out of that k's
            average.
          * An empty ``scores`` yields ``{}``.
    """
    if k_values is None:
        k_values = [2 ** i for i in range(11)]

    results = {}
    if len(scores) == 0:
        return results
    num_samples = len(scores[0])

    for k in k_values:
        if k > num_samples:
            continue

        pass_at_k_scores = [
            calculate_pass_at_k(len(problem_scores),
                                sum(1 for score in problem_scores if score >= threshold),
                                k)
            for problem_scores in scores
            if k <= len(problem_scores)
        ]
        results[f'pass@{k}'] = sum(pass_at_k_scores) / len(pass_at_k_scores) if pass_at_k_scores else 0.0

    return results


def calculate_max_at_k(scores: List[List[float]], k_values: Optional[List[int]] = None) -> Dict[str, float]:
    """Average, over problems, the expected best score among k sampled completions.

    For one problem with n scores sorted ascending, the i-th (0-based) score is
    the maximum of a size-k subset drawn without replacement with probability
    ``C(i, k-1) / C(n, k)``. max@k is the resulting weighted sum, averaged
    over problems.

    Args:
        scores: Rectangular list of completion scores, one row per problem.
        k_values: The k values (each >= 1) to report. Defaults to the powers
            of two 1, 2, 4, ..., 1024.

    Returns:
        A dict with one ``'max@k'`` entry per k that does not exceed the number
        of completions per problem. An empty ``scores`` yields ``{}``.
    """
    if k_values is None:
        k_values = [2 ** i for i in range(11)]

    results = {}
    if len(scores) == 0:
        return results

    score_matrix = np.asarray(scores, dtype=float)
    num_samples = score_matrix.shape[1]
    sorted_scores = np.sort(score_matrix, axis=1)

    for k in k_values:
        if k > num_samples:
            continue

        # Exact integer binomials: float comb() overflows to inf once n exceeds
        # roughly 1030 (e.g. C(1100, 512)), which would turn every weight into NaN.
        # int / int true division rounds correctly, and the result is always <= 1.
        subsets = math.comb(num_samples, k)
        weights = np.array([math.comb(i, k - 1) / subsets for i in range(num_samples)])
        max_at_k_per_problem = sorted_scores @ weights

        results[f'max@{k}'] = float(max_at_k_per_problem.mean())

    return results


def _write_json(path: Path, payload) -> None:
    """Write ``payload`` to ``path`` as indented JSON using ``DateTimeEncoder``."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, cls=DateTimeEncoder)


def save_detailed_results(problems: List[Dict], completions: List[List[str]],
                          scores: List[List[float]], output_dir: str, config: EvaluationConfig,
                          metrics: Optional[Dict[str, Dict[str, float]]] = None,
                          pass_threshold: float = 0.9999):
    """Save per-problem results, a run summary and the raw scores as JSON.

    Writes three kinds of file:
      * ``<output_dir>/problems/problem_XXXX.json``, one per problem;
      * ``<output_dir>/summary.json``;
      * ``<output_dir>/raw_data.json``.

    Args:
        problems: Problem dicts with ``id``, ``prompt``, ``test_list`` and
            ``metadata`` keys.
        completions: Completions per problem, aligned with ``problems``.
        scores: Scores per problem, aligned with ``completions``.
        output_dir: Directory to write into. It is created if missing.
        config: Run configuration. Selected fields are copied into the summary.
        metrics: Optional precomputed metrics under the keys ``'pass_at_k'``
            and ``'max_at_k'``. Any missing entry is computed from ``scores``.
        pass_threshold: A completion counts as passed when its score is
            ``>= pass_threshold``. The default matches the criterion in
            ``calculate_pass_at_k_metrics``, so per-problem counts agree with
            the reported pass@k.

    Returns:
        The summary dict that was written to ``summary.json``.
    """
    results_dir = Path(output_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Save individual problem results
    problems_dir = results_dir / "problems"
    problems_dir.mkdir(exist_ok=True)

    for i, (problem, problem_completions, problem_scores) in enumerate(zip(problems, completions, scores)):
        num_scores = len(problem_scores)
        problem_result = {
            'problem_id': problem['id'],
            'prompt': problem['prompt'],
            'test_list': problem['test_list'],
            'metadata': problem['metadata'],
            'completions': [
                {
                    'completion': completion,
                    'score': score,
                    # bool() so NumPy scalar scores still give a JSON-serializable flag.
                    'passed': bool(score >= pass_threshold)
                }
                for completion, score in zip(problem_completions, problem_scores)
            ],
            'num_passed': sum(1 for score in problem_scores if score >= pass_threshold),
            'total_completions': num_scores,
            # Guard problems without scores, where max()/division would raise.
            'best_score': max(problem_scores) if num_scores else None,
            'average_score': sum(problem_scores) / num_scores if num_scores else None
        }
        _write_json(problems_dir / f"problem_{i:04d}.json", problem_result)

    # Fill in whichever metric families the caller did not supply (without
    # mutating the caller's dict).
    metrics = dict(metrics or {})
    if 'pass_at_k' not in metrics:
        metrics['pass_at_k'] = calculate_pass_at_k_metrics(scores)
    if 'max_at_k' not in metrics:
        metrics['max_at_k'] = calculate_max_at_k(scores)

    summary = {
        'config': {
            'model_name': config.model_name,
            'dataset_path': config.dataset_path,
            'dataset_split': config.dataset_split,
            'k': config.k,
            'temperature': config.temperature,
            'max_tokens': config.max_tokens,
            'tensor_parallel_size': config.tensor_parallel_size,
            'visible_gpus': config.visible_gpus
        },
        'pass@k': metrics['pass_at_k'],
        'max@k': metrics['max_at_k'],
        'summary_stats': {
            'total_problems': len(problems),
            'total_completions': len(problems) * config.k,
            'average_completions_per_problem': config.k,
            'problems_with_at_least_one_pass': sum(1 for problem_scores in scores
                                                   if any(score >= pass_threshold for score in problem_scores)),
        },
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }
    _write_json(results_dir / "summary.json", summary)

    # Save raw scores for further analysis; prompts are truncated to 100 chars.
    raw_data = {
        'scores': scores,
        'completions': completions,
        'problems': [{'id': p['id'], 'prompt': p['prompt'][:100] + '...' if len(p['prompt']) > 100 else p['prompt']}
                     for p in problems]
    }
    _write_json(results_dir / "raw_data.json", raw_data)

    return summary


def parse_bool_flag(value) -> bool:
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as true, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point: generate, score and report pass@k / max@k.

    Steps:
      1. Parse arguments and derive a default output directory.
      2. Load the dataset.
      3. Sample ``--k`` completions per prompt with vLLM.
      4. Score the completions with the Docker-based reward.
      5. Print the metrics and write the JSON results.
    """
    parser = argparse.ArgumentParser(description="Pass@k Evaluation Script")
    parser.add_argument("--model", type=str, required=True, help="Model name or path")
    parser.add_argument("--dataset_path", type=str, default="datasets/mbpp_completion", help="Dataset path")
    parser.add_argument("--dataset_split", type=str, default="test", help="Dataset split to evaluate on")
    parser.add_argument("--test_column", type=str, default="test_list", help="Test column")
    parser.add_argument("--k", type=int, default=8, help="Number of completions per problem")
    parser.add_argument("--output_dir", type=str, default="./eval_results", help="Output directory")
    parser.add_argument("--temperature", type=float, default=1, help="Temperature for generation")
    parser.add_argument("--max_tokens", type=int, default=256, help="Maximum tokens for generation")
    parser.add_argument("--top_p", type=float, default=1, help="Top p for generation")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism")
    parser.add_argument("--max_model_len", type=int, default=4096, help="Maximum model length")
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument("--trust_remote_code", type=parse_bool_flag, default=True, help="Trust remote code")
    parser.add_argument("--prompt_column", type=str, default="signature_with_docstring_and_imports",
                        help="Prompt column")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--visible_gpus", type=str, help="Visible GPUs")
    # Validate up front: score_completions rejects other values, but only after generation.
    parser.add_argument("--format", type=str, default="completion", choices=["completion", "io"],
                        help="format of the generation for reward function to choose from")
    args = parser.parse_args()

    # Generate output directory based on model name and split if not specified
    if args.output_dir == "./eval_results":
        model_name_for_dir = args.model.replace("/", "_").replace("\\", "_")
        if "checkpoint" in model_name_for_dir:
            # NOTE(review): splitting on "_" and re-joining around the checkpoint
            # part reproduces the same string. The only exception is a leading "_",
            # which is added when the very first part contains "checkpoint".
            parts = model_name_for_dir.split("_")
            checkpoint_idx = next(i for i, part in enumerate(parts) if "checkpoint" in part)
            model_base = "_".join(parts[:checkpoint_idx])
            checkpoint_name = "_".join(parts[checkpoint_idx:])
            model_name_for_dir = f"{model_base}_{checkpoint_name}"
        # Add parameters to the output directory name for better traceability
        param_str = f"k_{args.k}_temp_{args.temperature}_top_{args.top_p}_maxlen_{args.max_tokens}"
        args.output_dir = f"./eval_results/{args.dataset_path.replace('/', '_')}_{args.dataset_split}/{model_name_for_dir}_{args.format}_{param_str}"

    config = EvaluationConfig(
        model_name=args.model,
        dataset_path=args.dataset_path,
        dataset_split=args.dataset_split,
        test_column=args.test_column,
        k=args.k,
        output_dir=args.output_dir,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=args.max_model_len,
        trust_remote_code=args.trust_remote_code,
        prompt_column=args.prompt_column,
        batch_size=args.batch_size,
        visible_gpus=args.visible_gpus,
        format=args.format
    )

    print("Starting pass@k evaluation with config:")
    print(f"  Model: {config.model_name}")
    print(f"  Dataset: {config.dataset_path}")
    print(f"  Split: {config.dataset_split}")
    print(f"  k: {config.k}")
    print(f"  Tensor Parallel Size: {config.tensor_parallel_size}")
    print(f"  Visible GPUs: {config.visible_gpus if config.visible_gpus else 'All available'}")
    print(f"  Output: {config.output_dir}")
    print()

    # Load dataset
    print("Loading dataset...")
    problems = load_evaluation_dataset(config.dataset_path, split=config.dataset_split,
                                       prompt_column=config.prompt_column, test_column=config.test_column)
    print(f"Loaded {len(problems)} problems from {config.dataset_split} split")
    if not problems:
        # Nothing to evaluate; stop before paying for model loading.
        print("No problems to evaluate after filtering; exiting.")
        return

    # Initialize model
    model = initialize_vllm_model(config)

    # Extract prompts
    prompts = [problem['prompt'] for problem in problems]

    # Generate completions
    start_time = time.time()
    completions = generate_completions(model, prompts, config)
    generation_time = time.time() - start_time

    print(f"Generated completions in {generation_time:.2f} seconds")

    # Release GPU memory after generation is complete
    print("Releasing GPU memory...")
    del model
    torch.cuda.empty_cache()
    print("GPU memory released")

    # Score completions
    start_time = time.time()
    scores = score_completions(completions, problems, config)
    scoring_time = time.time() - start_time

    print(f"Scored completions in {scoring_time:.2f} seconds")

    # Calculate metrics
    print("Calculating metrics...")
    pass_at_k_metrics = calculate_pass_at_k_metrics(scores)
    max_at_k_metrics = calculate_max_at_k(scores)

    # Print results
    print("\n" + "=" * 50)
    print("EVALUATION RESULTS")
    print("=" * 50)

    print("Pass@k metrics:")
    for metric, value in pass_at_k_metrics.items():
        print(f"{metric}: {value:.4f}")

    print("Max@k metrics:")
    for metric, value in max_at_k_metrics.items():
        print(f"{metric}: {value:.4f}")

    print(f"\nTotal problems: {len(problems)}")
    print(f"Total completions: {len(problems) * config.k}")

    # Save results
    summary = save_detailed_results(problems, completions, scores, config.output_dir, config,
                                    {'pass_at_k': pass_at_k_metrics,
                                     'max_at_k': max_at_k_metrics})

    print(f"Problems with at least one passing solution: {summary['summary_stats']['problems_with_at_least_one_pass']}")

    print(f"\nResults saved to: {config.output_dir}")
    print("=" * 50)


if __name__ == "__main__":
    main()
