import os
import pathlib
import random
import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from collections import deque
from tqdm import tqdm


# Task-name splits used by the dataset classes in this module:
# mode="train" selects `train_tasks`, mode="test" selects `test_tasks`.
# NOTE(review): the names look like Meta-World v2 environment ids - confirm.
train_tasks = [
    'button-press-wall-v2',
    'button-press-topdown-v2',
    'button-press-v2',
    'handle-press-v2',
    'handle-pull-side-v2',
    'handle-pull-v2',
    'plate-slide-v2',
    'plate-slide-back-v2',
    'plate-slide-back-side-v2',
    'pick-out-of-hole-v2',
    'pick-place-v2',
    'push-wall-v2',
    'push-back-v2',
    'coffee-push-v2',
    'coffee-pull-v2',
    'faucet-close-v2',
    'faucet-open-v2',
    'drawer-close-v2',
    'drawer-open-v2',
    'reach-v2',
]
test_tasks = [
    'button-press-topdown-wall-v2',
    'handle-press-side-v2',
    'pick-place-wall-v2',
    'push-v2',
]


def convert(value):
    """Turn *value* into a NumPy array with a compact dtype.

    Floating dtypes become ``float32``, signed integer dtypes become ``int32``
    and ``uint8`` stays ``uint8``. Any other dtype is returned as produced by
    ``np.array(value)``.
    """
    arr = np.array(value)
    # (dtype family, target dtype) pairs, checked in order; the first match wins.
    casts = (
        (np.floating, np.float32),
        (np.signedinteger, np.int32),
        (np.uint8, np.uint8),
    )
    for family, target in casts:
        if np.issubdtype(arr.dtype, family):
            return arr.astype(target)
    return arr


def cycle(dl):
    """Yield the items of *dl* forever.

    *dl* is iterated again from the start every time it is exhausted, so it
    must be re-iterable. An empty *dl* never yields and never returns.
    """
    while True:
        yield from dl
    

def list_files(root, tasks):
    """Return the entries directly under *root* whose names appear in *tasks*.

    Args:
        root: Directory to scan (string or path-like).
        tasks: Container of entry names to keep.

    Returns:
        A list of ``pathlib.Path`` objects of the form ``root/<name>``, sorted
        by name so the result does not depend on the listing order of the
        underlying filesystem.

    Raises:
        OSError: If *root* cannot be listed.
    """
    return [
        pathlib.Path(os.path.join(root, name))
        for name in sorted(os.listdir(root))
        if name in tasks
    ]


