import argparse
import torch
from rtpt import RTPT
import CLIP.clip as clip
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from torchvision.transforms import Normalize
import matplotlib.pyplot as plt
import os
import pandas as pd
from tqdm import tqdm
import glob
import pickle
from main.clip_models.baseline import initialize_model_imagenet, load_finetuned_model_clip
from main.explain.utils import gradientshap, noise_tunnel, interpret_vit
from main.mittinyimg.utils.tinyimage import img_count, sliceToBin, closeTinyImage, openTinyImage, getMetaData

def _str2bool(value):
    """Interpret a command-line string as a boolean flag value.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``, so
    ``--explain False`` used to *enable* the flag. Recognised false spellings
    (case-insensitive ``false``/``f``/``0``/``no``/``n``/``off`` and the empty
    string) map to ``False``; every other string maps to ``True``, so existing
    invocations such as ``--explain True`` or ``--explain 1`` behave as before.
    """
    if isinstance(value, bool):
        return value
    return value.lower() not in ('', 'false', 'f', '0', 'no', 'n', 'off')


parser = argparse.ArgumentParser(description='Crazy Stuff')
parser.add_argument('--data_dir', default='./data',
                    help='Select data path')
parser.add_argument('--data_name', default='rt-polarity', type=str, choices=['rt-polarity', 'toxicity',
                                                                             'toxicity_full', 'ethics', 'restaurant'],
                    help='Select name of data set')
parser.add_argument('--num_prototypes', default=10, type=int,
                    help='Total number of prototypes')
parser.add_argument('--num_classes', default=2, type=int,
                    help='How many classes are to be classified?')
# NOTE(review): no type/nargs, so a value given on the command line arrives as a single string.
parser.add_argument('--class_weights', default=[0.5, 0.5],
                    help='Class weight for cross entropy loss')
parser.add_argument('-g', '--gpu', type=int, default=[0], nargs='+',
                    help='GPU device number(s)')
parser.add_argument('--one_shot', type=_str2bool, default=False,
                    help='Whether to use one-shot learning or not (i.e. only a few training examples)')
parser.add_argument('--proto_size', type=int, default=1,
                    help='Define how many words should be used to define a prototype')
# NOTE(review): the default 'Bert' is not one of the choices (argparse does not
# validate defaults against choices); the CLIP wrappers expect 'Clip_<backbone>'.
parser.add_argument('--language_model', type=str, default='Bert',
                    choices=['Resnet', 'Clip_ViT-B/32', 'Clip_ViT-B/16', 'Clip_RN50x4', 'Clip_RN50'],
                    help='Define which language model to use')
parser.add_argument('--avoid_spec_token', type=_str2bool, default=False,
                    help='Whether to manually set PAD, SEP and CLS token to high value after Bert embedding computation')
parser.add_argument('--compute_emb', type=_str2bool, default=False,
                    help='Whether to recompute (True) the embedding or just load it (False)')
parser.add_argument('--metric', type=str, default='L2',
                    help='metric')
parser.add_argument('--input_type', type=str, required=True, choices=['text', 'img'],
                    help='choose between text and image')
parser.add_argument('--explain', type=_str2bool, default=False,
                    help='Whether to produce explanation (attribution) plots for image inputs')
parser.add_argument('--only_offending', type=_str2bool, default=False,
                    help="Only record/explain images predicted 'toxic' with a score of at least 0.90")
parser.add_argument('--model_type', type=str, default='probe', choices=['probe', 'sim', 'resnet50', 'finetuned'],
                    help="Which classifier to load (load_model implements 'sim', 'resnet50' and 'finetuned')")
parser.add_argument('--prompt_path', type=str,
                    help="Optional pickle of tuned prompt embeddings (used by model_type 'sim'); "
                         "its path also selects the SMID data type for image-folder runs")
# Class names indexed by predicted class id. ``eval_model_`` uses
# ``labels_care`` (reversed order) when ``data_type == 'harm'`` and
# ``labels`` otherwise.
labels = ['non toxic', 'toxic']
labels_care = ['toxic', 'non toxic']


