# from argparse import _SUPPRESS_T
import os
import glob
import random
import csv
import numpy as np
import shutil
from distutils.dir_util import copy_tree
import cv2
from tqdm import tqdm
from PIL import Image


class SplitGenerator(object):
    """Build a few-shot subset of an image-folder dataset.

    Class directories found under ``data_dir`` are randomly partitioned into
    disjoint train / val / test groups, and a random subset of each selected
    class's files is re-saved (via PIL) into ``few_data_dir``. Optionally the
    resulting per-split lists of output class directories are written as
    ``.npy`` files into ``dest_dir``.
    """

    # Minimum number of entries a class directory must contain to be
    # eligible, keyed by the ``option`` argument of ``split_train_val_test``.
    # Options not listed here apply no filtering.
    _MIN_IMAGES_BY_OPTION = {
        'fungi': 150,
        'birds': 60,
        'texture': 120,
        'aircraft': 100,
    }

    # Number of files sampled per class, keyed by ``self.verbose``. Dataset
    # names not listed fall back to the ``max_num`` argument.
    # NOTE(review): 'quickdraft' is matched literally — possibly intended as
    # 'quickdraw'; confirm against callers.
    _MAX_NUM_BY_DATASET = {
        'cu_birds': 60,
        'texture': 120,
        'aircraft': 100,
        'fungi': 150,
        'quickdraft': 1000,
    }

    def __init__(self, data_dir, dest_dir, few_data_dir, split_dir=None, supercls_raw=True, supercls_split=False, verbose=None):
        """Store the directory layout and options.

        Args:
            data_dir: Root directory under which class directories are globbed.
            dest_dir: Directory that receives ``metatrain.npy``,
                ``metaval.npy`` and ``metatest.npy`` when ``save_split_npy``
                is requested; it is deleted and recreated in that case.
            few_data_dir: Output directory for the sampled files; it is
                deleted and recreated on every call to
                ``split_train_val_test``.
            split_dir: Stored as an attribute; not read by any method here.
            supercls_raw: Boolean; together with ``supercls_split`` its sum
                is the number of ``/*`` glob levels below ``data_dir`` at
                which class directories are found.
            supercls_split: See ``supercls_raw``.
            verbose: Dataset name. When not ``None`` it enables tqdm progress
                bars labelled with it, and it selects a per-class sample size
                from ``_MAX_NUM_BY_DATASET`` when listed there.
        """
        self.data_dir = data_dir
        self.dest_dir = dest_dir
        self.few_data_dir = few_data_dir
        self.split_dir = split_dir
        self.supercls_raw = supercls_raw
        self.supercls_split = supercls_split
        self.verbose = verbose

    def split_train_val_test(self, nclass_train=64, nclass_val=16, nclass_test=20, save_split_npy=False, csv_save_form=None, max_num=100, option=None):
        """Randomly split classes into train/val/test and copy sampled files.

        Sets ``self.classdir_train/val/test`` (source class directories) and
        ``self.few_train/val/test`` (created output class directories).

        Args:
            nclass_train: Number of classes in the training split.
            nclass_val: Number of classes in the validation split.
            nclass_test: Number of classes in the test split.
            save_split_npy: If true, write the ``few_*`` lists as ``.npy``
                files into ``self.dest_dir`` (which is wiped first).
            csv_save_form: Accepted for compatibility; currently unused.
            max_num: Files sampled per class when ``self.verbose`` has no
                preset in ``_MAX_NUM_BY_DATASET``. Classes with fewer files
                are copied in full.
            option: Key into ``_MIN_IMAGES_BY_OPTION``; when present, classes
                with fewer entries than the threshold are discarded.

        Raises:
            ValueError: If fewer eligible classes exist than
                ``nclass_train + nclass_val + nclass_test``. This is raised
                before anything on disk is modified.
        """
        # Class directories sit (supercls_raw + supercls_split) glob levels
        # below data_dir. Sorting makes the split reproducible under
        # random.seed(), since raw glob order is filesystem-dependent.
        pattern = self.data_dir + '/*' * (self.supercls_raw + self.supercls_split)
        classdir_all = sorted(glob.glob(pattern))

        min_images = self._MIN_IMAGES_BY_OPTION.get(option)
        if min_images is not None:
            classdir_all = self._filter_min_images(classdir_all, min_images)

        nclass_total = nclass_train + nclass_val + nclass_test
        if nclass_total > len(classdir_all):
            raise ValueError(
                'Need {} classes ({} train + {} val + {} test) but only {} '
                'eligible classes were found under {!r}'.format(
                    nclass_total, nclass_train, nclass_val, nclass_test,
                    len(classdir_all), self.data_dir))

        # A single draw without replacement, then slicing: the three splits
        # are disjoint and each is a uniform random subset. (The previous
        # set-difference approach depended on string-hash ordering, so it
        # was not reproducible even with a fixed seed.)
        chosen = random.sample(classdir_all, nclass_total)
        self.classdir_train = chosen[:nclass_train]
        self.classdir_val = chosen[nclass_train:nclass_train + nclass_val]
        self.classdir_test = chosen[nclass_train + nclass_val:]

        max_num = self._resolve_max_num(max_num)

        if os.path.exists(self.few_data_dir):
            shutil.rmtree(self.few_data_dir)
        os.makedirs(self.few_data_dir)

        self.few_train = self._sample_split(self.classdir_train, max_num)
        self.few_val = self._sample_split(self.classdir_val, max_num)
        self.few_test = self._sample_split(self.classdir_test, max_num)

        if save_split_npy:
            if os.path.exists(self.dest_dir):
                # Clean the previous split.
                shutil.rmtree(self.dest_dir)
            os.makedirs(self.dest_dir)

            np.save(os.path.join(self.dest_dir, 'metatrain.npy'), self.few_train)
            np.save(os.path.join(self.dest_dir, 'metaval.npy'), self.few_val)
            np.save(os.path.join(self.dest_dir, 'metatest.npy'), self.few_test)

    def _resolve_max_num(self, max_num):
        """Return the dataset preset for ``self.verbose`` if one exists, else ``max_num``."""
        return self._MAX_NUM_BY_DATASET.get(self.verbose, max_num)

    @staticmethod
    def _filter_min_images(classdirs, min_images):
        """Keep only the directories that contain at least ``min_images`` entries."""
        return [d for d in classdirs if len(glob.glob(os.path.join(d, '*'))) >= min_images]

    def _sample_split(self, classdirs, max_num):
        """Re-save up to ``max_num`` random files of each class into ``few_data_dir``.

        Returns the created output class directories, in ``classdirs`` order.
        """
        if self.verbose is not None:
            iterator = tqdm(classdirs, desc='Generating {}'.format(self.verbose),
                            total=len(classdirs))
        else:
            iterator = classdirs

        few_class_dirs = []
        for classdir in iterator:
            # NOTE(review): classes are flattened to their basename, so two
            # selected classes sharing a basename (possible when globbing
            # through superclass directories) would overwrite each other —
            # confirm this cannot happen for the datasets in use.
            few_class_dir = os.path.join(self.few_data_dir, os.path.basename(classdir))
            few_class_dirs.append(few_class_dir)
            if os.path.exists(few_class_dir):
                shutil.rmtree(few_class_dir)
            os.makedirs(few_class_dir)

            # Sort before sampling so the result is reproducible under random.seed().
            img_path_list = sorted(glob.glob(os.path.join(classdir, '*')))
            # Clamp so that small classes are copied in full instead of
            # making random.sample raise ValueError.
            n_pick = min(max_num, len(img_path_list))
            for img_path in random.sample(img_path_list, n_pick):
                few_img_path = os.path.join(few_class_dir, os.path.basename(img_path))
                # The context manager releases the source file handle promptly.
                with Image.open(img_path) as image:
                    image.save(few_img_path)
        return few_class_dirs