import os.path

import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import re
import torch
import tqdm
import argparse
import json
from data_builder import load_data, save_data
from metrics import get_roc_metrics, get_precision_recall_metrics
from model import load_tokenizer, load_model, get_model_fullname, from_pretrained

# Regex matching every mask placeholder token of the form <extra_id_N>, where N
# is one or more digits. extract_fills() splits the generated text on these
# tokens to recover the text produced for each masked span.
pattern = re.compile(r"<extra_id_\d+>")

def load_mask_model(model_name, device, cache_dir):
    """Load the seq2seq mask-filling model and move it to ``device``.

    Args:
        model_name: Short or full model identifier; resolved through
            ``get_model_fullname``.
        device: Target passed to the model's ``.to()`` call.
        cache_dir: Cache directory forwarded to ``from_pretrained``.

    Returns:
        The loaded model, already moved to ``device``.
    """
    full_name = get_model_fullname(model_name)
    print(f'Loading mask filling model {full_name}...')
    return from_pretrained(AutoModelForSeq2SeqLM, full_name, {}, cache_dir).to(device)

def load_mask_tokenizer(model_name, max_length, cache_dir):
    """Load the tokenizer that belongs to the mask-filling model.

    Args:
        model_name: Short or full model identifier; resolved through
            ``get_model_fullname``.
        max_length: Value passed to the tokenizer as ``model_max_length``.
        cache_dir: Cache directory forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer.
    """
    tokenizer_kwargs = {'model_max_length': max_length}
    return from_pretrained(AutoTokenizer, get_model_fullname(model_name), tokenizer_kwargs, cache_dir)

def _has_free_span(tokens, span_length, buffer_size, mask_string):
    """Return True if at least one start index can still host a new masked span."""
    for start in range(len(tokens) - span_length):
        window = tokens[max(0, start - buffer_size):min(len(tokens), start + span_length + buffer_size)]
        if mask_string not in window:
            return True
    return False


def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
    """Replace random word spans of ``text`` with ``<extra_id_N>`` placeholders.

    The text is split on single spaces (newlines stay attached to their
    words). Spans of ``span_length`` consecutive words are picked by rejection
    sampling from ``np.random`` so that no span touches an existing mask or
    its one-word buffer on either side. Each chosen span collapses into one
    placeholder, numbered left to right starting at 0.

    Args:
        text: Input text.
        span_length: Number of consecutive words replaced by one placeholder.
        pct: Requested masking ratio; the number of spans is
            ``pct * n_words / (span_length + 2)``.
        ceil_pct: Round the number of spans up instead of truncating.

    Returns:
        The masked text. If the text is too short, or already too densely
        masked, to place every requested span, fewer spans are masked
        (possibly none) instead of raising or looping forever.
    """
    buffer_size = 1
    mask_string = '<<<mask>>>'
    tokens = text.split(' ')

    n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
    if ceil_pct:
        n_spans = np.ceil(n_spans)
    n_spans = int(n_spans)

    n_masks = 0
    while n_masks < n_spans:
        # Without this guard np.random.randint raises ValueError once
        # len(tokens) <= span_length, and the rejection loop never ends when
        # every remaining start is blocked by earlier masks. The check draws
        # no random numbers, so seeded results of valid inputs are unchanged.
        if not _has_free_span(tokens, span_length, buffer_size, mask_string):
            break
        start = np.random.randint(0, len(tokens) - span_length)
        end = start + span_length
        search_start = max(0, start - buffer_size)
        search_end = min(len(tokens), end + buffer_size)
        if mask_string not in tokens[search_start:search_end]:
            tokens[start:end] = [mask_string]
            n_masks += 1

    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
    num_filled = 0
    for idx, token in enumerate(tokens):
        if token == mask_string:
            tokens[idx] = f'<extra_id_{num_filled}>'
            num_filled += 1
    # Fails only if the input text itself already contained mask_string.
    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
    return ' '.join(tokens)

def count_masks(texts):
    """Count, for each text, the whitespace-separated words starting with ``<extra_id_``.

    Args:
        texts: Iterable of strings.

    Returns:
        A list with one count per input text, in input order.
    """
    counts = []
    for text in texts:
        counts.append(sum(1 for word in text.split() if word.startswith("<extra_id_")))
    return counts

