import gym
import d4rl
import os
import torch
import numpy as np
from tqdm import tqdm
from sgcrl.data.episodes_readers.base_episodes_readers import EpisodesReader

class D4RLEpisodesReader(EpisodesReader):
    """Load a D4RL dataset and split it into per-episode dicts of torch tensors.

    Episode boundaries come from the dataset's ``timeouts`` flags when that key
    exists; otherwise they are inferred by ``_compute_dones``. Frames after the
    last boundary are not included in any episode.

    Args:
        env_name: gym id of the D4RL environment whose ``get_dataset()`` is read.
            For ids containing ``"antmaze"`` the stored rewards are shifted by -1.
        split_obs: if True, observations are stored as ``'obs/pos'`` (first two
            dims) and ``'obs/other'`` (remaining dims); otherwise as ``'obs'``.
        clip_to_eps: if True, actions are clipped to ``[-(1 - 1e-5), 1 - 1e-5]``.
        hiql_dones: only used when the dataset has no ``timeouts`` key. If True,
            terminal flags are ignored when inferring boundaries (HIQL style);
            if False, a terminal frame also ends an episode (IQL style).

    Attributes:
        episodes: list of per-episode dicts with keys ``'action'``,
            ``'rewards'``/``'reward'``, ``'done'``, ``'masks'``/``'mask'``,
            optionally ``'goal'``, and the observation key(s) described above.
        lengths: number of frames in each episode.
        returns: sum of the raw (unshifted) rewards of each episode.
    """

    def __init__(
        self,
        env_name,
        split_obs=True,
        clip_to_eps=False,
        hiql_dones=True
    ):
        env = gym.make(env_name)

        # dataset contains observations, actions, rewards, terminals, and infos
        dataset = env.get_dataset()

        # Terminals are not used as episode ends in antmaze d4rl datasets (they
        # do not abort the episode); prefer timeouts when they are available.
        if "timeouts" in dataset:
            dones = dataset["timeouts"]
        else:
            dones = self._compute_dones(dataset, hiql_dones)

        if clip_to_eps:
            lim = 1 - 1e-5
            dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

        shift_rewards = "antmaze" in env_name
        has_goal = 'infos/goal' in dataset
        observations = dataset["observations"]

        self.episodes = []
        self.lengths = []
        self.returns = []
        ep_first_frame = 0
        # Each flagged frame closes an episode spanning [ep_first_frame, end).
        for end in (np.flatnonzero(dones) + 1).tolist():
            frames = slice(ep_first_frame, end)

            rewards = dataset['rewards'][frames].copy()
            self.returns.append(np.sum(rewards))
            if shift_rewards:
                rewards -= 1

            terminals = dataset['terminals'][frames]
            # 'rewards'/'reward' and 'masks'/'mask' are deliberately separate
            # tensors (aliases expected by different consumers).
            episode_frames = {
                'action': torch.tensor(dataset['actions'][frames].copy()).float(),
                'rewards': torch.tensor(rewards).float(),
                'reward': torch.tensor(rewards).float(),
                'done': torch.tensor(dones[frames].copy()),
                'masks': 1.0 - torch.tensor(terminals.copy()).float(),
                'mask': 1.0 - torch.tensor(terminals.copy()).float(),
            }
            if has_goal:
                episode_frames['goal'] = torch.tensor(dataset['infos/goal'][frames].copy())
            if split_obs:
                episode_frames['obs/pos'] = torch.tensor(observations[frames, :2].copy()).float()
                episode_frames['obs/other'] = torch.tensor(observations[frames, 2:].copy()).float()
            else:
                episode_frames['obs'] = torch.tensor(observations[frames, :].copy()).float()

            self.lengths.append(len(episode_frames['action']))
            self.episodes.append(episode_frames)
            ep_first_frame = end

    @staticmethod
    def _compute_dones(dataset, hiql_dones):
        """Infer episode-end flags for a dataset that has no ``timeouts`` key.

        Frame ``i`` ends an episode when ``observations[i + 1]`` differs from
        ``next_observations[i]`` (L2 distance > 1e-6). When ``hiql_dones`` is
        False, frames whose ``terminals`` flag equals 1 also end an episode;
        when True, terminals are ignored. The last frame always ends an episode.
        Requires the dataset to contain ``'next_observations'``.

        Returns:
            Boolean array shaped like ``dataset['rewards']``.
        """
        dones = np.zeros_like(dataset['rewards'], dtype=bool)
        n = len(dones)
        if n == 0:
            return dones
        if n > 1:
            gap = (np.asarray(dataset['observations'][1:n])
                   - np.asarray(dataset['next_observations'][:n - 1]))
            # Flatten each frame so the norm matches the per-frame norm.
            ends = np.linalg.norm(gap.reshape(n - 1, -1), axis=1) > 1e-6
            if not hiql_dones:
                ends |= np.asarray(dataset['terminals'][:n - 1]) == 1.0
            dones[:n - 1] = ends
        dones[-1] = True
        return dones

    def get_ids(self):
        """Return the episode indices ``0 .. len(self.episodes) - 1``."""
        return np.arange(start=0, stop=len(self.episodes), step=1)

