import os
import json

from PIL import Image

import torch, torchvision
import clip
from pytorch_lightning import seed_everything
import numpy as np
from collections import OrderedDict
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from clipmasterprints import clip_extract_image_embeddings_on_demand, get_similarities_per_class, scatter_optimized_classes, eval_fooling_accuracy, scatter_optimized_classes_multi,plot_similarity_heatmap

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def sample_opt_idcs(num_captions, random_seed=0):
    """Pick a reproducible random subset of caption indices.

    Re-seeds through ``seed_everything(random_seed)`` on every call, then
    permutes the index range of the module-level ``captions`` list and keeps
    the leading ``num_captions`` entries.

    Args:
        num_captions: Number of indices to return.
        random_seed: Seed handed to ``seed_everything`` before sampling.

    Returns:
        A list of Python ints indexing into the module-level ``captions``.
    """
    seed_everything(random_seed)
    shuffled = np.random.permutation(len(captions))
    chosen = shuffled[:num_captions]
    return chosen.tolist()

def sample_opt_captions(num_captions, random_seed=0):
    """Return the captions selected by ``sample_opt_idcs``.

    Args:
        num_captions: Number of captions to return.
        random_seed: Seed forwarded to ``sample_opt_idcs``.

    Returns:
        A list of entries of the module-level ``captions`` list, in the
        order of the sampled indices.
    """
    return [captions[position] for position in sample_opt_idcs(num_captions, random_seed)]


def set_up_imagenet_features(clip_model,preprocess,train=False):
    """Obtain CLIP image embeddings and class captions for an ImageNet split.

    Args:
        clip_model: Model forwarded to ``clip_extract_image_embeddings_on_demand``.
        preprocess: Transform applied to every image by the ``ImageFolder`` dataset.
        train: If true use the training split, otherwise the validation split.

    Returns:
        A 5-tuple ``(features, labels, captions, idx_to_class, features_unnorm)``:
        ``features`` is ``features_unnorm`` divided by its norm along the last
        dimension, ``labels`` is whatever the extraction helper returns next to
        the features, ``captions`` lists the class descriptions in
        ``class_to_idx`` iteration order, and ``idx_to_class`` is an
        ``OrderedDict`` from class index to class description.
    """
    # TODO: update path to correct imagenet validation set location here

    if train:
        imagenet_path = '/home/common/datasets/imagenet2012/train'
        features_filename = 'features/imagenet_train_clip_embeddings.pt'
        dataset_name = 'train'
    else:
        imagenet_path = '/home/common/datasets/imagenet2012/val'
        features_filename = 'features/imagenet_val_clip_embeddings.pt'
        dataset_name = 'validation'

    print(f'Loading CLIP embeddings for ImageNet {dataset_name} set')
    print('This could take a while ...')

    mapping_path = 'data/LOC_synset_mapping.txt'
    # Read inside a context manager so the file handle is closed deterministically.
    with open(mapping_path, 'r') as mapping_file:
        mapping_lst = mapping_file.read().split('\n')
    # Each non-empty line: the first 9 characters are the key (the ImageFolder
    # directory name looked up below), the remainder is the class description.
    mapping = {string_pair[:9]: string_pair[9:].strip() for string_pair in mapping_lst if string_pair}

    # Build the dataset for the selected ImageNet split.
    data_set = torchvision.datasets.ImageFolder(root=imagenet_path, transform=preprocess)
    data_loader = torch.utils.data.DataLoader(data_set, batch_size=3300, num_workers=7)

    idx_to_class = OrderedDict([(value, mapping[key]) for key, value in data_set.class_to_idx.items()])
    captions = list(idx_to_class.values())

    # NOTE(review): presumably the helper loads cached embeddings from
    # features_filename when present and computes them otherwise - confirm.
    (features_unnorm, labels) = clip_extract_image_embeddings_on_demand(clip_model, data_loader, features_filename, device=device)
    features = features_unnorm / features_unnorm.norm(dim=-1, keepdim=True)
    return features, labels, captions, idx_to_class, features_unnorm

def get_imagenet_similarities(clip_model,val_features, val_labels, captions, idx_to_class):
    """Compute per-class similarity scores and flatten them to plain lists.

    The captions are tokenized and encoded with ``clip_model``, the text
    embeddings are normalized along the last dimension, and the result is
    passed to ``get_similarities_per_class`` together with the image features.

    Returns:
        A dict with the keys produced by ``get_similarities_per_class``; each
        value is one flat Python list obtained by concatenating the first row
        of every element of that key's original value.
    """
    tokenized = clip.tokenize(captions).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokenized)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    per_class = get_similarities_per_class(val_features, val_labels, text_features, idx_to_class)

    flattened = {}
    for class_name, chunks in per_class.items():
        scores = []
        for chunk in chunks:
            # Only the first row of each chunk is kept, as nested Python numbers.
            scores.extend(chunk.cpu().numpy().tolist()[0])
        flattened[class_name] = scores
    return flattened

