import os
import argparse
import numpy as np
from tqdm import tqdm
import torch
import pathlib
import pandas as pd

import shared_dir

from my_utils.data_utils import save_list_to_tsv, load_list_from_tsv
from my_utils.test_utils import get_score_saving_dir_arg, get_ai_perc_threshold, get_auc, find_threshold


def _get_save_dir(dataset_name, metric):
    """Return the summary CSV path for one dataset/metric pair.

    The path is ``<shared_dir.test_results_dir>/summary/<dataset_name>_<metric>_result.csv``.
    The ``summary`` folder (and any missing parents) is created if needed;
    the CSV file itself is not created here.
    """
    summary_folder = pathlib.Path(shared_dir.test_results_dir) / 'summary'
    summary_folder.mkdir(parents=True, exist_ok=True)
    return summary_folder / f'{dataset_name}_{metric}_result.csv'


def _save_csv_file(dataset_name, metric_name, data_dict, detector_list, row_list, tag):
    """Write one metric table to ``<dataset_name + tag>_<metric_name>_result.csv``.

    The table has an ``attacker`` column holding ``row_list`` and one column
    per detector.  Cell values are read from ``data_dict`` with the key
    ``(dataset_name, detector, row_name)``; a missing key raises ``KeyError``.
    """
    # Resolve (and mkdir) the output location before touching the data.
    out_path = _get_save_dir(dataset_name + tag, metric_name)

    table = {'attacker': row_list}
    table.update({
        det: [data_dict[(dataset_name, det, row)] for row in row_list]
        for det in detector_list
    })

    pd.DataFrame(table).to_csv(out_path, index=False)
    print('Create and save to', out_path)





# def get_result(args):
#     save_folder, save_dir = get_score_saving_dir(args)
#     data_list = load_list_from_tsv(save_dir)
#     ai_prob_list = [float(d[1]) for d in data_list]
#
#     human_save_dir = save_dir.__str__().replace(args.attacker, 'human')
#     human_data_list = load_list_from_tsv(human_save_dir)[:len(ai_prob_list)]
#     human_ai_prob_list = [float(d[1]) for d in human_data_list]
#
#     high_threshold = find_threshold(human_ai_prob_list, 0.01)
#     high_threshold_TPR = get_ai_perc_threshold(ai_prob_list, high_threshold)
#     high_threshold_FPR = get_ai_perc_threshold(human_ai_prob_list, high_threshold)
#
#     print(f'Detector: {args.detector}\n'
#           f'Dataset: {args.dataset}-test\n'
#           f'Attack: {args.attacker}\n'
#           f'TPR {high_threshold_TPR:.2%} with threshold {high_threshold}\n'
#           f'FPR {high_threshold_FPR:.2%} with threshold {high_threshold}\n')
#
#     low_threshold = find_threshold(ai_prob_list, 0.85)
#     low_threshold_TPR = get_ai_perc_threshold(ai_prob_list, low_threshold)
#     low_threshold_FPR = get_ai_perc_threshold(human_ai_prob_list, low_threshold)
#
#     print(f'TPR {low_threshold_TPR:.2%} with threshold {low_threshold}\n'
#           f'FPR {low_threshold_FPR:.2%} with threshold {low_threshold}\n')
#
#     y_true = np.array([1 for _ in range(len(ai_prob_list))] + [0 for _ in range(len(human_ai_prob_list))])
#     y_score = ai_prob_list + human_ai_prob_list
#     auc = get_auc(y_true, y_score)
#
#     print(f'AUC: {auc:.4f}')

def get_auc_by_ai_human_score(ai_score_list, human_score_list):
    """Return ``get_auc`` of AI-generated vs. human scores.

    AI samples are labelled 1 and human samples 0; the scores are passed to
    ``get_auc`` in the same order (all AI scores first, then all human ones).

    Args:
        ai_score_list: iterable of detector scores for AI-generated text.
        human_score_list: iterable of detector scores for human text.
    """
    # Normalise to lists: with numpy arrays, ``a + b`` would add element-wise
    # (or fail on a length mismatch) instead of concatenating.
    ai_scores = list(ai_score_list)
    human_scores = list(human_score_list)

    y_true = np.array([1] * len(ai_scores) + [0] * len(human_scores))
    y_score = ai_scores + human_scores
    return get_auc(y_true, y_score)

def _load_ai_scores(args):
    """Load the score file selected by the current ``args`` and return column 1 of each row as float."""
    _, save_dir = get_score_saving_dir_arg(args)
    return [float(row[1]) for row in load_list_from_tsv(save_dir)]


