import tensorflow as tf
import numpy as np
import os
import gc
# Run a full garbage-collection pass as soon as this module is imported.
# NOTE(review): presumably meant to free memory left over from earlier work in
# the same interpreter session -- confirm it is still needed.
gc.collect()


#tot_samps_from_each=138*512
#tot_samps = tot_samps_from_each*4
#samps_from_each = 512
#num_samples = samps_from_each*4
#num_tfrecods = tot_samps // num_samples
# Directories holding the TFRecord shards. The trailing slash is kept because
# file names are appended directly to these strings.
tfrecords_dir_train = "./VR_train_data_seq40_pro/"  # training shards
tfrecords_dir_test = "./VR_test_data_seq40_pro/"  # held-out test shards


def parse_tfrecord_fn(example):
    """Decode one serialized example into a dict of fixed-length tensors.

    Args:
        example: a single serialized record, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        dict mapping feature name to tensor. ``"image"`` and ``"next_frame"``
        are stored as ``64*64*3`` flat int64 values; they are reshaped to
        ``[64, 64, 3]``, cast to float32 and divided by 255. ``"task"`` is
        int64 of length 1; every other feature is a float32 vector of the
        length listed below.
    """
    flat_frame_len = 64*64*3

    # (feature name, vector length) for every float32 feature in a record.
    float_feature_sizes = (
        ("reward", 1),
        ("action", 38),
        ("acceleration", 3),
        ("appendages", 15),
        ("joints", 30),
        ("zaxis", 3),
        ("touch", 4),
        ("joints_vel", 30),
        ("tendons_pos", 8),
        ("tendons_vel", 8),
        ("gyro", 3),
        ("velocity", 3),
    )

    schema = {
        "image": tf.io.FixedLenFeature([flat_frame_len], tf.int64),
        "task": tf.io.FixedLenFeature([1], tf.int64),
    }
    for feature_name, length in float_feature_sizes:
        schema[feature_name] = tf.io.FixedLenFeature([length], tf.float32)
    schema["next_frame"] = tf.io.FixedLenFeature([flat_frame_len], tf.int64)

    parsed = tf.io.parse_single_example(example, schema)

    # Both frames go through the same flat-int64 -> [64,64,3] float32 / 255 step.
    for frame_key in ("image", "next_frame"):
        frame = tf.reshape(parsed[frame_key], [64, 64, 3])
        parsed[frame_key] = tf.cast(frame, tf.float32)/255.

    return parsed



def dataset_preprocess(train=True,val_size=8192,shuffle_buffer=None):
    """Build parsed ``tf.data`` dataset(s) from the TFRecord directories.

    Args:
        train: read from ``tfrecords_dir_train`` when True, otherwise from
            ``tfrecords_dir_test``.
        val_size: only used when ``train`` is True. Number of leading records
            held out as a validation set; ``None`` disables the split.
        shuffle_buffer: if not None, the raw records are shuffled with this
            buffer size before any split or parsing.

    Returns:
        * ``train=True`` and ``val_size`` given: ``(train_dataset, val_dataset)``
          where the training dataset repeats indefinitely and the validation
          dataset is a single pass.
        * ``train=True`` and ``val_size=None``: one indefinitely repeating
          dataset.
        * ``train=False``: one single-pass dataset.
        Every element is the dict produced by ``parse_tfrecord_fn``.
    """
    records_dir = tfrecords_dir_train if train else tfrecords_dir_test
    # os.listdir returns names in arbitrary order; sort them so the record
    # order -- and therefore which records land in the validation split -- is
    # reproducible across runs and machines.
    tfrs = [os.path.join(records_dir, f) for f in sorted(os.listdir(records_dir))]
    raw_dataset = tf.data.TFRecordDataset(tfrs)

    split_validation = train and val_size is not None

    if shuffle_buffer is not None:
        if split_validation:
            # take() and skip() below each iterate the shuffled dataset on
            # their own. With the default per-iteration reshuffling they would
            # see different permutations and validation records could leak
            # into the training stream, so the permutation is frozen here.
            # Callers that want a fresh order every epoch shuffle the returned
            # training dataset themselves.
            raw_dataset = raw_dataset.shuffle(
                shuffle_buffer, reshuffle_each_iteration=False)
        else:
            raw_dataset = raw_dataset.shuffle(shuffle_buffer)

    if not train:
        # Test data: a single pass.
        return raw_dataset.map(parse_tfrecord_fn).repeat(1)

    if not split_validation:
        return raw_dataset.map(parse_tfrecord_fn).repeat()

    val_dataset = raw_dataset.take(val_size)
    train_dataset = raw_dataset.skip(val_size)

    parsed_val_dataset = val_dataset.map(parse_tfrecord_fn).repeat(1)
    parsed_train_dataset = train_dataset.map(parse_tfrecord_fn).repeat()
    return parsed_train_dataset, parsed_val_dataset