def _inverse_normalize(mean, std):
    """Return a ``Normalize`` transform that undoes ``Normalize(mean, std)``.

    ``Normalize`` maps each channel ``v -> (v - m) / s``; applying it again with
    mean ``-m / s`` and std ``1 / s`` maps ``v -> v * s + m``, i.e. the inverse.
    """
    return Normalize(mean=[-m / s for m, s in zip(mean, std)],
                     std=[1 / s for s in std])


# Undo the CLIP per-channel (RGB) mean/std normalisation.
inv_normalize_clip = _inverse_normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# Undo the standard ImageNet per-channel normalisation used for ResNet inputs.
inv_normalize_resnet = _inverse_normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Despite its name, this map is white at 0, reaches black at 0.25 and stays black up to 1.
default_cmap = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
    N=256,
)


class ClipSimModel(torch.nn.Module):
    """Zero-shot CLIP classifier scoring images against a fixed set of text embeddings.

    Logits are ``100 * cosine_similarity(image_embedding, text_embedding)`` for
    each text embedding. With the built-in prompts, index 0 corresponds to
    "... something positive" and index 1 to "... something negative"; with
    ``prompts`` the class order is whatever the supplied embeddings use.
    """

    def __init__(self, args, prompts=None):
        """Load the CLIP backbone and prepare the class text embeddings.

        Args:
            args: CLI namespace; ``language_model`` must look like
                ``'Clip_<backbone>'`` (the part after the first ``_`` is passed
                to ``clip.load``) and ``gpu[0]`` selects the CUDA device.
            prompts: optional precomputed text embeddings, converted with
                ``torch.HalfTensor``; presumably shaped (num_classes, embed_dim)
                — TODO confirm. When given, the built-in prompts are neither
                tokenized nor encoded.
        """
        super(ClipSimModel, self).__init__()
        device = f'cuda:{args.gpu[0]}'
        self.MMM, self.preprocess = clip.load(args.language_model.split('_')[1], device, jit=False)
        self.MMM.to(device)
        self.MMM.eval()

        if prompts is not None:
            self.text_features = torch.HalfTensor(prompts).to(device)
            print('Using tuned prompts', self.text_features.shape)
        else:
            # Only build and encode the hand-written prompts when they are used;
            # previously they were tokenized even when tuned prompts replaced them.
            labels_clip_prompt = ['positive', 'negative']
            text = clip.tokenize([f"This image is about something {labels_clip_prompt[0]}",
                                  f"This image is about something {labels_clip_prompt[1]}"
                                  ]).to(device)
            self.text_features = self.MMM.encode_text(text)

    def forward(self, x):
        """Return scaled cosine similarities between ``x`` and every text embedding.

        ``x`` is a preprocessed image batch. The result is squeezed, so a
        single image yields a 1-D tensor with one logit per class.
        """
        image_features = self.MMM.encode_image(x)
        text_features_norm = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        # 100 is the fixed logit scale applied to the cosine similarities.
        similarity = (100.0 * image_features_norm @ text_features_norm.T)
        return similarity.squeeze()


class ClipSingleSimModel(torch.nn.Module):
    """CLIP zero-shot scorer for an arbitrary list of class names.

    Every entry of ``labels`` is wrapped in the prompt
    "This image is about something <label>" and encoded once at construction;
    ``forward`` returns one scaled cosine similarity per label.
    """

    def __init__(self, args, labels):
        super().__init__()
        device = f'cuda:{args.gpu[0]}'
        backbone = args.language_model.split('_')[1]
        self.MMM, self.preprocess = clip.load(backbone, device, jit=False)
        self.MMM.to(device)
        self.MMM.eval()
        # Kept as a submodule for compatibility; forward() does not use it.
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        prompt_texts = [f"This image is about something {label}" for label in labels]
        self.text = clip.tokenize(prompt_texts).to(device)
        self.text_features = self.MMM.encode_text(self.text)

    def forward(self, x):
        """Return ``100 * cos(image, label_text)`` for each label, squeezed."""
        def unit(v):
            # Scale each embedding along its last axis to unit length.
            return v / v.norm(dim=-1, keepdim=True)

        image_unit = unit(self.MMM.encode_image(x))
        text_unit = unit(self.text_features)
        return (100.0 * image_unit @ text_unit.T).squeeze()