class KitchenEpisodesReader(EpisodesReader):
    """Load a D4RL q-learning dataset and split it into per-episode dicts.

    Episodes end at frames whose ``terminals`` flag is set; frames after the last
    terminal are not included in any episode.

    For ids containing ``"kitchen"``, the first 30 observation dims are kept as
    ``'observation'``/``'next_observation'`` and the remaining dims become
    ``'goal'``. For other ids the full observation is kept, and ``'goal'`` is only
    added if the dataset provides ``'infos/goal'``.

    Observation and goal entries are torch tensors; the remaining entries are
    left as produced by ``d4rl.qlearning_dataset``.

    Attributes:
        episodes: list of per-episode dicts.
        lengths: number of frames in each episode.
        returns: sum of the rewards of each episode.
        episodes_idxs: ``list(range(len(episodes)))``.
    """

    def __init__(
        self,
        env_name: str,
    ):
        # load numpy qlearning dataset
        env = gym.make(env_name)
        dataset = d4rl.qlearning_dataset(env)

        dones_float = dataset['terminals']

        if 'kitchen' in env_name:
            # Restrict the observation to its first 30 dims; the rest is the goal.
            dataset['infos/goal'] = torch.tensor(dataset['observations'][:, 30:])
            dataset['observation'] = torch.tensor(dataset['observations'][:, :30])
            dataset['next_observation'] = torch.tensor(dataset['next_observations'][:, :30])
        else:
            # Keep the full observation so the episode keys below always exist.
            dataset['observation'] = torch.tensor(dataset['observations'])
            dataset['next_observation'] = torch.tensor(dataset['next_observations'])

        has_goal = 'infos/goal' in dataset

        self.lengths = []
        self.returns = []
        self.episodes = []
        ep_first_frame = 0
        # Each terminal frame closes an episode spanning [ep_first_frame, end).
        for end in (np.flatnonzero(dones_float) + 1).tolist():
            frames = slice(ep_first_frame, end)
            rewards = dataset['rewards'][frames]
            self.returns.append(np.sum(rewards))
            episode = {
                'observation': dataset["observation"][frames, :],
                'next_observation': dataset["next_observation"][frames, :],
                'action': dataset['actions'][frames],
                'rewards': rewards,
                'dones': dones_float[frames],
                'masks': 1.0 - dataset['terminals'][frames],
            }
            if has_goal:
                episode['goal'] = dataset['infos/goal'][frames]

            self.lengths.append(len(episode['action']))
            self.episodes.append(episode)
            ep_first_frame = end

        self.episodes_idxs = list(range(len(self.episodes)))

    def __len__(self):
        """Return the number of episodes."""
        return len(self.episodes)

    def __getitem__(self, idx):
        """Return the episode dict at position ``idx``."""
        return self.episodes[idx]

    def get_ids(self):
        """Return the episode indices ``0 .. len(self.episodes_idxs) - 1``."""
        return list(range(len(self.episodes_idxs)))
    
