import os
import json
import pprint
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow as tf
import functools
from acme.adders import reverb as adders
import reverb
import tree
import numpy as np


def transform_reward(rewards):
    """Spread a geometrically decaying bonus backwards from large rewards.

    The sequence is walked from the last step to the first. A step whose
    reward exceeds 1 is set to exactly 1 and (re)starts a bonus of 0.75.
    Every other step gets its own reward plus the current bonus. After that,
    the bonus is multiplied by 0.75 and reset to 0 once it drops below 0.001.

    Args:
        rewards: indexable sequence of numbers (list or 1-D array).

    Returns:
        A new float numpy array with the same length as ``rewards``.
        ``rewards`` itself is not modified.
    """
    shaped = np.zeros(len(rewards))
    bonus = 0
    for idx in reversed(range(len(rewards))):
        if rewards[idx] > 1:
            # Large reward: clip it and restart the backward bonus.
            shaped[idx] = 1
            bonus = .75
            continue
        shaped[idx] = rewards[idx] + bonus
        bonus *= .75
        if bonus < .001:
            # Cut off the geometric tail so it does not linger forever.
            bonus = 0
    return shaped


def _build_sequence_example(sequences):
  """Wrap a batch of parsed sequences in a `reverb.ReplaySample`.

  Args:
    sequences: mapping with 'observation', 'action', 'reward' and 'discount'
      entries (as produced by `_padded_batch` in this file).

  Returns:
    A ReplaySample whose `data` is an `adders.Step` with empty
    `start_of_episode` / `extras`. Its `info` holds constant values
    (key 0, probability 1.0, table_size 0, priority 1.0). In this file the
    samples are read from TFRecord files, not a live Reverb table.
  """
  step_fields = ('observation', 'action', 'reward', 'discount')
  step = adders.Step(
      start_of_episode=(),
      extras=(),
      **{field: sequences[field] for field in step_fields})

  constant_info = reverb.SampleInfo(
      key=tf.constant(0, tf.uint64),
      probability=tf.constant(1.0, tf.float64),
      table_size=tf.constant(0, tf.int64),
      priority=tf.constant(1.0, tf.float64))
  return reverb.ReplaySample(info=constant_info, data=step)

def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False):
  """Batch data while handling unequal lengths."""
  padded_shapes = {}
  padded_shapes['observation'] = {}
  for k, v in shapes.items():
    if 'observation' in k:
      padded_shapes['observation'][
          k.replace('observation/', '')] = (-1,) + v
    else:
      padded_shapes[k] = (-1,) + v

  padded_shapes['length'] = ()

  return example_ds.padded_batch(batch_size,
                                 padded_shapes=padded_shapes,
                                 drop_remainder=drop_remainder)

def _parse_seq_tf_example(example, uint8_features, shapes):
  """Parse a serialized tf.Example holding a sequence of episode steps.

  Args:
    example: serialized tf.Example.
    uint8_features: keys stored as raw byte strings. These are decoded to
      uint8 and reshaped to `(-1,) + shapes[key]`.
    shapes: mapping from feature key to per-step shape. All other keys are
      parsed as float32 sequences of that shape.

  Returns:
    A dict containing:
    - every feature whose key does not contain 'observation', under its
      own key;
    - 'observation': a dict of the remaining features, keyed with the
      'observation/' prefix removed;
    - 'length': the number of rows in 'action'.
  """
  string_spec = dict(shape=[], dtype=tf.string, allow_missing=True)
  feature_map = {
      key: (tf.io.FixedLenSequenceFeature(**string_spec)
            if key in uint8_features else
            tf.io.FixedLenSequenceFeature(
                shape=step_shape, dtype=tf.float32, allow_missing=True))
      for key, step_shape in shapes.items()
  }
  parsed = tf.io.parse_single_example(example, features=feature_map)

  restructured = {
      key: value for key, value in parsed.items() if 'observation' not in key
  }
  observation = {}
  for key, value in parsed.items():
    if 'observation' not in key:
      continue
    short_key = key.replace('observation/', '')
    if key in uint8_features:
      raw_pixels = tf.io.decode_raw(value, out_type=tf.uint8)
      observation[short_key] = tf.reshape(raw_pixels, (-1,) + shapes[key])
    else:
      observation[short_key] = value

  restructured['observation'] = observation
  restructured['length'] = tf.shape(restructured['action'])[0]
  return restructured



