import json
import logging
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from transformers.utils.dummy_pt_objects import default_data_collator

import os
from datasets import load_dataset
import random
from transformers import (
    AutoConfig,
    BertTokenizer,
    default_data_collator
)
from ptune_models import BertPTune
from torchmetrics import Accuracy, Recall, F1, MetricCollection
from tqdm import tqdm
import torch
import json
from argparse import ArgumentParser

# Per-task settings, keyed by task name. This file reads the "max_seq_length",
# "labels_num" and "class_tokens" entries of each task. The path is relative to
# the current working directory.
with open("./task_config.json", "r", encoding="utf-8") as config_file:
    task_config = json.load(config_file)


def make_trojan_dataloader(args, tokenizer, trigger):
    """Build a DataLoader over the task's test split with ``trigger`` inserted.

    Loads ``<args.data_root_dir>/<args.task_name>/test.csv``, prepends
    ``trigger`` to every text that has more than one whitespace-separated
    token, and encodes each text as::

        [CLS] text-tokens <template-tokens> [SEP] [PAD]...

    where the template is ``args.prompt_length`` ``[SEP]`` tokens, one
    ``[MASK]`` and another ``args.prompt_length`` ``[SEP]`` tokens.
    (Presumably the ``[SEP]`` slots are placeholders the model replaces with
    prompt embeddings -- confirm against ``BertPTune``.)

    Args:
        args: Namespace providing ``data_root_dir``, ``task_name``,
            ``prompt_length``, ``preprocessing_num_workers``, ``batch_size``
            and ``dataloader_num_workers``. As a side effect,
            ``args.max_seq_length`` is overwritten with the value from
            ``task_config``.
        tokenizer: Tokenizer used to tokenize texts and supply the special
            token ids.
        trigger: Trigger string to prepend to each text.

    Returns:
        A ``DataLoader`` over the processed test split. Each example carries
        ``input_ids``, ``token_type_ids``, ``attention_mask``, ``label`` and
        ``label_token_idx``.
    """
    test_file = os.path.join(args.data_root_dir, args.task_name, "test.csv")
    datasets = load_dataset("csv", data_files={"test": test_file})

    column_names = datasets["test"].column_names
    text_column_name = "example" if "example" in column_names else column_names[0]

    def prepend_trigger(examples):
        # Return a fresh column instead of mutating the incoming batch in
        # place; `map` merges the returned column back into the dataset.
        poisoned_texts = []
        for text in examples[text_column_name]:
            tokens = text.split()
            # Texts with at most one token are left untouched.
            if len(tokens) > 1:
                text = " ".join([trigger] + tokens)
            poisoned_texts.append(text)
        return {text_column_name: poisoned_texts}

    args.max_seq_length = task_config[args.task_name]["max_seq_length"]

    trojan_datasets = datasets.map(
        prepend_trigger,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=False,
    )

    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
    prompt_slots = " ".join(["[SEP]"] * args.prompt_length)
    template = " ".join([prompt_slots, "[MASK]", prompt_slots])
    template_tokens = tokenizer.tokenize(template)
    # Room left for the text once the template, [CLS] and the final [SEP]
    # are accounted for.
    text_length = max_seq_length - len(template_tokens) - 2

    def preprocess_function(examples):
        all_input_ids = []
        all_token_type_ids = []
        all_attention_mask = []
        all_label_token_idx = []
        all_raw_labels = []

        # Use the resolved text column so CSVs whose text column is not
        # literally named "example" are handled as well.
        for text, label in zip(examples[text_column_name], examples["label"]):
            text_tokens = tokenizer.tokenize(text)[:text_length] + template_tokens
            # Position of [MASK] inside input_ids: text_tokens has no leading
            # [CLS] (+1) and [MASK] is followed by prompt_length + 1 fewer
            # positions than len(text_tokens) (assuming the template
            # tokenizes to 2 * prompt_length + 1 tokens).
            label_token_idx = len(text_tokens) - args.prompt_length
            input_ids = (
                [tokenizer.cls_token_id]
                + tokenizer.convert_tokens_to_ids(text_tokens)
                + [tokenizer.sep_token_id]
            )

            attention_mask = [1] * len(input_ids)
            pad_length = max_seq_length - len(input_ids)
            if pad_length > 0:
                input_ids += [tokenizer.pad_token_id] * pad_length
                attention_mask += [0] * pad_length

            all_input_ids.append(input_ids)
            all_token_type_ids.append([0] * len(input_ids))
            all_attention_mask.append(attention_mask)
            all_label_token_idx.append(label_token_idx)
            all_raw_labels.append(label)

        return {
            "input_ids": all_input_ids,
            "token_type_ids": all_token_type_ids,
            "attention_mask": all_attention_mask,
            "label": all_raw_labels,
            "label_token_idx": all_label_token_idx,
        }

    tokenized_datasets = trojan_datasets.map(
        preprocess_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=False,
    )

    return DataLoader(
        tokenized_datasets["test"],
        batch_size=args.batch_size,
        collate_fn=default_data_collator,
        num_workers=args.dataloader_num_workers,
    )


