{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/workspace/repositories/offimg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from IPython.display import Image\n",
    "import glob\n",
    "from main.experiments import find_images\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show offending images predicted by CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup: ImageNet validation image id (file name up to the first '.') -> full path.\n",
    "val_dir = '/workspace/datasets/imagenet_t/val'\n",
    "image_paths = find_images(val_dir)\n",
    "imagenet_path_dict = {\n",
    "    os.path.basename(path).split('.')[0]: path\n",
    "    for path in image_paths\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup: ImageNet training image id (file name up to the first '.') -> full path.\n",
    "train_dir = '/workspace/datasets/imagenet_t/train'\n",
    "image_paths = find_images(train_dir)\n",
    "imagenet_path_dict_train = {\n",
    "    os.path.basename(path).split('.')[0]: path\n",
    "    for path in image_paths\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup over all ImageNet-21k splits (val, train, small classes): image id -> path.\n",
    "# If an id occurs in several splits, the split listed last wins.\n",
    "imagenet21k_root = '/workspace/datasets/imagenet21k/imagenet21k_resized'\n",
    "image_paths = []\n",
    "for split_name in ('imagenet21k_val', 'imagenet21k_train', 'imagenet21k_small_classes'):\n",
    "    image_paths_dir = os.path.join(imagenet21k_root, split_name)\n",
    "    image_paths += find_images(image_paths_dir)\n",
    "imagenet_path_dict_21kval = {\n",
    "    os.path.basename(path).split('.')[0]: path\n",
    "    for path in image_paths\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readoffendingimages(data_path):\n",
    "    \"\"\"Return the image ids of all PNG files in the '100' sub-directory of data_path.\n",
    "\n",
    "    An id is the PNG's file name with every 'attentionGrad_' occurrence removed\n",
    "    and everything from the first '.' onward dropped.\n",
    "    \"\"\"\n",
    "    pattern = data_path + \"/100/*.png\"\n",
    "    ids = []\n",
    "    for png_path in glob.glob(pattern):\n",
    "        stem = os.path.basename(png_path).replace('attentionGrad_', '')\n",
    "        ids.append(stem.split('.')[0])\n",
    "    return ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readoffendingimages_csv(data_path, t=1.00):\n",
    "    \"\"\"Return ids of the images in a results CSV whose score is >= t.\n",
    "\n",
    "    The CSV has one header row; column 2 holds the score and column 3 the\n",
    "    image file name. Ids (file name up to the first '.') are returned in\n",
    "    order of descending score. Prints the number of selected rows and of all rows.\n",
    "    \"\"\"\n",
    "    # Let numpy open (and close) the file itself; ndmin=2 keeps a single-row\n",
    "    # CSV two-dimensional so that d[2] / e[3] index columns, not characters.\n",
    "    data = np.loadtxt(data_path, delimiter=\",\", skiprows=1, dtype=str, ndmin=2)\n",
    "    data = sorted(data, key=lambda d: -float(d[2]))\n",
    "    files_ = [e[3].split('.')[0] for e in data if float(e[2]) >= t]\n",
    "\n",
    "    print(len(files_), len(data))\n",
    "    return files_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#csv_tag = 'imagenet21k_small_classes'\n",
    "#csv_tag = 'imagenet21k_train'\n",
    "#csv_tag = 'imagenet21k_val'\n",
    "#csv_tag = 'imagenet_train'\n",
    "csv_tag = 'imagenet_val'\n",
    "\n",
    "data_type = 'moral'\n",
    "#data_type = 'valence'\n",
    "#data_type = 'harm'\n",
    "\n",
    "# Minimum score (column 2 of off_images.csv) for an image to be listed as offending.\n",
    "offending_threshold = .8\n",
    "\n",
    "### moral ###\n",
    "# 1.50 - 3.50 Vit-B-32\n",
    "#dir_path_csv = f'/workspace/datasets/results/normativity/clip_stuff/results/{csv_tag}/{data_type}/Clip_ViT-B/sim_prompt_tuned1627314107.2788103/toxic_images.csv'\n",
    "#prompt_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-32/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.50_1.50_3.50/1627314107.2788103/8/prompts.p'\n",
    "dir_path_csv = f'/workspace/repositories/offimg/results/Clip_ViT-B-32/{csv_tag}/off_images.csv'\n",
    "prompt_path = '/workspace/repositories/offimg/results/Clip_ViT-B-32/prompts.p'\n",
    "\n",
    "# 1.50 - 3.50 Vit-B-16\n",
    "#dir_path_csv = f'/workspace/datasets/results/normativity/clip_stuff/results/{csv_tag}/{data_type}/Clip_ViT-B/sim_prompt_tuned1627564750.4050713/toxic_images.csv'\n",
    "#prompt_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/1.00_1.50_3.50/1627564750.4050713/4/prompts.p'\n",
    "#dir_path_csv = f'/workspace/repositories/offimg/results/Clip_ViT-B-16/{csv_tag}/off_images.csv'\n",
    "#prompt_path = '/workspace/repositories/offimg/results/Clip_ViT-B-16/prompts.p'\n",
    "\n",
    "files = readoffendingimages_csv(dir_path_csv, t=offending_threshold)\n",
    "print(len(files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.ndimage\n",
    "\n",
    "def gaussian_blur(sharp_image, sigma):\n",
    "    \"\"\"Gaussian-blur the first three channels of an (H, W, C) image.\n",
    "\n",
    "    Each channel is filtered on its own (a 3-D filter would also blur across\n",
    "    the channel axis and wash colours out towards gray); the three results\n",
    "    are stacked back into an (H, W, 3) array.\n",
    "    \"\"\"\n",
    "    # Use scipy.ndimage directly: the scipy.ndimage.filters namespace is deprecated.\n",
    "    blurred_channels = [\n",
    "        scipy.ndimage.gaussian_filter(sharp_image[:, :, channel], sigma=sigma)\n",
    "        for channel in range(3)\n",
    "    ]\n",
    "    return np.dstack(blurred_channels)\n",
    "\n",
    "def uniform_blur(sharp_image, uniform_filter_size):\n",
    "    \"\"\"Box-blur an (H, W, C) image with a square spatial window.\n",
    "\n",
    "    The window has size 1 along the channel axis so channels are never mixed\n",
    "    (mixing would wash colours out towards gray).\n",
    "    \"\"\"\n",
    "    multidim_filter_size = (uniform_filter_size, uniform_filter_size, 1)\n",
    "    return scipy.ndimage.uniform_filter(sharp_image, size=multidim_filter_size)\n",
    "\n",
    "def blur_image_locally(sharp_image, mask, use_gaussian_blur, gaussian_sigma, uniform_filter_size):\n",
    "    \"\"\"Blend an image with a blurred copy of itself according to a soft mask.\n",
    "\n",
    "    Where mask is 1 the sharp pixel is kept, where it is 0 the blurred pixel\n",
    "    is used. The mask is blurred with the same filter as the image so the\n",
    "    transition between both regions is smooth.\n",
    "\n",
    "    Args:\n",
    "        sharp_image: (H, W, 3) image; values assumed to lie in [0, 255].\n",
    "        mask: array of the same shape, 1 = keep sharp, 0 = blur.\n",
    "        use_gaussian_blur: Gaussian filter if True, box filter otherwise.\n",
    "        gaussian_sigma: sigma of the Gaussian filter.\n",
    "        uniform_filter_size: window size of the box filter.\n",
    "\n",
    "    Returns:\n",
    "        The blended image as uint8.\n",
    "    \"\"\"\n",
    "    sharp_image_f32 = sharp_image.astype(dtype=np.float32)\n",
    "    sharp_mask_f32 = mask.astype(dtype=np.float32)\n",
    "\n",
    "    if use_gaussian_blur:\n",
    "        blurred_image_f32 = gaussian_blur(sharp_image_f32, sigma=gaussian_sigma)\n",
    "        blurred_mask_f32 = gaussian_blur(sharp_mask_f32, sigma=gaussian_sigma)\n",
    "    else:\n",
    "        blurred_image_f32 = uniform_blur(sharp_image_f32, uniform_filter_size)\n",
    "        blurred_mask_f32 = uniform_blur(sharp_mask_f32, uniform_filter_size)\n",
    "\n",
    "    # Per-pixel convex combination: mask * sharp + (1 - mask) * blurred.\n",
    "    locally_blurred_image_f32 = (sharp_image_f32 * blurred_mask_f32\n",
    "                                 + blurred_image_f32 * (1.0 - blurred_mask_f32))\n",
    "\n",
    "    # Round instead of truncating (float round-off such as 6.9999995 would\n",
    "    # otherwise become 6) and clip so the uint8 cast cannot wrap around.\n",
    "    return np.clip(np.rint(locally_blurred_image_f32), 0, 255).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import PIL.Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Demo of blur_image_locally on one ImageNet validation image: the image stays\n",
    "# sharp except for one rectangle (mask == 0), which is box-blurred.\n",
    "image_id = 'ILSVRC2012_val_00021081' # see LARGE DATASETS: A PYRRHIC WIN FOR COMPUTER VISION?\n",
    "#image_id = files[0]\n",
    "\n",
    "img_birhane_val = PIL.Image.open(imagenet_path_dict[image_id])\n",
    "img_birhane_val = img_birhane_val.convert(\"RGB\")\n",
    "# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same resampling filter.\n",
    "img_birhane_val = img_birhane_val.resize((256,256), PIL.Image.LANCZOS)\n",
    "sharp_image = np.asarray(img_birhane_val)\n",
    "print(sharp_image.shape)\n",
    "height, width, channels = sharp_image.shape\n",
    "sharp_mask = np.full((height, width, channels), fill_value=1)\n",
    "# Rectangle to blur; the offsets look hand-tuned for this particular image.\n",
    "sharp_mask[int(height / 4)-30: int(3 * height / 4)-60, int(width / 4)-10: int(3 * width / 4)-40, :] = 0\n",
    "\n",
    "result = blur_image_locally(\n",
    "        sharp_image,\n",
    "        sharp_mask,\n",
    "        use_gaussian_blur=False,\n",
    "        gaussian_sigma=31,\n",
    "        uniform_filter_size=10)\n",
    "plt.figure(figsize=(8,6), dpi=400)\n",
    "plt.imshow(result)\n",
    "ax = plt.gca()\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import PIL.Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Map the offending image ids in `files` to file paths, using the lookup\n",
    "# table of the split selected by `csv_tag`.\n",
    "if csv_tag == 'imagenet_train':\n",
    "    path_lookup = imagenet_path_dict_train\n",
    "elif 'imagenet21k' in csv_tag:\n",
    "    path_lookup = imagenet_path_dict_21kval\n",
    "else:\n",
    "    path_lookup = imagenet_path_dict\n",
    "images = [path_lookup[file_] for file_ in files]\n",
    "print('#Images found offending', len(images))\n",
    "\n",
    "\n",
    "def show_images_by_img(img):\n",
    "    \"\"\"Show the image file at path `img` as a 256x256 RGB image.\"\"\"\n",
    "    # Close the underlying file deterministically.\n",
    "    with PIL.Image.open(img) as pil_img:\n",
    "        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.\n",
    "        resized = pil_img.convert(\"RGB\").resize((256,256), PIL.Image.LANCZOS)\n",
    "    plt.imshow(np.asarray(resized))\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def show_images_by_idx(idx):\n",
    "    \"\"\"Show images[idx] (see show_images_by_img).\"\"\"\n",
    "    show_images_by_img(images[idx])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_dir_class = dict()\n",
    "\n",
    "# Class directory id -> first comma-separated name of that class (used as captions).\n",
    "if \"21k\" in csv_tag:\n",
    "    # Two parallel text files: line k of the ids file belongs to line k of the lemmas file.\n",
    "    lemmas_path = \"/workspace/datasets/imagenet21k/imagenet21k_wordnet_lemmas.txt\"\n",
    "    wnids_path = \"/workspace/datasets/imagenet21k/imagenet21k_wordnet_ids.txt\"\n",
    "    with open(lemmas_path, \"r\") as file_classes, open(wnids_path, \"r\") as file_dirs:\n",
    "        for wnid_line in file_dirs:\n",
    "            lemma_line = file_classes.readline()\n",
    "            imagenet_dir_class[wnid_line.replace('\\n', '')] = lemma_line.replace('\\n', '').split(',')[0]\n",
    "else:\n",
    "    import scipy.io\n",
    "    # Each synset record: field 1 is the directory id, field 2 the comma-separated class names.\n",
    "    synsets = scipy.io.loadmat(\"/workspace/datasets/imagenet_t/meta.mat\")[\"synsets\"]\n",
    "    for synset in synsets:\n",
    "        wnid_field, words_field = synset[0][1], synset[0][2]\n",
    "        imagenet_dir_class[wnid_field.item()] = words_field.item().split(',')[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import PIL\n",
    "import PIL.Image\n",
    "from tqdm import tqdm\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "from PIL import ImageFilter\n",
    "import os\n",
    "\n",
    "# Figure of offending images with their class name as caption: either a\n",
    "# hand-picked subset (cherry_pick=True) or all images from start_idx on.\n",
    "cherry_pick = True\n",
    "if cherry_pick:\n",
    "    fig = plt.figure(figsize=(16., 16.), dpi=800) # 16,36\n",
    "    rows, cols = 2, 8 # 15, 6\n",
    "    # Picks were made on an overview grid of 100 rows x 12 columns and are\n",
    "    # given as {1-based row: [1-based columns]}.\n",
    "    # !!! selected based on checkpoint: 0.50_1.50_3.50/1627314107.2788103/8 imagenet_val\n",
    "    overview_cols = 12\n",
    "    cherry_picks = {\n",
    "        2: [4,5,9],\n",
    "        4: [8],\n",
    "        5: [8,11],\n",
    "        6: [2,3,6,7],\n",
    "        8: [3],\n",
    "        11: [10],\n",
    "        21: [10],\n",
    "        33: [6],\n",
    "        38: [3],\n",
    "        44: [6],\n",
    "        #Row 38: 3?\n",
    "        #Row 44: 6?\n",
    "        46: [12],\n",
    "        67: [6],\n",
    "        82: [10]\n",
    "    }\n",
    "    # Flatten (row, col) of the overview grid into indices into `images`.\n",
    "    selected_images_idx = [\n",
    "        (c_row - 1) * overview_cols + c_col - 1\n",
    "        for c_row, c_cols in cherry_picks.items()\n",
    "        for c_col in c_cols\n",
    "    ]\n",
    "    # Grid positions whose image is shown blurred.\n",
    "    blurred_positions = {0, 4, 6, 10, 14, 15}\n",
    "else:\n",
    "    fig = plt.figure(figsize=(16,36)) # 16,36\n",
    "    rows, cols = 8, 6\n",
    "    start_idx = 0#1500\n",
    "    selected_images_idx = list(range(start_idx,len(images)))\n",
    "    blurred_positions = set()\n",
    "\n",
    "grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
    "                 nrows_ncols=(rows, cols),  # rows x cols grid of axes\n",
    "                 axes_pad=0.1,  # pad between axes in inch.\n",
    "                 )\n",
    "\n",
    "show_one_per_class = True\n",
    "if show_one_per_class:\n",
    "    # Keep only the first selected image of each class (= parent directory).\n",
    "    classes_seen = set()\n",
    "    selected_images_per_class_idx = []\n",
    "    for im_idx in selected_images_idx:\n",
    "        image_class = os.path.basename(os.path.dirname(images[im_idx]))\n",
    "        if image_class not in classes_seen:\n",
    "            selected_images_per_class_idx.append(im_idx)\n",
    "            classes_seen.add(image_class)\n",
    "    selected_images_idx = selected_images_per_class_idx\n",
    "    print('Per class #samples:', len(selected_images_idx))\n",
    "\n",
    "text_size = 12 #20\n",
    "for j, (ax, im_idx) in tqdm(enumerate(zip(grid, selected_images_idx[:rows*cols]))):\n",
    "    imgnet_dirname = os.path.basename(os.path.dirname(images[im_idx]))\n",
    "    with PIL.Image.open(images[im_idx]) as pil_img:\n",
    "        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.\n",
    "        img = np.asarray(pil_img.convert(\"RGB\").resize((256,256), PIL.Image.LANCZOS))\n",
    "    if j in blurred_positions:\n",
    "        img = uniform_blur(img, 20)\n",
    "    ax.imshow(img)\n",
    "    # Class name caption near the lower-left corner (image data coordinates).\n",
    "    t = ax.text(10., 243., imagenet_dir_class[imgnet_dirname], size=text_size, ha=\"left\", va='bottom', color=\"w\")\n",
    "    t.set_bbox(dict(facecolor='black', alpha=0.3, edgecolor='black'))\n",
    "for ax in grid:\n",
    "    ax.axis('off')\n",
    "\n",
    "#plt.savefig(os.path.join('/workspace/datasets/results/', f'offending_images_16.png'))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import PIL\n",
    "import PIL.Image\n",
    "from tqdm import tqdm\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "from PIL import ImageFilter\n",
    "import os\n",
    "\n",
    "# Save all offending images as contact sheets of rows x cols thumbnails each,\n",
    "# one PNG per sheet, named after csv_tag and the image index range it covers.\n",
    "rows, cols = 20, 15 #8, 6\n",
    "step_size = rows * cols  # one full sheet per step (300 images)\n",
    "show_one_per_class = False\n",
    "text_size = 12 #20\n",
    "for start_idx in range(0, len(images), step_size):\n",
    "    fig = plt.figure(figsize=(32., 2*136.)) # 16,36\n",
    "    selected_images_idx = list(range(start_idx, len(images)))\n",
    "    grid = ImageGrid(fig, 111,  # similar to subplot(111)\n",
    "                     nrows_ncols=(rows, cols),  # rows x cols grid of axes\n",
    "                     axes_pad=0.1,  # pad between axes in inch.\n",
    "                     )\n",
    "    if show_one_per_class:\n",
    "        # Keep only the first remaining image of each class (= parent directory).\n",
    "        classes_seen = set()\n",
    "        selected_images_per_class_idx = []\n",
    "        for im_idx in selected_images_idx:\n",
    "            image_class = os.path.basename(os.path.dirname(images[im_idx]))\n",
    "            if image_class not in classes_seen:\n",
    "                selected_images_per_class_idx.append(im_idx)\n",
    "                classes_seen.add(image_class)\n",
    "        selected_images_idx = selected_images_per_class_idx\n",
    "        print('Per class #samples:', len(selected_images_idx))\n",
    "    for ax, im_idx in tqdm(zip(grid, selected_images_idx[:rows*cols])):\n",
    "        imgnet_dirname = os.path.basename(os.path.dirname(images[im_idx]))\n",
    "        with PIL.Image.open(images[im_idx]) as pil_img:\n",
    "            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.\n",
    "            img = np.asarray(pil_img.convert(\"RGB\").resize((256,256), PIL.Image.LANCZOS))\n",
    "        ax.imshow(img)\n",
    "        # Class name caption near the lower-left corner (image data coordinates).\n",
    "        t = ax.text(10., 243., imagenet_dir_class[imgnet_dirname], size=text_size, ha=\"left\", va='bottom', color=\"w\")\n",
    "        t.set_bbox(dict(facecolor='black', alpha=0.3, edgecolor='black'))\n",
    "    for ax in grid:\n",
    "        ax.axis('off')\n",
    "\n",
    "    # The last sheet may hold fewer than step_size images; name it by its real range.\n",
    "    end_idx = min(start_idx + step_size, len(images)) - 1\n",
    "    plt.savefig(os.path.join('/workspace/datasets/results/', f'offending_images_{csv_tag}_{start_idx}-{end_idx}.png'),bbox_inches='tight')\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import PIL\n",
    "from tqdm import tqdm\n",
    "from mpl_toolkits.axes_grid1 import ImageGrid\n",
    "from PIL import ImageFilter\n",
    "import os\n",
    "\n",
    "# Group the offending images by class directory:\n",
    "# {dir_id: {'value': number of images, 'idx': their indices into `images`}}.\n",
    "selected_images_indices = list(range(len(images)))\n",
    "print(len(selected_images_indices))\n",
    "found_images_by_wordnet_id = dict()\n",
    "for im_idx in selected_images_indices:\n",
    "    imgnet_dirname = os.path.basename(os.path.dirname(images[im_idx]))\n",
    "    # Plain dict lookup; `in list(d.keys())` rebuilt a list on every iteration.\n",
    "    entry = found_images_by_wordnet_id.setdefault(imgnet_dirname, {'value': 0, 'idx': []})\n",
    "    entry['value'] += 1\n",
    "    entry['idx'].append(im_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#show_images_by_img(images[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same per-class counts, keyed by readable class name instead of directory id.\n",
    "# If two directory ids share a class name, the one iterated last overwrites the other.\n",
    "found_images_by_wordnet_id_classnames = {\n",
    "    imagenet_dir_class[wnid]: {'value': counts['value'], 'idx': counts['idx']}\n",
    "    for wnid, counts in found_images_by_wordnet_id.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect count and indices for one class (KeyError if that class had no offending images).\n",
    "print(found_images_by_wordnet_id_classnames['squirrel monkey'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#show_images_by_idx(found_images_by_wordnet_id_classnames['muzzle']['idx'][-1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "print(len(images))\n",
    "print('total classes with possible polluted images:',len(list(found_images_by_wordnet_id.keys())))\n",
    "# (class name, {'value': count, 'idx': [...]}) pairs, sorted by ascending count.\n",
    "found_images_by_wordnet_id_list = list()\n",
    "for imageclass_id, counts in found_images_by_wordnet_id.items():\n",
    "    print(imagenet_dir_class[imageclass_id])\n",
    "    found_images_by_wordnet_id_list.append((imagenet_dir_class[imageclass_id], counts))\n",
    "found_images_by_wordnet_id_list.sort(key=lambda tup: tup[1]['value'])\n",
    "\n",
    "# NOTE(review): the list is sorted ascending, so this plots the 60 classes with the\n",
    "# FEWEST hits - confirm that is intended.\n",
    "plt_list = found_images_by_wordnet_id_list[:60]\n",
    "print(plt_list)\n",
    "bars = [e[0] for e in plt_list]\n",
    "y_pos = [e[1]['value'] for e in plt_list]\n",
    "plt.figure(figsize=(18,6))\n",
    "bar_color = sns.color_palette('deep')[3]\n",
    "bar_color_edge = sns.color_palette('deep')[-3]\n",
    "plt.bar(bars,\n",
    "        y_pos, width=0.35,\n",
    "        alpha=0.8, color=bar_color, edgecolor=bar_color_edge, linewidth=1.)\n",
    "plt.ylabel('# offending images')\n",
    "plt.xticks(rotation=90)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show one image of every class with fewer than max_cnt offending images.\n",
    "# Use each (class name, counts) pair directly: looking the indices up again by\n",
    "# class name returns another directory's images when two ids share a name.\n",
    "min_cnt = 0\n",
    "max_cnt = 3\n",
    "for class_name, counts in found_images_by_wordnet_id_list:\n",
    "    n_hits = len(counts['idx'])\n",
    "    if min_cnt <= n_hits < max_cnt:\n",
    "        print(class_name)\n",
    "        show_images_by_idx(counts['idx'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show every offending image of one class (KeyError if that class had no hits).\n",
    "for i in found_images_by_wordnet_id_classnames['plastic bag']['idx']:\n",
    "    show_images_by_idx(i)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar plot of the first 30 entries of the (ascending-count) class list.\n",
    "plt_list = found_images_by_wordnet_id_list[:30]\n",
    "print(plt_list)\n",
    "bars = [e[0] for e in plt_list]\n",
    "# e[1] is the {'value': count, 'idx': [...]} dict; plot the count, not the dict.\n",
    "y_pos = [e[1]['value'] for e in plt_list]\n",
    "plt.figure(figsize=(18,6))\n",
    "bar_color = sns.color_palette('deep')[3]\n",
    "bar_color_edge = sns.color_palette('deep')[-3]\n",
    "plt.bar(bars,\n",
    "        y_pos, width=0.35,\n",
    "        alpha=0.8, color=bar_color, edgecolor=bar_color_edge, linewidth=1.)\n",
    "plt.ylabel('# offending images')\n",
    "plt.xticks(rotation=90)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show embedding space of SMID images (CLIP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/workspace/repositories/MoralVisionModels')\n",
    "# NOTE(review): a package named `main` was already imported from the offimg repo in the\n",
    "# first cells. Unless `main` is a namespace package, `main.tune_clip` is then searched\n",
    "# only inside that package, not in MoralVisionModels/main - confirm it resolves as intended.\n",
    "from main.tune_clip import setup_model_clip_prompt, setup_dataset\n",
    "language_model = 'Clip_ViT-B/16'\n",
    "model_clip = setup_model_clip_prompt(language_model)\n",
    "dataset = setup_dataset(model_clip.preprocess, test_size=None, train_size=None, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main.tune_clip import get_dataloaders\n",
    "# Only the first returned loader is kept; run_dataset iterates over this global `dataloader`.\n",
    "dataloader, _ = get_dataloaders(dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import torch\n",
    "# Imported here because inv_normalize_clip below needs it on a fresh kernel.\n",
    "from torchvision.transforms import Normalize\n",
    "\n",
    "prompt_path = '/workspace/repositories/offimg/results/Clip_ViT-B-16/prompts.p'\n",
    "# Inverse of CLIP's input normalisation (x - mean) / std, written as Normalize(-mean/std, 1/std).\n",
    "inv_normalize_clip = Normalize(\n",
    "    mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],\n",
    "    std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711]\n",
    ")\n",
    "\n",
    "\n",
    "def _cosine_to_all(prompt, embs_tensor):\n",
    "    \"\"\"Cosine similarity between one prompt vector and each row of `embs_tensor`.\"\"\"\n",
    "    prompt_rows = torch.FloatTensor(prompt).repeat(len(embs_tensor), 1)\n",
    "    return torch.nn.functional.cosine_similarity(prompt_rows, embs_tensor, dim=1, eps=1e-8)\n",
    "\n",
    "\n",
    "def run_dataset(model_, prompts=False, images=False):\n",
    "    \"\"\"Embed every batch of the global `dataloader` with `model_.encode`.\n",
    "\n",
    "    Args:\n",
    "        model_: model with an `encode(X)` method and, when `prompts` is set,\n",
    "            a `prompts` tensor whose rows are used as the initial prompts.\n",
    "        prompts: if True, load the tuned prompts from `prompt_path`, append\n",
    "            [tuned0, tuned1, init0, init1] to the embeddings with labels\n",
    "            [2, 3, 4, 5] and print statistics of tuned0 - tuned1.\n",
    "        images: if True, also collect the de-normalised input images\n",
    "            (channel axis moved last).\n",
    "\n",
    "    Returns:\n",
    "        (embs_, labels_, means_, imgs_, nn_images_). nn_images_ is None unless\n",
    "        `prompts` and `images` are both set; then it holds the images with the\n",
    "        highest cosine similarity to (init0, init1, tuned0, tuned1).\n",
    "    \"\"\"\n",
    "    embs_ = []\n",
    "    labels_ = []\n",
    "    means_ = []\n",
    "    imgs_ = []\n",
    "    nn_images_ = None\n",
    "    # Pure inference: no autograd graph is needed for the embeddings.\n",
    "    with torch.no_grad():\n",
    "        for X_batch, y_batch, means_batch, conf_batch in tqdm(dataloader):\n",
    "            X_batch = X_batch.to('cuda')\n",
    "            if images:\n",
    "                imgs_ += inv_normalize_clip(X_batch.detach().cpu()).permute(0,2,3,1).numpy().tolist()\n",
    "            embs_ += model_.encode(X_batch).detach().cpu().numpy().tolist()\n",
    "            labels_ += y_batch.detach().cpu().numpy().tolist()\n",
    "            means_ += means_batch.detach().cpu().numpy().tolist()\n",
    "    if not prompts:\n",
    "        return embs_, labels_, means_, imgs_, nn_images_\n",
    "\n",
    "    # Close the file after loading. NOTE: pickle.load can execute arbitrary\n",
    "    # code - only load prompt files you produced yourself.\n",
    "    with open(prompt_path, 'rb') as prompt_file:\n",
    "        prompts_tuned = pickle.load(prompt_file)\n",
    "    prompts_init = model_.prompts.detach().cpu().numpy().tolist()\n",
    "    print(len(embs_))\n",
    "    print('images', len(imgs_))\n",
    "\n",
    "    # Nearest dataset image (cosine similarity) to each prompt, computed before\n",
    "    # the prompts themselves are appended to embs_.\n",
    "    if len(imgs_) > 0:\n",
    "        embs_tensor = torch.FloatTensor(embs_)\n",
    "        query_prompts = (prompts_init[0], prompts_init[1], prompts_tuned[0], prompts_tuned[1])\n",
    "        nn_images_ = tuple(imgs_[int(torch.argmax(_cosine_to_all(p, embs_tensor)))]\n",
    "                           for p in query_prompts)\n",
    "\n",
    "    embs_.append(prompts_tuned[0])\n",
    "    embs_.append(prompts_tuned[1])\n",
    "    embs_.append(prompts_init[0])\n",
    "    embs_.append(prompts_init[1])\n",
    "    # np.asarray first so the difference also works if the pickle holds plain lists.\n",
    "    prompt_diff = np.asarray(prompts_tuned[0]) - np.asarray(prompts_tuned[1])\n",
    "    print(np.max(np.abs(prompt_diff)))\n",
    "    print(np.max(prompt_diff), np.min(prompt_diff))\n",
    "    print(np.mean(prompt_diff))\n",
    "    print(np.sum(prompt_diff**2>.5))\n",
    "\n",
    "    labels_ += [2,3,4,5]\n",
    "    print(len(embs_))\n",
    "    return embs_, labels_, means_, imgs_, nn_images_\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nearest images of prompts (init1, init2, opt1, opt2)\n",
    "# Full pass over the dataloader with image collection; only the nearest-neighbour images are kept.\n",
    "_, _, _, _, nn_images = run_dataset(model_clip, prompts=True, images=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import Normalize\n",
    "\n",
    "# Show the dataset image nearest to the 2nd tuned prompt (index 3 of (init1, init2, opt1, opt2)).\n",
    "# The de-normalised image is float and presumably lies in [0, 1] (ToTensor range - confirm);\n",
    "# clip small overshoots instead of casting *255 to int, which truncates and makes imshow\n",
    "# warn about out-of-range data.\n",
    "img = np.clip(np.asarray(nn_images[3], dtype=np.float32), 0.0, 1.0)\n",
    "plt.imshow(img)\n",
    "ax = plt.gca()\n",
    "ax.axis('off')\n",
    "ax.grid(False)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second full pass (no images): embeddings, labels and per-image means. The last 4\n",
    "# embeddings/labels are the prompt vectors appended by run_dataset; means_clip has none.\n",
    "embs_clip, labels_clip, means_clip, imgs_clip, _ = run_dataset(model_clip, prompts=True)\n",
    "print(len(embs_clip))\n",
    "# Round each mean to an integer (presumably used as the 'Moral rating' hue in plot_pca).\n",
    "means_clip = [int(round(l)) for l in means_clip]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad means_clip with markers for the 4 appended prompt embeddings (tuned0, tuned1, init0, init1)\n",
    "# so it aligns with embs_clip. NOTE: not idempotent - re-running this cell appends again.\n",
    "means_clip += ['prompt','prompt']\n",
    "means_clip += ['promptInit','promptInit']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Distinct rating values plus prompt markers (numpy casts the mixed int/str list to strings).\n",
    "np.unique(means_clip)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=500)\n",
    "tsne_results = tsne.fit_transform(embs)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "df = pd.DataFrame(data=tsne_results, columns=[\"tsne-2d-one\", \"tsne-2d-two\"])\n",
    "df.insert(2, \"y\", labels, True)\n",
    "df.insert(3, \"y_means\", means, True)\n",
    "plt.figure(figsize=(16,10))\n",
    "sns.scatterplot(\n",
    "    x=\"tsne-2d-one\", y=\"tsne-2d-two\",\n",
    "    hue=\"y\",\n",
    "    palette=sns.color_palette(\"deep\", 6),\n",
    "    data=df,\n",
    "    legend=\"full\",\n",
    "    alpha=0.7,\n",
    "    s=[40]*(len(embs)-4) + [200, 100] + [200, 100]\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "# Fit the 2-D PCA on the dataset embeddings only (the last 4 rows are the prompt vectors),\n",
    "# then project everything, prompts included, into that space.\n",
    "pca = PCA(n_components=2)\n",
    "pca.fit(embs_clip[:-4])\n",
    "pca_results_clip = pca.transform(embs_clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of the (dataset-only) embedding variance captured by PC1 and PC2.\n",
    "print(f'Explained variation per principal component: {pca.explained_variance_ratio_}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#TODO plot some imagenet projections\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline\n",
    "import matplotlib.pylab as pylab\n",
    "\n",
    "\n",
    "def plot_pca(data_, labels_, means_, prompts=False,\n",
    "             save_path='/workspace/datasets/results/emb_space3.svg'):\n",
    "    \"\"\"Scatter-plot 2-D embedding projections coloured by moral rating and save the figure.\n",
    "\n",
    "    Args:\n",
    "        data_: (n, 2) array-like of projected points; its columns become 'PC1' and 'PC2'.\n",
    "        labels_: per-point labels, stored in the 'y' column (not used for styling).\n",
    "        means_: per-point moral ratings; used as hue with one palette colour per distinct value.\n",
    "        prompts: if True, the last 4 points are labelled 'Prompt', 'Prompt', 'PromptInit',\n",
    "            'PromptInit', drawn enlarged, and told apart from 'Sample' points by marker style.\n",
    "        save_path: file the figure is written to (format inferred from the extension).\n",
    "            Pass a distinct path per call, otherwise each plot overwrites the previous one.\n",
    "    \"\"\"\n",
    "    fontsize = 30\n",
    "    figsize = (8, 6)\n",
    "    # sns.set_style() resets 'font.family' and 'font.sans-serif', so it has to run *before*\n",
    "    # the rcParams overrides below; otherwise the Helvetica setting is silently discarded.\n",
    "    sns.set_style(\"whitegrid\")\n",
    "    pylab.rcParams.update({\n",
    "        'legend.fontsize': 'xx-large',\n",
    "        'legend.loc': 'upper left',\n",
    "        'figure.figsize': figsize,\n",
    "        'axes.labelsize': fontsize,\n",
    "        'axes.titlesize': fontsize,\n",
    "        'xtick.labelsize': fontsize,\n",
    "        'ytick.labelsize': fontsize,\n",
    "        \"text.usetex\": True,\n",
    "        \"font.family\": \"sans-serif\",\n",
    "        \"font.sans-serif\": [\"Helvetica\"],\n",
    "    })\n",
    "    plt.rcParams['axes.xmargin'] = 0.05\n",
    "\n",
    "    n_points = len(means_)\n",
    "    df = pd.DataFrame(data=data_, columns=['PC1', 'PC2'])\n",
    "    df.insert(2, \"y\", labels_, True)\n",
    "    df.insert(3, \"Moral\\n rating\", means_, True)\n",
    "\n",
    "    sizes = [40] * n_points\n",
    "    point_types = [0] * n_points  # placeholder: the 'Type' column is only used when prompts=True\n",
    "    if prompts:\n",
    "        # Assumes the last 4 rows are 2 'Prompt' points followed by 2 'PromptInit' points.\n",
    "        sizes = [40] * (n_points - 4) + [400, 400] + [200, 200]\n",
    "        point_types = ['Sample'] * (n_points - 4) + ['Prompt', 'Prompt'] + ['PromptInit', 'PromptInit']\n",
    "    df.insert(3, \"Type\", point_types, True)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize, dpi=1200)\n",
    "    scatter_kwargs = dict(\n",
    "        x=\"PC1\", y=\"PC2\",\n",
    "        hue=\"Moral\\n rating\",\n",
    "        palette=sns.color_palette(\"deep\", len(np.unique(means_))),\n",
    "        data=df,\n",
    "        legend=\"full\",\n",
    "        alpha=0.5,\n",
    "        s=sizes,\n",
    "        ax=ax,\n",
    "    )\n",
    "    if prompts:\n",
    "        scatter_kwargs['style'] = \"Type\"  # one marker shape per point type\n",
    "    sns.scatterplot(**scatter_kwargs)\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    leg = ax.get_legend()\n",
    "    plt.setp(leg.get_texts(), fontsize=fontsize - 6)  # legend entries\n",
    "    plt.setp(leg.get_title(), fontsize=fontsize - 6)  # legend title\n",
    "    fig.savefig(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Plot the sample points only: the last 4 rows (the prompt points shown with prompts=True) are dropped.\n",
    "plot_pca(data_=pca_results_clip[:-4], labels_=labels_clip[:-4], means_=means_clip[:-4], prompts=False)\n",
    "# Variant including the prompt points:\n",
    "# plot_pca(pca_results_clip, labels_clip, means_clip, prompts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the CLIP images into one ndarray (a fresh copy; assumes equally shaped images).\n",
    "images_clip = np.array(imgs_clip, copy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if False:\n",
    "    # Disabled interactive view: hovering over a point previews its image.\n",
    "    # NOTE(review): relies on pca_results, labels, means, embs and images from earlier cells.\n",
    "    import pandas as pd\n",
    "    import matplotlib.pyplot as plt\n",
    "    import seaborn as sns\n",
    "    from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
    "    #%matplotlib inline\n",
    "    #%matplotlib nbagg\n",
    "    %matplotlib notebook\n",
    "\n",
    "\n",
    "    df = pd.DataFrame(data=pca_results, columns=['pca-one', 'pca-two'])\n",
    "    df.insert(2, \"y\", labels, True)\n",
    "    df.insert(3, \"y_means\", means, True)\n",
    "\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111)\n",
    "\n",
    "    sns.scatterplot(\n",
    "        x=\"pca-one\", y=\"pca-two\",\n",
    "        hue=\"y_means\",\n",
    "        palette=sns.color_palette(\"deep\", 8),\n",
    "        data=df,\n",
    "        legend=\"full\",\n",
    "        alpha=0.8,\n",
    "        s=[40]*(len(embs)-4) + [400, 400] + [200, 200],\n",
    "        ax=ax,\n",
    "    )\n",
    "    # The points seaborn just drew are the first collection on this fresh axes;\n",
    "    # hover() hit-tests the mouse position against them.\n",
    "    points = ax.collections[0]\n",
    "    # Data coordinates of every point, in the same row order as pca_results and images.\n",
    "    x, y = pca_results[:, 0], pca_results[:, 1]\n",
    "\n",
    "    im = OffsetImage(images[0,:,:,:], zoom=.5)\n",
    "\n",
    "    xybox=(50., 50.)\n",
    "    ab = AnnotationBbox(im, (0,0), xybox=xybox, xycoords='data',\n",
    "            boxcoords=\"offset points\",  pad=0.3,  arrowprops=dict(arrowstyle=\"->\"))\n",
    "    # add it to the axes and make it invisible\n",
    "    ax.add_artist(ab)\n",
    "    ab.set_visible(False)\n",
    "\n",
    "    def hover(event):\n",
    "        \"\"\"Preview the image of the scatter point under the mouse; hide the preview otherwise.\"\"\"\n",
    "        hit, details = points.contains(event)\n",
    "        if hit:\n",
    "            # Overlapping points can all match; show the first one instead of failing\n",
    "            # on a single-element unpack.\n",
    "            ind = details[\"ind\"][0]\n",
    "            im.set_data(images[ind,:,:,:]) #if the dataset is too large to load into memory, can instead replace this command with a realtime load\n",
    "\n",
    "            # get the figure size\n",
    "            w,h = fig.get_size_inches()*fig.dpi\n",
    "            ws = (event.x > w/2.)*-1 + (event.x <= w/2.)\n",
    "            hs = (event.y > h/2.)*-1 + (event.y <= h/2.)\n",
    "            # if event occurs in the top or right quadrant of the figure,\n",
    "            # change the annotation box position relative to mouse.\n",
    "            ab.xybox = (xybox[0]*ws, xybox[1]*hs)\n",
    "            # make annotation box visible\n",
    "            ab.set_visible(True)\n",
    "            # place it at the position of the hovered scatter point\n",
    "            ab.xy = (x[ind], y[ind])\n",
    "        else:\n",
    "            #if the mouse is not over a scatter point\n",
    "            ab.set_visible(False)\n",
    "        fig.canvas.draw_idle()\n",
    "\n",
    "    # add callback for mouse moves\n",
    "    fig.canvas.mpl_connect('motion_notify_event', hover)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show embedding space of SMID images (Imagenet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Make the project package importable; the membership check keeps re-runs of this cell\n",
    "# from appending the same entry to sys.path again and again.\n",
    "if '/workspace/repositories/offimg' not in sys.path:\n",
    "    sys.path.append('/workspace/repositories/offimg')\n",
    "from main.tune_clip import setup_model_imagenet_probe, setup_dataset\n",
    "# NOTE(review): despite the variable name, 'resnet50' is an image model, not a language model.\n",
    "language_model = 'resnet50'\n",
    "model_imagenet = setup_model_imagenet_probe(language_model, feature_extraction_forward=True)\n",
    "dataset = setup_dataset(model_imagenet.preprocess, test_size=None, train_size=None, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main.tune_clip import get_dataloaders\n",
    "# Only the first loader is used. The second is bound to a named throwaway instead of `_`,\n",
    "# because assigning `_` in the notebook namespace stops IPython updating its `_` output history.\n",
    "dataloader, _unused_loader = get_dataloaders(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embs_imagenet, labels_imagenet, means_imagenet, imgs_imagenet = run_dataset(model_imagenet, prompts=False)\n",
    "print(len(embs_imagenet))\n",
    "# Discretise the mean ratings so plot_pca gets one hue class per integer rating.\n",
    "# Note: Python's round() rounds halves to even (2.5 -> 2, 3.5 -> 4).\n",
    "means_imagenet = list(map(lambda rating: int(round(rating)), means_imagenet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# Fit a 2-component PCA on all ImageNet-probe embeddings and project them onto it.\n",
    "pca = PCA(n_components=2).fit(embs_imagenet)\n",
    "pca_results_imagenet = pca.transform(embs_imagenet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): called without a save path, plot_pca writes to the same SVG file as the\n",
    "# CLIP plot above, overwriting it.\n",
    "plot_pca(pca_results_imagenet, labels_imagenet, means_imagenet, prompts=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Make the project package importable without appending a duplicate sys.path entry on re-runs.\n",
    "if '/workspace/repositories/offimg' not in sys.path:\n",
    "    sys.path.append('/workspace/repositories/offimg')\n",
    "from main.tune_clip import setup_model_clip_prompt, setup_dataset\n",
    "language_model = 'Clip_ViT-B/32'\n",
    "model = setup_model_clip_prompt(language_model)\n",
    "# NOTE(review): t_low/t_high presumably select samples by mean rating; confirm in main.tune_clip.setup_dataset.\n",
    "dataset = setup_dataset(model.preprocess, test_size=None, train_size=None, verbose=False, t_low=2.99, t_high=3.00)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main.tune_clip import get_dataloaders\n",
    "# Only the first loader is used. The second is bound to a named throwaway instead of `_`,\n",
    "# because assigning `_` in the notebook namespace stops IPython updating its `_` output history.\n",
    "dataloader, _unused_loader = get_dataloaders(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pickle\n",
    "embs = []  # stays empty while the encoder call below is disabled\n",
    "labels = []\n",
    "means = []\n",
    "for X_batch, y_batch, means_batch, conf_batch in tqdm(dataloader):\n",
    "    # Only labels and mean ratings are collected, so the image batch is not moved to the GPU:\n",
    "    # an unused per-batch .to('cuda') copy is wasted transfer and makes CUDA mandatory.\n",
    "    # To collect embeddings as well, re-enable:\n",
    "    #     embs += model.encode(X_batch.to('cuda')).detach().cpu().numpy().tolist()\n",
    "    labels += y_batch.detach().cpu().numpy().tolist()\n",
    "    means += means_batch.detach().cpu().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline\n",
    "import matplotlib.pylab as pylab\n",
    "\n",
    "# sns.set_style() resets 'font.family' and 'font.sans-serif', so apply it *before* the\n",
    "# rcParams overrides below; otherwise the Helvetica setting is silently discarded.\n",
    "sns.set_style(\"whitegrid\")\n",
    "params = {\n",
    "    'legend.fontsize': 'x-large',\n",
    "    'figure.figsize': (15, 10),\n",
    "    'axes.labelsize': 'xx-large',\n",
    "    'axes.titlesize': 'x-large',\n",
    "    'xtick.labelsize': 'xx-large',\n",
    "    'ytick.labelsize': 'xx-large',\n",
    "    \"text.usetex\": True,\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    \"font.sans-serif\": [\"Helvetica\"]\n",
    "}\n",
    "pylab.rcParams.update(params)\n",
    "plt.rcParams['axes.xmargin'] = 0.05\n",
    "\n",
    "deep = sns.color_palette('deep')\n",
    "relative_bin_width = 0.8  # fraction of each bin's width covered by its bar\n",
    "bin_width = 0.25\n",
    "bar_color_edge = deep[-3]\n",
    "bar_color_alpha = 1.\n",
    "bins = np.arange(1, 5. + bin_width, bin_width)  # bin edges 1.00, 1.25, ..., 5.00\n",
    "\n",
    "# np.histogram gives the per-bin counts directly. Drawing a throwaway plt.hist and then calling\n",
    "# plt.close(); plt.clf() left behind an empty figure, because clf() opens a new one.\n",
    "n, b = np.histogram(means, bins=bins)\n",
    "print(n)\n",
    "\n",
    "# One colour per bar, chosen by the bin's left edge: <= 2.5 -> deep[3], < 3.5 -> deep[7], else deep[2].\n",
    "bar_colors = [deep[3] if e <= 2.5 else deep[7] if e < 3.5 else deep[2] for e in b[:-1]]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 4), dpi=1200)\n",
    "# bar() centres each bar on its x value, so pass bin centres (not left edges) to keep\n",
    "# every bar between the two edge ticks of the bin it counts.\n",
    "ax.bar(b[:-1] + bin_width / 2, n, width=bin_width * relative_bin_width,\n",
    "       alpha=bar_color_alpha, color=bar_colors, edgecolor=bar_color_edge, linewidth=1.)\n",
    "ax.set_xticks(bins)\n",
    "ax.set_xticklabels([f'{e:.2f}' for e in b], rotation='vertical')\n",
    "ax.set_xlabel('Moral rating', fontsize=20)\n",
    "ax.grid(False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}