def explain_pred(args, model, x, y, file_name, predicted_label='', prediction_score=0, save_path=None, show=False):
    """Route to the attribution routine matching the backbone family.

    Backbones whose ``args.language_model`` contains ``'ViT-B'`` use the
    attention-based ``explain_pred_vit``, which receives ``y`` converted to a
    NumPy value and honours ``show``. All other backbones use the
    GradientSHAP-based ``explain_pred_rn``, which is not given ``show``.
    """
    if 'ViT-B' not in args.language_model:
        explain_pred_rn(args, model, x, y, file_name, predicted_label, prediction_score, save_path)
        return

    label_np = y.cpu().detach().numpy()
    explain_pred_vit(args, model, x, label_np, file_name, predicted_label, prediction_score,
                     save_path, show=show)


def explain_pred_rn(args, model, x, y, file_name, predicted_label='', prediction_score=0, save_path=None):
    """Save a GradientSHAP attribution plot for a non-ViT backbone.

    Args:
        args: CLI namespace; ``language_model`` selects CLIP vs. ImageNet
            de-normalisation of ``x`` for the plotted background image.
        model: model passed to ``gradientshap``.
        x: preprocessed input batch.
        y: target class index, passed unchanged to ``gradientshap`` (a tensor
            when called through ``explain_pred``).
        file_name: stem of the output file ``expl_<stem>.png``.
        predicted_label, prediction_score: unused; kept so the signature
            mirrors ``explain_pred_vit``.
        save_path: output directory. ``None`` falls back to
            ``./clip_stuff/explain/toxicity/`` (the same default as the ViT
            path) instead of crashing in ``os.makedirs(None)``.
    """
    # Local import keeps this block self-contained; the file also imports pyplot at top level.
    import matplotlib.pyplot as plt

    if 'Clip' in args.language_model:
        inv_transformed_img = inv_normalize_clip(x)
    else:
        inv_transformed_img = inv_normalize_resnet(x)
    fig, _ = gradientshap(model, x, y, inv_transformed_img)
    if save_path is None:
        save_path = './clip_stuff/explain/toxicity/'
    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, f'expl_{file_name}.png'))
    # This runs once per image in long loops; release the figure so pyplot does
    # not accumulate open figures (a no-op if pyplot does not manage it).
    plt.close(fig)


