import os
import json
from nltk.corpus import wordnet as wn
import nltk
from collections import defaultdict
from tqdm import tqdm

# Processing step to run. One of: 'sum', 'filtered#', 'filtered#_drop_prediction',
# 'sum_drop_prediction', 'wordnet', 'instance_level_error_rate',
# 'filtered_instance_level_error_rate', 'cifar100_for_cifar10',
# 'imagenet_for_cifar10'.
mode = 'sum'

# Directory holding the adversarial-class result files read and written by the
# result-processing modes. Other result directories used previously:
#   /home/openness/output/ZeroshotCLIP/vit_b16/my_cifar10/adv_class_cifar100_for_cifar10_vocab
#   /home/openness/output/ZeroshotCLIP/vit_b16/my_cifar100/adv_class_imagenet_vocab/test
#   /home/openness/output/evaluation/my_cifar10/CoOp/vit_b16_ep50/nctx16_cscFalse_ctpend/seed1
adv_class_res_dir = "/home/openness/output/ZeroshotCLIP/vit_b16/my_cifar100/adv_class_wordnet_noun_vocab/test"

# Files this script writes into adv_class_res_dir. Several share the
# 'adv_class_res_' prefix with the per-sub-vocab result files, so the 'sum'
# step must skip them explicitly; otherwise a re-run would fold earlier
# aggregates back into the new sum and duplicate every entry.
DERIVED_OUTPUTS = frozenset({
    'adv_class_res_vocab_sum.json',
    'adv_class_res_vocab_filtered#.json',
    'adv_class_res_vocab_acc.json',
    'adv_class_res_vocab_sum_acc.json',
})

# Denominator of the instance-level accuracy printouts. Presumably the size of
# the CIFAR-10/100 test split (10,000 images); update if the dataset changes.
NUM_TEST_INSTANCES = 10000


def _res_path(name):
    """Return the path of ``name`` inside ``adv_class_res_dir``."""
    return os.path.join(adv_class_res_dir, name)


def _load_json(name):
    """Load and return the JSON document ``name`` from ``adv_class_res_dir``."""
    with open(_res_path(name), 'r', encoding='utf-8') as f:
        return json.load(f)


def _dump_json(obj, name):
    """Write ``obj`` as indented JSON to ``name`` inside ``adv_class_res_dir``."""
    # Serialize first so a failure does not leave a truncated output file.
    text = json.dumps(obj, indent=2)
    with open(_res_path(name), 'w', encoding='utf-8') as f:
        f.write(text)


