import copy
import json
import re

import numpy as np
from datasets import load_dataset

import params
from utils.evaluation_api import gemini_evaluate_batch
from utils.prompt_template import PromptTemplateLoader

# Directory holding the prompt templates used by the evaluators in this module
# (e.g. the "eval_truth" and "eval_info" keys).
PROMPT_TEMPLATE_DIR = "prompt_templates/"
# Module-wide loader shared by every evaluation function below.
prompt_loader = PromptTemplateLoader(template_dir_path=PROMPT_TEMPLATE_DIR)

def load_wiki(split, max_sample_num=100):
    """Load WikiQA questions, each grouped with its labelled candidate answers.

    Args:
        split: "test" or "validation" (the Hugging Face split to read).
        max_sample_num: Keep at most this many questions; None keeps all.

    Returns:
        list[dict]: One record per question id, ordered by sorted question id,
        with keys "question", "correct_answers" (answers labelled 1) and
        "incorrect_answers" (answers labelled 0).

    Raises:
        ValueError: For any other split name.
    """
    if split not in ("test", "validation"):
        raise ValueError(f"Split {split} not supported")
    ds = load_dataset("microsoft/wiki_qa")[split]

    # Group rows by question id in a single pass instead of rescanning the
    # whole split once per question (O(questions * rows)).
    rows_by_qid = {}
    for row in ds:
        rows_by_qid.setdefault(row["question_id"], []).append(row)

    # Iterate ids in sorted order (as the previous np.unique-based version did)
    # so sample positions stay aligned with previously generated response files.
    dataset = []
    for qid in sorted(rows_by_qid)[:max_sample_num]:
        samples = rows_by_qid[qid]
        dataset.append({
            "question": samples[0]["question"],
            "correct_answers": [s["answer"] for s in samples if s["label"] == 1],
            "incorrect_answers": [s["answer"] for s in samples if s["label"] == 0],
        })
    return dataset


def load_truthfulqa(split, max_sample_num=817):
    """Load TruthfulQA (generation config) and return one of two fixed slices.

    Only the Hugging Face "validation" split is loaded; it is partitioned here:
    rows [0, 400) form "validation" and rows [400, 817) form "test".

    Args:
        split: "test" or "validation".
        max_sample_num: Keep at most this many records from the slice; None keeps all.

    Returns:
        list[dict]: Records with "question", "correct_answers" (best answer
        first, then the listed correct answers) and "incorrect_answers".

    Raises:
        ValueError: For any other split name.
    """
    # Validate before the (potentially slow) dataset download/load.
    if split == "test":
        start, stop = 400, 817
    elif split == "validation":
        start, stop = 0, 400
    else:
        raise ValueError(f"Split {split} not supported")

    ori_dataset = load_dataset("truthfulqa/truthful_qa", "generation")["validation"]
    dataset = [
        {
            "question": sample["question"],
            "correct_answers": [sample["best_answer"]] + sample["correct_answers"],
            "incorrect_answers": sample["incorrect_answers"],
        }
        for sample in ori_dataset
    ][start:stop]

    if max_sample_num is not None:
        dataset = dataset[:max_sample_num]
    return dataset


def eval_truthful(data, args):
    """Have the judge model label each generated answer as correct or wrong.

    Only the last line of ``sample["response"]`` is graded, after stripping an
    optional ``**Refined Answer:**`` prefix. Every answer is rendered with the
    ``eval_truth`` template and all prompts are sent in a single batch call.

    Args:
        data: List of dicts with "question", "correct_answers",
            "incorrect_answers" and "response" keys.
        args: Run config; reads ``evaluation_type``, ``gemini_client``,
            ``save_results`` and, when saving, ``base_model``, ``eval_data``,
            ``method`` and ``data_split``.

    Returns:
        tuple[float, int]: (fraction judged "correct" among verdicts that start
        with "correct" or "wrong", number of such verdicts). The fraction is
        0.0 when no verdict could be parsed (check the second value).

    Raises:
        ValueError: If ``args.evaluation_type`` is not "gemini".
    """
    messages = []
    prompt_key = "eval_truth"
    for sample in data:
        response = sample["response"].strip()
        # Grade only the final line of a (possibly multi-step) response.
        response = response[response.rfind("\n") + 1:]
        if response.lower().startswith("**refined answer:**"):
            response = response[len("**Refined Answer:**"):]

        template_placeholder = {"question": sample["question"],
                                "correct_answers": "\n".join(sample["correct_answers"]),
                                "incorrect_answers": "\n".join(sample["incorrect_answers"]),
                                "generated_answer": response}
        prompt = prompt_loader.construct_prompt(prompt_key, template_placeholder)
        messages.append([{"role": "user", "content": prompt}])

    if args.evaluation_type == "gemini":
        eval_results = gemini_evaluate_batch(
            message_list=messages,
            gemini_client=args.gemini_client
        )
    else:
        raise ValueError(f"Model {args.evaluation_type} not supported")

    cnt = {"correct": 0, "wrong": 0}
    results = {}
    for i, sample in enumerate(data):
        eval_result = eval_results[i].strip().lower()
        # Drop leading non-letter characters (e.g. markdown "**") so the
        # verdict word is at the start; leave the string as-is if it has no letters.
        eval_result = next(
            (eval_result[j:] for j, ch in enumerate(eval_result) if ch.isalpha()),
            eval_result,
        )

        if args.save_results:
            results[i] = {
                'index': i,
                'sample': sample,
                'eval_result': eval_result,
                'eval_type': 'truth',
                'message': messages[i]
            }

        if eval_result.startswith("correct"):
            cnt["correct"] += 1
        elif eval_result.startswith("wrong"):
            cnt["wrong"] += 1
    total_valid = cnt["correct"] + cnt["wrong"]

    if args.save_results:
        file_name = f'{args.base_model}_{args.eval_data}_{args.method}_{args.evaluation_type}_{args.data_split}_truth.json'
        with open(f"{params.analysis_dir}/{file_name}", "w") as f:
            json.dump(results, f, indent=4)

    # Avoid ZeroDivisionError when the batch is empty or no verdict was parseable.
    if total_valid == 0:
        return 0.0, 0
    return cnt['correct'] / total_valid, total_valid


