#!/usr/bin/env python
# coding: utf-8

# # Analysis of Composite Images
# 
# This analysis requires the imagenettev2 dataset (https://github.com/fastai/imagenette).
# 
# Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute)
# 
# Our goal was to generate pairs of images that would compete with each other when combined into composite images. Thus, we randomly paired each Imagenette validation image with another validation image from a different category, and with a third image from yet another category to serve as a control comparison. These triplets (targetA, targetB, absent) were then either stacked side-by-side or blended with 50/50 opacity to form composites. This procedure yielded 3910 unique triplets to serve as our probe set.
# 
# This notebook measures classification accuracy on side-by-side composite pairs for a pretrained AlexNet model.

# In[79]:


import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torchvision.models import alexnet
from torchvision.models.alexnet import AlexNet_Weights

from torchvision import transforms
from torchvision.datasets.folder import default_loader
from pdb import set_trace

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Per-channel statistics used to normalize the inputs (and to undo it below).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Preprocessing applied to every probe image: shorter side to 256 px, then a
# 256x256 center crop, then tensor conversion and normalization.
# NOTE(review): the standard torchvision ImageNet eval transform for AlexNet
# center-crops to 224 px; 256 px is used here (and composites are 256x512),
# which presumably relies on the model's adaptive pooling — confirm. Left
# unchanged so the reported numbers stay reproducible.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Inverse of `transform`'s normalization, for displaying tensors as images.
# Normalize(mean=-m/s, std=1/s) computes x * s + m, undoing (x - m) / s.
# The clamp keeps values inside [0, 1]: ToPILImage scales float tensors by
# 255 and casts to uint8, which wraps around rather than saturating, so any
# out-of-range value would otherwise show up as a garbage pixel.
inv_transform = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    ),
    transforms.Lambda(lambda t: t.clamp(0, 1)),
    transforms.ToPILImage(),
])

class TripletDataset:
    """Image triplets (targetA, targetB, absent) described by a CSV file.

    Every CSV row is expected to carry the columns ``file1``, ``file2``,
    ``file3`` (paths relative to *root_dir*) and ``label1``, ``label2``,
    ``label3``. Indexing returns ``([img1, img2, img3], [label1, label2,
    label3])``: each image is read with *loader*, converted to RGB and, when
    *transform* is given, passed through it.
    """

    def __init__(self, root_dir, csv_file='./triplets/triplets.csv', transform=None, loader=default_loader):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.loader = loader
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        record = self.df.iloc[index]
        slots = (1, 2, 3)

        # Read all three images first, then transform them, mirroring the
        # original load-then-transform order.
        paths = [os.path.join(self.root_dir, record[f'file{k}']) for k in slots]
        images = [self.loader(path).convert('RGB') for path in paths]
        if self.transform is not None:
            images = [self.transform(image) for image in images]

        labels = [record[f'label{k}'] for k in slots]
        return images, labels
        


# 

# In[90]:


# Directory holding the Imagenette2 validation split. Defaults to the original
# container path; set IMAGENETTE_VAL_DIR to point elsewhere without editing
# the notebook.
root_dir = os.environ.get('IMAGENETTE_VAL_DIR', '/home/jovyan/work/DataLocal/imagenette2/val')
dataset = TripletDataset(root_dir, transform=transform)
len(dataset)


# In[81]:


# Image.open('/home/jovyan/work/DataLocal/imagenette2/val/n03028079/n03028079_2842.JPEG')


# In[82]:


# Preview the first triplet: targetA and targetB joined along the last
# (width) axis, converted back to a viewable image.
first_images, first_labels = dataset[0]
img1, img2, img3 = first_images
label1, label2, label3 = first_labels
img_concat = torch.cat((img1, img2), dim=-1)
inv_transform(img_concat)


# In[83]:


# Pretrained AlexNet with torchvision's default weights, moved to `device`.
# nn.Module.to moves parameters in place and returns the module itself.
pretrained_weights = AlexNet_Weights.DEFAULT
model = alexnet(weights=pretrained_weights).to(device)


# In[84]:


from tqdm import tqdm


def _evaluate_triplets(model, dataset, device):
    """Score isolated and side-by-side predictions for every triplet in *dataset*.

    For each triplet ``(img1, img2, img3)`` with labels ``(label1, label2, label3)``:
      * each image is classified on its own, and
      * img1 and img2 are concatenated along the width axis and classified as
        one composite; img3 is not part of the composite and acts as a control.

    Returns:
        ``(correct, is_max)``, two float arrays of shape ``(len(dataset), 3)``.
        ``correct[i, k]`` is 1.0 when image k of triplet i, shown alone, is
        predicted as its own label. ``is_max[i, k]`` is 1.0 when the top-1
        prediction for the img1|img2 composite equals label k.
    """
    model.eval()
    correct = []
    is_max = []
    with torch.no_grad():
        for (img1, img2, img3), labels in tqdm(dataset):
            # The isolated images share one shape, so one batched forward
            # pass replaces three separate ones.
            isolated = torch.stack([img1, img2, img3]).to(device)
            composite = torch.cat([img1, img2], dim=-1).unsqueeze(0).to(device)

            preds = model(isolated).argmax(dim=-1).tolist()
            pred_concat = model(composite).argmax(dim=-1).item()

            correct.append([float(label == pred) for label, pred in zip(labels, preds)])
            # Which (if any) of the three labels wins on the composite image.
            is_max.append([float(label == pred_concat) for label in labels])

    # reshape keeps a (0, 3) shape when the dataset is empty.
    return np.array(correct).reshape(-1, 3), np.array(is_max).reshape(-1, 3)


correct_matrix, is_max_matrix = _evaluate_triplets(model, dataset, device)
img1_correct, img2_correct, img3_correct = correct_matrix.T
img1_is_max, img2_is_max, img3_is_max = is_max_matrix.T


# In[88]:


# Shown on its own, each image is correctly classified by AlexNet roughly 75% of the time.
tuple(scores.mean() for scores in (img1_correct, img2_correct, img3_correct))


# In[89]:


# Side-by-side composites (img1 | img2): each target is the top-1 prediction
# about 25% of the time, while the output unit for the control image (img3),
# which is absent from the composite, is almost never the most active unit (< 1%).
tuple(hits.mean() for hits in (img1_is_max, img2_is_max, img3_is_max))


# In[ ]:




