import numpy as np
from tqdm import tqdm
from sgcrl.data.episodes_readers.base_episodes_readers import EpisodesReader
import joblib

from absl import logging

import joblib
import numpy as np
import pickle

import os

import torch
import io

# Example dataset config, kept for reference only:
# {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_60_demos.pkl',
#  'obs_dict': True, 'is_demo': True, 'use_latents': False}

# Root directory that relative dataset paths are resolved against.
LOCAL_LOG_DIR = 'stable_contrastive_rl'

# File-type tags accepted by ``load_local_or_remote_file(file_type=...)``.
PICKLE = 'pickle'
NUMPY = 'numpy'
JOBLIB = 'joblib'
TORCH = 'torch'

def sync_down(path, check_exists=True):
    """Map *path* to its location under ``LOCAL_LOG_DIR``.

    No download is performed here; the function only builds the local path
    ``LOCAL_LOG_DIR/path``.

    Args:
        path: path relative to ``LOCAL_LOG_DIR``.
        check_exists: when True (default), return the path only if a regular
            file exists there; when False, return it unconditionally.

    Returns:
        The local path, or ``None`` if *check_exists* is True and no file
        exists at that location.
    """
    local_path = "%s/%s" % (LOCAL_LOG_DIR, path)

    # Previously check_exists=False fell through and returned None, making the
    # flag useless; skipping the check now yields the computed path.
    if not check_exists or os.path.isfile(local_path):
        return local_path
    return None

def local_path_from_s3_or_local_path(filename):
    """Resolve *filename* to a readable local path.

    Lookup order:
      1. *filename* as given;
      2. ``os.path.join(LOCAL_LOG_DIR, filename)``;
      3. ``sync_down(filename)`` (after logging a warning).

    Returns:
        The first candidate that is an existing file, otherwise whatever
        ``sync_down`` returns (``None`` when it finds nothing).
    """
    if os.path.isfile(filename):
        return filename

    relative_filename = os.path.join(LOCAL_LOG_DIR, filename)
    if os.path.isfile(relative_filename):
        return relative_filename

    # absl's `logging.warn` is a deprecated alias of `warning`; also pass the
    # argument lazily instead of pre-formatting the message.
    logging.warning('Cannot find the file: %r', filename)
    return sync_down(filename)

def load_local_or_remote_file(filepath, file_type=None, delete_after_loading=False, **kwargs):
    """Load a Python object from *filepath*, choosing the loader by file type.

    Args:
        filepath: path resolved through ``local_path_from_s3_or_local_path``.
        file_type: one of ``NUMPY`` / ``JOBLIB`` / ``TORCH`` / ``PICKLE``.
            When None it is inferred from the extension: ``npy``/``npz`` ->
            NUMPY, ``pt`` -> TORCH, anything else -> PICKLE (JOBLIB is never
            inferred and must be requested explicitly).
        delete_after_loading: if True and the resolved path starts with
            ``/tmp``, the file is removed after loading.
        **kwargs: forwarded to ``torch.load`` only.

    Returns:
        The loaded object. For ``.npz`` files this is numpy's lazy ``NpzFile``,
        which owns the open file handle.

    Raises:
        AssertionError: if the path cannot be resolved to a local file.
    """
    print('filepath:', filepath)
    local_path = local_path_from_s3_or_local_path(filepath)
    print('local_path:', local_path)
    assert local_path is not None, "file not found. Filename: %s" % filepath

    if file_type is None:
        extension = local_path.split('.')[-1]
        if extension in ('npy', 'npz'):
            file_type = NUMPY
        elif extension == "pt":
            file_type = TORCH
        else:
            file_type = PICKLE

    # SECURITY: pickle (and numpy with allow_pickle=True) can execute arbitrary
    # code while loading -- only use this on trusted files.
    if file_type == NUMPY:
        # Pass the path rather than an open handle: numpy then opens/closes the
        # file itself (and for .npz the returned NpzFile owns the handle), so
        # no file object is leaked here.
        obj = np.load(local_path, allow_pickle=True)
    elif file_type == JOBLIB:
        obj = joblib.load(local_path)
    elif file_type == TORCH:
        obj = torch.load(local_path, **kwargs)
    else:
        with open(local_path, "rb") as f:
            obj = pickle.load(f)
    print("loaded", local_path)

    if local_path.startswith("/tmp") and delete_after_loading:
        print("deleting tmp file after loading.")
        os.remove(local_path)

    return obj

import matplotlib.pyplot as plt

