# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import argparse
import os

import deepspeed
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available

from safe_rlhf.configs import get_deepspeed_eval_config
from safe_rlhf.datasets import SafetyPreferenceDataset, parse_dataset
from safe_rlhf.logger import set_logger_level
from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.utils import (
    get_all_reduce_mean,
    is_main_process,
    seed_everything,
    str2bool,
    to_device,
)


def parse_arguments() -> argparse.Namespace:
    """Parse and validate the command-line arguments.

    Besides parsing, this function validates the combination of flags and, when
    ``--tf32`` is given and TF32 is available, sets
    ``torch.backends.cuda.matmul.allow_tf32`` accordingly.

    Returns:
        The parsed arguments.

    Raises:
        SystemExit: Raised through ``parser.error`` when ``--local_rank`` is not
            set, when both ``--fp16`` and ``--bf16`` are enabled, when ``--bf16``
            is requested but not supported, or when ``--output_dir`` is missing.
    """
    parser = argparse.ArgumentParser(
        prog='deepspeed --module safe_rlhf.evaluate.cost',
        description='Evaluate a cost model in Safe-RLHF.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model
    model_parser = parser.add_argument_group('model')
    model_parser.add_argument(
        '--model_name_or_path',
        type=str,
        help='Path to the model checkpoint or its name.',
        required=True,
    )
    model_parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='The maximum sequence length of the model.',
    )
    model_parser.add_argument(
        '--trust_remote_code',
        type=str2bool,
        default=False,
        help='Whether to trust the remote code.',
    )

    # Dataset
    dataset_parser = parser.add_argument_group('dataset')
    dataset_parser.add_argument(
        '--datasets',
        type=parse_dataset,
        nargs='+',
        metavar='DATASET[:PROPORTION[:PATH]]',
        help='Dataset name(s) registered in the raw dataset.',
        required=True,
    )

    # Evaluation
    evaluation_parser = parser.add_argument_group('evaluation')
    evaluation_parser.add_argument(
        '--per_device_eval_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the evaluation dataloader.',
    )
    evaluation_parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='A seed for reproducible evaluation.',
    )
    evaluation_parser.add_argument(
        '--fp16',
        type=str2bool,
        default=False,
        help='Whether to use float16 precision.',
    )
    evaluation_parser.add_argument(
        '--bf16',
        type=str2bool,
        default=False,
        help='Whether to use bfloat16 precision.',
    )
    evaluation_parser.add_argument(
        '--tf32',
        type=str2bool,
        default=None,
        help='Whether to use tf32 mix precision.',
    )

    # Logging
    logging_parser = parser.add_argument_group('logging')
    logging_parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Where to store the evaluation output.',
    )

    # DeepSpeed
    deepspeed_parser = parser.add_argument_group('deepspeed')
    deepspeed_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='Local rank for distributed training on GPUs',
    )
    deepspeed_parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for models.',
    )
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()
    if args.local_rank == -1:
        parser.error('`local_rank` not set, please use DeepSpeed launcher to run this script.')
    if args.fp16 and args.bf16:
        parser.error('Cannot use both bf16 and fp16 precision.')
    if args.bf16 and not is_torch_bf16_gpu_available():
        parser.error(
            'bf16 precision is not supported on this GPU. '
            'Please disable `--bf16` flag or use another precision flag (e.g., `--fp16`).',
        )
    # The evaluation routine writes score files and plots under `output_dir`
    # unconditionally, so reject a missing value here instead of failing later
    # with a `TypeError` from `os.path` after the model has been loaded.
    if args.output_dir is None:
        parser.error('`output_dir` not set, please specify where to store the evaluation output.')
    if args.tf32 is not None and is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = args.tf32

    return args


def _plot_score_distribution(values, value_range, title, path, bins=1000) -> None:
    """Render a histogram of ``values`` and save it to ``path``.

    Args:
        values: NumPy array of scores to histogram.
        value_range: ``(low, high)`` pair passed to ``np.histogram`` as the range.
        title: Title drawn on the figure.
        path: Destination path handed to ``plt.savefig``.
        bins: Number of equal-width histogram bins.
    """
    hist, bin_edges = np.histogram(values, bins=bins, range=value_range)
    # Draw every bar exactly as wide as its bin. A fixed width would make
    # neighbouring bars overlap whenever the bins are narrower than that width.
    plt.bar(bin_edges[:-1], hist, width=np.diff(bin_edges), align='edge')
    plt.xlabel('Score range')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.savefig(path)
    # Clear the axes so the next plot starts from a blank canvas.
    plt.cla()


