import os
import sys
import json
import torch
import numpy as np
import argparse
import logging
import re
from pathlib import Path
from typing import List, Dict, Any, Union

from lm_eval import tasks, evaluator
from lm_eval.api.registry import get_model
from logits_processors.primal_threshold_processor_special import PrimalThresholdProcessor, PrimalThresholdModelWrapper
from transformers import (
    TopKLogitsWarper
)
from utils.serialization_utils import NumpyEncoder, convert_to_serializable, make_array_horizontal

# Module-level logger; its level and output format are set up in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

def parse_args():
    """Parse the command-line options of the sampling evaluation script.

    The options are declared as a table of ``(flag, add_argument keywords)``
    pairs and registered in table order, so ``--help`` lists them in the
    order written below.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Evaluate different sampling methods (including Primal Thresholding) on selected math datasets")

    option_table = [
        # Model and sampling hyperparameters; the list-valued ones are kept as raw strings here.
        ("--model_name", {"type": str, "default": "meta-llama/Meta-Llama-3.1-8B-Instruct", "help": "HuggingFace model name or path"}),
        ("--alphas", {"type": str, "default": "1.5", "help": "Space-separated list of Alpha parameters for Primal Thresholding"}),
        ("--mus", {"type": str, "default": "1e-4", "help": "Space-separated list of Mu parameters (L0 penalty) for Primal Thresholding"}),
        ("--temperatures", {"type": str, "default": "1.0", "help": "Space-separated list of temperatures for sampling (0 means greedy decoding)"}),
        ("--k_max", {"type": int, "default": 50, "help": "Maximum k for Primal Thresholding (default: 50)"}),
        # Evaluation setup, output location and logging.
        ("--batch_size", {"type": int, "default": 8, "help": "Batch size for evaluation"}),
        ("--num_fewshot", {"type": int, "default": 5, "help": "Number of few-shot examples"}),
        ("--output_path", {"type": str, "default": "./primal_sampling_results.json", "help": "Path to save the aggregated results"}),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID to use"}),
        ("--verbose", {"action": "store_true", "help": "Enable verbose output"}),
        ("--log_level", {"type": str, "default": "WARNING", "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], "help": "Logging level"}),
        ("--task", {"type": str, "default": "gsm8k_cot_llama", "help": "Task to evaluate on"}),
        # Prompt formatting switches.
        ("--apply_chat_template", {"action": "store_true", "help": "Whether to apply chat template to the prompt (defaults to False if not specified)"}),
        ("--fewshot_as_multiturn", {"action": "store_true", "help": "Whether to provide fewshot examples as a multiturn conversation (requires --apply_chat_template)"}),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

# Per-sample keys that may carry the score, in lookup priority order.
_SAMPLE_SCORE_KEYS = ("exact_match", "math_verify", "acc", "accuracy")


def _extract_individual_scores(results: Any, task_name: str, verbose: bool = False) -> List[int]:
    """Collect one rounded integer score per sample from an evaluation result.

    Extraction is best-effort: any unexpected structure is logged (only when
    ``verbose`` is set) and the scores gathered so far are returned.

    Args:
        results: The object returned by ``evaluator.simple_evaluate``; may be
            empty or ``None``, in which case no scores are returned.
        task_name: Task whose samples are looked up when ``results['samples']``
            is a dict keyed by task.
        verbose: Emit diagnostic log messages while extracting.

    Returns:
        One ``int(round(float(score)))`` per sample that exposes a recognized
        score key with a convertible value.
    """
    individual_scores: List[int] = []
    try:
        # A falsy result is handled explicitly as "no samples" instead of
        # relying on an AttributeError being swallowed by the handler below.
        samples_data = results.get('samples', {}) if results else {}
        if isinstance(samples_data, dict):
            samples_list = samples_data.get(task_name, [])
        elif isinstance(samples_data, list):
            samples_list = samples_data
        else:
            samples_list = []
            if verbose:
                logger.warning(f"Unexpected type for samples_data: {type(samples_data)}. Expected dict or list.")

        if verbose and not samples_list:
             logger.warning(f"No samples found for task {task_name} in results.")

        for sample in samples_list:
            if not isinstance(sample, dict):
                if verbose:
                    logger.warning(f"Skipping non-dict sample: {type(sample)}")
                continue

            # The first key that is present wins, even if its value is None.
            score_key = next((key for key in _SAMPLE_SCORE_KEYS if key in sample), None)
            score = sample.get(score_key) if score_key is not None else None

            if score is not None:
                try:
                    int_score = int(round(float(score)))
                    individual_scores.append(int_score)
                    if verbose:
                        logger.info(f"Extracted score for sample: {int_score}")
                except (ValueError, TypeError) as conversion_error:
                    if verbose:
                        logger.warning(f"Could not convert score '{score}' to int for sample: {sample}. Error: {conversion_error}")
            elif verbose:
                logger.debug(f"No recognized score key (exact_match, math_verify, acc, accuracy) found in sample: {list(sample.keys())}")

    except Exception as e:
        # Score extraction must never abort an otherwise successful evaluation.
        if verbose:
            logger.warning(f"Error extracting individual scores: {e}", exc_info=True)

    return individual_scores


def evaluate_model_with_processor(
    model_name: str,
    processor: Any,
    processor_name: str,
    batch_size: int,
    num_fewshot: int,
    task_name: str,
    gpu_id: int = 0,
    temperature: float = 0.0,
    verbose: bool = False,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
) -> tuple[Dict[str, Any], List[int]]:
    """Evaluate one model / sampling-method combination on a single task.

    The model is loaded through the lm_eval ``hf`` backend for every call and
    released again (best effort) once the evaluation has finished or failed,
    so that repeated calls in a sweep do not pile up GPU memory.

    Args:
        model_name: HuggingFace model name or path passed as ``pretrained``.
        processor: ``PrimalThresholdProcessor`` (the underlying model is
            wrapped in ``PrimalThresholdModelWrapper``), ``TopKLogitsWarper``
            (translated into the ``top_k`` generation argument), any other
            object (passed through as ``logits_processor``), or ``None``.
        processor_name: Label stored in the result and used in log messages.
        batch_size: Batch size for model loading and evaluation.
        num_fewshot: Number of few-shot examples.
        task_name: lm_eval task to run.
        gpu_id: CUDA device index used when CUDA is available.
        temperature: Values ``> 0`` enable sampling at that temperature;
            otherwise ``do_sample`` stays ``False``.
        verbose: Emit progress and diagnostic log messages.
        apply_chat_template: Forwarded to ``evaluator.simple_evaluate``.
        fewshot_as_multiturn: Forwarded to ``evaluator.simple_evaluate``.

    Returns:
        A tuple ``(result, best_k_values)``: ``result`` is a dict with the run
        description, the task metrics and the per-sample integer scores;
        ``best_k_values`` is what the Primal processor recorded during the run
        (an empty list for every other processor kind).
    """
    import gc  # local import: only needed for the cleanup step below

    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    model_args = {"device": device, "batch_size": batch_size, "trust_remote_code": True}
    if verbose:
        logger.info(f"Loading model {model_name} with {processor_name}...")
    model = get_model("hf")(pretrained=model_name, **model_args)

    try:
        generation_kwargs = {"top_k": 0, "top_p": 1.0, "do_sample": False, "temperature": None}
        if temperature > 0:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["do_sample"] = True

        if isinstance(processor, PrimalThresholdProcessor):
            if verbose:
                logger.info(f"Using PrimalThresholdModelWrapper for {processor_name} to enable batching")
            # Start from a clean slate so the collected k values belong to this run only.
            processor.clear_best_k_values()
            model._model = PrimalThresholdModelWrapper(model._model, processor)
        elif processor is not None:
            if isinstance(processor, TopKLogitsWarper):
                # Top-K is expressed through the generation arguments, not as a processor object.
                generation_kwargs["top_k"] = processor.top_k
                processor = None
            if processor is not None:
                generation_kwargs["logits_processor"] = [processor]

        tasks_to_eval = [task_name]

        if verbose:
            logger.info(f"Evaluating {model_name} with {processor_name} on {task_name}...")
            logger.info(f"Generation kwargs: {generation_kwargs}")
            logger.info(f"Batch size: {batch_size}")
            logger.info(f"Apply chat template: {apply_chat_template}")
            logger.info(f"Fewshot as multiturn: {fewshot_as_multiturn}")

        results = evaluator.simple_evaluate(
            model=model,
            tasks=tasks_to_eval,
            num_fewshot=num_fewshot,
            batch_size=batch_size,
            gen_kwargs=generation_kwargs,
            limit=None,
            bootstrap_iters=1000,
            apply_chat_template=apply_chat_template,
            fewshot_as_multiturn=fewshot_as_multiturn
        )
    finally:
        # Best-effort release of the model: drop the reference, collect any
        # reference cycles, and hand cached CUDA blocks back to the driver.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if verbose:
        logger.info(f"Results for {processor_name} on {task_name}: {results}")

    task_metrics = {}
    if results and 'results' in results and task_name in results['results']:
        task_metrics = results['results'][task_name]

    individual_scores = _extract_individual_scores(results, task_name, verbose)

    all_best_k_for_this_run = []
    if isinstance(processor, PrimalThresholdProcessor):
        all_best_k_for_this_run = processor.get_all_best_k_values()
        if verbose:
            logger.info(f"Collected best_k values for {processor_name} on {task_name}: {len(all_best_k_for_this_run)} batches.")

    return {
        "model": model_name,
        "processor": processor_name,
        "task": task_name,
        "num_fewshot": num_fewshot,
        "temperature": temperature,
        "metrics": task_metrics,
        "individual_scores": individual_scores
    }, all_best_k_for_this_run

def main():
    """Run the Primal Thresholding sweep with a matched Top-K baseline and report the results.

    For every (temperature, alpha, mu) combination given on the command line:

    1. evaluate the task with a ``PrimalThresholdProcessor``;
    2. if the processor recorded any k values, save them to a ``.npy`` file
       next to the JSON output and derive ``k = round(mean(k values))``;
    3. if that derived k is at least 1, evaluate the same task again with a
       ``TopKLogitsWarper`` using it.

    All results are written to ``--output_path`` as JSON (with the
    ``individual_scores`` arrays compacted by ``make_array_horizontal``) and a
    summary table is printed to stdout. The process exits with status 1 when
    the model config cannot be loaded or a hyperparameter list is malformed.
    """
    args = parse_args()
    logging_level = getattr(logging, args.log_level)
    print(f"--- Evaluating Task: {args.task} with Hyperparameter (k_max default: {args.k_max}) ---")

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging_level,
    )
    # Keep third-party loggers quiet unless verbose output was requested.
    lm_eval_logger = logging.getLogger('lm_eval')
    lm_eval_logger.setLevel(logging.WARNING if not args.verbose else logging_level)
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.ERROR if not args.verbose else logging_level)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    # Once the mask above is in effect, the selected GPU is the only visible
    # device and is addressed as index 0; "cuda:<gpu_id>" would be an invalid
    # ordinal for any gpu_id > 0. If CUDA had already been initialized (mask
    # ignored), the original index is still within range and is kept.
    cuda_available = torch.cuda.is_available()
    visible_gpu_id = args.gpu_id if 0 <= args.gpu_id < torch.cuda.device_count() else 0
    device = f"cuda:{visible_gpu_id}" if cuda_available else "cpu"

    try:
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=True)
        vocab_size = config.vocab_size
        logger.info(f"Detected vocab size: {vocab_size}")
    except Exception as e:
        logger.error(f"Could not automatically determine vocab size for {args.model_name}. Please ensure the model exists. Error: {e}")
        sys.exit(1)

    try:
        temperatures = [float(t.strip()) for t in args.temperatures.split()] 
        alphas = [float(a.strip()) for a in args.alphas.split()]   
        mus = [float(m.strip()) for m in args.mus.split()] 
    except ValueError as e:
        logger.error(f"Invalid format for temperatures, alphas, or mus. Please provide space-separated numbers. Error: {e}")
        sys.exit(1)

    # Create the output directory up front: the per-run .npy files are written
    # into it during the sweep, long before the final JSON is saved.
    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    all_results = []

    for temp in temperatures:
        for alpha_val in alphas:
            for mu_val in mus:
                current_run_results = []
                logger.info(f"--- Running evaluation for: Task={args.task}, Temp={temp}, Alpha={alpha_val}, Mu={mu_val}, k_max={args.k_max} ---")
                print(f"--- Running evaluation for: Task={args.task}, Temp={temp}, Alpha={alpha_val}, Mu={mu_val}, k_max={args.k_max} ---")

                primal_processor = PrimalThresholdProcessor(
                    alpha=alpha_val,
                    mu=mu_val,
                    k_max=args.k_max,
                    device=device
                )
                # '.' is replaced so the values can be embedded in names and file names.
                temp_str = str(temp).replace('.', 'p')
                alpha_str = str(alpha_val).replace('.', 'p')
                mu_str = str(mu_val).replace('.', 'p')
                primal_processor_name = f"primal_alpha{alpha_str}_mu{mu_str}_kmax{args.k_max}_temp{temp_str}"
                
                if args.verbose:
                    logger.info(f"Starting Primal Thresholding evaluation: {primal_processor_name}")

                primal_result, best_k_values_from_primal = evaluate_model_with_processor(
                    model_name=args.model_name,
                    processor=primal_processor,
                    processor_name=primal_processor_name,
                    batch_size=args.batch_size,
                    num_fewshot=args.num_fewshot,
                    task_name=args.task,
                    gpu_id=visible_gpu_id,
                    temperature=temp,
                    verbose=args.verbose,
                    apply_chat_template=args.apply_chat_template,
                    fewshot_as_multiturn=args.fewshot_as_multiturn
                )
                primal_result['alpha'] = alpha_val
                primal_result['mu'] = mu_val
                primal_result['k_max_primal'] = args.k_max
                current_run_results.append(primal_result)

                if best_k_values_from_primal:
                    npy_filename_suffix = f"{args.task}_temp{temp_str}_alpha{alpha_str}_mu{mu_str}_kmax{args.k_max}_best_k.npy"
                    npy_filename = output_path.parent / f"{output_path.stem}_{npy_filename_suffix}"
                    try:
                        np.save(npy_filename, np.array(best_k_values_from_primal, dtype=np.int16))
                        if args.verbose:
                            logger.info(f"Saved best_k_values for {primal_processor_name} on {args.task} to {npy_filename}")
                    except Exception as e:
                        logger.error(f"Error saving best_k_values to .npy file {npy_filename}: {e}", exc_info=True)

                    # The collection is known to be non-empty here, so the mean is well defined.
                    mean_k = np.mean(best_k_values_from_primal)
                    derived_top_k = int(round(mean_k))
                    if args.verbose:
                        logger.info(f"Mean k from Primal Thresholding ({primal_processor_name}): {mean_k:.2f}, Rounded to: {derived_top_k}")

                    if derived_top_k >= 1:
                        top_k_derived_processor = TopKLogitsWarper(top_k=derived_top_k)
                        top_k_derived_processor_name = f"top_k_derived_from_primal_mean_k{derived_top_k}_temp{temp_str}_alpha{alpha_str}_mu{mu_str}_kmax{args.k_max}"
                        
                        if args.verbose:
                            logger.info(f"Starting Top-K evaluation with derived k: {top_k_derived_processor_name}")

                        top_k_result, _ = evaluate_model_with_processor(
                            model_name=args.model_name,
                            processor=top_k_derived_processor,
                            processor_name=top_k_derived_processor_name,
                            batch_size=args.batch_size,
                            num_fewshot=args.num_fewshot,
                            task_name=args.task,
                            gpu_id=visible_gpu_id,
                            temperature=temp,
                            verbose=args.verbose,
                            apply_chat_template=args.apply_chat_template,
                            fewshot_as_multiturn=args.fewshot_as_multiturn
                        )
                        top_k_result['derived_from_alpha'] = alpha_val
                        top_k_result['derived_from_mu'] = mu_val
                        top_k_result['derived_from_k_max_primal'] = args.k_max
                        top_k_result['derived_k'] = derived_top_k
                        current_run_results.append(top_k_result)
                    elif args.verbose:
                        logger.warning(f"Derived Top-K value is {derived_top_k}, which is less than 1. Skipping Top-K evaluation for this run (Temp={temp}, Alpha={alpha_val}, Mu={mu_val}).")
                elif args.verbose:
                    logger.info(f"No best_k_values collected from Primal Thresholding (list is None or empty) for this run (Temp={temp}, Alpha={alpha_val}, Mu={mu_val}), skipping derived Top-K evaluation and .npy save.")
                
                all_results.extend(current_run_results)

    try:
        # Serialize and compact in memory, then write the file exactly once.
        json_str = json.dumps(all_results, cls=NumpyEncoder, indent=2)

        pattern_scores = r'(\s*"individual_scores":\s*\[)([^\]]*)(\])'
        compact_json_str = re.sub(pattern_scores, make_array_horizontal, json_str, flags=re.DOTALL)

        with open(output_path, 'w') as f:
            f.write(compact_json_str)

    except Exception as e:
        logger.error(f"Error saving results with post-processing: {e}")
        # Fall back to a plain dump of an explicitly converted structure.
        try:
            with open(output_path, 'w') as f:
                 json.dump(convert_to_serializable(all_results), f, indent=2)
            print(f"Fallback: Serialized results saved to {output_path}")
        except Exception as e2:
            logger.error(f"Failed even with fallback serialization: {e2}")

    print("\nSummary of All Results:")
    print("-" * 120)
    print(f"{'Task':<25} {'Method':<55} {'Temp':<5} {'Alpha':<5} {'Mu':<8} {'Accuracy':<10}")
    print("-" * 120)
    for result in all_results:
        processor_name = result['processor']
        task_name = result['task']
        temp_disp = result.get('temperature', 'N/A')
        # Top-K rows carry the hyperparameters of the Primal run they were derived from.
        alpha_disp = result.get('alpha', result.get('derived_from_alpha', 'N/A'))
        mu_disp = result.get('mu', result.get('derived_from_mu', 'N/A'))

        try:
            metrics = result['metrics']
            # Preferred metric names first; otherwise fall back to any key that
            # looks like a match/accuracy metric.
            accuracy_keys = ['exact_match', 'exact_match,strict-match', 'exact_match,flexible-extract',
                           'acc', 'accuracy', 'math_verify']
            accuracy = None
            for key in accuracy_keys:
                if key in metrics:
                    accuracy = float(metrics[key])
                    break
            if accuracy is None:
                acc_keys = [k for k in metrics.keys() if 'match' in k.lower() or 'acc' in k.lower()]
                if acc_keys:
                    accuracy = float(metrics[acc_keys[0]])
                else:
                    raise KeyError(f"No accuracy metric found in metrics. Available keys: {list(metrics.keys())}")
            
            temp_str_disp = f"{temp_disp:<5.1f}" if isinstance(temp_disp, float) else f"{str(temp_disp):<5}"
            alpha_str_disp = f"{alpha_disp:<5.1f}" if isinstance(alpha_disp, float) else f"{str(alpha_disp):<5}"
            mu_str_disp = f"{mu_disp:<8}" if isinstance(mu_disp, (float, str)) else f"{str(mu_disp):<8}"

            print(f"{task_name:<25} {processor_name:<55} {temp_str_disp} {alpha_str_disp} {mu_str_disp} {accuracy:.4f}")

        except KeyError as e:
            print(f"{task_name:<25} {processor_name:<55} ERROR: Key not found - {str(e)}")
        except Exception as e:
            print(f"{task_name:<25} {processor_name:<55} ERROR: {str(e)}")
    print("-" * 120)
    print(f"Results saved to: {output_path}")

# Run the evaluation sweep only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
