import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
from openai import OpenAI
import jsonlines
import os
from vllm_inject.utils import *
import json, re
import evaluate, random
from transformers import AutoTokenizer
import numpy as np
from rouge import Rouge
from typing import Iterable, Dict
import gzip

# Column names for reference-answer data. NOTE(review): these look like
# TruthfulQA-style columns; none of them are used by the CNN evaluation code
# in this file.
BEST_COL = 'Best Answer'
ANSWER_COL = 'Correct Answers'
INCORRECT_COL = 'Incorrect Answers'
# Option letters and answer prefix for a multiple-choice prompt format; also
# unused by the CNN evaluation code in this file.
CHOICES = 'ABCD'
MC_ANSWER_PREFIX = 'The answer is:'

def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """Serialize each dict in *data* as one UTF-8 JSON line of *filename*.

    ``~`` in the path is expanded. If the name ends in ``.gz`` the lines are
    gzip-compressed. With ``append=True`` the file is opened for appending
    rather than truncated; for ``.gz`` files this adds a new gzip member.
    """
    path = os.path.expanduser(filename)
    file_mode = "ab" if append else "wb"

    def _encode(record):
        # One JSON document per line, newline-terminated, UTF-8 bytes.
        return (json.dumps(record) + "\n").encode("utf-8")

    with open(path, file_mode) as raw:
        if path.endswith(".gz"):
            with gzip.GzipFile(fileobj=raw, mode="wb") as sink:
                for record in data:
                    sink.write(_encode(record))
        else:
            for record in data:
                raw.write(_encode(record))
                
def format_best(best_ans, close=True):
    """Format the best answer to match the format of the reference answers.

    Strips surrounding whitespace and, when *close* is true, appends a period
    if the answer does not already end with one. An answer that is empty after
    stripping is returned as ``""``. This matches ``split_multi_answer``, which
    drops empty answers instead of closing them.
    """
    best = best_ans.strip()
    # ``endswith`` (with the emptiness guard) avoids the IndexError that
    # ``best[-1]`` raises on an empty string.
    if close and best and not best.endswith('.'):
        best += '.'
    return best

def split_multi_answer(ans, sep=';', close=True):
    """Split a string of reference answers into a list of formatted answers.

    The string is split on *sep*; each piece is stripped and empty pieces are
    discarded. With *close* true, a period is appended to every piece that
    does not already end in one.
    """
    pieces = (piece.strip() for piece in ans.strip().split(sep))
    return [
        piece if (not close or piece.endswith('.')) else piece + '.'
        for piece in pieces
        if piece
    ]

def trim_answer(answer):
    """
    Trim generated answer for open-ended evaluation setting.

    Steps, in order:
    - strip surrounding whitespace and a leading ``Answer:`` label;
    - remove ``[INST]`` and ``[/INST]`` tags, plus the literal backslash
      form ``[\\/INST]``;
    - replace each bracketed capital letter such as ``[A]`` with
      ``"\\n\\nA"``;
    - replace every ``"\\n\\n"`` with ``"\\n"``.
    """
    answer = answer.strip()
    if answer.startswith('Answer:'):
        answer = answer[len('Answer:'):].strip()
    # '[\\/INST]' is the literal backslash-slash spelling. It is kept alongside
    # the real '[/INST]' tag so text of either form is cleaned.
    for tag in ('[INST]', '[/INST]', '[\\/INST]'):
        answer = answer.replace(tag, '')
    # Reformat line-breaks for long-form answers.
    answer = re.sub(r"\[([A-Z])\]", lambda m: f"\n\n{m.group(1)}", answer)
    answer = answer.replace('\n\n', '\n')
    return answer

# Instruction appended after every article (in-context demos and test prompts).
_SUMMARIZE_INSTRUCTION = "\n\nSummarize the above article:"