def main() -> None:  # pylint: disable=too-many-locals,too-many-statements
    """Evaluate a cost model on safety preference data.

    The routine initializes the distributed environment, loads the model and
    dataset, scores every (safer, unsafer) pair, and then:

    * appends the decoded texts with their costs to ``scores_<rank>.txt``;
    * reports the pairwise accuracy, the sign accuracy and the mean cost of the
      safer responses (averaged across ranks when running on several ranks);
    * on the main process, saves the gathered scores as ``.npy`` files and
      plots their distributions.

    Raises:
        ValueError: If ``args.output_dir`` is ``None`` or if the dataloader
            yields no samples.
    """
    args = parse_arguments()
    if args.output_dir is None:
        # Every rank writes below `output_dir`; fail with a readable message
        # rather than a `TypeError` raised from `os.path`.
        raise ValueError('`output_dir` must be set to store the evaluation output.')

    deepspeed.init_distributed()

    args.global_rank = dist.get_rank()
    args.device = device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(args.device)
    seed_everything(args.seed)
    set_logger_level()

    dist.barrier()

    ds_config = get_deepspeed_eval_config(
        stage=args.zero_stage,
        fp16=args.fp16,
        bf16=args.bf16,
    )

    if ds_config['zero_optimization']['stage'] == 3:
        # Stored on `args` so the config object stays referenced while the
        # model is being loaded.
        args.dstchf = HfDeepSpeedConfig(ds_config)

    model, tokenizer = load_pretrained_models(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side='right',
        auto_model_type=AutoModelForScore,
        trust_remote_code=args.trust_remote_code,
    )
    model, *_ = deepspeed.initialize(model=model, config=ds_config)
    model.eval()

    dataset = SafetyPreferenceDataset(args.datasets, tokenizer=tokenizer)
    dataloader = DataLoader(
        dataset,
        collate_fn=dataset.get_collator(),
        sampler=DistributedSampler(dataset, shuffle=True),
        batch_size=args.per_device_eval_batch_size,
    )

    progress_bar = tqdm(
        total=len(dataloader),
        desc='Evaluating',
        position=0,
        leave=True,
        disable=not is_main_process(),
    )

    num_correct_predictions = 0
    num_correct_sign_predictions = 0
    num_total_predictions = 0
    score_sum = 0
    storage_scores = []
    storage_safe_scores, storage_unsafe_scores = [], []

    # Only the main process creates the directory; the barrier keeps the other
    # ranks from opening their score files before it exists.
    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    dist.barrier()

    # Open the per-rank score file once instead of once per batch. Append mode
    # is kept, so results of earlier runs in the same directory are preserved.
    scores_path = os.path.join(args.output_dir, f'scores_{args.global_rank}.txt')
    with open(scores_path, mode='a', encoding='utf-8') as f, torch.no_grad():
        for batch in dataloader:
            batch = to_device(batch, device)
            safer_input_ids = batch['safer_input_ids']
            safer_attention_mask = batch['safer_attention_mask']
            unsafer_input_ids = batch['unsafer_input_ids']
            unsafer_attention_mask = batch['unsafer_attention_mask']
            # NOTE(review): the sign tensors are assumed to be 1-D with one
            # entry per example (non-negative meaning "safe") -- confirm
            # against the dataset collator.
            safer_safety_sign = batch['safer_safety_sign']
            unsafer_safety_sign = batch['unsafer_safety_sign']

            # Safer and unsafer sequences are scored in a single forward pass
            # by stacking them along the batch dimension.
            end_scores = model(
                torch.cat([safer_input_ids, unsafer_input_ids], dim=0),
                torch.cat([safer_attention_mask, unsafer_attention_mask], dim=0),
            ).end_scores

            # Hints: safer examples are supposed to have lower costs,
            #        unsafer examples are supposed to have higher costs
            # The first half of the stacked batch belongs to the safer inputs,
            # the second half to the unsafer inputs.
            lower_end_costs, higher_end_costs = end_scores.squeeze(dim=-1).chunk(
                chunks=2,
                dim=0,
            )

            # Save to file: one line per response, safer first, then unsafer.
            # `.tolist()` moves all costs to the host in a single transfer.
            rows = zip(
                safer_input_ids,
                lower_end_costs.tolist(),
                unsafer_input_ids,
                higher_end_costs.tolist(),
            )
            for safer_ids, lower_cost, unsafer_ids, higher_cost in rows:
                text = tokenizer.decode(safer_ids, skip_special_tokens=True).replace('\n', ' ')
                f.write(f'{text}\t{lower_cost}\n')
                text = tokenizer.decode(unsafer_ids, skip_special_tokens=True).replace('\n', ' ')
                f.write(f'{text}\t{higher_cost}\n')

            num_correct_predictions += (lower_end_costs < higher_end_costs).sum()
            # A prediction has the correct sign when cost and safety sign
            # have opposite signs.
            num_correct_sign_predictions += (lower_end_costs * safer_safety_sign < 0).sum()
            num_correct_sign_predictions += (higher_end_costs * unsafer_safety_sign < 0).sum()

            num_total_predictions += lower_end_costs.size(0)
            # Accumulate a float32 sum so the final mean weights every sample
            # equally, even when the last batch is smaller than the others.
            score_sum += lower_end_costs.float().sum()
            storage_scores.append(lower_end_costs)
            storage_scores.append(higher_end_costs)

            # Interleave as (lower_0, higher_0, lower_1, higher_1, ...) and
            # split by safety sign with a boolean mask. One (possibly empty)
            # tensor is appended per batch, so the concatenation below works
            # even when a rank sees no safe or no unsafe responses.
            paired_costs = torch.stack([lower_end_costs, higher_end_costs], dim=1).reshape(-1)
            paired_signs = torch.stack([safer_safety_sign, unsafer_safety_sign], dim=1).reshape(-1)
            is_safe = paired_signs >= 0
            storage_safe_scores.append(paired_costs[is_safe])
            storage_unsafe_scores.append(paired_costs[~is_safe])

            progress_bar.update(1)

    progress_bar.close()

    if num_total_predictions == 0:
        raise ValueError('The evaluation dataloader yielded no samples.')

    accuracy = num_correct_predictions / num_total_predictions
    # Every pair contributes two sign predictions (safer and unsafer response).
    accuracy_sign = num_correct_sign_predictions / (2 * num_total_predictions)
    scores = score_sum / num_total_predictions
    storage_scores = torch.cat(storage_scores, dim=0)
    storage_safe_scores = torch.cat(storage_safe_scores, dim=0)
    storage_unsafe_scores = torch.cat(storage_unsafe_scores, dim=0)

    if dist.is_initialized() and dist.get_world_size() > 1:
        world_size = dist.get_world_size()

        accuracy = get_all_reduce_mean(accuracy).item()
        scores = get_all_reduce_mean(scores).item()
        accuracy_sign = get_all_reduce_mean(accuracy_sign).item()

        # gather storage_scores
        # NOTE(review): the receive buffers are shaped like the local tensor,
        # so this assumes every rank holds the same number of scores.
        storage_scores_lst = [torch.zeros_like(storage_scores) for _ in range(world_size)]
        dist.all_gather(storage_scores_lst, storage_scores)
        storage_scores = torch.cat(storage_scores_lst, dim=0)

        # gather storage_safe_scores and storage_unsafe_scores
        # The safe/unsafe split differs per rank, so each rank sends its safe
        # and unsafe scores as one concatenated tensor together with the
        # position at which the unsafe part starts.
        index = torch.tensor(
            storage_safe_scores.size(0),
            dtype=torch.long,
            device=storage_safe_scores.device,
        )
        index_lst = [torch.zeros_like(index) for _ in range(world_size)]
        dist.all_gather(index_lst, index)

        storage_scores_cated = torch.cat([storage_safe_scores, storage_unsafe_scores], dim=0)
        storage_scores_lst = [torch.zeros_like(storage_scores_cated) for _ in range(world_size)]
        dist.all_gather(storage_scores_lst, storage_scores_cated)

        storage_safe_scores_lst, storage_unsafe_scores_lst = [], []
        for gathered, split_index in zip(storage_scores_lst, index_lst):
            split = int(split_index)
            storage_safe_scores_lst.append(gathered[:split])
            storage_unsafe_scores_lst.append(gathered[split:])

        storage_safe_scores = torch.cat(storage_safe_scores_lst, dim=0)
        storage_unsafe_scores = torch.cat(storage_unsafe_scores_lst, dim=0)

    if is_main_process():
        print(f'accuracy: {accuracy:.4f}, accuracy_sign: {accuracy_sign:.4f}, scores: {scores:.4f}')

        # Cast to float32 before the NumPy conversion.
        all_scores = storage_scores.to(dtype=torch.float32).cpu().numpy()
        safe_scores = storage_safe_scores.to(dtype=torch.float32).cpu().numpy()
        unsafe_scores = storage_unsafe_scores.to(dtype=torch.float32).cpu().numpy()

        np.save(os.path.join(args.output_dir, 'scores.npy'), all_scores)
        np.save(os.path.join(args.output_dir, 'safe_scores.npy'), safe_scores)
        np.save(os.path.join(args.output_dir, 'unsafe_scores.npy'), unsafe_scores)

        # All three plots share the range of the full score set so that their
        # x-axes are directly comparable.
        score_range = (all_scores.min(), all_scores.max())
        _plot_score_distribution(
            all_scores,
            score_range,
            'Distribution of scores',
            os.path.join(args.output_dir, 'distribution.png'),
        )
        _plot_score_distribution(
            safe_scores,
            score_range,
            'Distribution of safe scores',
            os.path.join(args.output_dir, 'safe_distribution.png'),
        )
        _plot_score_distribution(
            unsafe_scores,
            score_range,
            'Distribution of unsafe scores',
            os.path.join(args.output_dir, 'unsafe_distribution.png'),
        )

    dist.barrier()


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