def explain_pred_vit(args, model, x, y, file_name, predicted_label='', prediction_score=0, save_path=None, show=False):
    """Plot the input next to per-class ``interpret_vit`` relevance maps.

    Produces a 1x3 figure with three panels:
    - the de-normalised input;
    - the ``interpret_vit`` map for class index 0;
    - the ``interpret_vit`` map for class index 1.

    The panel of the predicted class ``y`` is titled with ``prediction_score``.
    The other panel gets ``100% - score``, which assumes a two-class model.
    Titles always use the ``labels`` names.

    Args:
        args: CLI namespace; ``language_model`` selects CLIP vs. ImageNet
            de-normalisation and ``gpu[0]`` the device for ``interpret_vit``.
        model: model passed to ``interpret_vit``.
        x: preprocessed input, presumably a single image (1, 3, H, W), since it
            is squeezed and transposed to HWC for display — TODO confirm.
        y: predicted class index (0 or 1) as a NumPy integer.
        file_name: stem of the output file ``attentionGrad_<stem>.png``.
        predicted_label: unused; kept for interface compatibility.
        prediction_score: probability of class ``y``, expected in [0, 1]
            (shown as a percentage).
        save_path: output directory; defaults to ``./clip_stuff/explain/toxicity/``.
        show: display the figure interactively instead of saving it.
    """
    fig, axis = plt.subplots(1, 3, figsize=(9, 3))
    if 'Clip' in args.language_model:
        inv_transformed_img = inv_normalize_clip(x)
    else:
        inv_transformed_img = inv_normalize_resnet(x)
    img = np.transpose(inv_transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)) * 255
    # Integer RGB must lie in [0, 255]. matplotlib clips out-of-range values
    # anyway, but it warns on every image, so clip explicitly.
    axis[0].imshow(np.array(np.clip(img, 0, 255), dtype=np.int32))
    axis[0].set_title('Input')
    axis[0].axis('off')

    device = f'cuda:{args.gpu[0]}'
    for class_idx in (0, 1):
        vis = interpret_vit(model=model, x=x, device=device, index=class_idx)
        axis[1 + class_idx].imshow(vis)
        axis[1 + class_idx].axis('off')

    other = np.abs(y - 1)  # the non-predicted class index for a binary model
    axis[1 + y].set_title(f'Prediction\n{labels[y]} ({prediction_score * 100:.2f}%)')
    axis[1 + other].set_title(f'{labels[other]} ({100 - (prediction_score * 100):.2f}%)')

    if not show:
        if save_path is None:
            save_path = './clip_stuff/explain/toxicity/'
        os.makedirs(save_path, exist_ok=True)
        # Save this figure explicitly rather than pyplot's "current" figure,
        # in case interpret_vit creates figures of its own.
        fig.savefig(os.path.join(save_path, f'attentionGrad_{file_name}.png'))
    else:
        plt.show()
    plt.close(fig)


def eval_model_(args, x, model, file_name, save_path=None, verbose=True, show=False, data_type='moral'):
    """Classify one preprocessed image and, if requested, save its explanation.

    Args:
        args: CLI namespace. Reads ``gpu``, ``only_offending``, ``explain`` and
            ``input_type``, plus whatever ``explain_pred`` reads.
        x: preprocessed input batch, moved to ``cuda:<args.gpu[0]>``.
            Presumably a single image, since the top-1 score is turned into a
            Python scalar with ``.item()``.
        model: callable returning per-class logits for ``x``.
        file_name: stem used by the explanation helpers for output files.
        save_path: root directory for explanation plots, which go to
            ``<save_path>/<label>/<score percent>``. If ``None``, ``None`` is
            passed on to ``explain_pred``; its ViT path then uses its own
            default directory.
        verbose: print the prediction.
        show: forwarded to ``explain_pred`` (display instead of saving).
        data_type: ``'harm'`` selects the ``labels_care`` class names;
            anything else uses ``labels``.

    Returns:
        tuple: ``(score, label, label_idx, save_filename)`` where:
        - ``score`` is the top-1 softmax probability (float);
        - ``label`` is the class name;
        - ``label_idx`` is the class index (int);
        - ``save_filename`` is True when the image passes the
          ``--only_offending`` filter.
    """
    x = x.to(f'cuda:{args.gpu[0]}')

    # Note: no torch.no_grad() here on purpose. The (possibly modified) CLIP
    # forward may register autograd hooks that fail without gradients.
    logits = model(x)
    probs = logits.softmax(dim=-1)

    prediction_score, pred_label_idx = torch.topk(probs.float(), 1)

    pred_label_idx = pred_label_idx.squeeze_()
    label_idx = pred_label_idx.cpu().detach().numpy()
    class_names = labels_care if data_type == 'harm' else labels
    predicted_label = class_names[label_idx]

    if verbose:
        print(f'Predicted: {predicted_label} ({prediction_score.squeeze().item() * 100:.2f})')

    suffix = f'{prediction_score.squeeze().item() * 100:.0f}'
    # os.path.join(None, ...) raises TypeError, which made the default
    # save_path=None unusable; pass None through instead.
    save_path_sep = None if save_path is None else os.path.join(save_path, predicted_label, suffix)

    save_filename = False
    if not args.only_offending or (predicted_label == 'toxic' and prediction_score >= .90):
        save_filename = True
        if args.explain and args.input_type == 'img':
            explain_pred(args, model, x, pred_label_idx, file_name,
                         predicted_label, prediction_score.squeeze().item(),
                         save_path=save_path_sep, show=show)
    return prediction_score.item(), predicted_label, label_idx.item(), save_filename