def _read_jsonl(path, limit=None):
    """Return JSON objects parsed from the non-blank lines of *path* (at most *limit*)."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if limit is not None and len(records) >= limit:
                break
            if line.strip():
                records.append(json.loads(line))
    return records


def _build_cnn_prompt(article, prefix, max_chars=4096):
    """Return ``prefix + article.strip() + instruction``, cutting *prefix*'s tail to fit *max_chars*.

    The budget is counted in characters, not tokens. The article itself is
    never truncated; if it alone exceeds the budget, the prefix is dropped.
    """
    prompt = article.strip() + _SUMMARIZE_INSTRUCTION
    if len(prompt) + len(prefix) >= max_chars:
        # Clamp at zero: a negative bound would turn this into ``prefix[:-k]``
        # and keep most of the prefix when the article is over budget.
        prefix = prefix[:max(0, max_chars - len(prompt))]
    return prefix + prompt


@torch.inference_mode()
def get_cnn_output(base_model,
                   tokenizer,
                   max_tokens,
                   batch_size,
                   temperature,
                   top_p,
                   use_chat_format=False,
                   icl=False,
                   system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
                   data_dir="/xx/data/eval/cnn",
                   num_icl_examples=1,
                   max_prompt_chars=4096):
    """Generate summaries for the CNN test split with a batched ``generate`` model.

    Reads ``test.jsonl`` from *data_dir*. Each line is a JSON object whose
    ``"prompt"`` is the article and whose ``"completion"`` is the reference.
    One prompt is built per article: the article text followed by
    ``"Summarize the above article:"``, optionally preceded by in-context
    demos taken from ``train.jsonl``.

    Args:
        base_model: object with ``generate(prompts, sampling_params)``
            (``main`` passes a ``vllm.LLM``).
        tokenizer: unused; kept for interface compatibility.
        max_tokens: maximum generated tokens per prompt.
        batch_size: number of prompts per ``generate`` call.
        temperature: sampling temperature, forwarded to ``SamplingParams``.
        top_p: top-p value, forwarded to ``SamplingParams``.
        use_chat_format: not supported. Chat formatting is disabled, so a
            warning is issued and plain prompts are used.
        icl: if true, prepend the first ``num_icl_examples`` records of
            ``train.jsonl`` as demonstrations.
        system_prompt: unused; kept for interface compatibility.
        data_dir: directory containing ``train.jsonl`` and ``test.jsonl``.
        num_icl_examples: number of demonstrations used when *icl* is true.
        max_prompt_chars: character budget per prompt; the demo prefix is
            truncated to fit.

    Returns:
        ``(test_data, all_results)``. ``test_data`` is a list of
        ``{"question": article, "answer": reference}`` dicts. ``all_results``
        is a list of ``{"inputs": prompt, "output": generated_text}`` dicts in
        the same order.
    """
    if use_chat_format:
        import warnings
        # The chat-template code path is commented out, so this flag has no effect.
        warnings.warn("use_chat_format is not supported by get_cnn_output and is ignored",
                      stacklevel=2)

    print("Loading data...")
    prompt_prefix = ""
    if icl:
        demos = _read_jsonl(os.path.join(data_dir, "train.jsonl"), limit=num_icl_examples)
        prompt_prefix = "".join(
            demo["prompt"].strip() + _SUMMARIZE_INSTRUCTION + demo["completion"].strip() + "\n"
            for demo in demos
        )

    test_data = [
        {"question": example["prompt"], "answer": example["completion"]}
        for example in _read_jsonl(os.path.join(data_dir, "test.jsonl"))
    ]
    prompts = [_build_cnn_prompt(example["question"], prompt_prefix, max_prompt_chars)
               for example in test_data]

    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
    all_results = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[start:start + batch_size]
        batch_outputs = base_model.generate(batch_prompts, sampling_params)
        # vLLM returns one RequestOutput per prompt, in input order; keep the first completion.
        for prompt, request_output in zip(batch_prompts, batch_outputs):
            all_results.append({"inputs": prompt, "output": request_output.outputs[0].text})

    return test_data, all_results

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/Llama-2-13b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size: int = 1,
         max_num_seqs: int = 256,
         max_tokens: int = 256,
         save_dir: str = "outputs/cnn",
         icl: int = 0):
    """Evaluate a model on CNN summarization and print mean ROUGE F1 scores.

    Per-example predictions are written to ``<save_dir>/predictions.jsonl``.

    :param model_name: model name or path, passed to ``AutoTokenizer`` and vLLM ``LLM``.
    :param batch_size: number of prompts per ``generate`` call.
    :param temperature: sampling temperature.
    :param top_p: top-p sampling parameter.
    :param tensor_parallel_size: vLLM tensor-parallel size.
    :param max_num_seqs: vLLM ``max_num_seqs`` engine setting.
    :param max_tokens: maximum number of generated tokens per article.
    :param save_dir: output directory, created if missing.
    :param icl: non-zero to prepend in-context demonstrations to each prompt.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size,
                     gpu_memory_utilization=0.75, enforce_eager=True, max_num_seqs=max_num_seqs)
    test_data, all_results = get_cnn_output(base_model, tokenizer, max_tokens, batch_size,
                                            temperature, top_p, icl=bool(icl))

    # NOTE(review): the trailing newline on each hypothesis is presumably there
    # so empty generations are not rejected by ``rouge`` as empty hypotheses. Confirm.
    outputs = [o["output"] + "\n" for o in all_results]
    ref = [o["answer"] for o in test_data]
    rouge_score = Rouge().get_scores(outputs, ref)

    # Mean F1 per metric; an empty test set reports 0.0 rather than dividing by zero.
    num_scores = len(rouge_score)
    result = {
        metric: (sum(s[metric]["f"] for s in rouge_score) / num_scores if num_scores else 0.0)
        for metric in ("rouge-1", "rouge-2", "rouge-l")
    }
    print(result)

    os.makedirs(save_dir, exist_ok=True)
    predictions = [
        {"input": example["question"], "answer": example["answer"], "model output": output["output"]}
        for example, output in zip(test_data, all_results)
    ]
    write_jsonl(os.path.join(save_dir, "predictions.jsonl"), predictions)
    
if __name__ == "__main__":
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(main)
    except bdb.BdbQuit:
        # The user quit an interactive debugger session; exit quietly.
        sys.exit()
    except Exception:
        # Open a post-mortem debugger on unexpected errors. SystemExit (e.g. from
        # --help or a CLI parse error) and KeyboardInterrupt derive from
        # BaseException, not Exception, so they propagate normally.
        exc_type, exc_value, tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)
