import random

import numpy as np
import pandas as pd
import torch
import os
import glob
import argparse
import matplotlib.pyplot as plt
import json
import jsonlines
from tqdm import tqdm
from model import load_tokenizer, load_model
from fast_detect_gpt import get_sampling_discrepancy_analytic

# estimate the probability according to the distribution of our test results on ChatGPT and GPT-4
class ProbEstimator:
    """Map a Fast-DetectGPT criterion to an estimated probability of machine authorship.

    Reference criteria are loaded from every ``*.json`` file in ``REF_DIR``.
    Each file must provide ``res['predictions']['real']`` (criteria of
    human-written texts) and ``res['predictions']['samples']`` (criteria of
    machine-generated texts).
    """

    # Directory holding the reference result files. It is hard-coded; ``args.ref_path`` is not consulted.
    REF_DIR = "./AIDetection/fast-detect-gpt/local_infer_ref"

    def __init__(self, args):
        """Load all reference criteria from ``REF_DIR``.

        Args:
            args: parsed CLI namespace; currently unused, kept for interface compatibility.
        """
        self.real_crits = []
        self.fake_crits = []
        # sorted() makes the load order deterministic across filesystems.
        for result_file in sorted(glob.glob(os.path.join(self.REF_DIR, '*.json'))):
            with open(result_file, 'r', encoding='utf-8') as fin:
                res = json.load(fin)
            self.real_crits.extend(res['predictions']['real'])
            self.fake_crits.extend(res['predictions']['samples'])
        # Count both populations; they are not guaranteed to be the same size.
        total = len(self.real_crits) + len(self.fake_crits)
        print(f'ProbEstimator: total {total} samples.')

    def crit_to_prob(self, crit, k=100):
        """Estimate P(machine) for ``crit`` from the nearby reference criteria.

        The window half-width is the distance from ``crit`` to its reference
        neighbour of 0-based rank ``k``. The rank is clamped to the farthest
        neighbour when fewer than ``k + 1`` references exist. The estimate is the
        fraction of machine criteria among references strictly inside
        ``(crit - offset, crit + offset)``.

        Args:
            crit: criterion value to convert.
            k: 0-based neighbour rank that sets the window width (default 100).

        Returns:
            float in [0, 1].

        Raises:
            ValueError: if no reference criteria are loaded.
        """
        real = np.asarray(self.real_crits, dtype=float)
        fake = np.asarray(self.fake_crits, dtype=float)
        all_crits = np.concatenate([real, fake])
        if all_crits.size == 0:
            raise ValueError('ProbEstimator has no reference criteria to compare against.')
        dists = np.sort(np.abs(all_crits - crit))
        # Clamp so small reference sets do not raise IndexError.
        offset = dists[min(k, dists.size - 1)]
        lo, hi = crit - offset, crit + offset
        cnt_real = np.sum((real > lo) & (real < hi))
        cnt_fake = np.sum((fake > lo) & (fake < hi))
        if cnt_real + cnt_fake == 0:
            # Ties at the window edge can leave the open interval empty (0/0 -> NaN).
            # Fall back to the closed window. The neighbour that defined `offset` is
            # always inside it, because `offset` is exactly that neighbour's distance.
            cnt_real = np.sum(np.abs(real - crit) <= offset)
            cnt_fake = np.sum(np.abs(fake - crit) <= offset)
        return float(cnt_fake / (cnt_real + cnt_fake))

# Local (batch) inference: data loading helper and the scoring entry point.

def data_process(data_path, random_state=None):
    """Load a TSV file and return its ``content`` column in shuffled row order.

    Args:
        data_path: path to a tab-separated file whose header row includes a
            ``content`` column.
        random_state: optional seed forwarded to ``DataFrame.sample`` for a
            reproducible shuffle. ``None`` (default) keeps the unseeded shuffle.

    Returns:
        pandas Series holding the ``content`` column with rows shuffled. The
        original row index labels are kept. Empty cells come back as NaN
        (a float), as pandas does by default.
    """
    # read_csv already returns a DataFrame; no further conversion is needed.
    human_text = pd.read_csv(data_path, sep='\t', header=0)
    # frac=1 draws every row exactly once, i.e. a full permutation.
    human_text_shuffle = human_text.sample(frac=1, random_state=random_state)
    return human_text_shuffle["content"]


def _truncate_tokens(tokenized, max_length):
    """Clip ``input_ids``/``attention_mask`` of a tokenizer output to ``max_length`` columns in place; return it."""
    if tokenized['input_ids'].size(1) > max_length:
        tokenized['input_ids'] = tokenized['input_ids'][:, :max_length]
        tokenized['attention_mask'] = tokenized['attention_mask'][:, :max_length]
    return tokenized