def image_feature(value):
    """Return a bytes_list Feature holding a single string or bytes value.

    `str` values are UTF-8 encoded first. `bytes` values are stored as-is.
    Previously `.encode()` was called on any value, so passing `bytes`
    raised AttributeError despite this docstring.
    """
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[value])
    )


def bytes_feature(value):
    """Return a bytes_list Feature holding a single string or bytes value.

    `str` values are UTF-8 encoded first. `bytes` values are stored as-is
    (calling `.encode()` on bytes would raise AttributeError).
    """
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def float_feature(value):
    """Wrap a single float in a `tf.train.Feature` float_list."""
    single_float = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=single_float)


def int64_feature(value):
    """Wrap a single bool / enum / int in a `tf.train.Feature` int64_list."""
    single_int = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=single_int)

def int64_feature_list(value):
    """Wrap a sequence of ints in a `tf.train.Feature` int64_list."""
    ints = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=ints)


def float_feature_list(value):
    """Wrap a sequence of floats in a `tf.train.Feature` float_list."""
    floats = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=floats)


def create_example(image, output_dict):
    """Serialise one transition into a `tf.train.Example`.

    Args:
        image: flat sequence of ints for the current camera frame.
        output_dict: must provide 'task' (int), 'reward' (float),
            'next_frame' (flat int sequence) and every key in
            `vector_keys` below (float sequences).

    Returns:
        A tf.train.Example whose feature names match those parsed by
        `parse_tfrecord_fn`.
    """
    vector_keys = ("action", "acceleration", "appendages", "touch", "joints",
                   "zaxis", "joints_vel", "tendons_pos", "tendons_vel",
                   "gyro", "velocity")
    feature = {
        "image": int64_feature_list(image),
        "task": int64_feature(output_dict['task']),
        "reward": float_feature(output_dict['reward']),
    }
    feature.update(
        (key, float_feature_list(output_dict[key])) for key in vector_keys)
    feature["next_frame"] = int64_feature_list(output_dict['next_frame'])
    return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_tfrecord_fn(example):
    """Decode one serialized example written by `create_example`.

    Both camera frames ('image' and 'next_frame') are reshaped to
    [64, 64, 3] and scaled from integer pixel values to float32 in [0, 1]
    by dividing by 255. All other features keep their fixed-length shapes.
    """
    frame_len = 64 * 64 * 3
    float_sizes = (("reward", 1), ("action", 38), ("appendages", 15),
                   ("joints", 30), ("joints_vel", 30), ("tendons_vel", 8),
                   ("tendons_pos", 8), ("acceleration", 3), ("velocity", 3),
                   ("gyro", 3), ("zaxis", 3), ("touch", 4))

    spec = {
        "image": tf.io.FixedLenFeature([frame_len], tf.int64),
        "task": tf.io.FixedLenFeature([1], tf.int64),
    }
    for name, size in float_sizes:
        spec[name] = tf.io.FixedLenFeature([size], tf.float32)
    spec["next_frame"] = tf.io.FixedLenFeature([frame_len], tf.int64)

    parsed = tf.io.parse_single_example(example, spec)
    for frame_key in ("image", "next_frame"):
        pixels = tf.reshape(parsed[frame_key], [64, 64, 3])
        parsed[frame_key] = tf.cast(pixels, tf.float32) / 255.
    return parsed





