import collections
import os
import pandas as pd
from filter.Guard import *
import csv
import torch
import tqdm
from ratelimit import limits
import argparse

policy = ("<BEGIN UNSAFE CATEGORIES> \n"
            "S1: Violent Crimes \n"
            "S2: Non-Violent Crimes \n"
            "S3: Sex-Related Crimes \n"
            "S4: Hate Speech\n"
            "S5: Suicide & Self-Harm \n"
            "S6: Sexual Content \n"
            "<END UNSAFE CATEGORIES> \n")



class EvaluationScore:
    def __init__(self, input_scores, sample_num, method_type, method_used, output, time, toxic_categories, time_per_token):
        self.scores = {}
        for guard_scores in input_scores:
            if guard_scores == ['']:
                continue
            if len(guard_scores) < 2:
                continue
            self.scores[guard_scores[0]] = float(guard_scores[1])
        self.sample_num = sample_num
        self.method_type = method_type
        self.is_method_used = method_used
        self.output = output
        self.time = time
        self.toxic_categories = toxic_categories
        self.time_per_token = time_per_token


    def add_scores(self, added_scores):
        for score in added_scores:
            self.scores[score[0]] = float(score[1])

    def to_dict(self):
        # output a dict containing the sample num and methodology details, and include the actual scores

        out_dict = {'sample_num': self.sample_num, 'method_type': self.method_type, 'is_method_used': self.is_method_used,
                    'output': self.output, 'time': self.time, 'time_per_token': self.time_per_token}
        out_dict.update(self.toxic_categories)
        out_dict.update(self.scores)

        return out_dict


    def __str__(self):
        """
        Write the scores to a string a way that is easy to read later
        :return: None
        """
        str = ""
        for key, value in self.scores.items():
            str += f"{key}: {value}\n"
        return str

    @staticmethod
    def from_txt(str, sample_num, method_type, is_method_used, output, time, category_dict, time_per_token):
        """
        Convert a string to an EvaluationScore object
        :param str: the string to convert
        :return: None
        """
        scores = str.split('\n')
        scores = [score.split(': ') for score in scores]
        return EvaluationScore(scores, sample_num, method_type, is_method_used, output, time, category_dict, time_per_token)

    def to_txt(self, filename):
        with open (filename, 'w') as f:
            f.write(str(self))

