import glob
import os
import re

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as tfs
from PIL import Image
from torch.utils.data import Dataset

class stroke(Dataset):
    """Stroke imaging dataset: a CSV index plus per-uid folders of PNG slices.

    The first CSV column is used as the sample id (``uid``); the column(s)
    named by ``train_cols`` supply the label. Only raw labels 1 and 2 are
    usable: ``__getitem__`` maps 1 -> ``[1.0]`` and 2 -> ``[0.0]`` and returns
    the placeholder pair from ``_skip_sample`` for any other label.

    For a usable row, ``num_slices`` PNG slices are loaded for each series in
    ``series_ids`` from ``<data_root>/<uid>/<series>/`` and stacked into an
    array of shape ``(1, len(series_ids) * num_slices, image_size, image_size, 3)``
    in RGB channel order, scaled by 1/255.
    """

    def __init__(self,
                 csv_path,
                 image_size=32,
                 shuffle=True,
                 seed=123,
                 verbose=True,
                 train_cols=None,
                 mode='train',
                 data_root='/dual_data/not_backed_up/dixzhu/stroke/datasets_png_128x128',
                 series_ids=(10, 11, 12, 13),
                 num_slices=80):
        """Load the CSV index and precompute per-row ids and labels.

        Args:
            csv_path: path or buffer accepted by ``pandas.read_csv``.
            image_size: side length each slice is resized to.
            shuffle: if True, permute the rows once using ``seed``.
            seed: value passed to ``np.random.seed`` before shuffling
                (note: this reseeds NumPy's global RNG).
            verbose: if True, print class counts, per-sample labels and
                missing/unreadable-slice diagnostics.
            train_cols: label column(s); defaults to ``['label']``.
            mode: stored as ``self.mode``; not otherwise used in this class.
            data_root: directory containing one sub-folder per uid.
            series_ids: series sub-folders to read, in stacking order.
            num_slices: slices per series, numbered ``1..num_slices``.

        Raises:
            ValueError: if ``series_ids`` is empty or ``num_slices`` < 1.
        """
        if train_cols is None:
            train_cols = ['label']  # fresh list per call; avoids a shared mutable default
        series_ids = tuple(series_ids)
        if not series_ids or num_slices < 1:
            raise ValueError('series_ids must be non-empty and num_slices >= 1')

        # load data from csv
        self.df = pd.read_csv(csv_path)
        self._num_images = len(self.df)

        # shuffle data
        if shuffle:
            data_index = list(range(self._num_images))
            np.random.seed(seed)
            np.random.shuffle(data_index)
            self.df = self.df.iloc[data_index]

        self.select_cols = ['label']  # this var determines the number of classes
        self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict()
        if verbose:
            print(self.value_counts_dict)

        # Share of positives among usable rows: raw label 1 becomes 1.0 and raw
        # label 2 becomes 0.0 in __getitem__; every other label is skipped there.
        n_pos = self.value_counts_dict.get(1, 0)
        n_neg = self.value_counts_dict.get(2, 0)
        self.imratio = float(n_pos) / (n_pos + n_neg) if (n_pos + n_neg) else 0.0

        self.mode = mode
        self.image_size = image_size
        self.verbose = verbose
        self.data_root = data_root
        self.series_ids = series_ids
        self.num_slices = num_slices

        self._images_list = self.df.iloc[:, 0].values.tolist()
        self._labels_list = self.df[train_cols].values.tolist()

    @property
    def class_counts(self):
        """Mapping of raw label value -> row count in the ``label`` column."""
        return self.value_counts_dict

    @property
    def imbalance_ratio(self):
        """Positive share among rows with raw label 1 or 2 (0.0 if none)."""
        return self.imratio

    @property
    def num_classes(self):
        """Number of label columns used (``len(select_cols)``)."""
        return len(self.select_cols)

    @property
    def data_size(self):
        """Number of rows read from the CSV."""
        return self._num_images

    def __len__(self):
        return self._num_images

    @staticmethod
    def _skip_sample():
        """Return the placeholder ``(image, label)`` pair used for skipped rows."""
        return np.zeros([1, 1, 1, 1]), np.zeros([1])

    def _find_slice(self, uid, series, slice_idx):
        """Return the PNG path for one slice, or None if nothing matches.

        The glob ``*<series>-<slice_idx>*.png`` also matches names such as
        ``10-10`` when looking for ``10-1``. A match whose
        ``<series>-<slice_idx>`` token is not adjacent to another digit is
        therefore preferred. Otherwise the first match in sorted order is
        used, so the choice is deterministic.
        """
        pattern = os.path.join(
            glob.escape(str(self.data_root)),
            glob.escape(str(uid)),
            str(series),
            '*' + str(series) + '-' + str(slice_idx) + '*.png',
        )
        candidates = sorted(glob.glob(pattern))
        if not candidates:
            return None
        token = re.compile(
            r'(?<!\d)' + re.escape(str(series) + '-' + str(slice_idx)) + r'(?!\d)')
        exact = [p for p in candidates if token.search(os.path.basename(p))]
        return (exact or candidates)[0]

    def _load_slice(self, path):
        """Read one PNG as RGB float32, resized to ``image_size`` and scaled by 1/255.

        Returns None if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            return None
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = cv2.resize(image, (self.image_size, self.image_size),
                           interpolation=cv2.INTER_LINEAR).astype(np.float32)
        return image / 255.0

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for row ``idx``, or the placeholder pair.

        ``volume`` has shape ``(1, len(series_ids) * num_slices, image_size,
        image_size, 3)``. ``label`` is ``[1.0]`` for raw label 1 and ``[0.0]``
        for raw label 2. The placeholder from ``_skip_sample`` is returned
        instead for any other label, a missing slice, or an unreadable slice.
        """
        uid = self._images_list[idx]
        raw_label = self._labels_list[idx][0]
        if raw_label == 2:
            label = np.array([0.0])
        elif raw_label == 1:
            label = np.array([1.0])
        else:
            return self._skip_sample()

        frames = []
        for series in self.series_ids:
            for slice_idx in range(1, self.num_slices + 1):
                path = self._find_slice(uid, series, slice_idx)
                if path is None:
                    if self.verbose:
                        print('uid:' + str(uid))
                        print('i:' + str(series))
                        print('j:' + str(slice_idx))
                        print('no matching slice file')
                    return self._skip_sample()
                frame = self._load_slice(path)
                if frame is None:
                    if self.verbose:
                        print('unreadable slice: ' + str(path))
                    return self._skip_sample()
                frames.append(frame)

        # Stack once rather than re-concatenating after every slice (quadratic copying).
        volume = np.expand_dims(np.stack(frames, axis=0), axis=0)
        if self.verbose:
            print('label: ' + str(label))
        return volume, label

    def get_labels(self):
        """Return all raw label values as a flat NumPy array."""
        return np.array(self._labels_list).reshape(-1)

if __name__ == '__main__':
    # Convert every complete sample into two .npy arrays (volumes and labels).
    root = '/dual_data/not_backed_up/dixzhu/stroke/'
    train_set = stroke(csv_path=root + 'AI_data_codes.csv', image_size=128, mode='train')
    # Skipped rows come back as a (1, 1, 1, 1) placeholder. Keep only volumes
    # with 80 * 4 slices along axis 1 (the dataset's default slice layout).
    expected_slices = 80 * 4
    trX = []
    trY = []
    for idx in range(len(train_set)):
        train_data, train_label = train_set[idx]
        if train_data.shape[1] != expected_slices:
            continue
        print(idx)
        trX.append(train_data)
        trY.append(train_label)
    if not trX:
        raise SystemExit('No complete samples found; nothing to save.')
    trX = np.concatenate(trX, axis=0)
    trY = np.concatenate(trY, axis=0)
    # NOTE: the '*' is a literal character in the output filenames.
    np.save('/home/dixzhu/data/stroke_80*4_X', trX)
    np.save('/home/dixzhu/data/stroke_80*4_Y', trY)
    print(trX.shape)
    print(trY.shape)
    print(int((trY == 0).sum()))
    print(int((trY == 1).sum()))
    