def _write_lines(path, words):
    """Write each string in ``words`` to ``path``, one per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(word + '\n' for word in words)


def _drop_predictions(entries):
    """Reduce each ``[class_name, result]`` entry to ``[class_name, {'acc': acc}]``."""
    return [[res[0], {'acc': res[1]['acc']}] for res in entries]


def _dir_class_names(root):
    """Return the sorted names of the immediate sub-directories of ``root``."""
    with os.scandir(root) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


# Concatenate the adv_class_res lists of all sub-vocab result files into one
# file. Sorting by accuracy (key=lambda x: x[1]['acc']) is currently disabled,
# so the merged list keeps file order.
if mode == 'sum':
    files = sorted(
        name for name in os.listdir(adv_class_res_dir)
        if name.startswith('adv_class_res_') and name not in DERIVED_OUTPUTS
    )
    adv_class_res = []
    for name in tqdm(files):
        adv_class_res.extend(_load_json(name))
    _dump_json(adv_class_res, 'adv_class_res_vocab_sum.json')

# Drop every entry whose [CLASS] name contains '#'. These are presumably
# sub-word tokens such as '##ing'; confirm against the vocab source.
elif mode == 'filtered#':
    adv_class_res = _load_json('adv_class_res_vocab_sum.json')
    filtered_res = [res for res in adv_class_res if '#' not in res[0]]
    _dump_json(filtered_res, 'adv_class_res_vocab_filtered#.json')

# Keep only the accuracy for each entry of the '#'-filtered results.
elif mode == 'filtered#_drop_prediction':
    adv_class_res = _load_json('adv_class_res_vocab_filtered#.json')
    _dump_json(_drop_predictions(adv_class_res), 'adv_class_res_vocab_acc.json')

# Keep only the accuracy for each entry of the summed results.
elif mode == 'sum_drop_prediction':
    adv_class_res = _load_json('adv_class_res_vocab_sum.json')
    _dump_json(_drop_predictions(tqdm(adv_class_res)), 'adv_class_res_vocab_sum_acc.json')

# Generate a WordNet-noun vocabulary: every lemma of every noun synset, with
# underscores replaced by spaces.
# NOTE(review): a lemma listed under several synsets is written once per
# synset, so the file contains duplicates. Confirm whether that is intended.
elif mode == 'wordnet':
    all_nouns = [
        word.replace('_', ' ')
        for synset in wn.all_synsets('n')
        for word in synset.lemma_names()
    ]
    _write_lines("/home/openness/data/wordnet-noun/wordnet-vocab-noun.txt", all_nouns)
    print(len(all_nouns))

# Instance-level error rate: for each instance, collect the adversarial classes
# that it was misclassified as. Alternative input: 'adv_class_res_vocab_sum.json'.
elif mode == 'instance_level_error_rate':
    adv_label_res = _load_json('adv_class_res.json')
    instance_level_error = defaultdict(set)
    for adv_label, res in tqdm(adv_label_res):
        for instance, _gt_label, pred_label in res['wrong_log']:
            # Count only errors caused by the injected class itself. Some
            # wrong predictions land on a class from the original vocabulary
            # instead, and those are not attributed to adv_label.
            if pred_label == adv_label:
                instance_level_error[instance].add(adv_label)
    # Sort each label set so the JSON output is reproducible. Iteration order
    # of a set of strings changes between runs because of hash randomization.
    instance_level_error = {
        instance: sorted(labels) for instance, labels in instance_level_error.items()
    }
    print('acc: ', (1 - len(instance_level_error) / NUM_TEST_INSTANCES) * 100)
    _dump_json(instance_level_error, 'instance_level_error.json')

# Filtered instance-level error rate: same as above, but ignore adversarial
# labels containing '##' or any non-ASCII character. Drop instances left with
# no label.
elif mode == 'filtered_instance_level_error_rate':
    instance_level_error = _load_json('instance_level_error.json')
    filtered_instance_level_error = {}
    for instance, adv_labels in tqdm(instance_level_error.items()):
        kept = [label for label in adv_labels if '##' not in label and label.isascii()]
        if kept:
            filtered_instance_level_error[instance] = kept
    print('acc: ', (1 - len(filtered_instance_level_error) / NUM_TEST_INSTANCES) * 100)
    _dump_json(filtered_instance_level_error, 'filtered_instance_level_error.json')

# Build a CIFAR-100 vocabulary for CIFAR-10 from the class-folder names of both
# training sets.
elif mode == 'cifar100_for_cifar10':
    cifar100_cn = _dir_class_names("/home/openness/data/cifar100/images/train/")
    cifar10_cn = _dir_class_names("/home/openness/data/cifar10/images/train/")
    cifar100_vocab = []
    for cn in tqdm(cifar100_cn):  # TODO: refine the overlap filter
        # Drop a CIFAR-100 name only if it is a substring of some CIFAR-10
        # name. A stricter variant would also drop names that contain a
        # CIFAR-10 name (as the imagenet_for_cifar10 mode does).
        if all(cn not in cn_cifar10 for cn_cifar10 in cifar10_cn):
            cifar100_vocab.append(cn)
        else:
            print('drop class name: ', cn)
    _write_lines('/home/openness/data/vocab/cifar100-for-cifar10-vocab.txt', cifar100_vocab)
    print(len(cifar100_vocab))

# Build an ImageNet vocabulary for CIFAR-10.
elif mode == 'imagenet_for_cifar10':
    with open("/home/openness/data/imagenet/classnames.txt", 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # Drop the first space-separated field of each line (presumably the class
    # id) and keep the rest as the class name.
    imagenet_cn = [' '.join(line.strip().split(' ')[1:]) for line in lines]
    cifar10_cn = _dir_class_names("/home/openness/data/cifar10/images/train/")
    imagenet_vocab = []
    for cn in tqdm(imagenet_cn):  # TODO: refine the overlap filter
        # Drop an ImageNet name if it and some CIFAR-10 name contain each
        # other as a substring, in either direction.
        if all(cn not in cn_cifar10 and cn_cifar10 not in cn for cn_cifar10 in cifar10_cn):
            imagenet_vocab.append(cn)
        else:
            print('drop class name: ', cn)
    _write_lines('/home/openness/data/vocab/imagenet-for-cifar10-vocab.txt', imagenet_vocab)
    print('original imagenet vocab size: ', len(imagenet_cn))
    print('imagenet vocab size after dropping: ', len(imagenet_vocab))

else:
    # Fail loudly instead of silently doing nothing on a typo in `mode`.
    raise ValueError(f'unknown mode: {mode!r}')
