

import argparse
import dataclasses
import importlib
import json
import os
from pprint import pformat, pprint


# isort: off
import torch  # noqa: F401 # pylint: disable=unused-import # should be imported before bigbench
import bigbench.api.json_task as json_task
import bigbench.api.results as results_api
import bigbench.api.task as task_api
import bigbench.api.util as util
import bigbench.models.query_logging_model as query_logging_model
import tensorflow as tf

# isort: on
from safe_rlhf.evaluate.bigbench.model import SafeRLHFModel


# Turn on memory growth for every GPU device that TensorFlow reports.
# NOTE(review): presumably this keeps TensorFlow from reserving GPU memory up
# front so the PyTorch model can share the device -- confirm.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments of the evaluation script.

    Unknown arguments are tolerated and silently dropped because the parser is
    run with ``parse_known_args``.

    Returns:
        The namespace holding ``model_name_or_path``, ``task_name`` (a list),
        ``batch_size``, ``max_length``, ``output_dir``, ``max_examples``,
        ``not_log_queries`` and ``show_progress``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        required=True,
        help='Path to the saved model.',
    )
    parser.add_argument(
        '--task_name',
        nargs='+',
        type=str,
        required=True,
        help='Name(s) of the task(s) to evaluate.',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        help='Batch size.',
        default=4,
    )
    parser.add_argument(
        '--max_length',
        type=int,
        help='Maximum sequence length.',
        default=2048,
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        help='Output directory.',
        default='./eval_results',
    )
    parser.add_argument(
        '--max_examples',
        type=int,
        help='Maximum number of examples to evaluate.',
        default=None,
    )
    parser.add_argument(
        '--not_log_queries',
        action='store_true',
        # Passing the flag *disables* query logging; the help text must say so.
        help='Do not log queries.',
    )
    parser.add_argument(
        '--show_progress',
        action='store_true',
        help='Whether to show progress bar.',
    )

    # Deliberately lenient: extra arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()
    return args


def _sanitize_results(scores: list[task_api.ScoreData]) -> list[task_api.ScoreData]:
    """Return copies of ``scores`` whose ``score_dict`` holds only float-convertible values.

    Each entry of every ``score_dict`` is passed to ``float()``; entries for
    which that raises ``ValueError`` or ``TypeError`` are reported on stdout and
    left out. All other ``ScoreData`` fields are copied over unchanged.

    Args:
        scores: The score records to clean up.

    Returns:
        A new list with one freshly built ``ScoreData`` per input record.
    """

    def _convertible(name, candidate) -> bool:
        # Only the conversion attempt matters; the converted number is discarded
        # and the original value is what gets stored.
        try:
            float(candidate)
        except (ValueError, TypeError):
            print(f'error: ScoreData item {name} is a {type(candidate)} not a float, removing')
            return False
        return True

    return [
        task_api.ScoreData(
            score_dict={
                name: candidate
                for name, candidate in entry.score_dict.items()
                if _convertible(name, candidate)
            },
            preferred_score=entry.preferred_score,
            number_of_shots=entry.number_of_shots,
            low_score=entry.low_score,
            high_score=entry.high_score,
            subtask_description=entry.subtask_description,
        )
        for entry in scores
    ]


def _backup_existing(filename: str) -> None:
    """Move an existing file aside so that it can be rewritten safely.

    If ``filename`` exists it is renamed to the first unused name of the form
    ``<filename>.backup_NNNNNN`` (six-digit, zero-padded, counting from 0).
    Nothing happens when ``filename`` does not exist, or when all one million
    backup names are already taken.

    Args:
        filename: Path of the file that is about to be overwritten.
    """
    if os.path.exists(filename):
        candidates = (f'{filename}.backup_{slot:06d}' for slot in range(1000000))
        free_name = next((name for name in candidates if not os.path.exists(name)), None)
        if free_name is not None:
            os.rename(filename, free_name)


def main() -> None:
    """Evaluate a model on the requested BIG-bench tasks and write the results to disk.

    The model is loaded once and then evaluated on every task named in
    ``--task_name``. Tasks whose module cannot be imported are skipped with a
    warning. For each evaluated task, files are written below
    ``<output_dir>/<task directory name>/results``:

    * a scores file, named by ``results_api.score_result_filename``;
    * unless ``--not_log_queries`` is given, a queries file, named by
      ``results_api.query_result_filename``.

    A file that already exists under either name is first renamed to a backup.
    """
    args = parse_arguments()
    pprint(args)

    # load model
    print('loading model...')
    base_model = SafeRLHFModel(
        model_path=args.model_name_or_path,
        show_progress=args.show_progress,
        batch_size=args.batch_size,
        max_length=args.max_length,
    )
    print('model loaded.')

    # evaluate
    for task_name in args.task_name:
        print('=' * 80)
        print(f'evaluating on {task_name}...')
        task_module_name = f'bigbench.benchmark_tasks.{task_name}'
        try:
            task_module = importlib.import_module(task_module_name)
        except ModuleNotFoundError:
            print(f'Warning: Cannot find task {task_name}.')
            continue

        module_path = next(iter(task_module.__path__))
        json_path = os.path.join(module_path, 'task.json')
        task_submodule_name = f'{task_module_name}.task'

        # A task is JSON-defined when its directory ships a `task.json`;
        # otherwise it is loaded from its `task` submodule.
        if os.path.exists(json_path):
            task = json_task.JsonTask(json_path, max_examples=args.max_examples)
        else:
            task = util.load_programmatic_task(task_submodule_name)

        # Wrap the *base* model in a fresh query logger for every task. Wrapping
        # the previous iteration's wrapper instead would stack one logger on top
        # of another for each additional task.
        model = query_logging_model.QueryLoggingModel(base_model)
        results_data = results_api.ResultsFileData(
            scores=[],
            task=results_api.TaskData(task_name),
            model=model.model_data(),
            queries=[],
        )

        results = task.evaluate_model(model, max_examples=args.max_examples)
        if not isinstance(results, list):
            results = [results]

        results = _sanitize_results(scores=results)
        results = results_api.add_aggregate_scores(task_name=task_name, scores=results)

        print('results:')
        for result in results:
            print(f'{pformat(result.score_dict)}')

        results_data.scores = results

        results_dir = os.path.join(args.output_dir, os.path.split(module_path)[1], 'results')
        os.makedirs(results_dir, exist_ok=True)

        # The scores file is written while `queries` is still empty.
        scores_file = os.path.join(
            results_dir,
            results_api.score_result_filename(model.model_data()),
        )
        _backup_existing(scores_file)
        with open(scores_file, mode='w', encoding='utf-8') as f:
            json.dump(dataclasses.asdict(results_data), f, indent=2)

        if not args.not_log_queries:
            if isinstance(task, json_task.JsonTask):
                results_data.queries = util.update_json_task_queries(
                    task=task,
                    model=model,
                    scores=results_data.scores,
                )
            else:
                results_data.queries = model.queries

            # The queries file holds the same record, now including the queries.
            queries_file = os.path.join(
                results_dir,
                results_api.query_result_filename(model.model_data()),
            )
            _backup_existing(queries_file)
            with open(queries_file, mode='w', encoding='utf-8') as f:
                json.dump(dataclasses.asdict(results_data), f, indent=2)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
