# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import argparse
import json
import os
import time
import csv
import numpy as np
import torch
from tqdm import tqdm
import sys
from rich.console import Console
from rich.table import Table
from safe_rlhf.configs.constants import PROMPT_INPUT, STOP_WORDS
from safe_rlhf.models import load_pretrained_models, AutoModelForScore
from safe_rlhf.utils import to_device
from safe_rlhf.datasets import PromptOnlyDataset, parse_dataset, PromptOnlyRawDataset
from transformers import StoppingCriteria


# PROMPT_INPUT = '<s>[INST] <<SYS>>\n You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information.\n<</SYS>>\n{input}\n[/INST]'

def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments of the answer-generation script.

    Returns:
        The parsed namespace with the attributes ``num_per_prompt``,
        ``corner_model_name_or_path``, ``tokenizer_name_or_path``, ``ckptpath``,
        ``max_length``, ``datasets`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Generate answers for a prompt dataset with a pretrained model.',
    )
    parser.add_argument(
        '--num_per_prompt',
        type=int,
        default=1,
        help='Number of answers generated per prompt.',
    )

    # Model
    model_parser = parser.add_argument_group('model')
    model_parser.add_argument(
        '--corner_model_name_or_path',
        type=str,
        help='The name or path of the model to load and generate answers with.',
        required=True,
    )
    model_parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        help='The name or path of the tokenizer to load.',
        default=None,
    )
    model_parser.add_argument(
        '--ckptpath',
        type=str,
        # The value is forwarded to the generation routine, which does not load it yet.
        help='Path to an additional checkpoint (accepted, but currently not loaded).',
        default=None,
    )
    model_parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='The maximum sequence length of the model.',
    )

    # Dataset
    dataset_parser = parser.add_argument_group('dataset')
    dataset_parser.add_argument(
        '--datasets',
        type=parse_dataset,
        nargs='+',
        metavar='DATASET[:PROPORTION[:PATH]]',
        help='Dataset name(s) registered in the raw dataset.',
        required=True,
    )

    # Logging
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Where to store the generated answers. If omitted, nothing is saved.',
    )

    return parser.parse_args()


class KeywordStoppingCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with one of the given token-id keywords.

    Only the first row of ``input_ids`` is inspected, so the criterion is
    meaningful for a batch size of one.

    Args:
        keyword_ids_list: A list of keywords, each given as a list of token ids.
    """

    def __init__(self, keyword_ids_list: list[list[int]]) -> None:
        self.keyword_ids_list = keyword_ids_list

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor | None = None,
        **kwargs,
    ) -> bool:
        """Return ``True`` if the first sequence ends with any keyword.

        Args:
            input_ids: Token ids generated so far; row 0 is the one checked.
            scores: Unused. Kept (and optional) so the criterion can be called
                with or without prediction scores.
            **kwargs: Ignored.

        Returns:
            ``True`` when the tail of ``input_ids[0]`` equals one of the keyword
            id lists, ``False`` otherwise.
        """
        sequence = input_ids[0]
        sequence_length = len(sequence)
        for keyword_ids in self.keyword_ids_list:
            keyword_length = len(keyword_ids)
            # Compare only the trailing tokens; a keyword longer than the
            # sequence can never match.
            if (
                sequence_length >= keyword_length
                and sequence[-keyword_length:].tolist() == keyword_ids
            ):
                return True
        return False


def remove_keywords(answer, stop_words_list):
    """Cut ``answer`` off at the first occurrence of each stop word, applied in order.

    Args:
        answer: The text to truncate.
        stop_words_list: Stop words; each one truncates the result of the previous.

    Returns:
        The text preceding the earliest surviving stop word, or the text
        unchanged when no stop word occurs in it.
    """
    truncated = answer
    for marker in stop_words_list:
        # partition() yields the part before the first occurrence of the
        # marker, or the whole string when the marker is absent.
        truncated, _, _ = truncated.partition(marker)
    return truncated

def generate_answer(
    dataset,
    model_name_or_path: str,
    tokenizer_name_or_path: str | None = None,
    max_length: int = 512,
    num_per_prompt: int = 1,
    ckpt_path: str | None = None,
) -> list[dict]:
    """Sample answers from a pretrained model for every prompt in ``dataset``.

    Args:
        dataset: Iterable of records; each record's ``'input'`` entry is
            formatted into ``PROMPT_INPUT``.
        model_name_or_path: Name or path passed to ``load_pretrained_models``.
        tokenizer_name_or_path: Tokenizer name or path passed to
            ``load_pretrained_models``.
        max_length: Used both as the model's maximum length at load time and
            as ``max_length`` for ``model.generate``.
        num_per_prompt: Number of answers sampled per prompt.
        ckpt_path: Currently unused; accepted for interface compatibility.

    Returns:
        One dict per prompt with the keys ``'input'`` (the raw input),
        ``'answers'`` (list of generated strings) and ``'length'`` (token
        length of the last generated sequence as returned by
        ``model.generate``, or 0 if no answer was sampled).
    """
    model, tokenizer = load_pretrained_models(
        model_name_or_path,
        tokenizer_name_or_path,
        model_max_length=max_length,
        auto_device_mapping=True,
        trust_remote_code=True,
    )
    # The target device does not change between prompts, so pick it once.
    device = 'cuda' if torch.cuda.is_available() else None
    answers = []

    print(f'Generating answers with {model_name_or_path}')
    for data in tqdm(dataset):
        prompt = PROMPT_INPUT.format(input=data['input'])
        encoded = to_device(tokenizer(prompt, return_tensors='pt'), device=device)
        prompt_ids = encoded['input_ids']
        prompt_length = prompt_ids.shape[-1]

        samples = []
        # Reset per prompt so the value is never unbound or carried over from
        # the previous prompt when no answer is sampled.
        length = 0
        for _ in range(num_per_prompt):
            output_ids = model.generate(
                prompt_ids,
                max_length=max_length,
                num_return_sequences=1,
                repetition_penalty=1.3,
                do_sample=True,
            )
            length = output_ids.shape[-1]
            # Strip the prompt by token count rather than by character count:
            # the decoded text need not reproduce the prompt string exactly
            # (e.g. special tokens are skipped), which would shift a character
            # based cut. NOTE(review): assumes the generated sequence starts
            # with the prompt tokens, as for decoder-only models -- confirm.
            continuation_ids = output_ids[0][prompt_length:]
            samples.append(tokenizer.decode(continuation_ids, skip_special_tokens=True))

        answers.append(
            {
                'input': data['input'],
                'answers': samples,
                'length': length,
            }
        )
    return answers




def main() -> None:
    """Generate answers for the requested datasets and optionally save them as JSON.

    The output file is named ``<model basename>_<num_per_prompt>.json`` and is
    written into ``--output_dir``. When ``--output_dir`` is not given, the
    answers are generated but not saved.
    """
    args = parse_arguments()
    num_per_prompt = args.num_per_prompt
    corner_model_name = os.path.basename(os.path.normpath(args.corner_model_name_or_path))

    if args.output_dir is None:
        # Tell the user up front instead of silently discarding the results
        # after the (expensive) generation has finished.
        print(
            'warning: --output_dir is not set; generated answers will not be saved.',
            file=sys.stderr,
        )
    else:
        # Create the directory before generating so that an unwritable path
        # fails early; exist_ok avoids the check-then-create race.
        os.makedirs(args.output_dir, exist_ok=True)

    dataset = PromptOnlyRawDataset(args.datasets)
    answers = generate_answer(
        dataset,
        args.corner_model_name_or_path,
        args.tokenizer_name_or_path,
        args.max_length,
        num_per_prompt,
        args.ckptpath,
    )

    if args.output_dir is not None:
        file_name = f'{corner_model_name}_{num_per_prompt}.json'
        output_path = os.path.join(args.output_dir, file_name)
        with open(output_path, mode='w', encoding='utf-8') as f:
            json.dump(answers, f, indent=4, ensure_ascii=False)
    
# Script entry point: run main() only when executed directly, not on import.
# main() returns None, so a normal run exits with status 0.
if __name__ == '__main__':
    sys.exit(main())