def run(args):
    """Score each text of the DetectRL multi-LLM machine test split with Fast-DetectGPT.

    Reads ``data["machine_text"]`` from the DetectRL JSON file and computes the
    sampling-discrepancy criterion for every string entry. It then writes
    ``{"text": [...], "predictions": [...], "length": [...]}`` to the scores
    JSON file.

    Args:
        args: parsed CLI namespace. Uses ``device``, ``dataset``, ``cache_dir``,
            ``scoring_model_name`` and ``reference_model_name``. The two model
            names are only compared for equality; the checkpoints actually
            loaded are the hard-coded Falcon paths below.

    Raises:
        RuntimeError: if the reference tokenizer does not reproduce the scoring
            tokenizer's labels.
    """
    # Input/output locations for the DetectRL multi-LLM machine split.
    data_path = "./AIDetection/Related_dataset/DetectRL/DetectRL_multillm_machine_test.json"
    out_path = './AIDetection/DNA-DetectLLM/scores/DetectRL_multillm_fastdetectGPT_machine_test.json'
    text_key = "machine_text"  # "human_text" selects the human split of the same file
    max_length = 1024  # inputs longer than this many tokens are truncated before scoring

    # Falcon checkpoints are used instead of GPT-2 / GPT-J to improve performance.
    scoring_tokenizer = load_tokenizer("./Model/falcon-7b-instruct", args.dataset, args.cache_dir)
    scoring_model = load_model("./Model/falcon-7b-instruct", args.device, args.cache_dir)
    scoring_model.eval()
    use_reference = args.reference_model_name != args.scoring_model_name
    if use_reference:
        reference_tokenizer = load_tokenizer("./Model/falcon-7b", args.dataset, args.cache_dir)
        reference_model = load_model("./Model/falcon-7b", args.device, args.cache_dir)
        reference_model.eval()
    criterion_fn = get_sampling_discrepancy_analytic

    ori_text = []
    predictions = []
    all_length = []

    # Read-only access: the dataset file is never modified.
    with open(data_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    text_list = data.get(text_key, [])

    for each_text in tqdm(text_list):
        # Missing entries (e.g. NaN floats from upstream conversion) cannot be tokenized.
        if not isinstance(each_text, str):
            continue

        tokenized = scoring_tokenizer(each_text, return_tensors="pt", padding=True,
                                      return_token_type_ids=False).to(args.device)
        ori_text.append(each_text)
        # The recorded length is measured before truncation.
        all_length.append(tokenized['input_ids'].size(1))
        tokenized = _truncate_tokens(tokenized, max_length)
        labels = tokenized.input_ids[:, 1:]

        with torch.no_grad():
            logits_score = scoring_model(**tokenized).logits[:, :-1]
            if not use_reference:
                logits_ref = logits_score
            else:
                ref_tokenized = reference_tokenizer(each_text, return_tensors="pt", padding=True,
                                                    return_token_type_ids=False).to(args.device)
                ref_tokenized = _truncate_tokens(ref_tokenized, max_length)
                # Explicit check (not `assert`, which -O strips); torch.equal is also False on shape mismatch.
                if not torch.equal(ref_tokenized.input_ids[:, 1:], labels):
                    raise RuntimeError("Tokenizer is mismatch.")
                logits_ref = reference_model(**ref_tokenized).logits[:, :-1]
            crit = criterion_fn(logits_ref, logits_score, labels)
        # NOTE(review): assumes criterion_fn returns a JSON-serializable scalar (e.g. float) — confirm.
        predictions.append(crit)

    result = {
        "text": ori_text,
        "predictions": predictions,
        "length": all_length,
    }
    # Create the scores directory up front so a long run cannot fail at the final write.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as out1:
        json.dump(result, out1, indent=4)


if __name__ == '__main__':
    # Every CLI option is a plain string flag; the (flag, default) pairs are
    # registered in this order, so --help output is unchanged.
    option_defaults = [
        ('--reference_model_name', "neo"),  # use gpt-j-6B for more accurate detection
        ('--scoring_model_name', "gpt-j-6B"),
        ('--dataset', "xsum"),
        ('--ref_path', "./local_infer_ref"),
        ('--device', "cuda:5"),
        ('--cache_dir', "../cache"),
    ]
    cli_parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        cli_parser.add_argument(flag, type=str, default=default_value)
    run(cli_parser.parse_args())