# -*- coding: utf-8 -*-
"""TinyImageNetLoader.ipynb

Automatically generated by Colaboratory.


"""

#loads images as 3*64*64 tensors 

#!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
#!unzip -q tiny-imagenet-200.zip

import torch
import torchvision
import numpy as np
import os, glob

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset

from torchvision.io import read_image, ImageReadMode
from torchvision import transforms
from ..dataset_utils import *
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from pathlib import Path

class TinyImageNet200CLIP(Dataset):
    def __init__(self,data_conf):
        self.data_conf = data_conf
        self.d = self.data_conf["dimension"]
        self.num_classes = 200 
        self.f = None
        
    def __getitem__(self, index):
        
        x = self.X[index]   
        y = self.Y[index]
        
        return x, y, index 
    
    def __len__(self):

        return len(self.X)
    
    def len(self):
        return len(self.X)
    
    def get_subset(self,idcs):
        idcs = np.array(idcs) 
        X = None
        Y = None 
        if(self.X is not None):
            X = self.X[idcs] 
        if(self.Y is not None):
            Y = self.Y[idcs] 
            return CustomTensorDataset(X=X,Y=Y,num_classes = self.num_classes, d=self.d,transform=None)

    def compute_embeddings(self):
        print("computing embeddings")
        data_dir  = self.data_conf['data_path']
        data_conf = self.data_conf
        # Define the model ID
        model_ID = data_conf['emb_model']
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = CLIPModel.from_pretrained(model_ID).to(device)
 	    # Get the processor
        processor = CLIPProcessor.from_pretrained(model_ID)
        #print("here")

        id_dict = {}
        os.path.join(data_dir, 'wnids.txt')
        for i, line in enumerate(open(os.path.join(data_dir, 'wnids.txt'), 'r')):
            id_dict[line.rstrip()] = i

        # load training data
        train_filenames = glob.glob(os.path.join(data_dir, 'train', '*', '*', '*.JPEG'))
        lst_train_images = []
        lst_train_labels = []
        #k = 5000

        for img_path in train_filenames:
            image = Image.open(img_path).convert("RGB")
            image_p = processor( text = None, images = image,
                                    return_tensors="pt")["pixel_values"].to(device)
            embedding = model.get_image_features(image_p)
                
            label = id_dict[img_path.split('/')[-1].split('_')[0]]
            lst_train_images.append(embedding.detach().cpu())
            lst_train_labels.append(label)
        
        # load validation data
        #print('here')
        val_fname_cls_id = {}
        for i, line in enumerate(open(os.path.join(data_dir, 'val', 'val_annotations.txt'), 'r')):
            a = line.split('\t')
            fname, cls_id = a[0],a[1]
            val_fname_cls_id[a[0]] = a[1]

        val_filenames = glob.glob(os.path.join(data_dir, 'val', 'images', '*.JPEG'))
        lst_val_images = [] 
        lst_val_labels = []
        for img_path in val_filenames:
            image = Image.open(img_path).convert("RGB")
            image_p = processor( text = None, images = image,
                                    return_tensors="pt")["pixel_values"].to(device)
            embedding = model.get_image_features(image_p)

            label = val_fname_cls_id[img_path.split('/')[-1]]
            lst_val_images.append(embedding.detach().cpu())
            lst_val_labels.append(id_dict[label])
        
        all_images = lst_train_images + lst_val_images
        all_labels = lst_train_labels + lst_val_labels

        self.X  = torch.stack(all_images).squeeze()
        self.Y = torch.Tensor(all_labels).long()
        self.X_val = torch.stack(lst_val_images)

        self.Y_val = torch.Tensor(lst_val_labels).long()
        self.idcs_std_train = np.arange(0,len(lst_train_images),1).astype(int)
        self.idcs_std_val   = np.arange(len(lst_train_images),len(all_images),1).astype(int)
        

    def build_dataset(self):

        # expecting path to tiny-imagenet-200 directory that contains train test and val folders
        data_dir  = self.data_conf['data_path']
        data_conf = self.data_conf
        emb_path = data_conf['emb_path']
        emb_model = data_conf['emb_model']
        ckpt_path = os.path.join(emb_path, f'tiny_imgnet_{emb_model}.pt') 

        if(data_conf['compute_emb']):
            self.compute_embeddings()
            

            ckpt_content = {
                            'X':self.X, 'Y' :self.Y , 
                            'idcs_std_train':self.idcs_std_train,
                            'idcs_std_val':self.idcs_std_val
                            }
            
            pp = os.path.join(*ckpt_path.split('/')[:-1])

            Path(pp).mkdir(parents=True, exist_ok=True)

            torch.save(ckpt_content, ckpt_path)

        else:
            ckpt_content = torch.load(ckpt_path)
            self.X = ckpt_content['X']
            self.Y = ckpt_content['Y']
            self.idcs_std_train = ckpt_content['idcs_std_train'] 
            self.idcs_std_val = ckpt_content['idcs_std_val']

        # take 10k points at random from the training set
        idcs_aug_val = list( np.random.choice(self.idcs_std_train, 10000, replace=False) ) 

        # add these to the std val set, to increase its size.
        self.idcs_std_val   = list(self.idcs_std_val) + idcs_aug_val

        # remove the selected points from std_train set.
        self.idcs_std_train = [ i for i in self.idcs_std_train if i not in idcs_aug_val ]
        
        
        self.X_val = self.X[self.idcs_std_val]
        self.Y_val = self.Y[self.idcs_std_val]
        
        self.X_train = self.X[self.idcs_std_train]
        self.Y_train = self.Y[self.idcs_std_train]
        
        #***** The labels for test set are not known
        # this is temporary fix, don't use test numbers
        self.X_test = self.X_val 
        self.Y_test = self.Y_val 
        
        
        self.ds_std_train = CustomTensorDataset(X=self.X_train,Y=self.Y_train, num_classes = self.num_classes,d=self.d,transform=None)
        self.ds_std_val   = CustomTensorDataset(X=self.X_val,Y=self.Y_val, num_classes = self.num_classes,d=self.d,transform=None)

    def get_test_datasets(self):
        X_ = self.X_test
        Y_ = self.Y_test
        return CustomTensorDataset(X=X_,Y=Y_, num_classes = self.num_classes,d=self.d,transform=None)
