# -*- coding: utf-8 -*-
"""TinyImageNetLoader.ipynb

Automatically generated by Colaboratory.


"""

# Loads STL-10 images (3x96x96) and exposes their CLIP image embeddings as tensors.

#!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
#!unzip -q tiny-imagenet-200.zip

import torch
import torchvision
import numpy as np
import os, glob
from tqdm import tqdm
from scipy.io import loadmat

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset

from torchvision.io import read_image, ImageReadMode
from torchvision import transforms
from ..dataset_utils import *
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from pathlib import Path

class STL10CLIP(Dataset):
    """STL-10 served as pre-computed CLIP image embeddings.

    ``data_conf`` is a mapping that must provide:

    * ``dimension``   -- stored as ``self.d`` and forwarded to the derived datasets
    * ``data_path``   -- directory containing ``train.mat`` and ``test.mat``
    * ``emb_model``   -- CLIP model id passed to ``from_pretrained``
    * ``emb_path``    -- directory where the embedding checkpoint is cached
    * ``compute_emb`` -- truthy to (re)compute embeddings, falsy to load the cache

    Call :meth:`build_dataset` before indexing; until then ``X`` and ``Y`` are
    ``None`` and the dataset reports a length of 0.
    """

    # Number of training samples moved into the validation split.
    _NUM_AUG_VAL = 1000
    # (channels, side, side) used to unflatten one row of the ``X`` matrix.
    _IMAGE_SHAPE = (3, 96, 96)

    def __init__(self, data_conf):
        self.data_conf = data_conf
        self.d = self.data_conf["dimension"]
        # NOTE(review): STL-10 has 10 classes; 200 looks inherited from the
        # Tiny-ImageNet loader this file was derived from. Kept unchanged
        # because downstream code may already depend on it -- confirm.
        self.num_classes = 200
        self.f = None
        # Filled in by build_dataset(); get_subset() checks these for None.
        self.X = None
        self.Y = None

    def __getitem__(self, index):
        """Return ``(embedding, label, index)`` for the given position."""
        x = self.X[index]
        y = self.Y[index]

        return x, y, index

    def __len__(self):
        """Return the number of samples (0 before the dataset is built)."""
        return 0 if self.X is None else len(self.X)

    def len(self):
        """Alias of ``len(self)``."""
        return self.__len__()

    def get_subset(self, idcs):
        """Return a ``CustomTensorDataset`` restricted to the indices ``idcs``.

        Whichever of ``X`` / ``Y`` is ``None`` is passed through as ``None``.
        """
        idcs = np.array(idcs)
        X = self.X[idcs] if self.X is not None else None
        Y = self.Y[idcs] if self.Y is not None else None
        # The return used to sit inside the ``Y is not None`` branch, so the
        # method silently returned None whenever labels were missing.
        return CustomTensorDataset(X=X, Y=Y, num_classes=self.num_classes,
                                   d=self.d, transform=None)

    @staticmethod
    def _flat_to_hwc(flat):
        """Unflatten one row of the ``X`` matrix into a (96, 96, 3) array.

        The STL-10 files store each image channel by channel with every
        channel in column-major order, so after reshaping to
        (channel, column, row) the axes are reversed to obtain the
        (row, column, channel) layout PIL expects.
        """
        channel_col_row = np.asarray(flat).reshape(STL10CLIP._IMAGE_SHAPE)
        return np.ascontiguousarray(channel_col_row.transpose(2, 1, 0))

    def _embed_split(self, mat_path, model, processor, device):
        """Embed every image of one ``.mat`` split.

        Returns ``(embeddings, labels)``: a list of CPU tensors as returned by
        ``model.get_image_features`` for a single image, and a list of
        zero-based integer labels.
        """
        mat = loadmat(mat_path)
        flat_images = mat['X']
        labels = mat['y'] - 1  # labels in the file are 1-based

        embeddings = []
        targets = []
        # Pure inference: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            for flat, label in tqdm(zip(flat_images, labels), total=len(flat_images)):
                image = Image.fromarray(self._flat_to_hwc(flat))
                pixels = processor(text=None, images=image,
                                   return_tensors="pt")["pixel_values"].to(device)
                embedding = model.get_image_features(pixels)
                embeddings.append(embedding.detach().cpu())
                targets.append(int(label[0]))
        return embeddings, targets

    def compute_embeddings(self):
        """Compute CLIP embeddings for ``train.mat`` and ``test.mat``.

        Sets ``X`` / ``Y`` (train samples followed by test samples),
        ``X_val`` / ``Y_val`` (test samples only) and the index arrays
        ``idcs_std_train`` / ``idcs_std_val`` into ``X``.
        """
        print("computing embeddings")
        data_dir = self.data_conf['data_path']
        model_ID = self.data_conf['emb_model']
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = CLIPModel.from_pretrained(model_ID).to(device)
        processor = CLIPProcessor.from_pretrained(model_ID)

        train_emb, train_lbl = self._embed_split(
            os.path.join(data_dir, 'train.mat'), model, processor, device)
        val_emb, val_lbl = self._embed_split(
            os.path.join(data_dir, 'test.mat'), model, processor, device)

        n_train = len(train_emb)
        n_total = n_train + len(val_emb)

        # Each embedding carries a leading batch axis of size 1; concatenating
        # along it gives (N, D) for both tensors and keeps the batch axis even
        # when a split holds a single sample.
        self.X = torch.cat(train_emb + val_emb, dim=0)
        self.Y = torch.tensor(train_lbl + val_lbl, dtype=torch.long)
        self.X_val = torch.cat(val_emb, dim=0)
        self.Y_val = torch.tensor(val_lbl, dtype=torch.long)

        self.idcs_std_train = np.arange(0, n_train, 1).astype(int)
        self.idcs_std_val = np.arange(n_train, n_total, 1).astype(int)

    def _augment_val_with_train(self):
        """Move ``_NUM_AUG_VAL`` random training indices into the validation split.

        Raises ``ValueError`` (from ``np.random.choice``) when fewer training
        indices than ``_NUM_AUG_VAL`` are available.
        """
        idcs_aug_val = list(np.random.choice(self.idcs_std_train,
                                             self._NUM_AUG_VAL, replace=False))
        held_out = set(idcs_aug_val)  # O(1) membership instead of a list scan

        self.idcs_std_train = [i for i in self.idcs_std_train if i not in held_out]
        # Extend the *validation* indices. The previous code extended the
        # already-reduced training indices, which made the validation split
        # equal to the full training set and dropped the real held-out data.
        self.idcs_std_val = list(self.idcs_std_val) + idcs_aug_val

    def build_dataset(self):
        """Load or compute the embeddings and build the train/val datasets.

        When ``data_conf['compute_emb']`` is truthy the embeddings are computed
        and cached at ``<emb_path>/stl_<emb_model>.pt``; otherwise that file is
        loaded. Afterwards the validation split is augmented with training
        samples and ``ds_std_train`` / ``ds_std_val`` are created.
        """
        data_conf = self.data_conf
        emb_model = data_conf['emb_model']
        ckpt_path = os.path.join(data_conf['emb_path'], f'stl_{emb_model}.pt')

        if data_conf['compute_emb']:
            self.compute_embeddings()

            ckpt_content = {
                'X': self.X, 'Y': self.Y,
                'idcs_std_train': self.idcs_std_train,
                'idcs_std_val': self.idcs_std_val,
            }

            # Path.parent keeps absolute paths absolute and also works when the
            # checkpoint path has no directory component.
            Path(ckpt_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(ckpt_content, ckpt_path)
        else:
            # NOTE(review): the checkpoint stores NumPy index arrays next to the
            # tensors; recent PyTorch releases default torch.load to
            # weights_only=True, which may reject them -- confirm against the
            # pinned torch version.
            ckpt_content = torch.load(ckpt_path)
            self.X = ckpt_content['X']
            self.Y = ckpt_content['Y']
            self.idcs_std_train = ckpt_content['idcs_std_train']
            self.idcs_std_val = ckpt_content['idcs_std_val']

        self._augment_val_with_train()

        self.X_val = self.X[self.idcs_std_val]
        self.Y_val = self.Y[self.idcs_std_val]

        self.X_train = self.X[self.idcs_std_train]
        self.Y_train = self.Y[self.idcs_std_train]

        # ***** The labels for test set are not known
        # this is temporary fix, don't use test numbers
        self.X_test = self.X_val
        self.Y_test = self.Y_val

        self.ds_std_train = CustomTensorDataset(X=self.X_train, Y=self.Y_train,
                                                num_classes=self.num_classes,
                                                d=self.d, transform=None)
        self.ds_std_val = CustomTensorDataset(X=self.X_val, Y=self.Y_val,
                                              num_classes=self.num_classes,
                                              d=self.d, transform=None)

    def get_test_datasets(self):
        """Return the test split (currently an alias of the validation split)."""
        X_ = self.X_test
        Y_ = self.Y_test
        return CustomTensorDataset(X=X_, Y=Y_, num_classes=self.num_classes,
                                   d=self.d, transform=None)