# ---------------------------------------------------------------------------
# Conversion settings
# ---------------------------------------------------------------------------
num_threads = 1
# Per-step feature shapes in the source shards. These keys are the feature
# names parsed by _parse_seq_tf_example.
shapes = {
        'observation/walker/joints_pos': (30,),
        'observation/walker/joints_vel': (30,),
        'observation/walker/tendons_pos': (8,),
        'observation/walker/tendons_vel': (8,),
        'observation/walker/appendages_pos': (15,),
        'observation/walker/world_zaxis': (3,),
        'observation/walker/sensors_accelerometer': (3,),
        'observation/walker/sensors_velocimeter': (3,),
        'observation/walker/sensors_gyro': (3,),
        'observation/walker/sensors_touch': (4,),
        'observation/walker/egocentric_camera': (64, 64, 3),
        'action': (38,),
        'discount': (),
        'reward': (),
        'step_type': ()
    }
# Features stored as raw byte strings that are decoded to uint8.
uint8_features = {'observation/walker/egocentric_camera'}
num_shards = 1
shuffle_buffer_size = 1

tfrecords_dir = '/home/grace/cluster_code/virtual_rodent/RLunplugged/test40_tfrecords/' #'/nfs/nhome/live/glindsay/cluster_code/virtual_rodent/tfrecords/'

batch_size = 1
root_path = '/home/grace/cluster_code/virtual_rodent/RLunplugged/data/'
root_path2 = '/home/grace/cluster_code/virtual_rodent/RLunplugged/seq40data/'
#filenames = ['/home/grace/cluster_code/virtual_rodent/RLunplugged/data/dm_locomotion_rodent_gaps_seq2_train-00099-of-00100']
#all_filenames = [[root_path+'dm_locomotion_rodent_bowl_escape_seq2_train-00000-of-00100',root_path+'dm_locomotion_rodent_bowl_escape_seq2_train-00001-of-00100',root_path+'dm_locomotion_rodent_bowl_escape_seq2_train-00002-of-00100'],[root_path+'dm_locomotion_rodent_gaps_seq2_train-00000-of-00100',root_path+'dm_locomotion_rodent_gaps_seq2_train-00001-of-00100',root_path+'dm_locomotion_rodent_gaps_seq2_train-00002-of-00100',root_path+'dm_locomotion_rodent_gaps_seq2_train-00003-of-00100',root_path+'dm_locomotion_rodent_gaps_seq2_train-00004-of-00100'],[root_path+'dm_locomotion_rodent_two_touch_seq2_train-00001-of-00100',root_path+'dm_locomotion_rodent_two_touch_seq2_train-00000-of-00100'],[root_path+'dm_locomotion_rodent_mazes_seq2_train-00001-of-00100',root_path+'dm_locomotion_rodent_mazes_seq2_train-00000-of-00100',root_path+'dm_locomotion_rodent_mazes_seq2_train-00002-of-00100']]
# One inner list of source shards per task.
all_filenames = [[root_path+'dm_locomotion_rodent_bowl_escape_seq2_train-00003-of-00100'],[root_path+'dm_locomotion_rodent_gaps_seq2_train-00005-of-00100'],[root_path2+'dm_locomotion_rodent_two_touch_seq40_train-00002-of-00100'],[root_path2+'dm_locomotion_rodent_mazes_seq40_train-00003-of-00100']]


# Total samples to export per task, and how many of them go into each
# output file.
tot_samps_from_each = 2048
tot_samps = tot_samps_from_each * len(all_filenames)
samps_from_each = 512
num_samples = samps_from_each * len(all_filenames)  # samples per output file
print('Number of tasks: ', len(all_filenames))