def run_model_smir(args):
    """Classify the SMID images and report agreement with the SMID norms.

    Each image listed in ``SMID_norms.csv`` is classified with ``eval_model_``
    and compared to two binary targets derived from the norms. Running
    accuracies are printed. One row per image is written to
    ``<save_path>/prediction.csv``.

    Args:
        args: CLI namespace (``model_type``, ``language_model``,
            ``prompt_path``, ``gpu`` and the flags read by ``eval_model_``).

    Raises:
        ValueError: if no trained probe is found (``model_type='probe'``) or
            the norms file lists no images.
        FileNotFoundError: if an image listed in the CSV has no matching file.
    """
    if args.model_type == 'probe':
        files_expl = glob.glob(
            f'./experiments/train_results/toxicity/*_baseline_{args.language_model.split("/")[0]}/best_model.pth.tar')
        if len(files_expl) == 0: raise ValueError('trained model not found')
        args.model_path = files_expl[0]
    save_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMIR/{args.language_model.split("/")[0]}/'
    model, save_path = load_model(args, save_path)

    data_set_path = '/workspace/datasets/SMID_images_400px/'
    df = pd.read_csv(os.path.join(data_set_path, 'SMID_norms.csv'), sep=',', header=0)
    if len(df) == 0:
        # Avoid a ZeroDivisionError when computing the final accuracies.
        raise ValueError('SMID_norms.csv lists no images')
    valence_means = df['valence_mean'].values
    moral_means = df['moral_mean'].values

    res = list()
    acc_total_moral, acc_total_valence, cnt = 0., 0., 0
    for idx, image_name in enumerate(tqdm(df['img_name'].values)):
        # img_name is matched as "<name>.*" (the CSV presumably omits the extension).
        # Escape it so glob metacharacters in a name are taken literally.
        matches = glob.glob(glob.escape(os.path.join(data_set_path, 'img', image_name)) + '.*')
        if not matches:
            raise FileNotFoundError(f'no image file found for {image_name!r}')
        image_path = matches[0]

        with Image.open(image_path) as img:
            x = model.preprocess(img).unsqueeze(0)

        prediction_score, predicted_label, pred_label_idx, _ = eval_model_(
            args=args,
            x=x,
            model=model,
            file_name=os.path.basename(image_path).split('.')[0],
            save_path=os.path.join(save_path, 'plots'),
            verbose=False)

        res.append(
            (image_name, f'{prediction_score:.4f}', predicted_label, f'{pred_label_idx}', f'{valence_means[idx]:.2f}',
             f'{moral_means[idx]:.2f}'))
        # Moral mean < 2.5 counts as morally bad, i.e. class 1 ("toxic"); threshold from the SMID paper.
        moral_target = 1 if moral_means[idx] < 2.5 else 0
        # Valence mean >= 3 counts as neutral/positive, i.e. class 0.
        valence_target = 0 if valence_means[idx] >= 3. else 1
        if pred_label_idx == moral_target:
            acc_total_moral += 1.
        if pred_label_idx == valence_target:
            acc_total_valence += 1.
        cnt += 1
        if idx % 100 == 0:
            print(f'Moral Acc: {(acc_total_moral / cnt * 100):.2f}, Valence Acc: {(acc_total_valence / cnt) * 100:.2f}')

    acc_total_moral /= cnt
    acc_total_valence /= cnt

    print(f'Accuracy by moral: {acc_total_moral * 100:.2f}, Accuracy by valence: {acc_total_valence * 100:.2f}')
    np.savetxt(os.path.join(save_path, 'prediction.csv'), res, delimiter=',',
               header='img_name,prediction_score,predicted_label,pred_label_idx,valence_mean,moral_mean',
               fmt=('%s,%s,%s,%s,%s,%s'))