def load_and_transform_for_training(output_key,batch_size,shuffle_buffer=2048):
    """Return batched ``(train_set, val_set)`` datasets of (input, target) pairs.

    Args:
        output_key: selects what each element looks like:
            ``'all'``     -> ``(image, (task, reward, action, zaxis))``
            ``'all3'``    -> ``(image, (task, reward, zaxis))``
            ``'proprio'`` -> ``((image, proprio), action)``
            ``'jprop'``   -> ``(proprio, action)``
            anything else -> ``(image, example[output_key])``
            where ``proprio`` is the concatenation, along the last axis, of
            the proprioceptive features listed in ``proprio_fields`` below.
        batch_size: batch size applied to both datasets.
        shuffle_buffer: buffer size used to shuffle the training set only;
            ``None`` skips that shuffle.

    Returns:
        ``(train_set, val_set)`` built from ``dataset_preprocess(train=True)``.
    """
    # Concatenation order of the proprioceptive features.
    proprio_fields = (
        'acceleration', 'appendages', 'touch', 'joints', 'zaxis',
        'joints_vel', 'tendons_pos', 'tendons_vel', 'gyro', 'velocity',
    )

    def make_tuple(example):
        if output_key in ('proprio', 'jprop'):
            proprio = tf.concat([example[name] for name in proprio_fields], -1)
            if output_key == 'jprop':
                return (proprio, example['action'])
            return ((example['image'], proprio), example['action'])

        if output_key == 'all':
            targets = (example['task'], example['reward'],
                       example['action'], example['zaxis'])
            return (example['image'], targets)

        if output_key == 'all3':
            targets = (example['task'], example['reward'], example['zaxis'])
            return (example['image'], targets)

        return (example['image'], example[output_key])

    train_set, val_set = dataset_preprocess(train=True)
    if shuffle_buffer is not None:
        train_set = train_set.shuffle(shuffle_buffer)

    #TODO: any more shuffling here? and also something about prefetch, and batch?
    train_set = train_set.map(make_tuple).batch(batch_size)
    val_set = val_set.map(make_tuple).batch(batch_size)

    return train_set, val_set

def load_and_transform_for_CPCtraining(batch_size,frac_pos=.5,shuffle_buffer=2048):
    """Return batched ``(train_set, val_set)`` of frame pairs with a 0/1 label.

    Records are grouped two at a time. For each group, with probability
    ``frac_pos`` the element is ``((image_0, next_frame_0), 1)``; otherwise it
    is ``((image_0, image_1), 0)``, where the index is the position of the
    record inside the group.

    Args:
        batch_size: batch size of the returned datasets (counted in pairs).
        frac_pos: probability of emitting the positive (label 1) pair.
        shuffle_buffer: buffer size used to shuffle the training records before
            they are grouped; ``None`` skips that shuffle. The validation
            records are never shuffled here.

    Returns:
        ``(train_set, val_set)`` built from ``dataset_preprocess(train=True)``.
    """
    train_set, val_set = dataset_preprocess(train=True)
    if shuffle_buffer is not None:
        train_set = train_set.shuffle(shuffle_buffer)

    def make_tuple(example):
        return (example['image'],example['next_frame'])

    def scramble_tuple(example1,example2):
        # example1 holds the two current frames of the group, example2 the two
        # matching next frames.
        # NOTE(review): the negative pairs frame 0 with the *current* frame of
        # the other record (example1[1]), not with its next_frame -- confirm
        # that this is the intended negative.
        if tf.random.uniform([]) < frac_pos:
            return ((example1[0,:,:,:],example2[0,:,:,:]),1)
        else:
            return ((example1[0,:,:,:],example1[1,:,:,:]),0)

    def to_labelled_pairs(dataset):
        # drop_remainder=True: a trailing group holding a single record would
        # make example1[1] in the negative branch an out-of-range index.
        grouped = dataset.map(make_tuple).batch(2, drop_remainder=True)
        return grouped.map(scramble_tuple)

    #TODO: any more shuffling here? and also something about prefetch, and batch?
    train_set = to_labelled_pairs(train_set).batch(batch_size)
    val_set = to_labelled_pairs(val_set).batch(batch_size)

    return train_set, val_set


