import os
import json
import numpy as np
from pathlib import Path
from typing import Optional, Union, List

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer
)

import numpy as np

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
# from lm_eval.tasks import initialize_tasks
from lm_eval.utils import make_table

# Number of few-shot examples to use per task name.
# NOTE(review): these counts appear to mirror the Open LLM Leaderboard (v1)
# settings, and nothing in this module reads the table — confirm before editing.
TASK_TO_NUM_FEWSHOT = dict(
    arc_challenge=25,
    hellaswag=10,
    truthfulqa=0,
    mmlu=5,
    winogrande=5,
    gsm8k=5,
)


def _handle_non_serializable(o):
    """Convert an object the ``json`` encoder cannot handle into a JSON-friendly value.

    Presumably meant as the ``default=`` hook of ``json.dump``/``json.dumps``.

    Args:
        o: The object the JSON encoder could not serialize.

    Returns:
        A Python ``int``, ``float`` or ``bool`` for NumPy integer, floating or
        boolean scalars. A ``list`` for ``set``/``frozenset``. ``str(o)`` for
        anything else.
    """
    # Cover every NumPy integer width (int8 ... uint64), not just int32/int64.
    if isinstance(o, np.integer):
        return int(o)
    # e.g. np.float32 is not a ``float`` subclass, so ``json`` cannot encode it.
    if isinstance(o, np.floating):
        return float(o)
    # np.bool_ is not a ``bool`` subclass either.
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, (set, frozenset)):
        return list(o)
    return str(o)


def _lookup_metric(metrics, metric):
    """Return the numeric value reported for ``metric`` in ``metrics``, or None.

    Looks up the bare key first (e.g. ``"acc"``). Otherwise it tries
    ``"<metric>,<filter>"`` keys (e.g. ``"acc,none"``), the layout used by
    lm-eval-harness 0.4+, in dict order. Non-numeric values are skipped.
    """
    if metric in metrics:
        candidates = [metrics[metric]]
    else:
        # "acc_stderr,none" / "acc_norm,none" must not match "acc", so compare
        # the whole metric name before the first comma, not a prefix.
        candidates = [
            value
            for key, value in metrics.items()
            if isinstance(key, str) and key.split(",", 1)[0] == metric
        ]
    for value in candidates:
        if isinstance(value, (int, float, np.number)) and not isinstance(value, bool):
            return value
    return None


def calculate_average(results):
    """Average accuracy metrics across the tasks and groups in ``results``.

    Both bare metric keys (``"acc"``) and lm-eval-harness 0.4+ filtered keys
    (``"acc,none"``) are recognized.

    Args:
        results: Dict as returned by ``evaluator.simple_evaluate``. Only its
            optional ``"results"`` and ``"groups"`` entries are read. Each maps
            a task/group name to that entry's metric dict.

    Returns:
        Dict containing any of the following keys, each present only if at
        least one entry reported the metric:

        - ``"acc"``: mean ``acc`` over ``results["results"]``.
        - ``"acc_norm"``: mean ``acc_norm`` over ``results["results"]``.
        - ``"group_acc"``: mean ``acc`` over ``results["groups"]``.

        Returns an empty dict when nothing matched.
    """
    # NOTE(review): in some lm-eval versions group aggregates also appear in
    # results["results"], so "acc" may mix subtasks with their group — confirm.
    task_metrics = list(results["results"].values()) if "results" in results else []
    group_metrics = list(results["groups"].values()) if "groups" in results else []

    avg_values = {}
    # (output key, metric name, entries to average over) — order fixes key order.
    for out_key, metric, entries in (
        ("acc", "acc", task_metrics),
        ("acc_norm", "acc_norm", task_metrics),
        ("group_acc", "acc", group_metrics),
    ):
        found = [_lookup_metric(entry, metric) for entry in entries]
        values = [value for value in found if value is not None]
        if values:
            avg_values[out_key] = np.mean(values)
    return avg_values


def evaluate_fewshot(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        task: Union[str, List[str]],
        num_fewshot: int,
        eval_batch_size: Optional[int] = 4,
        log: Optional[bool] = True,
        output_path: Optional[str] = None,
        seed: Optional[int] = 42,
):
    """Run an lm-eval-harness few-shot evaluation on an in-memory HF model.

    Args:
        model: Model instance, wrapped in ``HFLM`` with ``device_map="auto"``.
        tokenizer: Tokenizer passed to ``HFLM`` alongside ``model``.
        task: Task name or list of task names for ``evaluator.simple_evaluate``.
        num_fewshot: Number of few-shot examples per prompt.
        eval_batch_size: Batch size forwarded to both ``HFLM`` and
            ``simple_evaluate``.
        log: If true, print the results table (and groups table, if any).
        output_path: If given, append the table(s) plus the per-metric averages
            from ``calculate_average`` to this file.
        seed: Used as the random, NumPy and torch seed of ``simple_evaluate``.

    Returns:
        Whatever ``evaluator.simple_evaluate`` returned. If that is ``None``,
        nothing is printed or written.
    """
    lm = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=eval_batch_size,
        device_map="auto",
    )
    results = evaluator.simple_evaluate(
        model=lm,
        tasks=task,
        num_fewshot=num_fewshot,
        batch_size=eval_batch_size,
        random_seed=seed,
        numpy_random_seed=seed,
        torch_random_seed=seed,
    )

    # NOTE(review): lm-eval's simple_evaluate is expected to return None on
    # non-zero ranks of a distributed run; guard so make_table never sees None.
    # Also skip table rendering entirely when no output was requested.
    if results is None or not (log or output_path):
        return results

    # Render each table once and reuse it for stdout and the report file.
    tables = [make_table(results)]
    if "groups" in results:
        tables.append(make_table(results, "groups"))

    if log:
        for table in tables:
            print(table)

    if output_path:
        _append_report(output_path, tables, calculate_average(results))

    return results


def _append_report(output_path, tables, avg_values):
    """Append rendered result tables and per-metric averages to ``output_path``.

    Args:
        output_path: File to append to (created if it does not exist).
        tables: Pre-rendered table strings, written in order.
        avg_values: Mapping of metric name to average value, as returned by
            ``calculate_average``. The averages section is omitted if empty.
    """
    # `with` closes the file even if formatting or writing raises.
    with open(output_path, "a", encoding="utf-8") as f:
        for table in tables:
            print(table, file=f)
        if not avg_values:
            return
        print("\nAverage:", file=f)
        for metric_name, avg_value in avg_values.items():
            print(f"{metric_name}_average: {avg_value:.4f}", file=f)
        # Also emit a single overall average accuracy line for easy comparison.
        if "acc" in avg_values:
            print(f"\nOverall average accuracy: {avg_values['acc']:.4f}", file=f)
