import argparse
import os
import re
import sys

import pandas as pd
import torch
from pandas.errors import ParserError

from utils.answer_extraction import process_output, process_output_qvq
from utils.report_generator import judge_accuracy

os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
from vllm import LLM, SamplingParams


def arg_sanitizer(args):
    """Validate parsed CLI arguments, terminating the process on invalid input.

    Checks that ``temperature`` lies in [0, 1] (NaN is rejected), that
    ``tensor_parallel_size`` is between 1 and the number of visible CUDA
    devices, and that every ``data_input`` path exists. Duplicate
    ``data_input`` entries are removed while keeping their first occurrence,
    so files are processed in command-line order.

    Args:
        args: ``argparse.Namespace``; reads ``temperature``,
            ``tensor_parallel_size``, ``data_input`` and ``do_sample``.

    Returns:
        The same namespace, with ``data_input`` de-duplicated.

    Raises:
        SystemExit: with status 1 when any check fails.
    """
    # Chained comparison also rejects NaN, which slips past `< 0 or > 1`.
    if not 0 <= args.temperature <= 1:
        print("E: temperature should be in [0, 1]", file=sys.stderr)
        sys.exit(1)
    n_devices = torch.cuda.device_count()
    if not 1 <= args.tensor_parallel_size <= n_devices:
        print(f"E: tensor_parallel_size={args.tensor_parallel_size} expect in [1, {n_devices}]", file=sys.stderr)
        sys.exit(1)
    # dict.fromkeys de-duplicates while preserving insertion order; set() would scramble it.
    args.data_input = list(dict.fromkeys(args.data_input))
    for data_input in args.data_input:
        if not os.path.exists(data_input):
            print(f"E: {data_input} does not exist", file=sys.stderr)
            sys.exit(1)
    if not args.do_sample and args.temperature > 0:
        print("Warning: do_sample is False, temperature will be ignored", file=sys.stderr)

    return args


def _load_template(args):
    """Return the judge prompt template for the requested task / answer-type combination.

    The file read is ``./static/lm_as_judge_{task}{ext}.template`` (relative to the
    current working directory). ``ext`` gains ``_short`` for ``--answer-type short``
    and ``_non_processed`` for ``--do-not-extract``.

    Raises:
        ValueError: if no template file exists for the combination.
    """
    template_type_ext = "_short" if args.answer_type == "short" else ""
    template_type_ext += '_non_processed' if args.do_not_extract else ''
    template_file_path = f'./static/lm_as_judge_{args.task}{template_type_ext}.template'
    if not os.path.exists(template_file_path):
        raise ValueError(f'Combination of task={args.task}, answer_type={args.answer_type}, '
                         f'do_not_extract={args.do_not_extract} is not supported')
    with open(template_file_path, encoding='utf-8') as f:
        return f.read()


def _load_model(args):
    """Instantiate the vLLM judge model according to ``--quantization-method``.

    ``max_model_len`` and ``tensor_parallel_size`` are forwarded for every method.
    """
    llm_kwargs = {
        'model': args.model_id,
        'max_model_len': args.max_model_len,
        'tensor_parallel_size': args.tensor_parallel_size,
    }
    if args.quantization_method == 'bitsandbytes':
        llm_kwargs.update(quantization='bitsandbytes', load_format='bitsandbytes')
    elif args.quantization_method == 'awq':
        # Lowercase, like 'bitsandbytes' above and vLLM's registered method names.
        llm_kwargs['quantization'] = 'awq'
    return LLM(**llm_kwargs)


def _read_records(data_input):
    """Read a JSON Lines file with every column as ``str``; exit with status 1 if it cannot be parsed."""
    try:
        return pd.read_json(data_input, lines=True, dtype=str)
    except ValueError:
        # Malformed JSON surfaces as a plain ValueError from read_json; ParserError is a
        # ValueError subclass, so this catches both.
        print(f"E: {data_input} is not a valid jsonl file", file=sys.stderr)
        sys.exit(1)


def _extract_answers(data, data_input, args):
    """Run answer extraction on final_answer outputs and check every row got a ``model_answer``.

    Files whose name contains ``qvq`` use the QVQ-specific extractor.

    Raises:
        ValueError: if any row is left with a null ``model_answer``.
    """
    if 'qvq' in os.path.basename(data_input):
        data = process_output_qvq(data)
    else:
        data = process_output(data, raw=args.do_not_extract)
    # NOTE(review): --ignore-fail is parsed but never read; presumably it was meant to relax this check.
    if not data['model_answer'].notnull().all():
        raise ValueError("E: model_answer fail to extract correctly")
    return data