def eval_informativeness(data, args):
    """Have the judge model decide whether each generated answer is informative.

    Each full (stripped) response is rendered with the ``eval_info`` template
    and all prompts are sent in a single batch call.

    Args:
        data: List of dicts with "question" and "response" keys.
        args: Run config; reads ``evaluation_type``, ``gemini_client``,
            ``save_results`` and, when saving, ``base_model``, ``eval_data``,
            ``method`` and ``data_split``.

    Returns:
        tuple[float, int]: (fraction of "yes" among verdicts that start with
        "yes" or "no", number of such verdicts). The fraction is 0.0 when no
        verdict could be parsed (check the second value).

    Raises:
        ValueError: If ``args.evaluation_type`` is not "gemini".
    """
    messages = []
    prompt_key = "eval_info"
    for sample in data:
        template_placeholder = {"question": sample["question"],
                                "answer": sample["response"].strip()}
        prompt = prompt_loader.construct_prompt(prompt_key, template_placeholder)
        messages.append([{"role": "user", "content": prompt}])

    if args.evaluation_type == "gemini":
        eval_results = gemini_evaluate_batch(
            message_list=messages,
            gemini_client=args.gemini_client
        )
    else:
        raise ValueError(f"Model {args.evaluation_type} not supported")

    cnt = {"complete": 0, "incomplete": 0}
    results = {}
    for i, sample in enumerate(data):
        eval_result = eval_results[i].strip().lower()
        # Drop leading non-letter characters (e.g. markdown "**") so the
        # verdict word is at the start; leave the string as-is if it has no letters.
        eval_result = next(
            (eval_result[j:] for j, ch in enumerate(eval_result) if ch.isalpha()),
            eval_result,
        )

        if args.save_results:
            results[i] = {
                'index': i,
                'sample': sample,
                'eval_result': eval_result,
                'eval_type': 'infor',
                'message': messages[i]
            }

        if eval_result.startswith("yes"):
            cnt["complete"] += 1
        elif eval_result.startswith("no"):
            cnt["incomplete"] += 1
    total_valid = cnt["complete"] + cnt["incomplete"]

    if args.save_results:
        # Include eval_data (as the truth file does) so runs on different
        # evaluation sets do not overwrite each other.
        file_name = f'{args.base_model}_{args.eval_data}_{args.method}_{args.evaluation_type}_{args.data_split}_infor.json'
        with open(f"{params.analysis_dir}/{file_name}", "w") as f:
            json.dump(results, f, indent=4)

    # Avoid ZeroDivisionError when the batch is empty or no verdict was parseable.
    if total_valid == 0:
        return 0.0, 0
    return cnt['complete'] / total_valid, total_valid