class RoboverseEpisodesReader(EpisodesReader):
    """Episode reader for pickled Roboverse datasets.

    Every loaded file is iterated and its items are treated as episode dicts
    with keys ``'observations'`` and ``'next_observations'`` (per-step dicts
    holding ``'image_observation'`` and ``'state_observation'``), plus
    ``'actions'``, ``'rewards'`` and ``'terminals'``. ``__getitem__`` returns
    episodes zero-padded (or truncated) along time to ``max_episode_length``.
    """

    # Dataset loaded when no ``paths`` are given. A tuple, so the default can
    # never be mutated and shared across instances.
    DEFAULT_PATHS = (
        'stable_contrastive_rl/data/env6_td_pnp_push_dataset/scene0_view0_mixed_0_demos.pkl',
    )
    # Shape each flat ``image_observation`` is reshaped to before axes 0 and 2
    # are swapped.
    IMAGE_SHAPE = (3, 48, 48)

    def __init__(
        self,
        paths: list[str] = None,
        env_name: str = None,
        max_episode_length: int = None,
    ):
        """Load every file in *paths* and record per-episode lengths/returns.

        Args:
            paths: dataset files to load; defaults to ``DEFAULT_PATHS``.
            env_name: accepted for interface compatibility; not used by this
                class.
            max_episode_length: length ``__getitem__`` pads/truncates to;
                defaults to the longest episode (0 for an empty dataset).
        """
        if paths is None:
            paths = self.DEFAULT_PATHS

        # Concatenate the episodes of all files into one flat list.
        self.dataset = []
        for filename in paths:
            data_i = load_local_or_remote_file(
                filename,
                delete_after_loading=False)
            self.dataset.extend(list(data_i))

        # Per-episode undiscounted return and true (unpadded) length.
        self.lengths = []
        self.returns = []
        for episode in tqdm(self.dataset, total=len(self.dataset), desc='reading episodes'):
            rewards = episode['rewards']
            self.returns.append(np.sum(rewards))
            self.lengths.append(len(rewards))

        if max_episode_length is None:
            # default=0 avoids ValueError from max() on an empty dataset.
            self.max_episode_length = max(self.lengths, default=0)
        else:
            self.max_episode_length = max_episode_length

    def _stack_images(self, frames, T):
        """Stack ``frames[t]['image_observation']`` for ``t < T`` as float32.

        Each flat image is reshaped to ``IMAGE_SHAPE`` and then has axes 0 and
        2 swapped (so (3, 48, 48) becomes (48, 48, 3) per frame).
        """
        # NOTE(review): swapaxes(0, 2) maps (C, A, B) -> (B, A, C); this is
        # (H, W, C) only if images were stored transposed -- confirm upstream.
        return np.stack(
            [np.swapaxes(np.reshape(frames[t]['image_observation'], self.IMAGE_SHAPE), 0, 2)
             for t in range(T)],
            axis=0,
        ).astype(np.float32)

    def __getitem__(self, i):
        """Return episode *i* as a dict of arrays of length ``max_episode_length``.

        Episodes shorter than ``max_episode_length`` are padded with zeros of
        each array's dtype (so padded steps have ``masks == 0`` and
        ``dones == 0``); longer ones are truncated.
        """
        raw = self.dataset[i]
        T = len(raw['observations'])

        def stack(seq):
            # First T entries of a per-step sequence, stacked along time.
            return np.stack([seq[t] for t in range(T)], axis=0)

        def stack_field(frames, key):
            # Field *key* of the first T per-step dicts, stacked along time.
            return np.stack([frames[t][key] for t in range(T)], axis=0)

        terminals = stack(raw['terminals'])
        episode = {
            'observations': self._stack_images(raw['observations'], T),
            'state_observations': stack_field(raw['observations'], 'state_observation'),
            'actions': stack(raw['actions']),
            'rewards': stack(raw['rewards']),
            'dones': terminals,
            'masks': 1 - terminals,
            'next_observations': self._stack_images(raw['next_observations'], T),
            'next_state_observations': stack_field(raw['next_observations'], 'state_observation'),
        }

        # Pad to max_episode_length. Truncating first keeps the padding size
        # non-negative when an explicit max_episode_length is shorter than the
        # episode (previously np.zeros raised on a negative dimension).
        target_len = self.max_episode_length
        padded_episode = {}
        for k, v in episode.items():
            v = v[:target_len]
            padding = np.zeros(shape=(target_len - len(v), *v.shape[1:]), dtype=v.dtype)
            padded_episode[k] = np.concatenate([v, padding], axis=0)

        return padded_episode

    def __len__(self):
        """Number of loaded episodes."""
        return len(self.dataset)


if __name__ == '__main__':
    # Demo: write the image observations of one episode to an MP4 file.
    import cv2

    er = RoboverseEpisodesReader()
    episode_index = 16
    episode = er[episode_index]

    size = 48, 48
    fps = 25
    video_path = 'videos/roboverse.mp4'
    # cv2.VideoWriter does not create missing parent directories; without this
    # it fails to open and writes are silently dropped.
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (size[1], size[0]), True)
    if not out.isOpened():
        raise RuntimeError('could not open %s for writing' % video_path)

    # Write only the real (unpadded) frames. The old hard-coded range(75) could
    # index past the padded array when max_episode_length < 75.
    n_frames = min(er.lengths[episode_index], len(episode['observations']))
    for t in range(n_frames):
        print(t)
        data = (episode['observations'][t] * 255).astype(np.uint8)
        out.write(data)
    out.release()