class InitialTrajDataset(Dataset):
    """Fixed-length trajectory windows, each paired with the first frame of its episode.

    Data is read from ``<path>/<task>/<episode>/<step>.npz``; every step file
    is expected to provide the arrays ``imgs``, ``actions`` and ``rewards``.
    Indexing yields ``(init_img, imgs, actions, rewards)`` (see ``load_traj``).
    """

    def __init__(self, path, seq_len=10, indices_path=None, mode="train", device="cpu"):
        """Collect the task directories and build (or load) the window index.

        Args:
            path: Dataset root containing one directory per task.
            seq_len: Number of consecutive steps in every indexed window.
            indices_path: Optional JSON file holding a previously saved index.
                If it is missing, unreadable or malformed the index is rebuilt.
            mode: ``"train"`` / ``"test"`` select the module-level task lists;
                any other value is treated as the name of a single task.
            device: Device for the tensors returned by
                ``sample_batch_dataset`` and ``sample_data``.
        """
        super().__init__()
        self._path = pathlib.Path(path)
        if mode == "train":
            self.tasks = train_tasks
        elif mode == "test":
            self.tasks = test_tasks
        else:
            self.tasks = [mode]
        self._mode = mode
        self._paths = list_files(self._path, self.tasks)
        self._seq_len = seq_len
        self._action_dim = None

        self.indices, self._num_episodes = self._load_or_build_indices(indices_path)

        print(f'num_episodes: {self._num_episodes}, num_steps: {len(self.indices)}')

        self.device = torch.device(device)

    def _load_or_build_indices(self, indices_path):
        """Read the index from *indices_path*, or rebuild and save it when that fails."""
        if indices_path is not None:
            try:
                with open(indices_path, 'r') as f:
                    info = json.load(f)
                return info['indices'], info['num_episodes']
            except (OSError, ValueError, KeyError, TypeError) as err:
                # Missing, unreadable or malformed cache: fall back to a rebuild.
                print(f'could not load indices from {indices_path} ({err!r}); rebuilding')

        indices, num_episodes = self.load_indices()
        # A rebuilt index is always saved next to the data, not to `indices_path`.
        with open(self._path / f'{self._mode}_indices.json', 'w') as file:
            json.dump({'indices': indices, 'num_episodes': num_episodes}, file)
        return indices, num_episodes

    def load_indices(self):
        """Scan the task directories and enumerate the sliding windows.

        Every entry of a task directory whose name does not end in ``.npz`` is
        treated as an episode directory named by an integer. Episodes with
        fewer than ``seq_len`` files are skipped; at most 100 windows are
        taken per episode.

        Returns:
            ``(indices, num_episodes)`` where each index is
            ``(task_name, episode, start, end)`` and ``num_episodes`` counts
            the episodes that were long enough.
        """
        indices = []
        num_episodes = 0
        for file_path in self._paths:
            task_name = file_path.name
            filenames = os.listdir(file_path)
            random.Random(0).shuffle(filenames)
            for filename in tqdm(filenames, desc='Loading indices'):
                if filename.endswith('.npz'):
                    continue
                ep_len = len(os.listdir(os.path.join(file_path, filename)))
                if ep_len < self._seq_len:
                    continue
                num_episodes += 1
                for j in range(min(ep_len - self._seq_len + 1, 100)):
                    indices.append((task_name, int(filename), j, j + self._seq_len))

        return indices, num_episodes

    def load_traj(self, file_path, ep_idx, start, end):
        """Load steps ``start`` .. ``end - 1`` of one episode plus its first frame.

        Args:
            file_path: Task directory name, relative to the dataset root.
            ep_idx: Episode directory name (converted with ``str``).
            start: First step to load.
            end: One past the last step to load. Steps beyond the end of the
                episode repeat its final step.

        Returns:
            ``(init_img, imgs, actions, rewards)`` as ``float32`` arrays.
            Images whose last axis has size 3 are moved to channel-first
            layout, and all images are scaled with ``x / 127.5 - 1``
            (presumably from 8-bit pixel values - confirm against the data).
            Rewards are compressed with ``sign(r) * log(|r| + 1)``.
        """
        file_dir = os.path.join(str(self._path), file_path, str(ep_idx))
        last_step = len(os.listdir(file_dir)) - 1

        imgs, actions, rewards = [], [], []
        for i in range(start, end):
            step_path = os.path.join(file_dir, f'{min(i, last_step)}.npz')
            # np.load keeps the archive open; close it as soon as the arrays are read.
            with np.load(step_path) as step:
                imgs.append(step['imgs'])
                actions.append(step['actions'])
                rewards.append(step['rewards'])
        imgs, actions, rewards = np.array(imgs), np.array(actions), np.array(rewards)
        with np.load(os.path.join(file_dir, '0.npz')) as first_step:
            init_img = first_step['imgs']

        if imgs.shape[-1] == 3:
            imgs = imgs.transpose(0, 3, 1, 2)
        # Apply the same layout rule as for `imgs`, so frames that are already
        # channel-first are not scrambled.
        if init_img.shape[-1] == 3:
            init_img = init_img.transpose(2, 0, 1)
        init_img = (init_img / 127.5 - 1.0).astype(np.float32)
        imgs = (imgs / 127.5 - 1.0).astype(np.float32)
        actions = actions.astype(np.float32)
        rewards = rewards.astype(np.float32)
        rewards = np.sign(rewards) * (np.log(np.abs(rewards) + 1))

        return init_img, imgs, actions, rewards

    def __getitem__(self, idx):
        """Return ``(init_img, imgs, actions, rewards)`` for the window at *idx*."""
        file_path, ep_idx, start, end = self.indices[idx]
        return self.load_traj(file_path, ep_idx, start, end)

    def __len__(self):
        """Return the number of indexed windows."""
        return len(self.indices)

    @property
    def action_dim(self):
        """Size of the last action axis, read lazily from the first item and cached."""
        if self._action_dim is None:
            _, _, actions, _ = self[0]
            self._action_dim = actions.shape[-1]
        return self._action_dim

    def sample_per_task(self, sample_num=1, sample_len=10):
        """Randomly pick *sample_num* episode prefixes of *sample_len* steps per task.

        Episode numbers are drawn from ``1 .. number of entries in the task
        directory``; draws above 3000 and episodes with no more than
        *sample_len* files are rejected and redrawn.
        NOTE: the loop only ends once enough eligible episodes were drawn, so
        it does not terminate for a task without any eligible episode.

        Returns:
            Dict mapping task name to a dict with the stacked arrays
            ``init_imgs``, ``imgs``, ``actions`` and ``rewards``.
        """
        data = {}
        for file_path in self._paths:
            task_name = file_path.name
            traj_num = len(os.listdir(file_path))
            num = 0
            init_imgs, imgs, actions, rewards = [], [], [], []
            while num < sample_num:
                ep_idx = np.random.randint(1, traj_num + 1)
                if ep_idx > 3000:
                    continue
                file_dir = os.path.join(file_path, str(ep_idx))
                ep_len = len(os.listdir(file_dir))
                if ep_len > sample_len:
                    # Windows always start at the beginning of the episode.
                    start_idx = 0
                    init_img, img, action, reward = self.load_traj(
                        task_name, ep_idx, start_idx, start_idx + sample_len)
                    init_imgs.append(init_img)
                    imgs.append(img)
                    actions.append(action)
                    rewards.append(reward)
                    num += 1
            data[task_name] = {
                'init_imgs': np.array(init_imgs),
                'imgs': np.array(imgs),
                'actions': np.array(actions),
                'rewards': np.array(rewards)
            }
        return data

    def sample_batch_dataset(self, num=1, horizon=32):
        """Load the first *horizon* steps of episodes ``1 .. num`` of every task.

        A task stops contributing at the first episode that is missing or
        unreadable; episodes loaded before that point are kept.

        Returns:
            ``(imgs, actions, rewards, task_names)``: three tensors on
            ``self.device`` stacked over the loaded episodes, and the task
            directory path (as a string) of each episode.

        Raises:
            ValueError: If no episode could be loaded at all.
        """
        results = {"imgs": [], "acts": [], "rewards": []}
        task_names = []
        for file_path in self._paths:
            try:
                for n in range(num):
                    # load_traj expects the task name and also returns the
                    # initial frame, which this method does not use.
                    _, imgs, acts, rewards = self.load_traj(file_path.name, n + 1, 0, horizon)
                    results["imgs"].append(imgs)
                    results["acts"].append(acts)
                    results["rewards"].append(rewards)
                    task_names.append(str(file_path))
            except (OSError, KeyError, ValueError):
                # Best effort: skip the remainder of a task that cannot be read.
                continue
        if not task_names:
            raise ValueError(f'no trajectories could be loaded from {self._path}')
        imgs = torch.from_numpy(np.stack(results['imgs'])).to(self.device)
        actions = torch.from_numpy(np.stack(results['acts'])).to(self.device)
        rewards = torch.from_numpy(np.stack(results['rewards'])).to(self.device)
        return imgs, actions, rewards, task_names

    def sample_data(self, sample_len=32):
        """Draw one random episode prefix per task and return it as tensors.

        Returns:
            ``(imgs, actions, rewards)`` float tensors on ``self.device``,
            concatenated over tasks. The initial frames are not returned.
        """
        data = self.sample_per_task(sample_len=sample_len)
        imgs = torch.FloatTensor(np.concatenate([v['imgs'] for v in data.values()])).to(self.device)
        actions = torch.FloatTensor(np.concatenate([v['actions'] for v in data.values()])).to(self.device)
        rewards = torch.FloatTensor(np.concatenate([v['rewards'] for v in data.values()])).to(self.device)
        return imgs, actions, rewards