# Ceiling division: one extra output file holds any remainder.
num_tfrecods = -(-tot_samps // num_samples)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(tfrecords_dir, exist_ok=True)

def map_func(example):
    """Parse one serialized sequence with the module-level `shapes` spec."""
    return _parse_seq_tf_example(example, uint8_features, shapes)

import contextlib

# (output_dict key, observation key) pairs copied from step 0 of each sample.
_OBS_KEYS = (
    ('appendages', 'walker/appendages_pos'),
    ('joints', 'walker/joints_pos'),
    ('joints_vel', 'walker/joints_vel'),
    ('tendons_pos', 'walker/tendons_pos'),
    ('tendons_vel', 'walker/tendons_vel'),
    ('zaxis', 'walker/world_zaxis'),
    ('acceleration', 'walker/sensors_accelerometer'),
    ('velocity', 'walker/sensors_velocimeter'),
    ('gyro', 'walker/sensors_gyro'),
    ('touch', 'walker/sensors_touch'),
)


def _task_id(filename):
    """Return the integer task label for a source shard filename.

    Raises ValueError for an unrecognised name. Previously an unknown task
    silently reused the previous task's label, or crashed with NameError on
    the first file.
    """
    for substring, task_id in (('bowl', 0), ('gaps', 1), ('maze', 2), ('touch', 3)):
        if substring in filename:
            print(substring)
            return task_id
    raise ValueError('Cannot infer task from filename: %s' % filename)


def _build_example_ds(filenames):
    """Build the parsed, padded-batched ReplaySample dataset for one task."""
    file_ds = tf.data.Dataset.from_tensor_slices(filenames).repeat(1)
    example_ds = file_ds.interleave(
        functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'))
    example_ds = example_ds.map(map_func, num_parallel_calls=num_threads)
    example_ds = _padded_batch(example_ds, batch_size, shapes, drop_remainder=True)
    example_ds = example_ds.map(
        _build_sequence_example,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return example_ds.prefetch(tf.data.experimental.AUTOTUNE)


def _sample_to_example(sample, task_name):
    """Turn the first sequence of a batched sample into a tf.train.Example.

    Frame 0 becomes 'image' and frame 1 becomes 'next_frame'. Rewards of
    tasks 2 and 3 are reshaped with transform_reward before step 0 is
    taken.
    """
    observation = sample.data.observation
    camera = observation['walker/egocentric_camera'].numpy()
    rewards = sample.data.reward.numpy()[0, :]  # first batch, all frames
    if task_name >= 2:
        rewards = transform_reward(rewards)

    output_dict = {
        'next_frame': camera[0, 1].flatten().tolist(),
        'action': sample.data.action.numpy()[0, 0, :],
        'reward': rewards[0],
        'task': task_name,
    }
    for out_key, obs_key in _OBS_KEYS:
        output_dict[out_key] = observation[obs_key].numpy()[0, 0, :]
    return create_example(camera[0, 0].flatten().tolist(), output_dict)


# Output file k holds, for each task in all_filenames order, that task's
# samples [k*samps_from_each, (k+1)*samps_from_each). Opening every writer
# up front means each source dataset is read once. Before, each dataset was
# re-read from the start for every output file, which is quadratic I/O.
with contextlib.ExitStack() as stack:
    writers = [
        stack.enter_context(tf.io.TFRecordWriter(
            tfrecords_dir + "/test40file_%.2i-%i.tfrec" % (tfrec_num, num_samples)))
        for tfrec_num in range(num_tfrecods)
    ]
    for filenames in all_filenames:
        task_name = _task_id(filenames[0])
        example_ds = _build_example_ds(filenames)
        iii = -1
        for iii, sample in enumerate(example_ds.take(samps_from_each * num_tfrecods)):
            example = _sample_to_example(sample, task_name)
            writers[iii // samps_from_each].write(example.SerializeToString())
        print(iii)


# Sanity check: read back the first record of output file 01, or 00 if only
# one file was written, and show both camera frames.
# The former `b=d` lines raised NameError, so the plots were never reached
# and the script always ended in a crash; they have been removed.
_check_record = min(1, num_tfrecods - 1)
raw_dataset = tf.data.TFRecordDataset(
    f"{tfrecords_dir}test40file_{_check_record:02d}-{num_samples}.tfrec")

parsed_dataset = raw_dataset.map(parse_tfrecord_fn)
ii = 0
for features in parsed_dataset.take(1):
    ii += 1

    for key, value in features.items():
        print(f"{key}: {value}")

    print(f"Image shape: {features['image'].shape}")
    print(features)
    for frame_key in ("image", "next_frame"):
        plt.figure(figsize=(7, 7))
        plt.imshow(features[frame_key].numpy())
        plt.yticks([])
        plt.xticks([])
    plt.show()