def get_results_multiple(args, detector_list, attacker_list, tag=''):
    """Summarise per-detector scores into three CSV tables for ``args.dataset``.

    For every detector, the scores of the original AI text ('orig'), human
    text ('human') and each attacker's text are loaded.  Two thresholds are
    derived with ``find_threshold``: one from the human scores at
    ``low_FPR_value`` and one from the orig scores at ``high_TPR_value``
    (presumably the low-FPR / high-TPR operating points, as the names
    suggest).  Then:

    * ``auc``: ``get_auc_by_ai_human_score`` of orig and every attacker
      against the human scores;
    * ``AIperc_lowFPR`` / ``AIperc_highTPR``: ``get_ai_perc_threshold`` of
      orig, every attacker and human at the respective threshold, plus the
      threshold value itself as a 'threshold' row.

    Datasets whose name contains 'gen' load orig/human scores from the base
    dataset (the part before the first '-') and attacker scores from the
    full dataset name.

    Note: ``args.detector`` and ``args.attacker`` are overwritten (they keep
    the last values used); ``args.dataset`` is restored even on error.
    """
    low_FPR_value = 0.05
    high_TPR_value = 0.9
    dataset_name = args.dataset

    # 'gen' datasets share orig/human scores with their base dataset.
    if 'gen' in dataset_name:
        base_dataset = dataset_name.split('-')[0]
        print(f'{dataset_name}: loading orig/human scores from {base_dataset}')
    else:
        base_dataset = dataset_name

    auc_result_dict = {}  # {(dataset, detector, row_name): value}, row_name in [orig + attackers]
    low_FPR_result_dict = {}  # {(dataset, detector, row_name): value}, row_name in [orig + attackers + human + threshold]
    high_TPR_result_dict = {}  # {(dataset, detector, row_name): value}, row_name in [orig + attackers + human + threshold]

    try:
        for detector in detector_list:
            args.detector = detector

            # orig AI text and human text come from the base dataset
            args.dataset = base_dataset
            args.attacker = 'orig'
            orig_ai_score_list = _load_ai_scores(args)
            args.attacker = 'human'
            human_ai_score_list = _load_ai_scores(args)
            args.dataset = dataset_name

            # thresholds
            low_FPR_threshold = find_threshold(human_ai_score_list, low_FPR_value)
            high_TPR_threshold = find_threshold(orig_ai_score_list, high_TPR_value)
            # assert low_FPR_threshold >= high_TPR_threshold

            low_FPR_result_dict[(dataset_name, detector, 'threshold')] = low_FPR_threshold
            high_TPR_result_dict[(dataset_name, detector, 'threshold')] = high_TPR_threshold

            # AI percentage of human and orig text at both thresholds
            for row_name, scores in (('human', human_ai_score_list), ('orig', orig_ai_score_list)):
                key = (dataset_name, detector, row_name)
                low_FPR_result_dict[key] = get_ai_perc_threshold(scores, low_FPR_threshold)
                high_TPR_result_dict[key] = get_ai_perc_threshold(scores, high_TPR_threshold)

            auc_result_dict[(dataset_name, detector, 'orig')] = get_auc_by_ai_human_score(
                orig_ai_score_list, human_ai_score_list)

            # attacked text, loaded from the full dataset name
            for attacker in attacker_list:
                args.attacker = attacker
                cur_ai_score_list = _load_ai_scores(args)

                cur_key = (dataset_name, detector, attacker)
                low_FPR_result_dict[cur_key] = get_ai_perc_threshold(cur_ai_score_list, low_FPR_threshold)
                high_TPR_result_dict[cur_key] = get_ai_perc_threshold(cur_ai_score_list, high_TPR_threshold)
                auc_result_dict[cur_key] = get_auc_by_ai_human_score(cur_ai_score_list, human_ai_score_list)
    finally:
        # Never leave the caller's args pointing at the base dataset.
        args.dataset = dataset_name

    ai_rows = ['orig'] + attacker_list
    _save_csv_file(dataset_name, 'auc', auc_result_dict, detector_list, row_list=ai_rows, tag=tag)

    _save_csv_file(dataset_name, 'AIperc_lowFPR', low_FPR_result_dict, detector_list,
                   row_list=ai_rows + ['human', 'threshold'], tag=tag)

    _save_csv_file(dataset_name, 'AIperc_highTPR', high_TPR_result_dict, detector_list,
                   row_list=ai_rows + ['human', 'threshold'], tag=tag)


if __name__ == '__main__':
    # Summarise every detector x attacker combination for one dataset.
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", required=True, type=str, choices=["squad", "eli5", "yelp"],
                        help="dataset to be tested")
    parser.add_argument('--task', type=str, default='essay', choices=['essay', 'qa', 'rev-gen', 'paraphrase'])
    # get_results_multiple overwrites args.attacker / args.detector with every
    # entry of attacker_list / detector_list before using them, so these two
    # flags are accepted for compatibility but are not needed.
    parser.add_argument("--attacker", type=str, default=None,
                        help="attack method (ignored: all attackers in attacker_list are evaluated)")
    parser.add_argument("--detector", type=str, default=None,
                        choices=["chatdetect", "gpt2detect", "logrank", "openai", "gptzero", "detectgpt"],
                        help="detector (ignored: all detectors in detector_list are evaluated)")
    parser.add_argument("--data-dir", required=True, type=str,
                        help="data directory")

    parser.add_argument("--tag", default='', type=str,
                        help="tag of run")

    args = parser.parse_args()

    attacker_list = ['parrot_paraphrase', 'dipper', 'chatgpt_paraphrase', 'human_prompt', 'SICO-chatgpt-chatdetect-paraphrase']
    detector_list = ['chatdetect', 'gpt2detect', 'logrank', 'openai', 'detectgpt', 'gptzero']

    get_results_multiple(args, detector_list, attacker_list, args.tag)