class TrajDataset(Dataset):
    """Fixed-length trajectory windows read from per-step ``.npz`` files.

    Data is read from ``<path>/<task>/<episode>/<step>.npz``; every step file
    is expected to provide the arrays ``imgs``, ``actions`` and ``rewards``.
    Indexing yields ``(imgs, actions, rewards)`` (see ``load_traj``).
    """

    def __init__(self, path, seq_len=10, indices_path=None, mode="train", device="cpu"):
        """Collect the task directories and build (or load) the window index.

        Args:
            path: Dataset root containing one directory per task.
            seq_len: Number of consecutive steps in every indexed window.
            indices_path: Optional JSON file holding a previously saved index.
                If it is missing, unreadable or malformed the index is rebuilt.
            mode: ``"train"`` / ``"test"`` select the module-level task lists;
                any other value is treated as the name of a single task.
            device: Device for the tensors returned by
                ``sample_batch_dataset`` and ``sample_data``.
        """
        super().__init__()
        self._path = pathlib.Path(path)
        if mode == "train":
            self.tasks = train_tasks
        elif mode == "test":
            self.tasks = test_tasks
        else:
            self.tasks = [mode]
        self._mode = mode
        self._paths = list_files(self._path, self.tasks)
        self._seq_len = seq_len
        self._action_dim = None

        self.indices, self._num_episodes = self._load_or_build_indices(indices_path)

        print(f'num_episodes: {self._num_episodes}, num_steps: {len(self.indices)}')

        self.device = torch.device(device)

    def _load_or_build_indices(self, indices_path):
        """Read the index from *indices_path*, or rebuild and save it when that fails."""
        if indices_path is not None:
            try:
                with open(indices_path, 'r') as f:
                    info = json.load(f)
                return info['indices'], info['num_episodes']
            except (OSError, ValueError, KeyError, TypeError) as err:
                # Missing, unreadable or malformed cache: fall back to a rebuild.
                print(f'could not load indices from {indices_path} ({err!r}); rebuilding')

        indices, num_episodes = self.load_indices()
        # A rebuilt index is always saved next to the data, not to `indices_path`.
        with open(self._path / f'{self._mode}_indices.json', 'w') as file:
            json.dump({'indices': indices, 'num_episodes': num_episodes}, file)
        return indices, num_episodes

    def load_indices(self):
        """Scan the task directories and enumerate the sliding windows.

        Every entry of a task directory whose name does not end in ``.npz`` is
        treated as an episode directory named by an integer. Episodes with
        fewer than ``seq_len`` files are skipped; at most 100 windows are
        taken per episode.

        Returns:
            ``(indices, num_episodes)`` where each index is
            ``(task_name, episode, start, end)`` and ``num_episodes`` counts
            the episodes that were long enough.
        """
        indices = []
        num_episodes = 0
        for file_path in self._paths:
            task_name = file_path.name
            filenames = os.listdir(file_path)
            random.Random(0).shuffle(filenames)
            for filename in tqdm(filenames, desc='Loading indices'):
                if filename.endswith('.npz'):
                    continue
                ep_len = len(os.listdir(os.path.join(file_path, filename)))
                if ep_len < self._seq_len:
                    continue
                num_episodes += 1
                for j in range(min(ep_len - self._seq_len + 1, 100)):
                    indices.append((task_name, int(filename), j, j + self._seq_len))

        return indices, num_episodes

    def load_traj(self, file_path, ep_idx, start, end):
        """Load steps ``start`` .. ``end - 1`` of one episode.

        Args:
            file_path: Task directory name, relative to the dataset root.
            ep_idx: Episode directory name (converted with ``str``).
            start: First step to load.
            end: One past the last step to load. Steps beyond the end of the
                episode repeat its final step.

        Returns:
            ``(imgs, actions, rewards)`` as ``float32`` arrays. Images whose
            last axis has size 3 are moved to channel-first layout, and all
            images are scaled with ``x / 127.5 - 1`` (presumably from 8-bit
            pixel values - confirm against the data). Rewards are compressed
            with ``sign(r) * log(|r| + 1)``.
        """
        file_dir = os.path.join(str(self._path), file_path, str(ep_idx))
        last_step = len(os.listdir(file_dir)) - 1

        imgs, actions, rewards = [], [], []
        for i in range(start, end):
            step_path = os.path.join(file_dir, f'{min(i, last_step)}.npz')
            # np.load keeps the archive open; close it as soon as the arrays are read.
            with np.load(step_path) as step:
                imgs.append(step['imgs'])
                actions.append(step['actions'])
                rewards.append(step['rewards'])
        imgs, actions, rewards = np.array(imgs), np.array(actions), np.array(rewards)
        if imgs.shape[-1] == 3:
            imgs = imgs.transpose(0, 3, 1, 2)
        imgs = (imgs / 127.5 - 1.0).astype(np.float32)
        actions = actions.astype(np.float32)
        rewards = rewards.astype(np.float32)
        rewards = np.sign(rewards) * (np.log(np.abs(rewards) + 1))

        return imgs, actions, rewards

    def __getitem__(self, idx):
        """Return ``(imgs, actions, rewards)`` for the window at *idx*."""
        file_path, ep_idx, start, end = self.indices[idx]
        return self.load_traj(file_path, ep_idx, start, end)

    def __len__(self):
        """Return the number of indexed windows."""
        return len(self.indices)

    @property
    def action_dim(self):
        """Size of the last action axis, read lazily from the first item and cached."""
        if self._action_dim is None:
            _, actions, _ = self[0]
            self._action_dim = actions.shape[-1]
        return self._action_dim

    def sample_per_task(self, sample_num=1, sample_len=10):
        """Randomly pick *sample_num* episode prefixes of *sample_len* steps per task.

        Episode numbers are drawn from ``1 .. number of entries in the task
        directory``; draws above 3000 and episodes with no more than
        *sample_len* files are rejected and redrawn.
        NOTE: the loop only ends once enough eligible episodes were drawn, so
        it does not terminate for a task without any eligible episode.

        Returns:
            Dict mapping task name to a dict with the stacked arrays
            ``imgs``, ``actions`` and ``rewards``.
        """
        data = {}
        for file_path in self._paths:
            task_name = file_path.name
            traj_num = len(os.listdir(file_path))
            num = 0
            imgs, actions, rewards = [], [], []
            while num < sample_num:
                ep_idx = np.random.randint(1, traj_num + 1)
                if ep_idx > 3000:
                    continue
                file_dir = os.path.join(file_path, str(ep_idx))
                ep_len = len(os.listdir(file_dir))
                if ep_len > sample_len:
                    # Windows always start at the beginning of the episode.
                    start_idx = 0
                    img, action, reward = self.load_traj(
                        task_name, ep_idx, start_idx, start_idx + sample_len)
                    imgs.append(img)
                    actions.append(action)
                    rewards.append(reward)
                    num += 1
            data[task_name] = {
                'imgs': np.array(imgs),
                'actions': np.array(actions),
                'rewards': np.array(rewards)
            }
        return data

    def sample_batch_dataset(self, num=1, horizon=32):
        """Load the first *horizon* steps of episodes ``1 .. num`` of every task.

        A task stops contributing at the first episode that is missing or
        unreadable; episodes loaded before that point are kept.

        Returns:
            ``(imgs, actions, rewards, task_names)``: three tensors on
            ``self.device`` stacked over the loaded episodes, and the task
            directory path (as a string) of each episode.

        Raises:
            ValueError: If no episode could be loaded at all.
        """
        results = {"imgs": [], "acts": [], "rewards": []}
        task_names = []
        for file_path in self._paths:
            try:
                for n in range(num):
                    # load_traj joins the dataset root itself, so pass only the task name.
                    imgs, acts, rewards = self.load_traj(file_path.name, n + 1, 0, horizon)
                    results["imgs"].append(imgs)
                    results["acts"].append(acts)
                    results["rewards"].append(rewards)
                    task_names.append(str(file_path))
            except (OSError, KeyError, ValueError):
                # Best effort: skip the remainder of a task that cannot be read.
                continue
        if not task_names:
            raise ValueError(f'no trajectories could be loaded from {self._path}')
        imgs = torch.from_numpy(np.stack(results['imgs'])).to(self.device)
        actions = torch.from_numpy(np.stack(results['acts'])).to(self.device)
        rewards = torch.from_numpy(np.stack(results['rewards'])).to(self.device)
        return imgs, actions, rewards, task_names

    def sample_data(self, sample_len=32):
        """Draw one random episode prefix per task and return it as tensors.

        Returns:
            ``(imgs, actions, rewards)`` float tensors on ``self.device``,
            concatenated over tasks.
        """
        data = self.sample_per_task(sample_len=sample_len)
        imgs = torch.FloatTensor(np.concatenate([v['imgs'] for v in data.values()])).to(self.device)
        actions = torch.FloatTensor(np.concatenate([v['actions'] for v in data.values()])).to(self.device)
        rewards = torch.FloatTensor(np.concatenate([v['rewards'] for v in data.values()])).to(self.device)
        return imgs, actions, rewards

    