def plot_imagenet_scatter_iclr(clip_model, preprocess, similarities_imagenet,opt_captions):
    """Scatter-plot the scores of three fooling images against ImageNet scores.

    Encodes the 25-class CMA, SGD and PGD master images and the optimized
    captions with ``clip_model``, computes their normalized dot products, and
    hands them (labelled 'LVE', 'SGD', 'PGD') to
    ``scatter_optimized_classes_multi`` together with the ImageNet
    similarities restricted to ``opt_captions``. The output file name passed
    on is 'imagenet_25.pdf'.
    """
    # (plot label, image file) for every fooling image, in plotting order
    method_images = [
        ('LVE', 'results/master_images/cmp_imagenet_cma_25.png'),
        ('SGD', 'results/master_images/cmp_imagenet_sgd_25.png'),
        ('PGD', 'results/master_images/cmp_imagenet_pgd_int_25.png'),
    ]

    opt_tokens = clip.tokenize(opt_captions).to(device)
    batch = [preprocess(Image.open(img_file).convert("RGB")) for _, img_file in method_images]
    image_input = torch.tensor(np.stack(batch)).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(image_input).float()
        text_features = clip_model.encode_text(opt_tokens).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # One row per fooling image, one column per optimized caption.
    similarities_fooling = image_features @ text_features.T
    fooling_scores = [
        (label, similarities_fooling[row, :][None].cpu().numpy())
        for row, (label, _) in enumerate(method_images)
    ]

    similarities_imagenet_opt = {
        caption: scores
        for caption, scores in similarities_imagenet.items()
        if caption in opt_captions
    }
    scatter_optimized_classes_multi(opt_captions, fooling_scores, similarities_imagenet_opt, 'imagenet_25.pdf')

def plot_lines_multi_iclr(similarities_imagenet):
    """Plot CLIP scores against the number of optimized classes.

    For 25, 50, 75 and 100 sampled captions, collects the ImageNet similarity
    values of those captions and the similarity of each method's master image
    (LVE, SGD, PGD) to every sampled caption, then draws a seaborn point plot
    with one hue per source, shows it and saves it as 'lines_multi.pdf'.

    Besides its argument this function reads the module-level ``clip_model``,
    ``preprocess`` and ``device``, and samples captions through
    ``sample_opt_captions``.

    Args:
        similarities_imagenet: Dict mapping a caption to an iterable of
            similarity values.
    """
    samples = [25, 50, 75, 100]
    # (legend label, file-name tag) for each optimization method
    methods = [('LVE', 'cma'), ('SGD', 'sgd'), ('PGD', 'pgd_int')]
    paths = {
        sample: [(label, f'results/master_images/cmp_imagenet_{tag}_{sample}.png') for label, tag in methods]
        for sample in samples}

    fontsize = 10
    params = {  # 'backend': 'pdf',
        'axes.labelsize': fontsize,
        'axes.titlesize': fontsize,
        'font.size': fontsize,
        'legend.fontsize': fontsize,
        'xtick.labelsize': fontsize,
        'ytick.labelsize': fontsize}
    plt.rcParams.update(params)
    clip_scores_optimized = []
    clip_scores_imagenet = []

    for sample in samples:
        opt_captions = sample_opt_captions(sample)
        opt_captions = [caption for caption in opt_captions if caption]
        tokens_input = clip.tokenize(opt_captions).to(device)
        with torch.no_grad():
            text_features = clip_model.encode_text(tokens_input).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # The text side is the same for every method, so convert it once.
        text_features_np = text_features.cpu().numpy()

        # Filter once per caption (set lookup) instead of once per score value.
        opt_caption_set = set(opt_captions)
        clip_scores_imagenet.extend(
            (key, value, sample, 'ImageNet')
            for key, scores in similarities_imagenet.items() if key in opt_caption_set
            for value in scores)

        for method_label, img_path in paths[sample]:
            print([img_path])
            image = preprocess(Image.open(img_path).convert("RGB"))
            image_input = torch.tensor(np.stack([image])).to(device)

            with torch.no_grad():
                image_features = clip_model.encode_image(image_input).float()

            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Single image, so every row of the product holds exactly one score.
            similarity_single = text_features_np @ image_features.cpu().numpy().T
            clip_scores_optimized.extend(
                (key, value[0], sample, method_label)
                for key, value in zip(opt_captions, similarity_single.tolist()))

        # TODO: this needs to be transformed, we should only compute ImageNet similarities on demand once

    input_data = clip_scores_imagenet + clip_scores_optimized
    df = pd.DataFrame(input_data, columns=['class', 'mean CLIP score', 'number of optimized classes', 'type'])
    # catplot is figure-level and creates its own figure, so none is opened
    # beforehand. The palette needs one colour per hue level: the methods
    # plus the 'ImageNet' baseline.
    sns_plot = sns.catplot(
        data=df, kind="point",
        x='number of optimized classes', y="mean CLIP score", hue='type', dodge=True, height=2.26, aspect=1.416,
        palette=sns.color_palette('colorblind', n_colors=len(methods)+1), legend=False)

    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.show()
    sns_plot.savefig('lines_multi.pdf', dpi=300)