def load_and_transform_for_CPCtest(batch_size,frac_pos=.5,shuffle_buffer=None):
    """Return a batched test dataset of frame pairs with a 0/1 label.

    Records are grouped two at a time. For each group, with probability
    ``frac_pos`` the element is ``((image_0, next_frame_0), 1)``; otherwise it
    is ``((image_0, image_1), 0)``, where the index is the position of the
    record inside the group.

    Args:
        batch_size: batch size of the returned dataset (counted in pairs).
        frac_pos: probability of emitting the positive (label 1) pair.
        shuffle_buffer: forwarded to ``dataset_preprocess``; ``None`` keeps the
            records in file order.

    Returns:
        The batched test dataset built from ``dataset_preprocess(train=False)``.
    """
    test_set = dataset_preprocess(train=False,shuffle_buffer=shuffle_buffer)

    def make_tuple(example):
        return (example['image'],example['next_frame'])

    def scramble_tuple(example1,example2):
        # example1 holds the two current frames of the group, example2 the two
        # matching next frames.
        # NOTE(review): the negative pairs frame 0 with the *current* frame of
        # the other record (example1[1]), not with its next_frame -- confirm
        # that this is the intended negative.
        if tf.random.uniform([]) < frac_pos:
            return ((example1[0,:,:,:],example2[0,:,:,:]),1)
        else:
            return ((example1[0,:,:,:],example1[1,:,:,:]),0)

    test_set = test_set.map(make_tuple)

    # drop_remainder=True: with an odd number of test records the last group
    # would hold a single record, and example1[1] in the negative branch would
    # then be an out-of-range index.
    test_set = test_set.batch(2, drop_remainder=True)

    #TODO: any more shuffling here? and also something about prefetch, and batch?
    test_set = test_set.map(scramble_tuple)
    test_set = test_set.batch(batch_size)

    return test_set


def load_and_transform_for_test(output_key,batch_size,shuffle_buffer=None):
    """Return the batched test dataset shaped according to ``output_key``.

    Args:
        output_key: selects what each element looks like:
            ``'all'``      -> ``(image, (task, reward, action, zaxis))``
            ``'all3'``     -> ``(image, (task, reward, zaxis))``
            ``'all_orig'`` -> ``(image, task, reward)``
            ``'proprio'``  -> ``((image, proprio), action)``
            ``'jprop'``    -> ``(proprio, action)``
            ``'all_prop'`` -> ``((image, proprio), task, reward, action, zaxis)``
            anything else  -> ``(image, example[output_key])``
            where ``proprio`` is the concatenation, along the last axis, of
            the proprioceptive features listed in ``proprio_fields`` below.
        batch_size: batch size of the returned dataset.
        shuffle_buffer: forwarded to ``dataset_preprocess``, which shuffles the
            raw records only when this is not None.

    Returns:
        The batched test dataset built from ``dataset_preprocess(train=False)``.
    """
    # Concatenation order of the proprioceptive features.
    proprio_fields = (
        'acceleration', 'appendages', 'touch', 'joints', 'zaxis',
        'joints_vel', 'tendons_pos', 'tendons_vel', 'gyro', 'velocity',
    )

    def make_tuple(example):
        if output_key in ('proprio', 'jprop', 'all_prop'):
            proprio = tf.concat([example[name] for name in proprio_fields], -1)
            if output_key == 'jprop':
                return (proprio, example['action'])
            if output_key == 'proprio':
                return ((example['image'], proprio), example['action'])
            return ((example['image'], proprio), example['task'],
                    example['reward'], example['action'], example['zaxis'])

        if output_key == 'all':
            targets = (example['task'], example['reward'],
                       example['action'], example['zaxis'])
            return (example['image'], targets)

        if output_key == 'all3':
            targets = (example['task'], example['reward'], example['zaxis'])
            return (example['image'], targets)

        if output_key == 'all_orig':
            return (example['image'], example['task'], example['reward'])

        return (example['image'], example[output_key])

    raw_test = dataset_preprocess(train=False, shuffle_buffer=shuffle_buffer)

    #TODO: any more shuffling here? and also something about prefetch, and batch?
    test_set = raw_test.map(make_tuple).batch(batch_size)

    return test_set


def check_val_data():
    """Print shapes and a few values from the training and validation sets.

    Manual sanity check. Builds both datasets with
    ``load_and_transform_for_training('image', 128)`` and walks up to 5000
    batches of each:

    * training: prints the target shape of every batch, a small slice of the
      first batch, the index of the last batch seen, and the same slice of the
      last batch;
    * validation: prints the target array of every batch, the index of the
      last batch seen, and a small slice of the last batch.

    Returns:
        None. Output goes to stdout only.
    """
    train, val = load_and_transform_for_training('image',128)
    print(512*4, 'is how many examples per file')

    # ii stays at -1 and sample at None if a dataset yields nothing, so the
    # summary prints after each loop cannot hit an unbound name.
    ii = -1
    sample = None
    for ii, sample in enumerate(train.take(5000)):
        print(sample[1].numpy().shape)
        if ii == 0:
            print(sample[0].numpy()[0,3:6,8:12,1])
            print(sample[1].numpy()[0])
    print(ii)
    if sample is not None:
        print(sample[0].numpy()[0,3:6,8:12,1])
        print(sample[1].numpy()[0])

    ii = -1
    sample = None
    for ii, sample in enumerate(val.take(5000)):
        print(sample[1].numpy()[:])
    print(ii)
    if sample is not None:
        print(sample[0].numpy()[0,3:6,8:12,1])
        print(sample[1].numpy()[0])