def find_images(image_paths, types=None):
    """Collect image files directly inside ``image_paths`` and one level below.

    Args:
        image_paths: directory to search. Each pattern is appended verbatim,
            so the returned paths keep this prefix exactly as given.
        types: optional iterable of glob suffixes, each starting with ``/``.
            Defaults to ``.JPEG``, ``.png`` and ``.jpg`` files at depth one and
            two, using exactly those extension spellings.

    Returns:
        list[str]: matching paths, grouped by pattern in ``types`` order.
        Matches are sorted within each pattern so runs are reproducible
        (``glob`` itself returns filesystem order).

    Raises:
        ValueError: if nothing matches.
    """
    if types is None:
        types = ('/*.JPEG', '/*.png', '/*.jpg', '/*/*.JPEG', '/*/*.png', '/*/*.jpg')
    files_grabbed = []
    for pattern in types:
        files_grabbed.extend(sorted(glob.glob(image_paths + pattern)))
    if not files_grabbed:
        raise ValueError('no data found')
    return files_grabbed


def load_model(args, save_path):
    """Instantiate the classifier selected by ``args.model_type``.

    Args:
        args: CLI namespace. Reads ``model_type``, ``prompt_path`` and ``gpu``;
            ``ClipSimModel`` additionally reads ``language_model``.
        save_path: results root. The model type is appended and the directory
            is created. With tuned prompts, a ``_prompt_tuned<name>`` suffix is
            also added, named after the prompt file's grandparent directory.

    Returns:
        tuple: ``(model, save_path)``, i.e. the model (its underlying network
        switched to eval mode) and the extended output directory. Loaded
        prompts are only used by ``model_type='sim'``.

    Raises:
        ValueError: for model types without a loader here (e.g. ``'probe'``).
    """
    model_type = args.model_type
    save_path = os.path.join(save_path, model_type)
    prompts = None
    if args.prompt_path is not None:
        save_path += '_prompt_tuned' + str(os.path.basename(os.path.dirname(os.path.dirname(args.prompt_path))))
        # NOTE: unpickling can execute arbitrary code; only load prompt files you trust.
        # The context manager closes the handle that was previously leaked.
        with open(args.prompt_path, 'rb') as f:
            prompts = pickle.load(f)
    os.makedirs(save_path, exist_ok=True)
    device = f'cuda:{args.gpu[0]}'
    if model_type == 'sim':
        model = ClipSimModel(args, prompts=prompts)
    elif model_type == 'finetuned':
        model, _ = load_finetuned_model_clip(2, device=device,
                                             path='/workspace/datasets/results/normativity/clip_stuff/results/SMIR/CLIPResnet/fine_tuning/model.pt')
        model.to(device)
        model.eval()
    elif model_type == 'resnet50':
        model, _ = initialize_model_imagenet(2, True, 'cuda')
        # Load onto the CPU first so the checkpoint does not depend on the GPU it
        # was saved from; load_state_dict copies into the model's parameters.
        model.load_state_dict(torch.load('/workspace/datasets/results/normativity/clip_stuff/'
                                         'results/SMIR/Resnet/fine_tuning/model.pt', map_location='cpu'))
        model.to(device)
        model.eval()
    else:
        raise ValueError('not implemented')

    return model, save_path