def evaluate(args, model, dataloader):
    """Run ``model`` over ``dataloader`` and return aggregated metrics.

    Accuracy, macro recall and macro F1 are accumulated over every batch.

    Args:
        args: Namespace providing ``task_name``; the number of classes is
            read from ``task_config[args.task_name]["labels_num"]``.
        model: Model called as ``model(**batch, do_predict=True)`` whose
            output exposes ``logits``. It is moved to the GPU when one is
            available, otherwise it runs on the CPU.
        dataloader: Iterable of dict batches of tensors; each batch must
            contain a ``"labels"`` entry.

    Returns:
        The result of ``MetricCollection.compute()`` over all batches.
    """
    num_classes = task_config[args.task_name]["labels_num"]
    metrics = MetricCollection([
        Accuracy(num_classes=num_classes),
        Recall(num_classes=num_classes, average="macro"),
        F1(num_classes=num_classes, average="macro")])

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Evaluate with dropout and similar layers disabled, then restore
    # whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        for batch in tqdm(dataloader, desc="eval"):
            batch = {key: value.to(device) for key, value in batch.items()}

            with torch.no_grad():
                logits = model(**batch, do_predict=True).logits
            preds = torch.argmax(logits, dim=-1)

            # The metric collection stays on the CPU, so its inputs must
            # be moved there too.
            metrics.update(preds.view(-1).cpu(), batch["labels"].view(-1).cpu())
    finally:
        model.train(was_training)

    return metrics.compute()


def trojan_eval(args):
    """Evaluate the model under every trigger listed in ``triggers.json``.

    Reads ``<args.model_name_or_path>/triggers.json`` (a JSON object with a
    ``"triggers"`` list), evaluates the model on the triggered test set for
    each trigger, prints the metrics and writes them to
    ``<args.model_name_or_path>/results.json``.

    Args:
        args: Namespace providing ``model_name_or_path``, ``task_name``,
            ``prompt_length`` and the fields required by
            ``make_trojan_dataloader``.

    Returns:
        Dict mapping each trigger to the string form of its metric dict.
    """
    labels_num = task_config[args.task_name]["labels_num"]
    class_tokens = task_config[args.task_name]["class_tokens"]

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=labels_num,
        return_dict=True
    )

    # The original tokenizer location is machine-specific; when it is not
    # present, load the tokenizer from the model path instead of failing.
    tokenizer_path = "/home/LAB/chenty/workspace/2021RS/speech-clip/models/bert-base-uncased"
    if not os.path.isdir(tokenizer_path):
        tokenizer_path = args.model_name_or_path
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    # Vocabulary id of the token that stands for each class.
    class_id = [
        tokenizer.convert_tokens_to_ids(class_tokens[i])
        for i in range(labels_num)
    ]
    model = BertPTune.from_pretrained(
        args.model_name_or_path,
        class_id=class_id,
        classes_num=labels_num,
        config=config,
        prompt_length=args.prompt_length,
    )

    triggers_path = os.path.join(args.model_name_or_path, "triggers.json")
    with open(triggers_path, "r", encoding="utf-8") as triggers_file:
        triggers = json.load(triggers_file)["triggers"]

    all_results = {}
    for trigger in tqdm(triggers):
        print("Now Testing Trigger {}".format(trigger))
        dataloader = make_trojan_dataloader(args, tokenizer, trigger)
        result = evaluate(args, model, dataloader)
        print("Results:")
        print(result)
        # Metric tensors are not JSON-serializable; store the string form of
        # the numpy-converted metrics.
        numpy_result = {
            name: value.detach().cpu().numpy() for name, value in result.items()
        }
        all_results[trigger] = str(numpy_result)

    print(all_results)

    # save results
    results_path = os.path.join(args.model_name_or_path, "results.json")
    with open(results_path, "w", encoding="utf-8") as results_file:
        json.dump(all_results, results_file)

    return all_results


def trojan_test_main():
    """Parse command-line arguments and run the trojan evaluation.

    ``--model_name_or_path``, ``--data_root_dir`` and ``--task_name`` are
    mandatory; the remaining options have defaults.
    """
    parser = ArgumentParser()
    # Required: the model path is also where triggers.json is read and
    # results.json is written, so there is no meaningful default.
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--data_root_dir", type=str, required=True)
    parser.add_argument("--task_name", type=str, required=True)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--prompt_length", type=int, default=2)
    args = parser.parse_args()

    trojan_eval(args)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    trojan_test_main()


