'''
Usage:
python -m mix_eval.evaluate \
    --model_name {gpt_35_turbo, qwen_15_7b_chat, qwen_15_72b_chat, ...} \
    --split {close_freeform,close_multichoice,close_freeform_hard,close_multichoice_hard,open,open_hard} \
    [--output_folder OUTPUT_FOLDER] \
    [--batch_size BATCH_SIZE] \
    [--max_gpu_memory MAX_GPU_MEMORY] \
    [--verbose]
'''
import argparse
import os
import time
import warnings
# Install the filters before the third-party imports below so they are already
# active while those modules load (presumably to silence import-time
# deprecation noise from dependencies -- confirm which one emits it).
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=FutureWarning)

from torch.utils.data import DataLoader

import mix_eval.api.registry
from mix_eval.models import AVAILABLE_MODELS
from mix_eval.utils.dataset import get_eval_dataset
from mix_eval.utils.common_utils import (
    set_seed, 
    cache_status, 
    read_status, 
    dict_equal,
    log_error
    )

def parse_args():
    '''Build the command-line interface for an evaluation run and parse it.

    Two options are mandatory: ``--model_name`` (restricted to the keys of
    ``AVAILABLE_MODELS``) and ``--split``. Every other option has a default.

    Returns:
        argparse.Namespace: the parsed command-line options.
    '''
    cli = argparse.ArgumentParser()

    # Mandatory selectors: which model to run, and on which benchmark split.
    cli.add_argument(
        "--model_name",
        type=str,
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model to evaluate.",
    )
    cli.add_argument(
        "--split",
        type=str,
        required=True,
        choices=[
            "close_freeform",
            "close_multichoice",
            "close_freeform_hard",
            "close_multichoice_hard",
            "open",
            "open_hard",
        ],
        help="Split to evaluate.",
    )

    # One optional path override per dataset file: (flag, default, help text).
    # Registration order is kept so the generated --help output is unchanged.
    dataset_options = (
        (
            "--data_path_freeform",
            "mix_eval/data/text2text/text2text_closeended/free-form.json",
            "Path to free-form dataset.",
        ),
        (
            "--data_path_multiplechoice",
            "mix_eval/data/text2text/text2text_closeended/multiple-choice.json",
            "Path to multiple-choice dataset.",
        ),
        (
            "--data_path_freeform_hard",
            "mix_eval/data/text2text/text2text_closeended/free-form-hard.json",
            "Path to free-form-hard dataset.",
        ),
        (
            "--data_path_multiplechoice_hard",
            "mix_eval/data/text2text/text2text_closeended/multiple-choice-hard.json",
            "Path to multiple-choice-hard dataset.",
        ),
        (
            "--data_path_open",
            "mix_eval/data/text2text/text2text_openended/text2text_openended.json",
            "Path to open dataset.",
        ),
        (
            "--data_path_open_hard",
            "mix_eval/data/text2text/text2text_openended/text2text_openended_hard.json",
            "Path to open-hard dataset.",
        ),
    )
    for flag, default_path, description in dataset_options:
        cli.add_argument(flag, type=str, default=default_path, help=description)

    # Output location and runtime knobs.
    cli.add_argument(
        "--output_folder",
        type=str,
        default="mix_eval/data/model_responses",
        help="Path to save model responses.",
    )
    cli.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for evaluation.",
    )
    cli.add_argument(
        "--max_gpu_memory",
        type=str,
        default=None,
        help="Max memory to use for each GPU.",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Print verbose information.",
    )

    return cli.parse_args()


def _identity_collate(batch):
    '''Return the list of samples unchanged (used as the DataLoader collate_fn).

    Defined at module level rather than as a local lambda so it can be pickled
    when DataLoader worker processes are started with the "spawn" method.
    '''
    return batch