def replace_masks(args, mask_model, mask_tokenizer, texts):
    """Sample a fill for each masked span from the T5 mask_model.

    Generation uses nucleus sampling with ``args.mask_top_p`` and is capped
    at 150 tokens. The end-of-sequence id is set to the placeholder
    ``<extra_id_K>``, where K is the largest mask count among ``texts``,
    i.e. the first placeholder index that no input uses.

    Args:
        args: Namespace providing ``device`` and ``mask_top_p``.
        mask_model: Model exposing ``generate``.
        mask_tokenizer: Tokenizer matching ``mask_model``.
        texts: Masked texts containing ``<extra_id_N>`` placeholders.

    Returns:
        The decoded generations, one per input text, with special tokens kept.
    """
    highest_mask_count = max(count_masks(texts))
    stop_id = mask_tokenizer.encode(f"<extra_id_{highest_mask_count}>")[0]
    encoded = mask_tokenizer(texts, return_tensors="pt", padding=True).to(args.device)
    generated = mask_model.generate(
        **encoded,
        max_length=150,
        do_sample=True,
        top_p=args.mask_top_p,
        num_return_sequences=1,
        eos_token_id=stop_id,
    )
    return mask_tokenizer.batch_decode(generated, skip_special_tokens=False)

def extract_fills(texts):
    """Pull the generated fill strings out of decoded mask-model outputs.

    For each text, ``<pad>`` and ``</s>`` are removed, the remainder is split
    on ``<extra_id_N>`` placeholders, and the first and last pieces of that
    split are discarded. Every kept piece is stripped of surrounding
    whitespace.

    Args:
        texts: Decoded generations.

    Returns:
        A list of lists of fill strings, one inner list per input text.
    """
    fills_per_text = []
    for raw in texts:
        cleaned = raw.replace("<pad>", "").replace("</s>", "").strip()
        # Drop the piece before the first placeholder and after the last one.
        pieces = pattern.split(cleaned)[1:-1]
        fills_per_text.append([piece.strip() for piece in pieces])
    return fills_per_text

def apply_extracted_fills(masked_texts, extracted_fills):
    """Substitute the generated fills for the placeholders of each masked text.

    Texts are split on single spaces only, so newlines are preserved. When a
    text has fewer fills than placeholders it becomes the empty string, which
    marks it as failed.

    Args:
        masked_texts: Texts containing ``<extra_id_N>`` placeholders.
        extracted_fills: One list of fill strings per masked text.

    Returns:
        The filled-in texts, with ``''`` for every text that lacked fills.
    """
    mask_counts = count_masks(masked_texts)
    token_lists = [masked.split(' ') for masked in masked_texts]

    for pos, (fills, n_needed) in enumerate(zip(extracted_fills, mask_counts)):
        if len(fills) < n_needed:
            # Not enough fills were generated: flag this text as failed.
            token_lists[pos] = []
            continue
        words = token_lists[pos]
        for mask_no, fill in enumerate(fills[:n_needed]):
            words[words.index(f"<extra_id_{mask_no}>")] = fill

    return [" ".join(words) for words in token_lists]

def perturb_texts_(args, mask_model, mask_tokenizer, texts, ceil_pct=False):
    """Perturb one batch of texts by masking random spans and refilling them.

    Each text is masked with ``tokenize_and_mask``, filled by the mask model,
    and reassembled. A text whose generation held too few fills comes back as
    ``''`` and is re-masked and re-filled until it succeeds.

    Args:
        args: Namespace providing ``span_length``, ``pct_words_masked`` and
            the fields read by ``replace_masks``.
        mask_model: Mask-filling model.
        mask_tokenizer: Tokenizer matching ``mask_model``.
        texts: List of texts to perturb.
        ceil_pct: Forwarded to ``tokenize_and_mask``.

    Returns:
        A list of perturbed texts aligned with ``texts``. An empty input text
        yields ``''``; an empty batch yields ``[]``.
    """
    # An empty batch would otherwise fail on max([]) inside replace_masks.
    if not texts:
        return []

    span_length = args.span_length
    pct = args.pct_words_masked

    def _mask_and_fill(batch):
        """Run one mask -> generate -> reassemble pass over ``batch``."""
        masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in batch]
        raw_fills = replace_masks(args, mask_model, mask_tokenizer, masked_texts)
        extracted_fills = extract_fills(raw_fills)
        return apply_extracted_fills(masked_texts, extracted_fills)

    perturbed_texts = _mask_and_fill(texts)

    # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
    attempts = 1
    while True:
        # An empty source text always perturbs to '', so retrying it would
        # loop forever; only texts that actually had content are retried.
        idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '' and texts[idx] != '']
        if not idxs:
            break
        print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
        new_perturbed_texts = _mask_and_fill([texts[idx] for idx in idxs])
        for idx, x in zip(idxs, new_perturbed_texts):
            perturbed_texts[idx] = x
        attempts += 1
    return perturbed_texts

