# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import argparse
import json
import os
import time
import csv
import numpy as np
import torch
from tqdm import tqdm
import sys
from rich.console import Console
from rich.table import Table
from safe_rlhf.configs.constants import PROMPT_INPUT
from safe_rlhf.models import load_pretrained_models, AutoModelForScore
from safe_rlhf.utils import to_device
from safe_rlhf.datasets import PromptOnlyDataset, parse_dataset, PromptOnlyRawDataset


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments.

    Returns:
        The parsed namespace with the attributes ``response_path``,
        ``reward_model_name_or_path``, ``max_length`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Score pre-generated model responses with a reward model.',
    )

    # Model
    model_parser = parser.add_argument_group('model')
    model_parser.add_argument(
        '--response_path',
        type=str,
        help=(
            'Path to the JSON file with the responses to score '
            '(a list of objects with "input" and "answers" entries).'
        ),
        required=True,
    )
    model_parser.add_argument(
        '--reward_model_name_or_path',
        type=str,
        help='the name or path of the reward model to load from',
        # The reward model is always loaded by the script, so fail at parse time
        # with a usage message instead of crashing later on a ``None`` path.
        required=True,
    )
    # NOTE: the cost-model option (`--cost_model_name_or_path`) and the
    # `--datasets` option are currently disabled and therefore not defined here.
    model_parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='The maximum sequence length of the model.',
    )

    # Logging
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Where to store the eval output.',
    )

    return parser.parse_args()


def generate_answer(responses) -> list[str]:
    """Build the texts to be scored from pre-generated responses.

    Each record contributes one string: its ``'input'`` rendered through
    ``PROMPT_INPUT`` followed directly by the first entry of its ``'answers'``.

    Args:
        responses: Iterable of mappings that provide an ``'input'`` entry and
            an ``'answers'`` sequence with at least one element.

    Returns:
        One prompt-plus-answer string per record, in input order.
    """
    # Read both fields before formatting so lookups happen in a fixed order.
    pairs = ((record['input'], record['answers'][0]) for record in responses)
    return [PROMPT_INPUT.format(input=question) + reply for question, reply in pairs]


def model_eval(inputs: list[str], model_name_or_path: str, max_length: int) -> list[float]:
    """Score every text in ``inputs`` with a score model.

    The model and tokenizer are loaded once via ``load_pretrained_models`` with
    ``AutoModelForScore``; each text is then tokenized and scored on its own
    (one forward pass per text).

    Args:
        inputs: The texts to score.
        model_name_or_path: Name or path of the score model to load.
        max_length: Passed to the loader as ``model_max_length``.

    Returns:
        One score per text, in input order, read from the model output's
        ``end_scores``.
    """
    model, tokenizer = load_pretrained_models(
        model_name_or_path,
        model_max_length=max_length,
        auto_model_type=AutoModelForScore,
        auto_device_mapping=True,
        trust_remote_code=True,
    )
    # The target device does not change between samples, so resolve it once.
    device = 'cuda' if torch.cuda.is_available() else None

    scores = []
    # Scoring is inference only: disabling gradient tracking avoids building an
    # autograd graph for every forward pass.
    with torch.no_grad():
        for text in tqdm(inputs):
            encoded = to_device(tokenizer(text, return_tensors='pt'), device=device)
            scores.append(model(**encoded).end_scores.item())
    return scores


def main() -> None:
    """Score the responses in ``--response_path`` and save the results as JSON.

    The response file is read, every (input, first answer) pair is scored with
    the reward model, the average reward is printed, and a list of
    ``{'input', 'response', 'reward'}`` records is written to
    ``<output_dir>/<response file stem>.json``.

    Raises:
        ValueError: If no reward model name or path was supplied.
    """
    args = parse_arguments()
    if args.reward_model_name_or_path is None:
        raise ValueError('--reward_model_name_or_path is required.')
    # Resolve the output location before the (slow) scoring run, so a missing
    # --output_dir cannot crash the script after all the work is done.
    output_dir = args.output_dir if args.output_dir is not None else os.curdir

    # The output file reuses the response file's name; splitext strips whatever
    # extension the input has instead of assuming a 5-character '.json' suffix.
    response_base = os.path.basename(os.path.normpath(args.response_path))
    model_name = os.path.splitext(response_base)[0]
    file_name = model_name + ".json"
    with open(args.response_path, mode='r', encoding='utf-8') as f:
        responses = json.load(f)

    reward_model_name = os.path.basename(os.path.normpath(args.reward_model_name_or_path))
    answers = generate_answer(responses)
    print(f"Evaluating with {reward_model_name} ...")
    rewards = model_eval(answers, args.reward_model_name_or_path, args.max_length)
    # NOTE: cost-model scoring (safety rate / average cost) is currently disabled.

    data = [
        {
            'input': response['input'],
            'response': response['answers'][0],
            'reward': reward,
        }
        for response, reward in zip(responses, rewards)
    ]

    if rewards:
        print("Average Reward: ", sum(rewards) / len(rewards))
    else:
        # An empty response file would otherwise raise ZeroDivisionError.
        print("No responses to evaluate.")

    # NOTE(review): if output_dir is the directory that contains a '.json'
    # response file, the write below replaces that input file -- confirm this
    # is intended.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, file_name), mode='w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    
if __name__ == '__main__':
    # Script entry point: run the evaluation and hand main()'s return value
    # to sys.exit as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
    

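# Example usage:
#
# python model_eval.py \
#     --response_path ./generate/pku/satac_sampletag_3_1.json \
#     --reward_model_name_or_path "../../model/beaver-7b-v1.0-reward" \
#     --output_dir ./
#
# NOTE: `--cost_model_name_or_path "../../model/beaver-7b-v1.0-cost"` is not
# accepted at the moment: the argument parser does not currently define that option.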