def eval(args):
    '''Run the selected model over the selected split, batch by batch.

    Responses go to ``<output_folder>/<split>/<model_name>/<model_name>.jsonl``.
    After every batch the progress (batch index, accumulated wall time, state)
    is handed to ``cache_status``. If the response file already exists, the
    cached status is read back with ``read_status`` and the run either returns
    immediately (already complete) or resumes after the last recorded batch.

    Args:
        args: parsed command-line namespace; this function reads
            ``model_name``, ``split``, ``output_folder``, ``batch_size`` and
            ``verbose`` and forwards the whole namespace to the model
            constructor, the dataset loader and the status helpers.

    Raises:
        ValueError: if the cached arguments differ from the current ones, or
            if the number of lines in the response file does not match the
            cached batch index.
    '''
    print(f"\n\nStart to evaluate {args.model_name}'s {args.split} split. \n\n")
    time_elapsed = 0
    start_time = time.time()

    response_file = os.path.join(
        args.output_folder,
        args.split,
        args.model_name,
        f"{args.model_name}.jsonl"
        )
    os.makedirs(os.path.dirname(response_file), exist_ok=True)

    # Index of the last batch already recorded by a previous run.
    # -1 means "nothing to skip", i.e. a fresh run.
    resume_after = -1
    if os.path.exists(response_file):
        cached = read_status(args)
        if not dict_equal(cached['args'], args.__dict__):
            raise ValueError("The cached arguments are "
                            "different from those in the current run. Please check.")
        progress = cached['status']
        if progress['status'] == 'complete':
            print(f"The evaluation for {args.model_name}'s {args.split} "
                    "split is already complete. Skipping.")
            return

        # Count lines lazily and release the file handle right away.
        with open(response_file) as f:
            n_lines = sum(1 for _ in f)

        # NOTE(review): this check expects every recorded batch to be full
        # (presumably one response line per sample -- confirm against
        # get_responses). A run interrupted right after a *partial* final
        # batch, but before 'complete' was cached, would be rejected here.
        if n_lines != (progress['batch_id'] + 1) * args.batch_size:
            raise ValueError(f"The response file [{response_file}] has different "
                            "lines as recorded in cached metadadta. Please check the response file. "
                            "You might consider delete the response and metadata file to start from scratch.")

        resume_after = progress['batch_id']
        # Carry over the wall time accumulated by the previous run.
        time_elapsed += time.time() - start_time + progress['time_elapsed']
        start_time = time.time()
        print(f"Resuming from last run: \n{cached}")

    model = mix_eval.api.registry.get_model(args.model_name)(args)
    eval_dataset = get_eval_dataset(args)
    dataloader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=32,
        collate_fn=_identity_collate
        )

    # Pre-bind the loop variable so the final status below is well defined
    # even when the dataloader yields no batch at all.
    b_id = -1
    for b_id, batch in enumerate(dataloader):
        # Skip batches whose responses were already written by a previous run.
        if b_id <= resume_after:
            continue

        if args.verbose:
            _start_time = time.time()
        model.get_responses(batch, response_file)
        if args.verbose:
            _finish_time = time.time()
            print(f"Batch {b_id} finished in {_finish_time - _start_time} seconds.")

        time_elapsed += time.time() - start_time
        start_time = time.time()

        # Checkpoint after every batch so an interrupted run can resume.
        cache_status(args, {
            'batch_id': b_id,
            'time_elapsed': time_elapsed,
            'status': 'in progress'
        })

    cache_status(args, {
        'batch_id': b_id,
        'time_elapsed': time_elapsed,
        'status': 'complete'
    })
    print(f"Finished evaluating {args.model_name}'s {args.split} split. "
          f"Used {round(time_elapsed / 60, 2)} minutes.")


if __name__ == '__main__':
    # Script entry point: seed, parse the command line, run the evaluation.
    set_seed()
    args = parse_args()
    try:
        eval(args)
    except Exception as e:
        # Top-level boundary: pass a one-line summary of the failure to
        # log_error for the per-split error log, then let the original
        # exception propagate.
        msg = (f"Error: {e}; Model: {args.model_name}; "
        f"Split: {args.split}; "
        f"Check the logfile: {args.output_folder}/{args.split}/"
        f"{args.model_name}/{args.model_name}.log")
        log_error(msg, f"{args.output_folder}/{args.split}/error.log")
        # Bare raise re-raises the active exception as-is.
        raise
    