def _fooling_percentages(image_features, labels, text_features, adv_features, class_indices):
    """Return 100 * eval_fooling_accuracy for every row of adv_features, in row order."""
    return [
        100*eval_fooling_accuracy(image_features, labels, text_features, adv_features[row,:][None], class_indices)
        for row in range(adv_features.shape[0])]


def _print_poi_rows(row_labels, percentages):
    """Print one 'label : value%' line per (label, percentage) pair."""
    for row_label, percentage in zip(row_labels, percentages):
        print(f"{row_label} : {percentage:0.2f}%")


def generate_poi_table(val_features_unnorm, val_lables, captions, clip_model, preprocess):
    """Print fooling-accuracy percentages for the 25-class master images.

    Six master images are evaluated (SGD, LVE, PGD, each trained without and
    with a 0.25 shift) in three settings: on the 25 sampled target classes,
    on the same classes after shifting image and text embeddings along the
    gap vector (training-set image mean minus text mean, scaled by 0.25),
    and - for the three unshifted images only - on all remaining classes.

    Args:
        val_features_unnorm: Unnormalized image embeddings; cast to float here.
        val_lables: Labels belonging to ``val_features_unnorm``. The spelling
            is kept so existing keyword callers continue to work.
        captions: Class captions that are tokenized and encoded as text.
        clip_model: CLIP model used to encode images and text.
        preprocess: Image transform applied before encoding.
    """
    # The body used to read a module-level `val_labels` and ignored this
    # argument because of the spelling mismatch; bind it explicitly.
    val_labels = val_lables

    caption_indices = sample_opt_idcs(25)
    targeted = set(caption_indices)
    other_indices = [idx for idx in range(len(captions)) if idx not in targeted]

    # Row order of the image batch: unshifted SGD/LVE/PGD, then shift-trained SGD/LVE/PGD.
    paths = [
        'results/master_images/cmp_imagenet_sgd_25.png',
        'results/master_images/cmp_imagenet_cma_25.png',
        'results/master_images/cmp_imagenet_pgd_int_25.png',
        'results/master_images/cmp_imagenet_sgd_25_shift_0.25.png',
        'results/master_images/cmp_imagenet_cma_25_shift_0.25.png',
        'results/master_images/cmp_imagenet_pgd_int_25_shift_0.25.png']
    row_labels = [
        "SGD, no training shift",
        "LVE, no training shift",
        "PGD, no training shift",
        "SGD, with training shift",
        "LVE, with training shift",
        "PGD, with training shift"]

    images = [Image.open(path).convert("RGB") for path in paths]
    images = [preprocess(image) for image in images]
    image_input = torch.tensor(np.stack(images)).to(device)

    tokens = clip.tokenize(captions).to(device)
    with torch.no_grad():
        adv_features_unnorm = clip_model.encode_image(image_input).float()
        text_features_unnorm = clip_model.encode_text(tokens).float()
    val_features_unnorm = val_features_unnorm.float()
    val_features = val_features_unnorm/val_features_unnorm.norm(dim=-1, keepdim=True)
    adv_features = adv_features_unnorm/adv_features_unnorm.norm(dim=-1, keepdim=True)
    text_features = text_features_unnorm/text_features_unnorm.norm(dim=-1, keepdim=True)

    poi_plain = _fooling_percentages(val_features, val_labels, text_features, adv_features, caption_indices)

    # Raw LVE value ahead of the formatted table (part of the existing output).
    print(poi_plain[1])
    print("without test time shift:")
    _print_poi_rows(row_labels, poi_plain)

    _, _, _, _, train_features_unnorm = set_up_imagenet_features(clip_model, preprocess, train=True)

    # compute mean vector for gap shift
    image_features_mean = torch.mean(train_features_unnorm, dim=0, keepdim=True)
    text_features_mean = torch.mean(text_features_unnorm, dim=0, keepdim=True)
    # now define gap vector
    gap_vector = image_features_mean - text_features_mean
    gap_shift = 0.25
    # Image-side embeddings move against the gap vector, text embeddings along it.
    val_features_shift_un = val_features_unnorm - gap_shift * gap_vector
    adv_features_shift_un = adv_features_unnorm - gap_shift * gap_vector
    text_features_shift_un = text_features_unnorm + gap_shift * gap_vector

    val_features_shift = val_features_shift_un/val_features_shift_un.norm(dim=-1, keepdim=True)
    adv_features_shift = adv_features_shift_un/adv_features_shift_un.norm(dim=-1, keepdim=True)
    text_features_shift = text_features_shift_un /text_features_shift_un.norm(dim=-1, keepdim=True)

    poi_shifted = _fooling_percentages(val_features_shift, val_labels, text_features_shift, adv_features_shift, caption_indices)

    print("_with_ test time shift:")
    _print_poi_rows(row_labels, poi_shifted)

    # Only the three images trained without shift are evaluated on the other classes.
    poi_other = _fooling_percentages(val_features, val_labels, text_features, adv_features[:3], other_indices)

    print("performance on classes not targeted during optimization")
    _print_poi_rows(row_labels[:3], poi_other)