def _build_prompts(data, template, tokenizer, args):
    """Render one judge prompt per row of ``data``, in row order.

    The template is filled with ``model_output``, ``ground_truth`` and ``question``.
    When the tokenizer defines a chat template, the prompt is wrapped as a
    system + user conversation and a generation prompt is appended.

    Raises:
        ValueError: if ``--answer-type short`` is requested but ``data`` has no ``short_answer`` column.
        NotImplementedError: for an unknown task.
    """
    if args.task == 'final_answer' and args.answer_type == 'short' and 'short_answer' not in data.columns:
        raise ValueError("E: short_answer not found in the data; first augment the output with short_answer")
    prompts = []
    for _, row in data.iterrows():
        if args.task == 'final_answer':
            model_answer = row['model_answer']
            ground_truth = row['answer'] if args.answer_type == 'full' else row['short_answer'].capitalize()
            question = row['question']
        elif args.task == 'self_correction':
            model_answer = row['output_text']
            ground_truth = row['correction_human']
            question = row['correction_context']
        else:
            raise NotImplementedError(f"Task {args.task} is not supported")
        prompt = template.format(model_output=model_answer, ground_truth=ground_truth, question=question)
        if tokenizer.chat_template:
            conversation = [
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': prompt},
            ]
            prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        prompts.append(prompt)
    return prompts


def _parse_verdict(raw_text, verdict_pattern):
    """Reduce a judge completion to ``'yes'`` / ``'no'`` when a verdict word is found.

    The template instructs the judge to end with its verdict, so only the last line
    is searched, after stripping code fences. ``correct`` / ``incorrect`` are
    normalised to ``yes`` / ``no``. If no verdict is found, the cleaned full
    response is returned unchanged so it can be inspected later.
    """
    cleaned = raw_text.strip().replace('```', '').strip()
    last_line = cleaned.split('\n')[-1].strip().lower()
    match = verdict_pattern.search(last_line)
    if match is None:
        return cleaned
    verdict = match.group(1)
    return {'correct': 'yes', 'incorrect': 'no'}.get(verdict, verdict)


def _output_path(data_input, args):
    """Return the CSV path written next to ``data_input``.

    The file name is ``judge_[short_][raw_]<stem>.csv``, where ``<stem>`` is the
    input file name with its extension removed.
    """
    stem = os.path.splitext(os.path.basename(data_input))[0]
    prefix = (f'judge_{"" if args.answer_type == "full" else "short_"}'
              f'{"raw_" if args.do_not_extract else ""}')
    return os.path.join(os.path.dirname(data_input), f'{prefix}{stem}.csv')


def main():
    """Judge model outputs with an LLM and report accuracy for each input file.

    For each ``--data-input`` JSONL file, this:
      1. renders one judge prompt per record;
      2. generates verdicts with vLLM and stores them in a ``judge`` column;
      3. writes the table as ``judge_[short_][raw_]<stem>.csv`` beside the input;
      4. prints the accuracy computed by ``judge_accuracy``.
    """
    parser = argparse.ArgumentParser(description="evaluate model final answer")
    parser.add_argument('--model-id', type=str, required=True, help='model name')
    parser.add_argument('--data-input', type=str, required=True, help='data input file', action='extend', nargs='+')
    parser.add_argument('--ignore-fail', action='store_true', help='ignore when fail to extract answer')
    parser.add_argument('--do-sample', action='store_true', help='use sampling or greedy decoding')
    parser.add_argument('--do-not-extract', action='store_true', help='pass the entire model output to judge')
    parser.add_argument('--task', type=str, default='final_answer', choices=['final_answer', 'self_correction'],
                        help='task to evaluate')
    parser.add_argument('--quantization-method', type=str, default='bitsandbytes',
                        choices=['bitsandbytes', 'none', 'awq'], help='quantization method to use')
    parser.add_argument('--max-length', type=int, default=512, help='max length of the generated text')
    parser.add_argument('--max-model-len', type=int, default=None, help='max length of the model')
    parser.add_argument('--temperature', type=float, default=0.1, help='temperature for sampling')
    parser.add_argument('--tensor_parallel_size', type=int, default=1, help='tensor parallel size')
    parser.add_argument('--answer-type', type=str, default='full', choices=['full', 'short'],
                        help='type of answer to extract')
    args = parser.parse_args()
    args = arg_sanitizer(args)

    # Greedy decoding is requested as temperature=0; --temperature only applies with --do-sample.
    temperature = args.temperature if args.do_sample else 0
    sampling_params = SamplingParams(temperature=temperature, max_tokens=args.max_length)

    template = _load_template(args)
    model = _load_model(args)
    tokenizer = model.get_tokenizer()
    verdict_pattern = re.compile(
        r"\b(yes|no)\b" if args.task == 'self_correction' else r"\b(correct|incorrect)\b",
        re.IGNORECASE,
    )

    for data_input in args.data_input:
        data = _read_records(data_input)
        if data.empty:
            print(f"Warning: {data_input} contains no records, skipping", file=sys.stderr)
            continue
        if args.task == 'final_answer':
            data = _extract_answers(data, data_input, args)

        prompts = _build_prompts(data, template, tokenizer, args)
        # LLM.generate returns outputs in prompt order, so verdicts line up with data's rows.
        responses = model.generate(prompts, sampling_params=sampling_params, use_tqdm=True)
        data['judge'] = [_parse_verdict(response.outputs[0].text, verdict_pattern) for response in responses]

        output_path = _output_path(data_input, args)
        print(f"Saving to: {output_path}")
        data.to_csv(output_path, index=False)

        accuracy = judge_accuracy(data)
        print(f"Accuracy: {accuracy:.2%}")


if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()