def run_model_imagefolder(args, data_set_path, save_dir):
    """Classify every image found under ``data_set_path`` and log the selected ones.

    The SMID split named in ``--prompt_path`` ('SMID/moral', 'SMID/valence' or
    'SMID/harm') chooses ``data_type``. It sets the class-name ordering in
    ``eval_model_`` and is a path component of the output directory. Images
    that pass the ``--only_offending`` filter are written to
    ``<save_path>/toxic_images.csv`` as ``label,label_idx,score,filename``.

    Args:
        args: CLI namespace (``model_type``, ``language_model``,
            ``prompt_path``, ``gpu`` and the flags read by ``eval_model_``).
        data_set_path: directory searched by ``find_images``.
        save_dir: sub-directory of the results root for this dataset.

    Raises:
        ValueError: if no trained probe (``model_type='probe'``) or no images are found.
    """
    torch.set_num_threads(6)
    if args.model_type == 'probe':
        files_expl = glob.glob(
            f'./experiments/train_results/toxicity/*_baseline_{args.language_model.split("/")[0]}/best_model.pth.tar')
        if len(files_expl) == 0: raise ValueError('trained model not found')
        args.model_path = files_expl[0]

    # --prompt_path is optional (default None); `'...' in None` raised a TypeError.
    prompt_path = args.prompt_path or ''
    if 'SMID/moral' in prompt_path:
        data_type = 'moral'
    elif 'SMID/valence' in prompt_path:
        data_type = 'valence'
    elif 'SMID/harm' in prompt_path:
        data_type = 'harm'
    else:
        # (sic) spelling kept: it is part of existing result directory names.
        data_type = 'unkown_datatype'

    save_path = os.path.join('/workspace/datasets/results/normativity/clip_stuff/results',
                             save_dir, data_type,
                             f'{args.language_model.split("/")[0]}/')
    model, save_path = load_model(args, save_path)

    image_paths = find_images(data_set_path)
    filenames_tosave = list()

    # Create RTPT object and start the RTPT tracking
    rtpt = RTPT(name_initials='PS', experiment_name='check_offending_images', max_iterations=len(image_paths))
    rtpt.start()

    for idx, image_path in enumerate(tqdm(image_paths)):
        with Image.open(image_path) as img:
            x = model.preprocess(img).unsqueeze(0)
        filename = os.path.basename(image_path)
        prediction_score, predicted_label, pred_label_idx, save_filename = eval_model_(
            args=args,
            x=x,
            model=model,
            file_name=filename.split('.')[0],
            save_path=os.path.join(save_path, 'plots'),
            verbose=False,
            data_type=data_type)
        if save_filename:
            filenames_tosave.append((predicted_label, pred_label_idx, prediction_score, filename))
        rtpt.step(f'{len(image_paths)-idx-1}')

    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'toxic_images.csv'), 'w') as f:
        for label, label_idx, score, item in filenames_tosave:
            f.write(f"{label},{label_idx},{score:.2f},{item}\n")


def run_model_mittinyimages(args, save_dir):
    """Classify every image of the 80M Tiny Images binary and log the selected ones.

    Each image ``i`` in ``range(img_count)`` is read with ``sliceToBin``,
    reshaped to a 32x32 RGB array (column-major layout) and classified with
    ``eval_model_``. For images passing the ``--only_offending`` filter, the
    keyword and source filename from ``getMetaData`` are recorded. A
    checkpoint CSV is written every 100000 images and
    ``toxic_images.csv`` at the end, both with the columns
    ``label,label_idx,score,index,keyword,filename``.

    Args:
        args: CLI namespace (``model_type``, ``language_model``,
            ``prompt_path``, ``gpu`` and the flags read by ``eval_model_``).
        save_dir: sub-directory of the results root for this run.

    Raises:
        ValueError: if no trained probe is found (``model_type='probe'``).
    """
    torch.set_num_threads(6)
    if args.model_type == 'probe':
        files_expl = glob.glob(
            f'./experiments/train_results/toxicity/*_baseline_{args.language_model.split("/")[0]}/best_model.pth.tar')
        if len(files_expl) == 0: raise ValueError('trained model not found')
        args.model_path = files_expl[0]
    save_path = os.path.join('/workspace/datasets/results/normativity/clip_stuff/results',
                             save_dir,
                             f'{args.language_model.split("/")[0]}/')
    model, save_path = load_model(args, save_path)

    # Only the flagged rows are kept in memory. Previously a per-image result
    # list was also built (and never used) across the whole dataset.
    filenames_tosave = list()
    os.makedirs(save_path, exist_ok=True)
    checkpoint_idx = 100000
    openTinyImage()
    try:
        for idx, image_idx in enumerate(tqdm(range(img_count))):
            img = sliceToBin(image_idx)
            t = img.reshape(32, 32, 3, order="F").copy()
            x = model.preprocess(Image.fromarray(t)).unsqueeze(0)
            filename = str(image_idx)
            prediction_score, predicted_label, pred_label_idx, save_filename = eval_model_(
                args=args,
                x=x,
                model=model,
                file_name=filename,
                save_path=os.path.join(save_path, 'plots'),
                verbose=False)
            if save_filename:
                meta = getMetaData(image_idx)
                keyword, dataset_filename = meta[0], meta[1]
                filenames_tosave.append(
                    (predicted_label, pred_label_idx, prediction_score, filename, keyword, dataset_filename))

            if (idx + 1) % checkpoint_idx == 0:
                print(f'Saving checkpoint {len(filenames_tosave)}# images found offending')
                _write_tinyimage_rows(os.path.join(save_path, f'toxic_images_{idx + 1}_ckpt.csv'),
                                      filenames_tosave)
    finally:
        # Always release the Tiny Images file handle, even if a step fails.
        closeTinyImage()
    # The final write previously unpacked 4 of the 6 fields and raised ValueError.
    _write_tinyimage_rows(os.path.join(save_path, 'toxic_images.csv'), filenames_tosave)