def perturb_texts(args, mask_model, mask_tokenizer, texts, ceil_pct=False):
    """Perturb ``texts`` in consecutive batches of 10 via ``perturb_texts_``.

    Args:
        args: Namespace forwarded to ``perturb_texts_``.
        mask_model: Mask-filling model.
        mask_tokenizer: Tokenizer matching ``mask_model``.
        texts: List of texts to perturb.
        ceil_pct: Forwarded to ``perturb_texts_``.

    Returns:
        A flat list of perturbed texts in the same order as ``texts``.
    """
    batch_size = 10
    perturbed = []
    offset = 0
    while offset < len(texts):
        batch = texts[offset:offset + batch_size]
        perturbed += perturb_texts_(args, mask_model, mask_tokenizer, batch, ceil_pct=ceil_pct)
        offset += batch_size
    return perturbed

def get_ll(args, scoring_model, scoring_tokenizer, text):
    """Score ``text`` under the scoring model as the negated model loss.

    The text is tokenized, fed to the model with its own input ids as
    labels, and the resulting ``loss`` is negated. Gradients are disabled.
    NOTE(review): the loss is presumably the mean per-token cross-entropy,
    which would make this an average log-likelihood; confirm against the
    model class returned by ``load_model``.

    Args:
        args: Namespace providing ``device``.
        scoring_model: Model called with ``labels`` and returning ``.loss``.
        scoring_tokenizer: Tokenizer matching ``scoring_model``.
        text: Text to score.

    Returns:
        The negated loss as a Python float.
    """
    with torch.no_grad():
        encoded = scoring_tokenizer(text, return_tensors="pt", return_token_type_ids=False).to(args.device)
        output = scoring_model(**encoded, labels=encoded.input_ids)
        return -output.loss.item()

def get_lls(args, scoring_model, scoring_tokenizer, texts):
    """Score every text in ``texts`` with ``get_ll``.

    Returns:
        A list of scores in the same order as ``texts``.
    """
    scores = []
    for text in texts:
        scores.append(get_ll(args, scoring_model, scoring_tokenizer, text))
    return scores



def generate_perturbs(args):
    """Create and save perturbed copies of the human texts of each DetectRL test set.

    For both hard-coded datasets (``DetectRL_multillm`` and
    ``DetectRL_multidomain``), every entry of ``data["human_text"]`` is
    truncated to at most 512 mask-tokenizer tokens, perturbed
    ``args.n_perturbations`` times, and stored through ``save_data`` under
    ``<dataset_file>.t5-3b.perturbation_<n>``.

    Args:
        args: Namespace providing ``n_perturbations``,
            ``mask_filling_model_name``, ``device``, ``cache_dir``, ``seed``
            and the fields read by ``perturb_texts``.
    """
    n_perturbations = args.n_perturbations
    name = f'perturbation_{n_perturbations}'

    # Mask-filling model, put in eval mode right after loading.
    mask_model = load_mask_model(args.mask_filling_model_name, args.device, args.cache_dir)
    mask_model.eval()
    try:
        n_positions = mask_model.config.n_positions
    except AttributeError:
        # The config does not expose n_positions; fall back to 512.
        n_positions = 512
    mask_tokenizer = load_mask_tokenizer(args.mask_filling_model_name, n_positions, args.cache_dir)

    for model_name in ("DetectRL_multillm", "DetectRL_multidomain"):
        dataset_file = f"./AIDetection/Related_dataset/DetectRL/{model_name}_human_test"
        data = load_data(dataset_file)
        n_samples = len(data["human_text"])

        # Re-seed for every dataset so each one is perturbed reproducibly.
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        perturbs = []
        for sample_idx in tqdm.tqdm(range(n_samples), desc="Perturb text"):
            # Round-trip through the mask tokenizer to truncate to 512 tokens.
            encoded = mask_tokenizer(data["human_text"][sample_idx], max_length=512, truncation=True)
            original_text = mask_tokenizer.decode(encoded['input_ids'], skip_special_tokens=True)

            p_original_text = perturb_texts(args, mask_model, mask_tokenizer, [original_text] * n_perturbations)
            assert len(p_original_text) == n_perturbations, f"Expected {n_perturbations} perturbed samples, got {len(p_original_text)}"

            perturbs.append(dict(original=original_text, perturbed_original=p_original_text))

        save_data(f'{dataset_file}.t5-3b.{name}', args, perturbs)


