import tarfile
import yaml
import os
from utils import mkdir_p
import numpy as np
import pandas as pd
from PIL import Image
import wget
import gzip
import shutil

# Path of the YAML file that describes where data and sources live.
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
# Both working directories come straight from the parsed configuration.
root_dir, src_dir = yamlfile['root_dir'], yamlfile['src_dir']

## prepare
# Relocate the memguard folder from the source tree into the data root when
# it is still present under src_dir.
memguard_src = os.path.join(src_dir, 'memguard')
if os.path.exists(memguard_src):
    shutil.move(memguard_src, os.path.join(root_dir, 'memguard'))

# Scratch folder that holds the downloaded archives.
tmp_dir = os.path.join(root_dir, 'tmp')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
### Assuming the two tar files dataset_purchase.tgz and dataset_texas.tgz are saved in root_dir/tmp.
####prepare purchase dataset
# Fetch the Purchase and Texas archives into root_dir/tmp; an archive that is
# already on disk is left untouched and nothing is printed for it.
for dataset_label, archive_url, archive_name in (
    ('purchase',
     "https://www.comp.nus.edu.sg/~reza/files/dataset_purchase.tgz",
     'dataset_purchase.tgz'),
    ('texas',
     "https://www.comp.nus.edu.sg/~reza/files/dataset_texas.tgz",
     'dataset_texas.tgz'),
):
    archive_path = os.path.join(root_dir, 'tmp', archive_name)
    if os.path.isfile(archive_path):
        continue
    print(f"Dowloading {dataset_label} dataset...")
    wget.download(archive_url, archive_path)
    print('Dataset Dowloaded')

# Download the CIFAR-100 archive once; a copy already in root_dir/tmp is reused.
cifar_archive = os.path.join(root_dir, 'tmp', 'cifar-100-python.tar.gz')
if not os.path.isfile(cifar_archive):
    print("Dowloading cifar100 dataset...")
    wget.download("http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
                  cifar_archive)
    print('Dataset Dowloaded')

# Unpack only when the final root_dir/cifar100/data folder is missing, so a
# rerun does not extract again.
cifar_dir = os.path.join(root_dir, 'cifar100')
if not os.path.exists(os.path.join(cifar_dir, 'data')):
    print("Prepare CIFAR100 dataset")
    # The context manager closes the archive handle even if extraction fails;
    # the previous code never closed it.
    with tarfile.open(cifar_archive) as tar:
        # NOTE(review): extractall trusts the member paths stored in the
        # archive, and the archive is fetched over plain HTTP.
        tar.extractall(path=cifar_dir)
    # The rename expects the archive to unpack into a 'cifar-100-python' folder.
    os.rename(os.path.join(cifar_dir, 'cifar-100-python'),
              os.path.join(cifar_dir, 'data'))
# Download the TinyImageNet archive once; a copy already in root_dir/tmp is reused.
tiny_archive = os.path.join(root_dir, 'tmp', 'tiny-imagenet-200.zip')
if not os.path.isfile(tiny_archive):
    print("Dowloading TinyImageNet dataset...")
    wget.download("http://cs231n.stanford.edu/tiny-imagenet-200.zip",
                  tiny_archive)

# The TinyImageNet preparation step further down reads
# root_dir/tmp/tiny-imagenet-200, so the archive has to be unpacked here;
# previously it was only downloaded and a fresh run failed on the missing folder.
# Assumes the zip's top-level folder is 'tiny-imagenet-200' — TODO confirm.
if not os.path.isdir(os.path.join(root_dir, 'tmp', 'tiny-imagenet-200')):
    print("Extracting TinyImageNet dataset...")
    # unpack_archive picks the zip format from the file extension.
    shutil.unpack_archive(tiny_archive, os.path.join(root_dir, 'tmp'))