class VisualAntmazeEpisodesReader(EpisodesReader):
    """Index episodes of a pre-rendered (top-view image) antmaze dataset.

    The dataset is read from ``datasets/<amz_dataset_dir>/<name>.npz`` (relative
    to the current working directory), where ``<name>`` is the part of
    ``env_name`` after ``'topview-'``. Episode boundaries are the frames whose
    ``dones_float`` is set; episode data is built lazily in ``__getitem__``.

    Args:
        env_name: environment id containing ``'topview-'``.
        amz_dataset_dir: sub-directory of ``datasets/`` holding the ``.npz`` file.
        max_episode_length: length episodes are zero-padded to in
            ``__getitem__``; defaults to the longest episode in the dataset.

    Attributes:
        dataset: dict of all arrays loaded from the ``.npz`` file.
        episodes_idxs: ``(start, end)`` pairs; episode ``i`` covers ``[start, end)``.
        lengths: number of frames in each episode.
        returns: sum of the rewards of each episode.
        max_episode_length: padded length used by ``__getitem__``.
    """

    def __init__(
        self,
        env_name: str,
        amz_dataset_dir: str = 'antmaze_topview_6_60',
        max_episode_length: int = None,
    ):
        orig_env_name = env_name.split('topview-')[1]
        print(f'loading visual antmaze image dataset: {amz_dataset_dir}/{env_name}')
        path = os.path.abspath(f'datasets/{amz_dataset_dir}/{orig_env_name}.npz')
        # dict() reads every array eagerly, so the archive can be closed afterwards.
        with np.load(path) as npz:
            self.dataset = dict(npz)

        self.lengths = []
        self.returns = []
        self.episodes_idxs = []
        dones_float = self.dataset['dones_float']
        ep_first_frame = 0
        for idx, done in tqdm(enumerate(dones_float), total=dones_float.shape[0], desc='reading transitions'):
            if not done:
                continue
            end = idx + 1
            # Store (start, end) such that the episode is within [start, end).
            self.episodes_idxs.append((ep_first_frame, end))
            rewards = self.dataset['rewards'][ep_first_frame:end]
            self.returns.append(np.sum(rewards))
            self.lengths.append(len(rewards))
            ep_first_frame = end

        if max_episode_length is None:
            self.max_episode_length = max(self.lengths)
        else:
            self.max_episode_length = max_episode_length

    def __getitem__(self, i):
        """Return episode ``i`` as a dict of arrays zero-padded to ``max_episode_length``.

        Image entries are scaled by 1/255 and cast to float32. ``'goal'`` and
        ``'state_goal'`` repeat the episode's last image and last state
        observation for every frame. ``'infos'`` is not padded and holds the true
        episode length.

        Raises:
            ValueError: if the episode is longer than ``max_episode_length``.
        """
        start, end = self.episodes_idxs[i]
        length = end - start
        data = self.dataset

        episode = {
            'observations': (data['images'][start:end] / 255).astype(np.float32),
            'state_observations': data['observations'][start:end],
            'actions': data['actions'][start:end],
            'rewards': data['rewards'][start:end],
            'dones': data['dones_float'][start:end],
            'masks': data['masks'][start:end],
            'next_observations': (data['next_images'][start:end] / 255).astype(np.float32),
            'next_state_observations': data['next_observations'][start:end],
            'goal': (np.repeat(data['images'][end - 1:end], repeats=length, axis=0) / 255).astype(np.float32),
            # Last state of this episode, matching the image goal above
            # (index end would be the first frame of the next episode).
            'state_goal': np.repeat(data['observations'][end - 1:end], repeats=length, axis=0),
            'infos': {'length': length},
        }

        rest = self.max_episode_length - length
        if rest < 0:
            raise ValueError(
                f'episode {i} has length {length}, longer than '
                f'max_episode_length={self.max_episode_length}'
            )

        # Zero-pad every array along the time axis; leave non-arrays untouched.
        padded_episode = {}
        for k, v in episode.items():
            if isinstance(v, np.ndarray):
                padding = np.zeros(shape=(rest, *v.shape[1:]), dtype=v.dtype)
                padded_episode[k] = np.concatenate([v, padding], axis=0)
            else:
                padded_episode[k] = v
        return padded_episode

    def get_episode_frame(self, i, t):
        """Return frame ``t`` of episode ``i`` (images scaled to float32 in [0, 1]).

        NOTE(review): ``t`` is not bounds-checked; values outside
        ``[0, length)`` read frames of neighbouring episodes.
        """
        start, end = self.episodes_idxs[i]
        frame_idx = start + t
        data = self.dataset

        return {
            'observations': (data['images'][frame_idx] / 255).astype(np.float32),
            'state_observations': data['observations'][frame_idx],
            'actions': data['actions'][frame_idx],
            'rewards': data['rewards'][frame_idx],
            'dones': data['dones_float'][frame_idx],
            'masks': data['masks'][frame_idx],
            'next_observations': (data['next_images'][frame_idx] / 255).astype(np.float32),
            'next_state_observations': data['next_observations'][frame_idx],
            'infos': {'length': end - start},
        }

    def get_ids(self):
        """Return the episode indices ``0 .. len(self) - 1``."""
        return np.arange(start=0, stop=len(self), step=1)

    def __len__(self):
        """Return the number of episodes."""
        return len(self.episodes_idxs)