def compare_images(img_path, adv_path, clip_model, preprocess):
    """Print the similarity between the CLIP embeddings of two images.

    Both files are loaded as RGB, preprocessed and encoded in one batch; the
    embeddings are normalized along the last dimension and their dot product
    is printed as a 1x1 tensor.

    Args:
        img_path: Path of the first image.
        adv_path: Path of the second image.
        clip_model: Model whose ``encode_image`` produces the embeddings.
        preprocess: Transform applied to each image before encoding.
    """
    pair = [Image.open(location).convert("RGB") for location in (img_path, adv_path)]
    pair = [preprocess(picture) for picture in pair]
    image_input = torch.tensor(np.stack(pair)).to(device)
    with torch.no_grad():
        embeddings = clip_model.encode_image(image_input).float()
    embeddings /= embeddings.norm(dim=-1, keepdim=True)
    first = embeddings[0,:][None]
    second = embeddings[1,:][None]
    print(first @ second.T)

def crop_image_visual(path,input_size,outfilename):
    """Resize and center-crop an image, then save it.

    The image at ``path`` is loaded as RGB, passed through a bicubic,
    antialiased ``Resize(input_size)`` followed by a
    ``CenterCrop((input_size, input_size))``, and written to ``outfilename``.

    Args:
        path: Source image file.
        input_size: Size given to ``Resize`` and used for both crop dimensions.
        outfilename: Destination file for the cropped image.
    """
    transforms = torchvision.transforms
    resize = transforms.Resize(
        input_size,
        interpolation=transforms.InterpolationMode.BICUBIC,
        max_size=None,
        antialias=True)
    center_crop = transforms.CenterCrop((input_size, input_size))
    pipeline = transforms.Compose([resize, center_crop])

    source_image = Image.open(path).convert("RGB")
    # write the cropped result to disk
    pipeline(source_image).save(outfilename)

# Driver script: runs on import/execution of this module.
# The names assigned here are module-level globals; functions defined above
# read some of them directly (`captions` in sample_opt_idcs, `clip_model` and
# `preprocess` in plot_lines_multi_iclr), so they must exist before those
# functions are called.

# load clip model
clip_model, preprocess = clip.load('ViT-L/14', device=device)
clip_model.eval()
# Embeddings, labels and captions for the ImageNet validation split (train=False by default).
val_features, val_labels, captions, idx_to_class, val_features_unnorm = set_up_imagenet_features(clip_model, preprocess)
similarities_imagenet = get_imagenet_similarities(clip_model, val_features, val_labels, captions, idx_to_class)
# Persist the per-class similarity lists as JSON.
json_object = json.dumps(similarities_imagenet, indent=8)
with open("imagenet_scores.json", "w") as outfile:
    outfile.write(json_object)
# Figures: scatter plot for 25 sampled captions, then the line plot over 25-100 classes.
plot_imagenet_scatter_iclr(clip_model, preprocess, similarities_imagenet, sample_opt_captions(25))
plot_lines_multi_iclr(similarities_imagenet)
# Prints the embedding similarity between a reference photo and the 25-class PGD master image.
compare_images("data/sunflower.jpg", 'results/master_images/cmp_imagenet_pgd_int_25.png', clip_model, preprocess)
# Prints the fooling-accuracy table.
generate_poi_table(val_features_unnorm, val_labels, captions, clip_model, preprocess)