# NOTE: the Purchase100 preprocessing that follows is commented out, so this
# message is printed but no Purchase100 files are produced.
print("Prepare Purchase100 dataset")
# tar = tarfile.open(os.path.join(root_dir, 'tmp', 'dataset_purchase.tgz'), 'r:gz')
# tar.extractall(path=os.path.join(root_dir, 'tmp'))
# with tarfile.open(os.path.join(root_dir, 'tmp', 'dataset_purchase.tgz'), 'r:gz') as tar:
#    tar.extractall(path=os.path.join(root_dir, 'tmp'))
# data_set =np.genfromtxt(os.path.join(root_dir, 'tmp', 'dataset_purchase'), delimiter=',')
#
# X = data_set[:,1:].astype(np.float64)
# Y = (data_set[:,0]).astype(np.int32)-1
#
# DATASET_PATH = os.path.join(root_dir, 'purchase', 'data')
# if not os.path.exists(DATASET_PATH):
#    mkdir_p(DATASET_PATH)
#
# np.save(os.path.join(DATASET_PATH, 'X.npy'), X)
# np.save(os.path.join(DATASET_PATH,'Y.npy'), Y)

# print("Prepare Texas100 dataset")
#####prepare texas dataset
# print(os.path.join(root_dir, 'tmp', 'dataset_texas.tgz'))
# tar = tarfile.open(os.path.join(root_dir, 'tmp', 'dataset_texas.tgz'))
# tar.extractall(path=os.path.join(root_dir, 'tmp'))
#
# data_set_features =np.genfromtxt(os.path.join(root_dir, 'tmp', 'texas/100/feats'), delimiter=',')
# data_set_label =np.genfromtxt(os.path.join(root_dir, 'tmp', 'texas/100/labels'), delimiter=',')
#
# X =data_set_features.astype(np.float64)
# Y = data_set_label.astype(np.int32)-1
#
# DATASET_PATH = os.path.join(root_dir, 'texas', 'data')
# if not os.path.exists(DATASET_PATH):
#    mkdir_p(DATASET_PATH)

######save dataset in numpy format as loading by genfromtxt takes several minutes when loading.
# np.save(os.path.join(DATASET_PATH, 'feats.npy'), X)
# np.save(os.path.join(DATASET_PATH, 'labels.npy'), Y)

# NOTE: the UTK-Face preprocessing that follows is commented out, so this
# message is printed but no UTK-Face files are produced.
print("Prepare UTK-Face age dataset")
####prepare UTK-Face age dataset
#print(os.path.join(root_dir, 'tmp', 'age_gender.gz'))
#
#pd00 = pd.read_csv(os.path.join(root_dir, 'tmp', 'age_gender.gz'), compression='gzip')
#age_bins = [0, 10, 15, 20, 25, 30, 40, 50, 60, 120]
#age_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8]
#pd00['age_bins'] = pd.cut(pd00.age, bins=age_bins, labels=age_labels)
#X = pd00.pixels.apply(lambda x: np.array(x.split(" "), dtype=float))
#X = np.stack(X)
#X = X / 255.0
#X = X.astype('float32').reshape(X.shape[0], 1, 48, 48)
#print(np.max(X), np.min(X))
#y = pd00['age_bins'].to_numpy()
#y = y.reshape(y.shape[0], 1)
## g = pd00['ethnicity'].to_numpy()  # .values.reshape(-1, 1)
## g = g.reshape(g.shape[0], 1)
#Y = y  # np.concatenate((y, g), axis=1)
#
#X = X.astype(np.float64)
#Y = Y.astype(np.int32)
#print(X.shape, Y.shape)
#
#DATASET_PATH = os.path.join(root_dir, 'utk', 'data')
#if not os.path.exists(DATASET_PATH):
#    mkdir_p(DATASET_PATH)
#
#######save dataset in numpy format as loading by genfromtxt takes several minutes when loading.
#np.save(os.path.join(DATASET_PATH, 'feats.npy'), X)
#np.save(os.path.join(DATASET_PATH, 'labels.npy'), Y)

