
import os, sys
import gzip, pickle
import numpy as np
import torch
import torchvision
import argparse
from tqdm import tqdm

from dataset import CacheNPY, ToMesh, ProjectOnSphere, Shrec17

sys.path.append('../../..')

from utils.argparse_utils import *


import warnings
# Silence every warning category for the whole process. The original note
# attributes the noise being suppressed to trimesh.
warnings.filterwarnings('ignore')

class KeepName:
    """Callable wrapper that returns its input alongside the transformed value.

    Calling an instance with ``file_name`` yields the pair
    ``(file_name, transform(file_name))`` so the caller can still see which
    input produced the transformed output.
    """

    def __init__(self, transform):
        """Store the callable that will be applied to each input."""
        self.transform = transform

    def __call__(self, file_name):
        """Apply the wrapped transform and return ``(file_name, result)``."""
        transformed = self.transform(file_name)
        return (file_name, transformed)

def get_schrec17_data(raw_data_dir, split, augmentation, bw, perturbed=True, random_rotations=False, random_translation=0):
    """Build the ``Shrec17`` dataset object for one split.

    Parameters
    ----------
    raw_data_dir:
        Passed to ``Shrec17`` as its first argument.
    split:
        Split name. ``'valid'`` is accepted as an alias and forwarded to
        ``Shrec17`` as ``'val'``; any other value is forwarded unchanged.
    augmentation:
        Forwarded to ``CacheNPY`` as ``repeat``. Increasing it will generate
        more cached files.
    bw:
        Bandwidth handed to ``ProjectOnSphere`` and embedded in the cache
        file prefix (``"b<bw>_"``).
    perturbed:
        Forwarded to ``Shrec17``. Original author's note: the 'perturbed'
        dataset seems to already contain random rotations.
    random_rotations, random_translation:
        Forwarded to ``ToMesh``.

    Returns
    -------
    The constructed ``Shrec17`` object (``download=False``). For the
    ``'test'`` split the input transform is wrapped in ``KeepName``, so the
    transformed sample is a ``(file_name, data)`` pair.
    """
    # Category identifier strings; the integer label of a sample is the
    # position of its identifier in this table.
    class_ids = (
        '02691156', '02747177', '02773838', '02801938', '02808440', '02818832', '02828884', '02843684', '02871439', '02876657',
        '02880940', '02924116', '02933112', '02942699', '02946921', '02954340', '02958343', '02992529', '03001627', '03046257',
        '03085013', '03207941', '03211117', '03261776', '03325088', '03337140', '03467517', '03513137', '03593526', '03624134',
        '03636649', '03642806', '03691459', '03710193', '03759954', '03761084', '03790512', '03797390', '03928116', '03938244',
        '03948459', '03991062', '04004475', '04074963', '04090263', '04099429', '04225987', '04256520', '04330267', '04379243',
        '04401088', '04460130', '04468005', '04530566', '04554684',
    )

    def target_transform(x):
        # Map the string identifier in x[0] to its integer index.
        # Raises ValueError for an identifier that is not in the table.
        return class_ids.index(x[0])

    # Mesh loading followed by projection onto the spherical grid.
    mesh_to_sphere = torchvision.transforms.Compose([
        ToMesh(random_rotations=random_rotations, random_translation=random_translation),
        ProjectOnSphere(bandwidth=bw),
    ])
    transform = CacheNPY(prefix=f"b{bw}_", repeat=augmentation, transform=mesh_to_sphere)

    # This script says 'valid'; Shrec17 is given 'val'.
    shrec_split = 'val' if split == 'valid' else split

    if shrec_split == 'test':
        # Keep the file name next to each test sample.
        transform = KeepName(transform)

    return Shrec17(
        raw_data_dir,
        shrec_split,
        perturbed=perturbed,
        download=False,
        transform=transform,
        target_transform=target_transform,
    )


# Passed as the `augmentation` argument of get_schrec17_data, which forwards
# it to CacheNPY as `repeat`.
# Original author's observation: unsure what this does, it doesn't seem to
# change the number of datapoints nor the quality of the augmentations.
AUGMENTATION = 5


def _collect_split(dataset_obj, with_ids=False):
    """Read every sample of one split and stack it into NumPy arrays.

    Parameters
    ----------
    dataset_obj:
        Indexable dataset supporting ``len()`` whose items are
        ``(image, label)`` pairs, or ``((file_name, image), label)`` pairs
        when ``with_ids`` is true.
    with_ids:
        When true, also collect an id per sample, taken as the file's base
        name up to its first ``.``.

    Returns
    -------
    dict with ``'images'`` (each image flattened to ``(n_channels, -1)`` and
    stacked along a new leading axis), ``'labels'``, and, only when
    ``with_ids`` is true, ``'ids'``.
    """
    images = []
    labels = []
    ids = []
    for i in tqdm(range(len(dataset_obj))):
        sample, label = dataset_obj[i]
        if with_ids:
            file_name, image = sample
            ids.append(file_name.split("/")[-1].split(".")[0])
        else:
            image = sample
        n_channels = image.shape[0]
        # Flatten everything after the channel axis.
        images.append(image.reshape(n_channels, -1))
        labels.append(label)

    print('stacking arrays...', file=sys.stderr)
    split_data = {
        'images': np.stack(images, axis=0),
        'labels': np.array(labels),
    }
    if with_ids:
        split_data['ids'] = np.array(ids)
    return split_data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both directories are mandatory: without them the script would fail
    # later with an opaque TypeError instead of a usage message.
    parser.add_argument('--raw_data_dir', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--bandwidth', type=int, default=30)
    parser.add_argument('--perturbed', type=str_to_bool, default=True)
    parser.add_argument('--random_rotations', type=str_to_bool, default=False)
    parser.add_argument('--random_translation', type=float, default=0.0)

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    # Build all three dataset objects up front so a problem with any split
    # surfaces before the (slow) per-sample processing starts.
    split_objs = {
        split: get_schrec17_data(
            args.raw_data_dir,
            split,
            AUGMENTATION,
            args.bandwidth,
            perturbed=args.perturbed,
            random_rotations=args.random_rotations,
            random_translation=args.random_translation,
        )
        for split in ('train', 'valid', 'test')
    }

    # Only the test split carries file names (and therefore ids).
    dataset = {
        split: _collect_split(dataset_obj, with_ids=(split == 'test'))
        for split, dataset_obj in split_objs.items()
    }

    # All splits are written together into a single gzip-compressed pickle.
    output_filename = 'shrec17_on_grid-b=%d-perturbed=%s-random_rotations=%s-random_translation=%.2f.gz' % (args.bandwidth, args.perturbed, args.random_rotations, args.random_translation)

    with gzip.open(os.path.join(args.output_dir, output_filename), 'wb') as f:
        pickle.dump(dataset, f)