def eval_single_file(file_path, args):
    """Score one file of generated WikiQA answers for truthfulness and informativeness.

    The file is a JSON list; entry i's "generated_answer" is attached to the
    i-th question returned by ``load_wiki`` for ``args.data_split``.

    Args:
        file_path: Path to the JSON file of generated answers.
        args: Run config forwarded to the evaluators and echoed in the result.

    Returns:
        dict: Truthfulness and informativeness scores, their product
        ("t_times_i"), the number of parseable verdicts for each, and the run
        configuration.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        response_data = json.load(f)
    # NOTE(review): always scores against WikiQA even though load_truthfulqa
    # exists in this module — confirm this is intended for every eval_data.
    data = load_wiki(split=args.data_split, max_sample_num=len(response_data))

    # Responses are index-aligned with the loaded questions; if fewer questions
    # are available this raises IndexError rather than silently misaligning.
    for i, response in enumerate(response_data):
        data[i]["response"] = response["generated_answer"]

    truthful_score, truth_valid = eval_truthful(data=data, args=args)
    info_score, info_valid = eval_informativeness(data=data, args=args)
    t_times_i = truthful_score * info_score

    return {
        "data_split": args.data_split,
        "sample_num": args.max_sample_num,
        "model": args.base_model,
        "method": args.method,
        "truthful": truthful_score,
        "informativeness": info_score,
        "t_times_i": t_times_i,
        "truth_valid": truth_valid,
        "info_valid": info_valid,
        "evaluation_type": args.evaluation_type,
        "eval_data_path": args.eval_data_path,
        "train_data": args.train_data,
        "eval_data": args.eval_data,
        "standard_prompt_key": args.standard_prompt_key,
        "max_new_tokens": args.max_new_tokens,
    }
    
    
BIOGRAPHY_DATA_PATH = "data/article_200.json"


def load_biography(split, max_sample_num=128, data_path=None):
    """Load biography prompts as ``{"question": name, "answer": value}`` records.

    The JSON file is an object keyed by person name; records follow the file's
    key order. The first 100 entries form the "validation" split and the rest
    form the "test" split. (The value is presumably the reference article for
    that person — confirm against the data file.)

    Args:
        split: "validation" or "test".
        max_sample_num: Keep at most this many records from the split; None keeps all.
        data_path: JSON file to read; defaults to BIOGRAPHY_DATA_PATH.

    Returns:
        list[dict]: The selected records.

    Raises:
        ValueError: For any other split name.
    """
    # Validate before touching the filesystem.
    if split == "test":
        start, stop = 100, None
    elif split == "validation":
        start, stop = 0, 100
    else:
        raise ValueError(f"Split {split} not supported")

    path = BIOGRAPHY_DATA_PATH if data_path is None else data_path
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dataset = [{"question": name, "answer": value} for name, value in data.items()][start:stop]
    if max_sample_num is not None:
        dataset = dataset[:max_sample_num]
    return dataset

def parse_bullets(sentence):
    """Split a model response (or reference text) into clean bullet strings.

    Steps:
      * remove all ``*`` characters (markdown emphasis);
      * if "Question 2" occurs, keep only the text from it onward;
      * if "refined answer" occurs (case-insensitive), keep only the text after
        the line containing it;
      * split into non-empty lines, or into "."-terminated sentences when
        there is only one line;
      * drop lines ending in ":" and chatty lines ("Here is", "I hope", ...);
      * strip leading non-letter characters (bullet markers, numbering) and
        drop lines that contain no letters.

    Args:
        sentence: Raw text to split.

    Returns:
        list[str]: The cleaned bullets, in order.
    """
    skip_markers = ("Here is", "Here are", "I apologize", "Sorry",
                    "Thank you", "Please", "I hope")

    sentence = sentence.replace('*', '')
    question_pos = sentence.find("Question 2")
    if question_pos != -1:
        sentence = sentence[question_pos:]
    refined_pos = sentence.lower().find("refined answer")
    if refined_pos != -1:
        sentence = sentence[refined_pos + len("refined answer"):]
        # Skip the rest of the heading line (e.g. ":" after "Refined Answer").
        sentence = sentence[sentence.find("\n") + 1:]

    lines = [line for line in sentence.split("\n") if line]
    if len(lines) == 1:
        # Single paragraph: fall back to sentence-level splitting.
        lines = [part + '.' for part in sentence.split(".") if part]

    bullets = []
    for line in lines:
        if line.endswith(':') or any(marker in line for marker in skip_markers):
            continue
        first_alpha = next((i for i, ch in enumerate(line) if ch.isalpha()), None)
        if first_alpha is None:
            continue
        bullets.append(line[first_alpha:])
    return bullets


def parse_yes_no(string):
    """
    Interpret a free-text judge reply as a yes/no verdict.

    "yes" and "no" are matched case-insensitively as whole words, so words such
    as "know", "cannot" or "eyes" are not mistaken for a verdict. If both words
    appear, "yes" wins.

    Args:
        string (str): The reply to parse.

    Returns:
        bool | None: True if the reply contains the word "yes", False if it
        contains the word "no" (and not "yes"), None if it contains neither.
    """
    lowered = string.lower()
    if re.search(r"\byes\b", lowered):
        return True
    if re.search(r"\bno\b", lowered):
        return False
    return None


def filter_people(person):
    """Return the part of *person* before its first "(".

    The prefix is returned verbatim (including any trailing space); a string
    without "(" is returned unchanged.
    """
    name, _sep, _remainder = person.partition("(")
    return name


def _bio_build_message(eval_metric, person, gt_text, bullet):
    """Build the one-turn chat message asking the judge to grade *bullet* for *eval_metric*."""
    if eval_metric == "informativeness":
        template_placeholder = {
            "question": f'Please list major achievements or contributions of {person}',
            "answer": bullet,
        }
        prompt = prompt_loader.construct_prompt('eval_info', template_placeholder)
        return [{"role": "user", "content": prompt}]
    if eval_metric == "truthfulness":
        return [{"role": "user",
                 "content": "Reference: \n {} \n\n Based on the above reference and your own knowledge about the computer scientist {}, is the following statement about the achievement made by this computer scientist correct and factual? \n {} \n Give a single word answer, yes or no. ".format(
                     gt_text, person, bullet)}]
    raise ValueError(f"Unsupported evaluation metric: {eval_metric}")


def _bio_per_sample_counts(accuracies, bullets_per_sample):
    """Split the flat per-bullet accuracy list back into per-sample correct/incorrect counts."""
    counts = []
    offset = 0
    for bullets in bullets_per_sample:
        n = len(bullets)
        correct = sum(accuracies[offset:offset + n])
        counts.append({"correct_num": correct, "incorrect_num": n - correct})
        offset += n
    return counts


def bio_eval_single_file(eval_data_path, args):
    """Grade generated biographies bullet-by-bullet for informativeness and truthfulness.

    Each entry of the JSON list at *eval_data_path* needs "question" (a
    person's name, optionally followed by a parenthesised note) and
    "generated_answer". Entries whose name (trimmed at "(") is not a key of
    data/article.json (also trimmed at "(") are skipped. At most the first
    five parsed bullets of each answer are graded, one judge call per bullet;
    a reply without a parseable yes/no counts as not accurate.

    Args:
        eval_data_path: Path to the JSON file of generated biographies.
        args: Run config; reads ``gemini_client`` and the fields echoed in the result.

    Returns:
        dict: Run configuration plus, for each metric, the mean per-sample
        correct/incorrect bullet counts, the overall bullet accuracy and the
        total number of reference bullets. Means are NaN when nothing was graded.
    """
    with open(eval_data_path, "r") as f:
        response_data = json.load(f)
    print(f"Evaluating {eval_data_path}")
    with open("data/article.json", "r") as f:
        gt_data = json.load(f)
    gt_data = {filter_people(name): text for name, text in gt_data.items()}

    # Parse every sample once; both metrics grade exactly the same bullets.
    prepared = []  # (person, joined reference bullets, up to 5 generated bullets)
    total_gt_bullets = 0
    for sample in response_data:
        person = filter_people(sample["question"])
        if person not in gt_data:
            continue
        gt_bullets = parse_bullets(gt_data[person])
        total_gt_bullets += len(gt_bullets)
        bio_bullets = parse_bullets(sample["generated_answer"])[:5]
        prepared.append((person, " ".join(gt_bullets), bio_bullets))
    bullets_per_sample = [bullets for _, _, bullets in prepared]

    eval_result = {}
    for eval_metric in ["informativeness", "truthfulness"]:
        # Built fresh per metric so each metric's scores reflect only its own verdicts.
        messages = [
            _bio_build_message(eval_metric, person, gt_text, bullet)
            for person, gt_text, bullets in prepared
            for bullet in bullets
        ]
        eval_results = gemini_evaluate_batch(message_list=messages, gemini_client=args.gemini_client)
        # An unparseable reply (None) counts as not accurate.
        accuracies = [bool(parse_yes_no(result)) for result in eval_results]

        sample_eval_results = _bio_per_sample_counts(accuracies, bullets_per_sample)
        avg_correct_num = np.mean([s["correct_num"] for s in sample_eval_results])
        avg_incorrect_num = np.mean([s["incorrect_num"] for s in sample_eval_results])
        avg_accuracy = np.mean(accuracies)
        print(f"eval_metric: {eval_metric}, correct_num: {avg_correct_num}, incorrect_num: {avg_incorrect_num}, accuracy: {avg_accuracy}")

        eval_result[eval_metric] = {
            "correct_num": avg_correct_num,
            "incorrect_num": avg_incorrect_num,
            "accuracy": avg_accuracy,
            "total_gt_bullets": total_gt_bullets,
        }

    return {
        "data_split": args.data_split,
        "sample_num": args.max_sample_num,
        "model": args.base_model,
        "method": args.method,
        "truthfulness": eval_result["truthfulness"],
        "informativeness": eval_result["informativeness"],
        "evaluation_type": args.evaluation_type,
        "eval_data_path": args.eval_data_path,
        "train_data": args.train_data,
        "eval_data": args.eval_data,
        "standard_prompt_key": args.standard_prompt_key,
        "max_new_tokens": args.max_new_tokens,
    }