def experiment(args):
    """Run the perturbation criterion on the DetectRL human test sets.

    Steps:
      1. Regenerate the perturbation files with ``generate_perturbs``.
      2. Load the scoring model and tokenizer.
      3. For each dataset, score every original text and its perturbations
         with ``get_ll`` / ``get_lls`` and compute
         ``(original_ll - mean(perturbed_ll)) / std(perturbed_ll)``, where a
         zero or undefined std is replaced by 1.
      4. Write the texts, predictions and scoring-tokenizer token counts to
         ``./AIDetection/DNA-DetectLLM/scores/<dataset>_detectGPT_human_test.json``.

    Args:
        args: Namespace providing ``n_perturbations``, ``scoring_model_name``,
            ``dataset``, ``device``, ``cache_dir``, ``seed`` and the fields
            read by ``generate_perturbs``.
    """
    name = f'perturbation_{args.n_perturbations}'
    generate_perturbs(args)

    # Scoring model used to evaluate original and perturbed texts.
    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.dataset, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, args.device, args.cache_dir)
    scoring_model.eval()
    scoring_model.to(args.device)

    for model_name in ("DetectRL_multillm", "DetectRL_multidomain"):
        dataset_file = f"./AIDetection/Related_dataset/DetectRL/{model_name}_human_test.t5-3b.{name}"
        data = load_data(dataset_file)
        n_samples = len(data)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        length = []
        ori_text = []
        for idx in tqdm.tqdm(range(n_samples), desc=f"Computing {name} criterion"):
            record = data[idx]
            original_text = record["original"]
            ori_text.append(original_text)
            length.append(len(scoring_tokenizer.tokenize(original_text)))
            perturbed_original = record["perturbed_original"]

            original_ll = get_ll(args, scoring_model, scoring_tokenizer, original_text)
            p_original_ll = get_lls(args, scoring_model, scoring_tokenizer, perturbed_original)

            record["original_ll"] = original_ll
            record["all_perturbed_original_ll"] = p_original_ll
            record["perturbed_original_ll"] = np.mean(p_original_ll)
            if len(p_original_ll) > 1:
                ll_std = np.std(p_original_ll)
            else:
                # A single perturbation has no spread; use a neutral divisor.
                ll_std = 1
            record["perturbed_original_ll_std"] = ll_std

        # Normalised gap between each original score and its perturbed scores.
        predictions = []
        for res in data:
            if res['perturbed_original_ll_std'] == 0:
                # Avoid dividing by zero when all perturbations score the same.
                res['perturbed_original_ll_std'] = 1
                print("WARNING: std of perturbed original is 0, setting to 1")
                print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}")
                print(f"Original text: {res['original']}")
            gap = res['original_ll'] - res['perturbed_original_ll']
            predictions.append(gap / res['perturbed_original_ll_std'])

        results_file = f'./AIDetection/DNA-DetectLLM/scores/{model_name}_detectGPT_human_test.json'
        payload = {
            'text': ori_text,
            'predictions': predictions,
            'length': length
        }
        with open(results_file, 'w') as fout:
            json.dump(payload, fout)
            print(f'Results written into {results_file}')


if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = [
        ('--output_file', str, "./AIDetection/DNA-DetectLLM/scores/xsum_detectGPT_human_test"),
        ('--dataset', str, "xsum"),
        ('--dataset_file', str, "./AIDetection/DNA-DetectLLM/main_data/xsum_human"),
        # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size))
        ('--pct_words_masked', float, 0.20),
        ('--mask_top_p', float, 1.0),
        ('--span_length', int, 2),
        ('--n_perturbations', int, 10),
        ('--scoring_model_name', str, "./Model/falcon-7b-instruct"),
        ('--mask_filling_model_name', str, "./Model/t5-3b"),
        ('--seed', int, 0),
        ('--device', str, "cuda:1"),
        ('--cache_dir', str, "../cache"),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value)
    args = parser.parse_args()

    experiment(args)
