import lighteval

from datetime import timedelta

from lighteval.logging.evaluation_tracker import EvaluationTracker

from lighteval.models.transformers.transformers_model import TransformersModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

from lighteval.utils.imports import is_accelerate_available


# Module-level Accelerator handle; stays None when `accelerate` is not installed.
# NOTE(review): nothing else in this file reads `accelerator`; it is presumably
# created for the side effect of configuring the distributed process group
# before the pipeline starts -- confirm before removing.
accelerator = None
if is_accelerate_available():
    from accelerate import Accelerator, InitProcessGroupKwargs

    accelerator = Accelerator(
        kwargs_handlers=[
            # Process-group timeout of 3000 seconds (50 minutes).
            InitProcessGroupKwargs(timeout=timedelta(seconds=3000)),
        ],
    )

def main(args):
    """Build a lighteval pipeline from parsed CLI arguments, run it, and show results.

    Args:
        args: Namespace-like object exposing ``output_dir``, ``pretrained``,
            ``dtype``, ``batch_size``, ``generation_parameters`` and ``task``.
    """
    # The tracker is pointed at the output directory chosen on the command line.
    tracker = EvaluationTracker(output_dir=args.output_dir)

    # Original author's note: max_new_tokens does not need to be supplied in
    # generation_parameters because the tasks already set it.
    model_settings = {
        "model_name": args.pretrained,
        "dtype": args.dtype,
        "batch_size": args.batch_size,
        "generation_parameters": args.generation_parameters,
    }
    config = TransformersModelConfig(**model_settings)

    # Always launch through accelerate.
    launch_params = PipelineParameters(launcher_type=ParallelismManager.ACCELERATE)

    runner = Pipeline(
        tasks=args.task,
        pipeline_parameters=launch_params,
        evaluation_tracker=tracker,
        model_config=config,
    )
    runner.evaluate()
    runner.show_results()


def parse_generation_parameters(text):
    """Parse the ``--generation_parameters`` CLI value into a Python object.

    The value is expected to be a Python literal, typically a dict such as
    ``"{'temperature': 0.6, 'top_p': 0.95}"``.

    ``ast.literal_eval`` is used instead of ``eval`` so that a command-line
    string can never execute arbitrary code; only literals (dicts, lists,
    tuples, numbers, strings, booleans, ``None``) are accepted.

    Args:
        text: Raw string taken from the command line.

    Returns:
        The Python object described by ``text``.

    Raises:
        ValueError: If ``text`` is not a valid Python literal.
    """
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError) as err:
        # SyntaxError is folded into ValueError so callers handle one type.
        raise ValueError(
            "--generation_parameters must be a Python literal, e.g. "
            "\"{'temperature': 0.6, 'top_p': 0.95}\" "
            f"(got {text!r}: {err})"
        ) from err


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="./results")
    parser.add_argument("--pretrained", type=str, required=True)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--generation_parameters", type=str, required=True)
    args = parser.parse_args()

    try:
        args.generation_parameters = parse_generation_parameters(args.generation_parameters)
    except ValueError as err:
        # Report malformed input as a normal usage error instead of a traceback.
        parser.error(str(err))

    main(args)


# Example usage, single GPU:
# CUDA_VISIBLE_DEVICES=0 python main.py --task "lighteval|aime24|0|0" --pretrained /XXX/public/DeepSeek-R1-Distill-Qwen-1.5B/ --batch_size 1 --generation_parameters "{'temperature':0.6, 'top_p':0.95}"
# Example usage, 8 GPUs via `accelerate launch`, run in the background with output redirected to a log file:
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup accelerate launch main.py --task "lighteval|aime24|0|0" --pretrained /XXX/public/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/ --batch_size 1 --generation_parameters "{'temperature':0.6, 'top_p':0.95}" &>polar44_aime24_qwen_1.5b.log &


# Qwen2.5-1.5B
#                       | aime24 | aime25  | math_500 |  gpqa  |
# vanilla (bf16)        | 36.67  |  23.33  |   85.20  |  39.90 |
# kivi4   (gs128rs128)  | 20.00  |  23.33  |   80.40  |  33.84 |
# polar44 (gs128rs128)  | 30.00  |  20.00  |   80.20  |  37.88 |


# Llama-8B
#                       | aime24 | aime25  | math_500 |  gpqa  |
# vanilla (bf16)        | 50.00  |  36.67  |   91.20  |  51.52 |
# zip4 (bf16)           | 43.33  |  43.33  |   91.60  |  48.48 |    
# kivi4   (gs128rs128)  | 43.33  |  33.33  |   89.80  |  51.01 |
# polar44 (gs128rs128)  | 60.00  |  36.67  |   89.00  |  50.00 |
