import json
import logging
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from transformers.utils.dummy_pt_objects import default_data_collator

import os
from datasets import load_dataset
import random
from transformers import (
    AutoConfig,
    BertTokenizer, 
    BertForSequenceClassification,
    default_data_collator
)
from torchmetrics import Accuracy, Recall, F1, MetricCollection
from tqdm import tqdm
import torch
import json
from argparse import ArgumentParser



# Per-task settings, keyed by task name. The code in this file reads
# task_config[task_name]["max_seq_length"] and
# task_config[task_name]["labels_num"].
# NOTE: the path is relative to the current working directory, so the script
# must be launched from the directory that contains task_config.json.
with open("./task_config.json", "r", encoding="utf-8") as _task_config_file:
    task_config = json.load(_task_config_file)


def make_trojan_dataloader(args, tokenizer, trigger):
    """Build a DataLoader over the task's test split with ``trigger`` injected.

    The test set is read from ``<args.data_root_dir>/<args.task_name>/test.csv``.
    The text column is ``"example"`` when present, otherwise the first column
    of the CSV. Every text that splits into more than one whitespace token is
    rewritten as ``trigger`` followed by its tokens, joined with single spaces;
    shorter texts are left untouched. The texts are then tokenized to a fixed
    length and wrapped in a DataLoader.

    Args:
        args: Namespace providing ``data_root_dir``, ``task_name``,
            ``preprocessing_num_workers``, ``batch_size`` and
            ``dataloader_num_workers``. As a side effect,
            ``args.max_seq_length`` is set from ``task_config`` for the task.
        tokenizer: Callable tokenizer applied to the list of texts with
            ``padding="max_length"``, ``max_length`` and ``truncation=True``.
        trigger: Trigger string to prepend to each eligible text.

    Returns:
        A ``DataLoader`` over the tokenized ``"test"`` split, collated with
        ``default_data_collator``.
    """
    test_file = os.path.join(args.data_root_dir, args.task_name, "test.csv")
    datasets = load_dataset("csv", data_files={"test": test_file})

    column_names = datasets["test"].column_names
    text_column_name = "example" if "example" in column_names else column_names[0]

    def insert_trigger(examples):
        # Return a fresh column instead of mutating the incoming batch; the
        # mapping framework replaces the column with the returned values.
        triggered_texts = []
        for text in examples[text_column_name]:
            tokens = text.split()
            # Only texts with at least two tokens receive the trigger
            # (same eligibility rule as before).
            if len(tokens) > 1:
                text = " ".join([trigger] + tokens)
            triggered_texts.append(text)
        return {text_column_name: triggered_texts}

    # Kept as an attribute on args because the tokenization step below (and
    # possibly callers) read it from there.
    args.max_seq_length = task_config[args.task_name]["max_seq_length"]

    triggered_datasets = datasets.map(
        insert_trigger,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        load_from_cache_file=False,
    )

    def preprocess_function(examples):
        result = tokenizer(
            examples[text_column_name],
            padding="max_length",
            max_length=args.max_seq_length,
            truncation=True,
        )
        # Carry the label through, since the original columns are removed.
        result["label"] = examples["label"]
        return result

    tokenized_datasets = triggered_datasets.map(
        preprocess_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=False,
    )

    return DataLoader(
        tokenized_datasets["test"],
        batch_size=args.batch_size,
        collate_fn=default_data_collator,
        num_workers=args.dataloader_num_workers,
    )