def _write_tinyimage_rows(path, rows):
    """Write flagged Tiny Images rows as ``label,label_idx,score,index,keyword,filename`` CSV lines."""
    with open(path, 'w') as f:
        for label, label_idx, score, item, keyw, fname in rows:
            f.write(f"{label},{label_idx},{score:.2f},{item},{keyw},{fname}\n")


def run_model_image(args, save_dir, images, filenames):
    """Classify in-memory images, showing any explanation plots interactively.

    Args:
        args: CLI namespace (``model_type``, ``language_model``,
            ``prompt_path``, ``gpu`` and the flags read by ``eval_model_``).
        save_dir: sub-directory of the results root used by ``load_model``.
        images: sequence of images accepted by ``model.preprocess``.
        filenames: names aligned with ``images`` by position.

    Returns:
        list of tuples ``(name, score formatted '%.4f', label, label index as str)``.
    """
    torch.set_num_threads(6)
    if args.model_type == 'probe_protos':
        candidates = glob.glob(
            f'./experiments/train_results/toxicity/*_baseline_{args.language_model.split("/")[0]}/best_model.pth.tar')
        if not candidates:
            raise ValueError('trained model not found')
        args.model_path = candidates[0]
    model_tag = args.language_model.split("/")[0]
    save_path = os.path.join('/workspace/datasets/results/normativity/clip_stuff/results',
                             save_dir,
                             f'{model_tag}/')
    model, save_path = load_model(args, save_path)
    plot_dir = os.path.join(save_path, 'plots')

    results = []
    for position, image in enumerate(images):
        batch = model.preprocess(image).unsqueeze(0)
        name = str(filenames[position])
        score, label, label_idx, _ = eval_model_(args=args,
                                                 x=batch,
                                                 model=model,
                                                 file_name=name.split('.')[0],
                                                 save_path=plot_dir,
                                                 verbose=False,
                                                 show=True)
        results.append((name, f'{score:.4f}', label, f'{label_idx}'))
    return results


def main():
    """Script entry point: classify the ImageNet validation folder and log the selected images."""
    torch.set_num_threads(6)
    args = parser.parse_args()

    # Process-level RTPT tracker; run_model_imagefolder starts its own per-image tracker.
    tracker = RTPT(name_initials='Kersting', experiment_name='CrazyStuff', max_iterations=1)
    tracker.start()

    # Alternative dataset: '/workspace/datasets/yfcc100m' with output name 'yfcc100m'.
    source_dir, output_name = '/workspace/datasets/imagenet_t/val', 'imagenet_val'
    run_model_imagefolder(args, source_dir, output_name)


if __name__ == '__main__':
    main()
