import glob
import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

from eval import *

# Per-channel (R, G, B) mean / std normalization with the ImageNet statistics
# used by torchvision's pretrained models. Meant for tensors produced by
# ``standard_transform`` (ToTensor() output, values in [0, 1]).
normalizer = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# Resize the shorter image side to 256 px, take the central 224x224 crop and
# convert to a float tensor. Deliberately does NOT normalize -- see ``normalizer``.
standard_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

_OBJECTNET_ROOT = ''#REDACTED
_SKETCH_ROOT = ''#REDACTED
_IMAGENET_R_ROOT = ''#REDACTED

######## OBJECTNET
### LOADER 
input_size=224

class ObjectNet(Dataset):
    """ObjectNet images paired with their ObjectNet class ids.

    Images are collected recursively under ``root`` by file extension. Each
    image's label comes from the name of the folder that contains it, looked
    up in ``<root>/mappings/folder_to_onet_id.json``.
    """

    def __init__(self, root=_OBJECTNET_ROOT, transform=standard_transform,
                 normalize=None, img_format='png'):
        """
        Args:
            root (str): ObjectNet root directory; must contain
                ``mappings/folder_to_onet_id.json``. A trailing separator is optional.
            transform (callable or None): applied to each border-cropped PIL image.
            normalize (callable, optional): applied to the transformed image when not None.
            img_format (str): extension (without the dot) of the image files to collect.
        Raises:
            ValueError: if ``root`` is empty, i.e. the dataset root was never configured.
        """
        if not root:
            raise ValueError('ObjectNet root directory is not set (got an empty path)')
        self.root = root
        self.transform = transform
        self.normalize = normalize
        # Sorted so the index -> image assignment does not depend on the
        # filesystem's directory listing order.
        files = sorted(glob.glob(os.path.join(root, '**', '*.' + img_format), recursive=True))
        # Keyed by bare file name: files sharing a name in different folders
        # collapse into a single entry (the path listed last wins).
        self.pathDict = {os.path.basename(f): f for f in files}
        self.imgs = list(self.pathDict)
        self.loader = self.pil_loader
        # os.path.join yields a correct path whether or not ``root`` ends with a
        # separator (plain string concatenation required a trailing '/').
        with open(os.path.join(self.root, 'mappings', 'folder_to_onet_id.json'), 'r') as f:
            self.folder_to_onet_id = json.load(f)

    def __getitem__(self, index):
        """
        Get an image and its label.
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, onet_id). onet_id is the ID of the objectnet class (0 to 112)
        """
        img, onet_id = self.getImage(index)
        if self.transform is not None:
            img = self.transform(img)
        if self.normalize is not None:
            img = self.normalize(img)

        return img, onet_id

    def getImage(self, index):
        """
        Load an image (with 2 pixels cropped from every side) and its label.
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, onet_id). image is a PIL RGB image; onet_id is the
            ObjectNet id mapped from the name of the image's parent folder.
        """
        filepath = self.pathDict[self.imgs[index]]
        img = self.loader(filepath)

        # crop out the red border (2 px on each side)
        width, height = img.size
        img = img.crop((2, 2, width - 2, height - 2))

        # the parent folder name identifies the class; map it to the objectnet id
        folder = os.path.basename(os.path.dirname(filepath))
        onet_id = self.folder_to_onet_id[folder]
        return img, onet_id

    def __len__(self):
        """Get the number of ObjectNet images to load."""
        return len(self.imgs)

    def pil_loader(self, path):
        """Open the image at ``path`` with PIL and return it converted to RGB."""
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load while the file is still open
            return img.convert('RGB')