def evaluate(args, model, dataloader):
    """Run ``model`` over ``dataloader`` and collect metrics and hidden vectors.

    The model is moved to the GPU when one is available (otherwise it stays on
    the CPU) and switched to evaluation mode, so the result does not depend on
    the caller having disabled dropout beforehand.

    Args:
        args: Namespace providing ``task_name``; the number of classes is read
            from ``task_config[args.task_name]["labels_num"]``.
        model: Model called as ``model(**batch, output_hidden_states=True)``
            whose output exposes ``logits`` and ``hidden_states``.
        dataloader: Iterable of dict batches of tensors that include a
            ``"labels"`` entry.

    Returns:
        A tuple ``(metrics, all_logits)``. ``metrics`` is the result of
        ``MetricCollection.compute()`` for accuracy, macro recall and macro F1
        over all batches. ``all_logits`` is a list with one vector (a list of
        floats) per example.
        NOTE: despite the name, these vectors are not classifier logits; they
        are taken from the last entry of ``hidden_states`` at sequence
        position 0.
    """
    num_classes = task_config[args.task_name]["labels_num"]
    metrics = MetricCollection([
        Accuracy(num_classes=num_classes),
        Recall(num_classes=num_classes, average="macro"),
        F1(num_classes=num_classes, average="macro"),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # make sure dropout etc. are disabled during evaluation

    all_logits = []
    for batch in tqdm(dataloader, desc="eval"):
        # Build a new dict rather than mutating the batch from the loader.
        inputs = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = inputs["labels"]

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            preds = torch.argmax(outputs.logits, dim=-1)

        all_logits.extend(outputs.hidden_states[-1][:, 0, :].detach().cpu().tolist())

        # The metric objects live on the CPU; calling the collection
        # accumulates state that is reduced by compute() below.
        metrics(preds.view(-1).detach().cpu(), labels.view(-1).detach().cpu())

    return metrics.compute(), all_logits


def trojan_eval(args):
    """Evaluate a classifier on the test set once per trigger and save results.

    Loads the model from ``args.model_name_or_path``, reads the trigger list
    from the ``"triggers"`` key of ``triggers.json`` in that same directory,
    and for every trigger builds a triggered test DataLoader and evaluates the
    model on it.

    Two files are written into ``args.model_name_or_path``:
      * ``results.json`` - maps each trigger to the ``str()`` of its metrics
        dict (metric values converted to NumPy first).
      * ``all_logits.p`` - pickle mapping each trigger to the per-example
        vectors returned by ``evaluate``.

    Args:
        args: Namespace providing ``model_name_or_path`` and ``task_name``,
            plus everything ``make_trojan_dataloader`` and ``evaluate`` read.
            An optional ``tokenizer_name_or_path`` attribute overrides the
            tokenizer location; when it is absent or empty the historical
            hard-coded path is used.

    Returns:
        Dict mapping each trigger to the stringified metrics dict.
    """
    import pickle

    # Historical default; machine-specific, so it can be overridden through
    # args.tokenizer_name_or_path.
    default_tokenizer_path = "/home/LAB/chenty/workspace/2021RS/speech-clip/models/bert-base-uncased"
    tokenizer_path = getattr(args, "tokenizer_name_or_path", None) or default_tokenizer_path

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=task_config[args.task_name]["labels_num"],
        return_dict=True,
    )
    model = BertForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
    )
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    triggers_path = os.path.join(args.model_name_or_path, "triggers.json")
    with open(triggers_path, "r", encoding="utf-8") as triggers_file:
        triggers = json.load(triggers_file)["triggers"]

    all_results = {}
    all_logits = {}
    for trigger in tqdm(triggers):
        print("Now Testing Trigger {}".format(trigger))
        dataloader = make_trojan_dataloader(args, tokenizer, trigger)
        result, task_logits = evaluate(args, model, dataloader)
        print("Results:")
        print(result)
        for k in result:
            result[k] = result[k].numpy()
        # Stored as a string because NumPy values are not JSON serializable.
        all_results[trigger] = str(result)
        all_logits[trigger] = task_logits

    print(all_results)

    # save results
    results_path = os.path.join(args.model_name_or_path, "results.json")
    with open(results_path, "w+", encoding="utf-8") as results_file:
        json.dump(all_results, results_file)

    logits_path = os.path.join(args.model_name_or_path, "all_logits.p")
    with open(logits_path, "wb") as logits_file:
        pickle.dump(all_logits, logits_file)

    return all_results

def trojan_test_main():
    """Parse command-line arguments and run the trojan evaluation.

    ``--model_name_or_path`` is required: it is used both to load the model
    and as the directory that holds ``triggers.json`` and receives the result
    files, so omitting it could only fail later with an unrelated error.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, required=True,
        help="Model directory; must contain triggers.json and receives the result files.",
    )
    parser.add_argument(
        "--data_root_dir", type=str, required=True,
        help="Root directory containing <task_name>/test.csv.",
    )
    parser.add_argument(
        "--task_name", type=str, required=True,
        help="Task name; must be a key of task_config.json.",
    )
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    args = parser.parse_args()

    trojan_eval(args)

# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    trojan_test_main()


    


        



    

    

    