class EvaluationAccumulator:
    def __init__(self, guards):
        # initialize three-level nested dictionary to hold the scores
        self.scores = {}
        self.guards = guards
        self.example_score = None

    def add_score(self, eval_score: EvaluationScore):
        # check if the sample number is already in the dictionary
        if eval_score.sample_num not in self.scores:
            self.scores[eval_score.sample_num] = {}
        # check if the method type is already in the dictionary
        if eval_score.method_type not in self.scores[eval_score.sample_num]:
            self.scores[eval_score.sample_num][eval_score.method_type] = {}
        # check if the method used is already in the dictionary
        if eval_score.is_method_used not in self.scores[eval_score.sample_num][eval_score.method_type]:
            self.scores[eval_score.sample_num][eval_score.method_type][eval_score.is_method_used] = eval_score
            if self.example_score is None:
                self.example_score = eval_score

    def classify_sample(self, content, guards, sample_num, method_type, is_method_used, output, time, category_dict, time_per_token):
        score_list = []
        for guard in guards:
            scores = guard.get_scores(content)
            scores = [list(score) for score in scores]
            score_list.extend(scores)
        eval_score = EvaluationScore(score_list, sample_num, method_type, is_method_used, output, time, category_dict, time_per_token)
        self.add_score(eval_score)
        return eval_score


    def to_csv(self, filename):
        eval_fields = list(self.example_score.__dict__.keys())
        eval_fields.remove('scores')
        eval_score_fields = list(self.example_score.scores.keys())
        eval_toxic_categories = list(self.example_score.toxic_categories.keys())
        fields = eval_fields + eval_score_fields + eval_toxic_categories
        with open(filename, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()
            for sample_num in self.scores:
                for method_type in self.scores[sample_num]:
                    for is_method_used in self.scores[sample_num][method_type]:
                        sample_dict = self.scores[sample_num][method_type][is_method_used].to_dict()
                        # remove any fields that are not in the fields list
                        for key in list(sample_dict.keys()):
                            if key not in fields:
                                del sample_dict[key]
                        writer.writerow(sample_dict)



def get_num_missing_scores(root_directory):
    missing_scores = 0
    with tqdm.tqdm(total=7200) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            pbar.update(1)
            for file in filenames:
                if file in ['output.txt']:
                   if not os.path.exists(os.path.join(dirpath, 'scores.txt')):
                       missing_scores += 1
    print(f"Number of missing scores: {missing_scores}")
    return missing_scores

def convert_config_str_to_dict(config_str):
    config_dict = {}
    for line in config_str.split('\n'):
        if line == '':
            continue
        key, value = line.split('-')[0], line.split('-')[1]
        config_dict[key] = value
    return config_dict


def get_content(file_path, split_output=False):
    file = file_path.split('/')[-1]
    if file == 'output.txt':
        content = open(file_path, 'r').read()
        if split_output:
            content_list = content.split('========================')
            prompt = content_list[-3]
            answer = content_list[-2]
            return prompt, answer
    elif file == 'time.txt':
        lines = open(file_path, 'r').readlines()

        total_time = float(lines[0].split(': ')[1].replace('\n', ''))
        if len(lines) > 1:
            time_per_token = float(lines[1].split(': ')[1].replace('\n', ''))
        else:
            time_per_token = 0
        return total_time, time_per_token

    elif file == 'configuration.txt':
        content = open(file_path, 'r').read()
        content = convert_config_str_to_dict(content)
    elif file == 'scores.txt':
        content = open(file_path, 'r').read()
    else:
        content = open(file_path, 'r').read()
    return content

def del_scores(root_directory):
    with tqdm.tqdm(total=8000) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            pbar.update(1)
            for file in filenames:
                if file == 'scores.txt':
                    os.remove(os.path.join(dirpath, file))

def del_rain(root_directory):
    with tqdm.tqdm(total=8000) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            pbar.update(1)
            # delete all files in the sub-directoriy /.../num/rain
            if 'rain' in dirpath:
                for file in filenames:
                    os.remove(os.path.join(dirpath, file))
                os.removedirs(dirpath)

def get_file_indexes(indexing_system="new"):
    if indexing_system == 'old':
        sample_num_index = -1
        method_type_index = -3
        is_method_index = -2
    elif indexing_system == 'new':
        sample_num_index = -2
        method_type_index = -4
        is_method_index = -1
    else:
        sample_num_index = -2
        method_type_index = -4
        is_method_index = -3
    return sample_num_index, method_type_index, is_method_index

def get_sample_details(dirpath, sample_num_index, model_type_index, is_method_index):
    sample_num = dirpath.split('/')[sample_num_index]
    model_type = dirpath.split('/')[model_type_index]
    is_method_used = dirpath.split('/')[is_method_index]
    time, time_per_token = get_content(os.path.join(dirpath, 'time.txt'))
    output = get_content(os.path.join(dirpath, 'output.txt'))
    return sample_num, model_type, is_method_used, output, time, time_per_token

def fill_category_dict(category_dict, dirpath):
    with open(os.path.join(dirpath, 'category.txt'), 'r') as f:
        category = f.read()
        for line in category.split('\n')[:-1]:
            key = line.split(': ')[0]
            value = line.split(': ')[1]
            category_dict[key] = value

def remove_slash_from_category(directory):
    for dirpath, dirnames, filenames in os.walk(directory):
        if not filenames:
            continue
        if 'scores.txt' in filenames:
            with open(os.path.join(dirpath, 'scores.txt'), 'r') as f:
                scores = f.read()
            with open(os.path.join(dirpath, 'scores.txt'), 'w') as f:
                f.write(scores.replace('/', '_'))

def get_num_samples(root_directory):
    num_samples = 0
    list_of_subdirs = os.listdir(root_directory)
    num_of_subdirs = len(list_of_subdirs)
    first_subdir = list_of_subdirs[5]
    list_of_subsubdirs = os.listdir(os.path.join(root_directory, first_subdir))
    num_samples = len(list_of_subsubdirs)
    total = num_of_subdirs * num_samples
    return num_of_subdirs, num_samples, total



def add_score_to_dir(root_directory, guards):
    sample_num_index, method_type_index, is_method_index = get_file_indexes()
    total_samples = get_num_samples(root_directory)
    with tqdm.tqdm(total=total_samples) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            if not filenames or filenames == ['config.txt']:
                continue
            pbar.update(1)
            sample_num, method_type, is_method_used, output, time = get_sample_details(dirpath, sample_num_index, method_type_index, is_method_index)
            for guard in guards:
                assistant_output = output.split('========================')[2]
                new_scores = guard.get_scores(assistant_output)
            if os.path.exists(os.path.join(dirpath, 'scores.txt')):
                with open(os.path.join(dirpath, 'scores.txt'), 'r') as f:
                    existing_scores = f.read()
                if "OAI" in existing_scores:
                    continue
                with open(os.path.join(dirpath, 'scores.txt'), 'a') as f:
                    for score in new_scores:
                        category = score[0].replace('/', '_')
                        category_score = float(score[1])
                        line = f"{category}: {category_score}\n"
                        f.write(line)
            else:
                print(f"Missing score file in example: {sample_num}")


def generate_scores_truthful(root_directory, guards, csv_filename='truthful_scores.csv'):
    sample_num_index = -1
    eval_acc = {}
    score_filename = "scores.txt"
    num_subdirs, num_outputs, total = get_num_samples(root_directory)
    print(f"Scoring {num_subdirs} samples")
    with tqdm.tqdm(total=num_subdirs) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            if not filenames or not dirnames or 'best_answer.txt' not in filenames:
                continue
            pbar.update(1)
            sample_num = dirpath.split('/')[sample_num_index]
            best_answer = get_content(os.path.join(dirpath, 'best_answer.txt'))
            correct_answers = get_content(os.path.join(dirpath, 'correct_answers.txt'))
            incorrect_answers = get_content(os.path.join(dirpath, 'incorrect_answers.txt'))
            eval_acc[sample_num] = {}
            for dirname in dirnames:
                eval_acc[sample_num][dirname] = {}
                if os.path.exists(os.path.join(dirpath, dirname, 'scores.txt')):
                    with open(os.path.join(dirpath, dirname, 'scores.txt'), 'r') as f:
                        scores = f.read()
                    score_name = scores.split(':')[0]
                    score_value = scores.split(': ')[1].strip()
                    score_value = int(score_value)
                    eval_acc[sample_num][dirname] = score_value
                else:
                    prompt, answer = get_content(os.path.join(dirpath, dirname, 'output.txt'),
                                                                     split_output=True)
                    scores = []
                    for guard in guards:
                        score = guard.get_truthful_scores(prompt, answer, best_answer, correct_answers, incorrect_answers)
                        scores.append(score)
                    for score in scores:
                        eval_acc[sample_num][dirname] = score[1]
                    with open(os.path.join(dirpath, dirname, 'scores.txt'), 'w') as f:
                        for score in scores:
                            f.write(f"{score[0]}: {score[1]}\n")

    with open(csv_filename, 'w') as f:
        dataframe = pd.DataFrame(eval_acc).T
        # tag columns with the first row
        dataframe.set_axis(dataframe.iloc[0], axis=1)
        dataframe.set_axis(dataframe.index, axis=0)
        dataframe.to_csv(f)






def generate_scores_truthful_old(root_directory, guards, csv_filename='truthful_scores.csv'):
    # indexes for translating the file path to the sample number, method type, and whether the method was used
    sample_num_index, model_type_index, is_method_index = get_file_indexes()
    # accumulator object to hold all the different scores in a dict
    eval_acc = {}
    score_filename = "scores.txt"
    num_samples = get_num_samples(root_directory)
    print(f"Scoring {num_samples} samples")
    with tqdm.tqdm(total=num_samples) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            pbar.update(1)
            # skip empty/parent directories
            if not filenames or 'best_answer.txt' not in filenames or not dirnames:
                continue
            sample_num = dirpath.split('/')[sample_num_index]
            best_answer = get_content(os.path.join(dirpath, 'best_answer.txt'))
            correct_answers = get_content(os.path.join(dirpath, 'correct_answers.txt'))
            incorrect_answers = get_content(os.path.join(dirpath, 'incorrect_answers.txt'))
            output_0 = dirnames[0]
            if 'no' in output_0:
                no_method_dirname = output_0
                method_dirname = dirnames[1]
            else:
                method_dirname = output_0
                no_method_dirname = dirnames[1]
            prompt_method, answer_method = get_content(os.path.join(dirpath, method_dirname,'output.txt'), split_output=True)

            prompt_no_method, answer_no_method = get_content(os.path.join(dirpath, no_method_dirname, 'output.txt'), split_output=True)
            all_scores = []
            for guard in guards:
                scores_method = guard.get_truthful_scores(prompt_method, answer_method, best_answer, correct_answers, incorrect_answers)
                scores_no_method = guard.get_truthful_scores(prompt_no_method, answer_no_method, best_answer, correct_answers, incorrect_answers)
                scores = (scores_method, scores_no_method)
                all_scores.append(scores)
            eval_acc[sample_num] = {}
            eval_acc[sample_num][method_dirname] = {}
            eval_acc[sample_num][no_method_dirname] = {}
            for score in all_scores:
                eval_acc[sample_num][method_dirname][score[0][0]] = score[0][1]
                eval_acc[sample_num][no_method_dirname][score[1][0]] = score[1][1]

            with open(os.path.join(dirpath, method_dirname, 'scores.txt'), 'w') as f:
                for score in all_scores:
                    f.write(f"{score[0][0]}: {score[0][1]}\n")
            with open(os.path.join(dirpath, no_method_dirname, 'scores.txt'), 'w') as f:
                for score in all_scores:
                    f.write(f"{score[1][0]}: {score[1][1]}\n")
    with open(csv_filename, 'w') as f:
        pd.DataFrame(eval_acc).to_csv(f)

def check_score_for_error(dirpath):
    str = open(dirpath, 'r').read()
    scores = str.split('\n')
    scores = [score.split(': ') for score in scores]
    for score in scores:
        if score == ['']:
            continue
        if score[1] == '-1':
            return True
    return False

def generate_scores_eval_extras(root_directory, guards, csv_filename='eval_scores.csv', extras=None):
    sample_dirs = os.listdir(root_directory)
    eval_acc = EvaluationAccumulator(guards)
    score_filename = "scores.txt"
    num_samples = len(sample_dirs)
    print(f"Scoring {num_samples} samples")
    with tqdm.tqdm(total=num_samples) as pbar:
        for sample_dir in sample_dirs:
            sample_dirpath = os.path.join(root_directory, sample_dir)
            if not os.path.isdir(sample_dirpath):
                continue
            pbar.update(1)
            sample_num = int(sample_dir)
            subdirs_in_sampledir = os.listdir(sample_dirpath)
            if extras:
                missing_flag = False
                for name in extras:
                    if name not in subdirs_in_sampledir:
                        missing_flag = True
                        break
                if missing_flag:
                    continue
            for subdir in subdirs_in_sampledir:
                subdir_path = os.path.join(sample_dirpath, subdir)
                if not os.path.isdir(subdir_path):
                    continue
                if 'method' not in subdir and (extras and subdir not in extras):
                    continue
                files_in_subdir = os.listdir(subdir_path)
                if 'output.txt' not in files_in_subdir:
                    continue
                sample_num, model_type, is_method_used, output, time, time_per_token = get_sample_details(subdir_path, -2, -4, -1)
                score_filepath = os.path.join(subdir_path, score_filename)
                if os.path.exists(score_filepath) and not check_score_for_error(score_filepath):
                    # load the scores from the file
                    with open(score_filepath, 'r') as f:
                        scores_txt = f.read()

                    score = EvaluationScore.from_txt(scores_txt, sample_num, model_type, is_method_used, output, time, {}, time_per_token)
                    eval_acc.add_score(score)
                else:
                    output_content = get_content(os.path.join(subdir_path, 'output.txt'))
                    assistant_output = output_content.split('Assistant: ')[-1].replace("========================\n", '')
                    scores = eval_acc.classify_sample(assistant_output, guards, sample_num, model_type, is_method_used, output, time, {}, time_per_token)
                    scores.to_txt(os.path.join(subdir_path, score_filename))
    eval_acc.to_csv(csv_filename.replace('.csv', '_extra.csv'))


def generate_scores_eval(root_directory, guards, csv_filename='eval_scores.csv'):
    """
    Generate scores for the evaluation samples
    :param root_directory: root directory containing the samples
    :param guards: List of the evaluators to use for each sample
    :param csv_filename: output file name
    :return:
    """
    # indexes for translating the file path to the sample number, method type, and whether the method was used
    sample_num_index, model_type_index, is_method_index = get_file_indexes()
    # accumulator object to hold all the different scores in a dict
    eval_acc = EvaluationAccumulator(guards)
    score_filename = "scores.txt"
    _, _, num_samples = get_num_samples(root_directory)
    print(f"Scoring {num_samples} samples")

    with tqdm.tqdm(total=num_samples) as pbar:
        for dirpath, dirnames, filenames in os.walk(root_directory):
            pbar.update(1)
            # skip empty/parent directories
            if not filenames or 'output.txt' not in filenames:
                continue
            sample_num, model_type, is_method_used, output, time, time_per_token = get_sample_details(dirpath, sample_num_index, model_type_index, is_method_index)

            category_dict = {}
            # check if the scores already exist for this sample - if so load them from the file
            score_filepath = os.path.join(dirpath, score_filename)
            if os.path.exists(score_filepath) and not check_score_for_error(score_filepath):
                # load the scores from the file
                with open(os.path.join(dirpath, score_filename), 'r') as f:
                    scores_txt = f.read()

                score = EvaluationScore.from_txt(scores_txt, sample_num, model_type, is_method_used, output, time, category_dict, time_per_token)
                score.toxic_categories = category_dict
                eval_acc.add_score(score)
            else:
                    output_content = get_content(os.path.join(dirpath, 'output.txt'))
                    assistant_output = output_content.split('Assistant: ')[-1].replace("========================\n", '')
                    scores = eval_acc.classify_sample(assistant_output, guards, sample_num, model_type, is_method_used, output, time, category_dict, time_per_token)
                    scores.to_txt(os.path.join(dirpath, score_filename))
    eval_acc.to_csv(csv_filename)

def check_num_of_tokens_in_evaluation_output(csv_file):
    data = pd.read_csv(csv_file)
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-filter/Llama-2-7b-chat-hf")
    total_tokens = 0
    for index, row in data.iterrows():
        output = row['output']
        tokens = tokenizer.tokenize(output)
        num_tokens = len(tokens)
        total_tokens += num_tokens
    print(f"Number of tokens in output: {total_tokens}")

def test_perspective_rate():
    guards = [PerspectiveGuard()]
    dir = 'eval_output_samples/mistral/'
    del_scores(dir)
    generate_scores_eval(dir, guards)

def score_all_dirs():
    base = 'eval_output_samples/'
    dirs = [base + 'final_eval/mistral', base + 'final_eval/llama2', base + 'final_eval/dolphin',
            base + 'ablation/beavertail_llama_2', base + 'ablation/beavertails',
            base + 'ablation/truthful_llama_2', base + 'ablation/truthful',
            base + 'code_eval/torch', base + 'code_eval/tensorflow'
            ]
    for dir in dirs:
        main(dir)

def main(dir = 'eval_output_samples/final_eval/llama2', extras=False):
    filename = 'eval_results_csv/' + '_'.join(dir.split('/')[1:]) + '.csv'
    guards = [OpenAImodGuard()]
    # guards = [CodeGuard()]
    # guards = [OpenAImodGuard()]
    if extras:
        generate_scores_eval_extras(dir, guards, filename, extras)
    else:
        generate_scores_eval(dir, guards, filename)
    # generate_scores_truthful(dir, guards, filename)

def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate the output of the models')
    parser.add_argument('-d', '--dir', type=str, default='eval_output_samples/final_eval/llama2')
    parser.add_argument('-delete', '--delete', action='store_true')
    parser.add_argument('-extras', '--extras', nargs='*', default=[])
    args = parser.parse_args()
    return args

def del_extras(root_directory):
    print('started deleting extras')
    for dirpath, dirnames, filenames in os.walk(root_directory):
        if not filenames:
            continue

        for dir in dirnames:
            if dir in ["0.985"]:
                print('deleting', dirpath, dir)
                # delete all files in the sub-directory and then delete the sub-directory
                files = os.listdir(os.path.join(dirpath, dir))
                for file in files:
                    os.remove(os.path.join(dirpath, dir, file))
                os.removedirs(os.path.join(dirpath, dir))

def remove_imend(root_directory):
    output_end_token = "<|im_end|>"
    print("started")
    for dirpath, dirnames, filenames in os.walk(root_directory):
        if not filenames:
            continue
        for file in filenames:
            if file == "output.txt":
                with open(os.path.join(dirpath, file), 'r') as f:
                    output = f.read()
                if output_end_token in output:
                    split_output = output.split(output_end_token)
                    new_output = split_output[0]
                    new_output += "\n========================"
                    with open(os.path.join(dirpath, file), 'w') as f:
                        f.write(new_output)

if __name__ == "__main__":
    args = parse_args()
    if args.delete:
        # del_scores(args.dir)
        del_extras(args.dir)
    # remove_imend(args.dir)
    main(args.dir, args.extras)