# NOTE: the 200x200 UTK-Face preprocessing that follows is commented out, so
# this message is printed but no files are produced.
print("Prepare UTK-Face age dataset (200x200)")
## folder_path
#folder_path = os.path.join(root_dir, 'tmp', 'utkface')
## data list
#Xs = []
#Ages = []
##Genders = []
##Races = []
## Define age bin boundaries
#bins = [10, 15, 20, 25, 30, 40, 50, 60, 75]
## go through all files
#for root, dirs, files in os.walk(folder_path):
#    for file_name in files:
#        # get attribute
#        attributes = file_name.split('_')
#        if len(attributes) != 4:
#            print("error image name")
#            continue
#        age = int(attributes[0])
#        gender = int(attributes[1])
#        race = int(attributes[2])
#        # Open the image file
#        image = Image.open(os.path.join(root, file_name))
#        # Convert the image to a numpy array
#        image_array = np.array(image)
#        # Rearrange dimensions to have RGB channels at the first dimension
#        image_array = np.moveaxis(image_array, -1, 0)
#        # Print the shape of the numpy array
#        print("Shape of the numpy array:", image_array.shape)
#
#        # add to list
#        Xs.append(image_array)
#        Ages.append(age)
#        #Genders.append(gender)
#        #Races.append(race)
## define dataset path
#DATASET_PATH = os.path.join(root_dir, 'utkface', 'data')
#if not os.path.exists(DATASET_PATH):
#    mkdir_p(DATASET_PATH)
## save data to numpy
#X = np.stack(Xs).astype(np.float32)
#print('UTKFace: X is done')
#np.save(os.path.join(DATASET_PATH, 'feats.npy'), X)
#xshape = X.shape
#X, Xs = None, None
#Y = np.stack(Ages).astype(np.int32)
## Cut ages into bins
#Y = np.digitize(Y, bins, right=True)
#print('UTKFace: Y is done')
#np.save(os.path.join(DATASET_PATH, 'labels.npy'), Y)
#print(xshape, Y.shape)
## print(X[0])
#print(Y[0:10])


print("Prepare TinyImageNet dataset")
folder_path = os.path.join(root_dir, 'tmp', 'tiny-imagenet-200')


def _load_image_chw(image_path):
    """Load one image file as a channel-first RGB numpy array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The pixel array with the RGB channel axis moved to the front.
    """
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(image_path) as image:
        # Force three channels so every array has the same shape for np.stack;
        # grayscale ('L') and any other non-RGB mode would otherwise differ.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_array = np.array(image)
    # Rearrange dimensions to have RGB channels at the first dimension
    return np.moveaxis(image_array, -1, 0)


# Read the class mapping file: every distinct word in wnids.txt gets the next
# unused integer, in order of first appearance.
word_to_int = {}
with open(os.path.join(folder_path, 'wnids.txt'), 'r') as file:
    for line in file:
        for word in line.split():
            # len(word_to_int) is the next free id; setdefault keeps the id
            # of a word that was already seen.
            word_to_int.setdefault(word, len(word_to_int))
# Print the mapping
print(word_to_int)

# Build the output folder
DATASET_PATH = os.path.join(root_dir, 'tinyimagenet', 'data')
if not os.path.exists(DATASET_PATH):
    mkdir_p(DATASET_PATH)

# Read Train Data: one sub-folder per class id, images under '<id>/images'.
X = []
Y = []
train_folder_path = os.path.join(folder_path, 'train')
for key, label in word_to_int.items():
    train_img_path = os.path.join(train_folder_path, key, 'images')
    for root, _dirs, files in os.walk(train_img_path):
        for file_name in files:
            # Join with the directory os.walk is currently visiting so files
            # in nested folders resolve correctly (the old code always joined
            # with the top-level images folder).
            X.append(_load_image_chw(os.path.join(root, file_name)))
            Y.append(label)

X = np.stack(X).astype(np.float32)
Y = np.array(Y, dtype=np.int32)
print("train set:", X.shape, Y.shape)
np.savez(os.path.join(DATASET_PATH, 'train.npz'), x=X, y=Y)

# Read Test Data: val_annotations.txt lists '<file name> <class id> ...' per line.
val_folder_path = os.path.join(folder_path, 'val')
val_img_path = os.path.join(val_folder_path, 'images')
X = []
Y = []
with open(os.path.join(val_folder_path, 'val_annotations.txt'), 'r') as file:
    for line in file:
        words = line.split()
        if not words:
            # Tolerate blank lines instead of failing with IndexError.
            continue
        file_name, key = words[0], words[1]
        X.append(_load_image_chw(os.path.join(val_img_path, file_name)))
        Y.append(word_to_int[key])

X = np.stack(X).astype(np.float32)
Y = np.array(Y, dtype=np.int32)
print("val set:", X.shape, Y.shape)
np.savez(os.path.join(DATASET_PATH, 'test.npz'), x=X, y=Y)