class EvalTrajDataset(Dataset):
    """One fixed-length prefix per episode, read from per-step ``.npz`` files.

    Data is read from ``<path>/<task>/<episode>/<step>.npz``; every step file
    is expected to provide the arrays ``imgs``, ``actions`` and ``rewards``.
    Indexing yields ``(imgs, actions, rewards)`` for the first ``seq_len``
    steps of an episode.
    """

    def __init__(self, path, seq_len=32, mode="train", device="cpu"):
        """Collect the task directories and build (or load) the episode index.

        Args:
            path: Dataset root containing one directory per task.
            seq_len: Number of leading steps loaded from every episode.
            mode: ``"train"`` / ``"test"`` select the module-level task lists;
                any other value is treated as the name of a single task.
            device: Stored as ``self.device``; not used by the methods of this class.
        """
        super().__init__()
        self._path = path
        self.mode = mode
        if mode == "train":
            self.tasks = train_tasks
        elif mode == "test":
            self.tasks = test_tasks
        else:
            self.tasks = [mode]
        self._mode = mode
        self._paths = list_files(pathlib.Path(path), self.tasks)
        self._seq_len = seq_len
        self._action_dim = None

        self.indices, self._num_episodes = self.load_indices()

        print(f'num_episodes: {self._num_episodes}, num_steps: {len(self.indices)}')

        self.device = torch.device(device)

    def load_indices(self):
        """Load the cached episode index, or build and cache it.

        The cache lives at ``<path>/<mode>_indices_<seq_len>.json``. When
        building, every entry of a task directory whose name does not end in
        ``.npz`` is treated as an episode directory named by an integer.

        Returns:
            ``(indices, num_episodes)``. Each index is
            ``(task_dir, episode, ep_len)`` for an episode with at least
            ``seq_len`` files, in a seeded shuffled order. ``num_episodes``
            counts every episode directory, including the shorter ones.
        """
        indices_path = os.path.join(self._path, f'{self.mode}_indices_{self._seq_len}.json')
        if os.path.exists(indices_path):
            with open(indices_path, 'r') as f:
                info = json.load(f)
            return info['indices'], info['num_episodes']
        indices = []
        num_episodes = 0
        for file_path in self._paths:
            filenames = os.listdir(file_path)
            random.Random(0).shuffle(filenames)
            for filename in tqdm(filenames, desc='Loading indices'):
                if filename.endswith('.npz'):
                    continue
                file_dir = os.path.join(file_path, filename)
                ep_len = len(os.listdir(file_dir))
                if ep_len >= self._seq_len:
                    indices.append((str(file_path), int(filename), ep_len))
                num_episodes += 1
        random.Random(0).shuffle(indices)
        with open(indices_path, 'w') as file:
            json.dump({'indices': indices, 'num_episodes': num_episodes}, file)

        return indices, num_episodes

    def load_traj(self, file_path, ep_idx, ep_len):
        """Load the first ``seq_len`` steps of one episode.

        Args:
            file_path: Task directory, either a bare task name or a path whose
                last component is the task name (the form stored in the index).
            ep_idx: Episode directory name (converted with ``str``).
            ep_len: Episode length from the index; unused here.

        Returns:
            ``(imgs, actions, rewards)`` as ``float32`` arrays. Images whose
            last axis has size 3 are moved to channel-first layout, and all
            images are scaled with ``x / 127.5 - 1`` (presumably from 8-bit
            pixel values - confirm against the data). Rewards are compressed
            with ``sign(r) * log(|r| + 1)``.
        """
        # Index entries already include the dataset root; keep only the task
        # name so the root is not joined twice when it is a relative path.
        task_name = pathlib.Path(file_path).name
        file_dir = os.path.join(str(self._path), task_name, str(ep_idx))

        imgs, actions, rewards = [], [], []
        for i in range(self._seq_len):
            # np.load keeps the archive open; close it as soon as the arrays are read.
            with np.load(os.path.join(file_dir, f'{i}.npz')) as step:
                imgs.append(step['imgs'])
                actions.append(step['actions'])
                rewards.append(step['rewards'])
        imgs, actions, rewards = np.array(imgs), np.array(actions), np.array(rewards)
        if imgs.shape[-1] == 3:
            imgs = imgs.transpose(0, 3, 1, 2)
        imgs = (imgs / 127.5 - 1.0).astype(np.float32)
        actions = actions.astype(np.float32)
        rewards = rewards.astype(np.float32)
        rewards = np.sign(rewards) * (np.log(np.abs(rewards) + 1))

        return imgs, actions, rewards

    def __getitem__(self, idx):
        """Return ``(imgs, actions, rewards)`` for the episode at *idx*."""
        file_path, ep_idx, ep_len = self.indices[idx]
        return self.load_traj(file_path, ep_idx, ep_len)

    def __len__(self):
        """Return the number of indexed episodes."""
        return len(self.indices)

    @property
    def action_dim(self):
        """Size of the last action axis, read lazily from the first item and cached."""
        if self._action_dim is None:
            _, actions, _ = self[0]
            self._action_dim = actions.shape[-1]
        return self._action_dim