### EVALUATION
def eval_on_objectnet(model, apply_norm, results=None, batch_size=16):
    """Top-1 accuracy (in %) of an ImageNet-1k classifier on ObjectNet.

    The model's arg-max ImageNet prediction is mapped to an ObjectNet id via
    ``<objectnet root>/mappings/inet_id_to_onet_id.json``. Predictions with no
    ObjectNet counterpart count as wrong.

    Args:
        model: ``torch.nn.Module`` producing ImageNet logits (argmax over dim 1).
        apply_norm (bool): if true, inputs are normalized with ``normalizer``;
            pass False for models that normalize internally.
        results (dict, optional): cache; if it already holds ``'acc'`` nothing is
            evaluated. A fresh dict is used when omitted.
        batch_size (int): DataLoader batch size.
    Returns:
        tuple: (accuracy, results), where ``results['acc'] == accuracy``.
    """
    # A default of ``dict()`` is created once at definition time and shared by
    # every call, so later models would silently get the first model's cached score.
    if results is None:
        results = {}
    if 'acc' in results:
        return results['acc'], results

    dset = ObjectNet(normalize=normalizer if apply_norm else None)
    loader = DataLoader(dset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()

    mapping_path = os.path.join(_OBJECTNET_ROOT, 'mappings', 'inet_id_to_onet_id.json')
    with open(mapping_path, 'r') as f:
        # JSON object keys are strings; predictions are ints.
        inet_id_to_onet_id = {int(k): v for k, v in json.load(f).items()}

    correct, total = 0, 0
    with torch.no_grad():
        for imgs, onet_ids in tqdm(loader):
            inet_preds = model(imgs.to(device)).argmax(1).cpu()
            # -1 never equals a real ObjectNet id, so unmapped predictions are wrong.
            onet_preds = torch.tensor(
                [inet_id_to_onet_id.get(p, -1) for p in inet_preds.tolist()],
                dtype=torch.long,
            )
            correct += (onet_preds == onet_ids).sum().item()
            total += imgs.shape[0]

    results['acc'] = 100. * correct / total
    return results['acc'], results


######## IMAGENET-SKETCH
def eval_on_sketch(model, apply_norm, results=None, batch_size=16):
    """Accuracy of ``model`` on ImageNet-Sketch, computed by ``standard_acc``.

    Args:
        model: classifier passed through to ``standard_acc``.
        apply_norm: forwarded unchanged to ``standard_acc``.
        results (dict, optional): cache; if it already holds ``'acc'`` nothing is
            evaluated. A fresh dict is used when omitted.
        batch_size (int): DataLoader batch size.
    Returns:
        tuple: (accuracy, results), where ``results['acc'] == accuracy``.
    """
    # Avoid a shared mutable default: it would cache the first model's score
    # and return it for every later call that omits ``results``.
    if results is None:
        results = {}
    if 'acc' not in results:
        in_sketch = datasets.ImageFolder(root=_SKETCH_ROOT, transform=standard_transform)
        loader = DataLoader(in_sketch, batch_size=batch_size, num_workers=16, shuffle=True)
        results['acc'] = standard_acc(model, loader, apply_norm)
    return results['acc'], results

######## IMAGENET-R
def eval_on_imagenet_r(model, apply_norm, results=None, batch_size=16):
    """Accuracy of ``model`` on ImageNet-R, computed by ``standard_acc``.

    NOTE: this differs from the original ImageNet-R evaluation, which only
    considers the maximally activated logit among the 200 ImageNet-R classes.
    Here ``in_r_to_og_idx`` (200 ImageNet-1k indices, presumably one per
    ImageFolder label of ImageNet-R, in label order) is passed to
    ``standard_acc`` as ``label_mapping``. The prediction is presumably taken
    over all logits; confirm against ``standard_acc``.

    Args:
        model: classifier passed through to ``standard_acc``.
        apply_norm: forwarded unchanged to ``standard_acc``.
        results (dict, optional): cache; if it already holds ``'acc'`` nothing is
            evaluated. A fresh dict is used when omitted.
        batch_size (int): DataLoader batch size.
    Returns:
        tuple: (accuracy, results), where ``results['acc'] == accuracy``.
    """
    # Avoid a shared mutable default: it would cache the first model's score
    # and return it for every later call that omits ``results``.
    if results is None:
        results = {}
    if 'acc' not in results:
        in_r_to_og_idx = [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100,
                            105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172,
                            178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260,
                            263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314,
                            315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367,
                            368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462,
                            463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609,
                            613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815,
                            820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945,
                            947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988]

        in_r = datasets.ImageFolder(root=_IMAGENET_R_ROOT, transform=standard_transform)
        loader = DataLoader(in_r, batch_size=batch_size, num_workers=16, shuffle=True)
        results['acc'] = standard_acc(model, loader, apply_norm, label_mapping=in_r_to_og_idx)
    return results['acc'], results