import os, cv2
import pickle
import glob

import torch
from torch.utils.data import Dataset
import numpy as np
import gym.spaces

from utils.logger import logger
from utils.gym_env import get_non_absorbing_state, get_absorbing_state, zero_value
from collections import deque

class ExpertDataset(Dataset):
    """Per-transition expert demonstration dataset for imitation learning.

    Every item is a dict describing a single transition. The default
    (high-level) path stores the keys ``ob``, ``ob_next``,
    ``traj_start_indicator``, ``ac``, ``rew``, ``done`` and, when
    ``with_taskID`` is set, ``id``. The ``use_low_level`` path stores only
    ``ob``, ``ob_next``, ``ac`` and ``done``.
    """

    def __init__(
        self,
        path,
        subsample_interval=1,
        ac_space=None,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        use_low_level=False,
        sample_range_start=0.0,
        sample_range_end=1.0,
        num_task=4,
        target_taskID=None,
        with_taskID=False,
        num_target_demos=None,  # for few-shot learning, if None, use all demos, else randomly sample and use num_target_demos
        target_demo_path=None,
        is_sqil=False,
        frame_stack=None,
        encoder_image_size=None,
        sampled_few_demo_index=None,  # this is given when we do heatmap for maze2d, and we need to load the same demos as training
    ):
        """Load every demonstration file matched by ``path`` into memory.

        Args:
            path: One or more demo locations joined by ``#``. Each entry is
                either a ``.pkl`` file/glob or a prefix that is expanded with
                ``*.pkl``.
            subsample_interval: Stride between the transitions kept from a
                demo.
            ac_space: Action space handed to ``gym.spaces.unflatten`` for
                actions that are not already dicts.
            train: Stored on ``self.train``; not otherwise used here.
            transform, target_transform, download: Accepted but unused.
            use_low_level: Read ``low_level_obs`` / ``low_level_actions``
                instead of ``obs`` / ``actions``.
            sample_range_start, sample_range_end: Fractions of the demo
                length bounding the steps that are sampled.
            num_task: Length of the one-hot task encoding.
            target_taskID: Optional replacement for the last task encoding;
                when given, the last demo file is assigned the last task slot.
            with_taskID: Attach the task encoding to transitions as ``id``.
            num_target_demos: Number of demos to keep from
                ``target_demo_path`` (few-shot); ``None`` keeps all of them.
            target_demo_path: Demo file treated as the target task.
            is_sqil: When truthy, rewards are replaced by 1.0 for target
                demos and 0.0 for all other demos.
            frame_stack: Number of consecutive observations to stack, or
                ``None`` to disable stacking.
            encoder_image_size: Side length images are resized to before
                stacking (only applied when ``frame_stack`` is set).
            sampled_few_demo_index: Indices of the target demos to keep;
                sampled at random when ``None`` and few-shot is requested.

        Raises:
            AssertionError: If ``path`` is ``None``.
        """
        self.train = train  # training set or test set
        self.sampled_few_demo_index = sampled_few_demo_index

        self._data = []
        self._ac_space = ac_space

        assert (
            path is not None
        ), "--demo_path should be set (e.g. demos/Sawyer_toy_table)"

        path = path.split("#")
        demo_files = [f for _p in path for f in self._get_demo_files(_p)]

        self.one_hot_encode = []
        for x in range(num_task):
            arr = list(np.zeros(num_task, dtype=int))
            arr[x] = 1
            self.one_hot_encode.append(arr)
        # Compare against None (not truthiness) so that this agrees with the
        # per-file check below and array-valued encodings do not raise.
        if target_taskID is not None:
            self.one_hot_encode[num_task - 1] = target_taskID

        # Number of leading steps of each sampled range that are skipped.
        offset = 2  # np.random.randint(0, subsample_interval)

        task_id = 0
        num_demos = 0
        self.num_files = len(demo_files)
        last_file = demo_files[-1] if demo_files else None
        # now load the pickled numpy arrays
        for file_path in demo_files:
            is_target_demo = file_path == target_demo_path
            if target_taskID is not None and file_path == last_file:
                task_id = num_task - 1
            # SQIL-style constant reward: 1.0 for target demos, 0.0 otherwise.
            sqil_reward = float(bool(is_sqil) and is_target_demo)

            # NOTE: pickle executes arbitrary code while loading; only point
            # this dataset at trusted demonstration files.
            with open(file_path, "rb") as f:
                demos = pickle.load(f)
            if not isinstance(demos, list):
                demos = [demos]

            if is_target_demo and num_target_demos is not None:
                if self.sampled_few_demo_index is None:
                    # replace=False ensures the selection has no duplicates.
                    self.sampled_few_demo_index = np.random.choice(
                        len(demos), size=num_target_demos, replace=False
                    )
                demos = [demos[i] for i in self.sampled_few_demo_index]

            for demo in demos:
                if len(demo["obs"]) != len(demo["actions"]) + 1:
                    logger.error(
                        "Mismatch in # of observations (%d) and actions (%d) (%s)",
                        len(demo["obs"]),
                        len(demo["actions"]),
                        file_path,
                    )
                    continue

                demo_clone = demo.copy()

                # Stack frame_stack observations per step; the first
                # observation is repeated frame_stack times.
                if frame_stack is not None:
                    demo_clone["obs"] = self.stack_frames(
                        demo_clone, frame_stack, encoder_image_size
                    )

                num_demos += 1

                if use_low_level:
                    low_obs = demo["low_level_obs"]
                    low_acs = demo["low_level_actions"]
                    length = len(low_acs)
                    start = int(length * sample_range_start)
                    end = int(length * sample_range_end)
                    for i in range(start + offset, end, subsample_interval):
                        ac = low_acs[i]
                        if not isinstance(ac, dict):
                            ac = gym.spaces.unflatten(ac_space, ac)
                        self._data.append(
                            {
                                "ob": low_obs[i],
                                "ob_next": low_obs[i + 1],
                                "ac": ac,
                                "done": 1 if i + 1 == length else 0,
                            }
                        )
                    continue

                obs = demo_clone["obs"]
                actions = demo_clone["actions"]
                length = len(actions)
                start = int(length * sample_range_start)
                end = int(length * sample_range_end)
                first = start + offset
                for i in range(first, end, subsample_interval):
                    transition = {
                        "ob": obs[i],
                        "ob_next": obs[i + 1],
                        # 1 if the transition is the start of the trajectory
                        "traj_start_indicator": 1 if i == first else 0,
                    }
                    ac = actions[i]
                    if not isinstance(ac, dict):
                        ac = gym.spaces.unflatten(ac_space, ac)
                    transition["ac"] = ac

                    if "rewards" in demo_clone and not is_sqil:
                        transition["rew"] = demo_clone["rewards"][i]
                    else:
                        transition["rew"] = sqil_reward

                    if "dones" in demo_clone:
                        transition["done"] = int(demo_clone["dones"][i])
                    else:
                        transition["done"] = 1 if i + 1 == length else 0

                    if with_taskID:
                        transition["id"] = np.array(self.one_hot_encode[task_id])

                    self._data.append(transition)
            task_id += 1

        logger.warn(
            "Load %d demonstrations with %d states from %d files (%s)",
            num_demos,
            len(self._data),
            len(demo_files),
            path,
        )

    def add_absorbing_states(self, ob_space, ac_space):
        """Rewrite ``self._data`` with absorbing-state bookkeeping.

        Every transition gets a ``done_mask`` (-1 absorbing, 0 done, 1 not
        done). A terminal transition has its ``ob_next`` replaced by the
        state returned by ``get_absorbing_state(ob_space)`` and is followed
        by one extra absorbing -> absorbing transition whose action is
        ``zero_value(ac_space)``.

        Args:
            ob_space: Observation space passed to ``get_absorbing_state``.
            ac_space: Action space passed to ``zero_value``.
        """
        absorbing_state = get_absorbing_state(ob_space)
        absorbing_action = zero_value(ac_space, dtype=np.float32)

        new_data = []
        for original in self._data:
            transition = original.copy()
            transition["ob"] = get_non_absorbing_state(original["ob"])
            # learn reward for the last transition regardless of timeout (different from paper)
            if original["done"]:
                transition["ob_next"] = absorbing_state
                transition["done_mask"] = 0  # -1 absorbing, 0 done, 1 not done
                new_data.append(transition)
                new_data.append(
                    {
                        "ob": absorbing_state,
                        "ob_next": absorbing_state,
                        "ac": absorbing_action,
                        # "rew": np.float64(0.0),
                        "done": 0,
                        "done_mask": -1,  # -1 absorbing, 0 done, 1 not done
                    }
                )
            else:
                transition["ob_next"] = get_non_absorbing_state(original["ob_next"])
                transition["done_mask"] = 1  # -1 absorbing, 0 done, 1 not done
                new_data.append(transition)

        self._data = new_data

    def stack_frames(self, demo_clone, frame_stack, encoder_image_size):
        """Return the demo's observations stacked over a sliding window.

        For each step the last ``frame_stack`` frames (``obs[i]['ob']``) are
        concatenated along axis 0. The window starts out filled with copies
        of the first frame, so the first output repeats that frame
        ``frame_stack`` times.

        Args:
            demo_clone: Demo dict whose ``"obs"`` entry is a sequence of
                dicts holding a channel-first image under ``'ob'``.
            frame_stack: Number of frames per stacked observation.
            encoder_image_size: When not ``None``, every frame is first
                resized to ``encoder_image_size`` x ``encoder_image_size``.
                The resized image is written back into the ``obs`` dicts, so
                the input frames are modified in place.

        Returns:
            A list with one ``{'ob': stacked_array}`` dict per observation.
        """
        obs = demo_clone["obs"]

        # resize the channel-first images to the encoder input size
        if encoder_image_size is not None:
            target_size = (encoder_image_size, encoder_image_size)
            for frame in obs:
                img = np.transpose(frame['ob'], (1, 2, 0))
                img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
                if img.ndim == 2:
                    # cv2.resize returns a 2-D array for single-channel
                    # input; restore the channel axis before transposing.
                    img = img[:, :, np.newaxis]
                frame['ob'] = np.transpose(img, (2, 0, 1))

        # Expected channel count is derived from the data rather than assuming
        # 3-channel frames, so non-RGB inputs are not reported as mismatches.
        expected_channels = np.shape(obs[0]['ob'])[0] * frame_stack

        self._frames = deque([obs[0]] * frame_stack, maxlen=frame_stack)
        demo_obs_temp = []
        for idx, frame in enumerate(obs):
            if idx > 0:  # the window already holds the first frame
                self._frames.append(frame)
            stacked = np.concatenate([f['ob'] for f in self._frames], axis=0)
            demo_obs_temp.append({'ob': stacked})
            if stacked.shape[0] != expected_channels:
                logger.error(
                    "Mismatch in # of channels (%d) and frame_stack (%d)",
                    stacked.shape[0],
                    expected_channels,
                )

        return demo_obs_temp

    def _get_demo_files(self, demo_file_path):
        """Return the existing ``.pkl`` files matched by ``demo_file_path``.

        A path that does not end in ``.pkl`` is treated as a prefix and
        expanded with ``*.pkl``. Files come back in ``glob`` order, which the
        caller uses as-is when assigning task ids.
        """
        if not demo_file_path.endswith(".pkl"):
            demo_file_path = demo_file_path + "*.pkl"
        return [f for f in glob.glob(demo_file_path) if os.path.isfile(f)]

    def __getitem__(self, index):
        """Return the transition dict stored at ``index``.

        Args:
            index (int): Index
        Returns:
            dict: the stored transition (see the class docstring for keys).
        """
        return self._data[index]

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self._data)

class ExpertDataset_Episode(Dataset):
    """Episode-level expert demonstration dataset for imitation learning.

    Every item is a dict describing one whole demonstration: ``ob`` holds
    the list of observations, ``ac`` the list of actions (only stored when
    the actions are dicts) and ``done`` the list of done flags (only stored
    when the demo provides ``dones``).
    """

    def __init__(
        self,
        path,
        subsample_interval=1,
        ac_space=None,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        use_low_level=False,
        sample_range_start=0.0,
        sample_range_end=1.0,
        num_task=4,
        target_taskID=None,
    ):
        """Load every demonstration file matched by ``path`` into memory.

        Args:
            path: One or more demo locations joined by ``#``. Each entry is
                either a ``.pkl`` file/glob or a prefix that is expanded with
                ``*.pkl``.
            ac_space: Stored on ``self._ac_space``; not otherwise used here.
            train: Stored on ``self.train``; not otherwise used here.
            num_task: Length of the one-hot task encodings kept in
                ``self.one_hot_encode``.
            target_taskID: Optional replacement for the last task encoding.
            subsample_interval, transform, target_transform, download,
            use_low_level, sample_range_start, sample_range_end: Accepted
                for signature parity with the other datasets in this module
                but unused.

        Raises:
            AssertionError: If ``path`` is ``None``.
        """
        self.train = train  # training set or test set

        self._data = []
        self._ac_space = ac_space

        assert (path is not None), "--demo_path should be set (e.g. demos/Sawyer_toy_table)"

        demo_files = [f for _p in path.split("#") for f in self._get_demo_files(_p)]

        # The task encodings are exposed as an attribute only; episodes are
        # not tagged with a task id in this class.
        self.one_hot_encode = []
        for x in range(num_task):
            arr = list(np.zeros(num_task, dtype=int))
            arr[x] = 1
            self.one_hot_encode.append(arr)
        # Compare against None (not truthiness) so array-valued encodings do
        # not raise on the ambiguous truth value.
        if target_taskID is not None:
            self.one_hot_encode[num_task - 1] = target_taskID

        num_demos = 0
        self.num_files = len(demo_files)
        # now load the pickled numpy arrays
        for file_path in demo_files:
            # NOTE: pickle executes arbitrary code while loading; only point
            # this dataset at trusted demonstration files.
            with open(file_path, "rb") as f:
                demos = pickle.load(f)
            if not isinstance(demos, list):
                demos = [demos]

            for demo in demos:
                if len(demo["obs"]) != len(demo["actions"]) + 1:
                    logger.error(
                        "Mismatch in # of observations (%d) and actions (%d) (%s)",
                        len(demo["obs"]),
                        len(demo["actions"]),
                        file_path,
                    )
                    continue
                num_demos += 1

                episode = {"ob": list(demo["obs"])}

                actions = demo["actions"]
                # Guard the probe of the first action so a demo with a single
                # observation and no actions does not raise IndexError.
                if len(actions) > 0 and isinstance(actions[0], dict):
                    episode["ac"] = list(actions)

                if "dones" in demo:
                    episode["done"] = [int(x) for x in demo["dones"]]

                self._data.append(episode)

        logger.warn(
            "Load %d demonstrations with %d states from %d files",
            num_demos,
            len(self._data),
            len(demo_files),
        )

    def _get_demo_files(self, demo_file_path):
        """Return the existing ``.pkl`` files matched by ``demo_file_path``.

        A path that does not end in ``.pkl`` is treated as a prefix and
        expanded with ``*.pkl``. Files come back in ``glob`` order.
        """
        if not demo_file_path.endswith(".pkl"):
            demo_file_path = demo_file_path + "*.pkl"
        return [f for f in glob.glob(demo_file_path) if os.path.isfile(f)]

    def __getitem__(self, index):
        """Return the episode dict stored at ``index``.

        Args:
            index (int): Index
        Returns:
            dict: the stored episode (see the class docstring for keys).
        """
        return self._data[index]

    def __len__(self):
        """Return the number of stored episodes."""
        return len(self._data)


class ExpertDataset_H_steps(Dataset):
    """Windowed expert demonstration dataset for imitation learning.

    Every item is a dict covering ``step_size`` consecutive steps of one
    demo: ``ob`` is ``{"ob": array}`` with the observations stacked along a
    new leading axis, ``ac`` is ``{"ac": array}`` built the same way from
    the actions, ``done`` is the list of done flags (only when the demo
    provides ``dones``) and ``id`` is the one-hot task encoding.
    """

    def __init__(
        self,
        path,
        subsample_interval=10,
        ac_space=None,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        use_low_level=False,
        sample_range_start=0.0,
        sample_range_end=1.0,
        num_task=4,
        target_taskID=None,
        step_size=10,
        batch_size=None,
        num_target_demos=None,  # for few-shot learning, if None, use all demos, else randomly sample and use num_target_demos
        target_demo_path=None,
        sampled_few_demo_index=None,
        encoder_image_size=None,
    ):
        """Load every demonstration file matched by ``path`` into memory.

        Args:
            path: One or more demo locations joined by ``#``. Each entry is
                either a ``.pkl`` file/glob or a prefix that is expanded with
                ``*.pkl``.
            subsample_interval: Stride between the start indices of
                consecutive windows.
            ac_space: Stored on ``self._ac_space``; not otherwise used here.
            train: Stored on ``self.train``; not otherwise used here.
            transform, target_transform, download, use_low_level,
            sample_range_start, sample_range_end: Accepted but unused.
            num_task: Length of the one-hot task encoding.
            target_taskID: Optional replacement for the last task encoding;
                when given, the last demo file is assigned the last task slot.
            step_size: Number of steps in every window.
            batch_size: When larger than the number of windows, the dataset
                reports this as its length and wraps indices around.
            num_target_demos: When not ``None``, only a subset of the demos
                in ``target_demo_path`` is kept (few-shot).
            target_demo_path: Demo file treated as the target task.
            sampled_few_demo_index: Indices of the target demos to keep;
                sampled at random (``num_target_demos`` of them, without
                replacement) when ``None``. The indices actually used are
                kept on ``self.sampled_few_demo_index``.
            encoder_image_size: When not ``None``, images are resized to
                this side length before windows are built.

        Raises:
            AssertionError: If ``path`` is ``None``.
        """
        self.train = train  # training set or test set
        self.sampled_few_demo_index = sampled_few_demo_index

        self._data = []
        self._ac_space = ac_space

        assert (
            path is not None
        ), "--demo_path should be set (e.g. demos/Sawyer_toy_table)"

        demo_files = [f for _p in path.split("#") for f in self._get_demo_files(_p)]

        self.one_hot_encode = []
        for x in range(num_task):
            arr = list(np.zeros(num_task, dtype=int))
            arr[x] = 1
            self.one_hot_encode.append(arr)
        # Compare against None (not truthiness) so that this agrees with the
        # per-file check below and array-valued encodings do not raise.
        if target_taskID is not None:
            self.one_hot_encode[num_task - 1] = target_taskID

        task_id = 0
        num_demos = 0
        self.num_files = len(demo_files)
        last_file = demo_files[-1] if demo_files else None
        # now load the pickled numpy arrays
        for file_path in demo_files:
            if target_taskID is not None and file_path == last_file:
                task_id = num_task - 1
            # NOTE: pickle executes arbitrary code while loading; only point
            # this dataset at trusted demonstration files.
            with open(file_path, "rb") as f:
                demos = pickle.load(f)
            if not isinstance(demos, list):
                demos = [demos]

            if file_path == target_demo_path and num_target_demos is not None:
                if self.sampled_few_demo_index is None:
                    # No indices supplied: sample them here instead of
                    # failing on iteration over None.
                    self.sampled_few_demo_index = np.random.choice(
                        len(demos), size=num_target_demos, replace=False
                    )
                demos = [demos[i] for i in self.sampled_few_demo_index]

            for demo in demos:
                if len(demo["obs"]) != len(demo["actions"]) + 1:
                    logger.error(
                        "Mismatch in # of observations (%d) and actions (%d) (%s)",
                        len(demo["obs"]),
                        len(demo["actions"]),
                        file_path,
                    )
                    continue
                num_demos += 1

                length = len(demo["actions"])
                # we don't use the first two steps because they may differ
                # from other steps due to the random initialization of the
                # environment (the env takes some steps to initialize)
                if length - 2 < step_size:
                    continue

                if encoder_image_size is not None:
                    self._resize_obs(demo["obs"], length, encoder_image_size)

                # sample step_size steps every subsample_interval steps
                for start in range(2, length - step_size, subsample_interval):
                    self._data.append(
                        self._build_window(demo, start, step_size, task_id)
                    )

                # always include the window ending at the last action
                self._data.append(
                    self._build_window(demo, length - step_size, step_size, task_id)
                )

            task_id += 1

        # In case batch size is larger than dataset size (few-shot learning),
        # record batch size to manage. An empty dataset never records it,
        # since there is nothing to wrap indices around.
        self._batch_size = None
        if batch_size and self._data and batch_size > len(self._data):
            self._batch_size = batch_size

        logger.warn(
            "Load %d demonstrations with %d states from %d files",
            num_demos,
            len(self._data),
            len(demo_files),
        )

    def _resize_obs(self, obs, count, image_size):
        """Resize the first ``count`` channel-first images of ``obs`` in place."""
        target_size = (image_size, image_size)
        for i in range(count):
            img = np.transpose(obs[i]['ob'], (1, 2, 0))
            img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
            obs[i]['ob'] = np.transpose(img, (2, 0, 1))

    def _build_window(self, demo, start, step_size, task_id):
        """Build the item covering steps ``start`` .. ``start + step_size - 1``."""
        stop = start + step_size

        transition = {
            "ob": {"ob": np.stack([x["ob"] for x in demo["obs"][start:stop]], axis=0)}
        }
        assert transition["ob"]["ob"].shape[0] == step_size

        actions = demo["actions"][start:stop]
        if isinstance(actions[0], dict):
            transition["ac"] = {"ac": np.stack([x["ac"] for x in actions], axis=0)}
        # Only dict actions are supported: for any other action type "ac" is
        # missing and the lookup below raises KeyError.
        assert transition["ac"]["ac"].shape[0] == step_size

        if "dones" in demo:
            transition["done"] = [int(x) for x in demo["dones"][start:stop]]

        transition["id"] = np.array(self.one_hot_encode[task_id])
        return transition

    def add_absorbing_states(self, ob_space, ac_space):
        """Rewrite ``self._data`` with absorbing-state bookkeeping.

        Every item gets a ``done_mask`` (-1 absorbing, 0 done, 1 not done).
        An item whose ``done`` value is truthy has its ``ob_next`` replaced
        by the state returned by ``get_absorbing_state(ob_space)`` and is
        followed by one extra absorbing -> absorbing item whose action is
        ``zero_value(ac_space)``.

        NOTE(review): items built by this class store ``done`` as a list and
        carry no ``ob_next`` key, so this method looks like it was written
        for per-transition items -- confirm before relying on it here.

        Args:
            ob_space: Observation space passed to ``get_absorbing_state``.
            ac_space: Action space passed to ``zero_value``.
        """
        absorbing_state = get_absorbing_state(ob_space)
        absorbing_action = zero_value(ac_space, dtype=np.float32)

        new_data = []
        for original in self._data:
            transition = original.copy()
            transition["ob"] = get_non_absorbing_state(original["ob"])
            # learn reward for the last transition regardless of timeout (different from paper)
            if original["done"]:
                transition["ob_next"] = absorbing_state
                transition["done_mask"] = 0  # -1 absorbing, 0 done, 1 not done
                new_data.append(transition)
                new_data.append(
                    {
                        "ob": absorbing_state,
                        "ob_next": absorbing_state,
                        "ac": absorbing_action,
                        # "rew": np.float64(0.0),
                        "done": 0,
                        "done_mask": -1,  # -1 absorbing, 0 done, 1 not done
                    }
                )
            else:
                transition["ob_next"] = get_non_absorbing_state(original["ob_next"])
                transition["done_mask"] = 1  # -1 absorbing, 0 done, 1 not done
                new_data.append(transition)

        self._data = new_data

    def _get_demo_files(self, demo_file_path):
        """Return the existing ``.pkl`` files matched by ``demo_file_path``.

        A path that does not end in ``.pkl`` is treated as a prefix and
        expanded with ``*.pkl``. Files come back in ``glob`` order, which the
        caller uses as-is when assigning task ids.
        """
        if not demo_file_path.endswith(".pkl"):
            demo_file_path = demo_file_path + "*.pkl"
        return [f for f in glob.glob(demo_file_path) if os.path.isfile(f)]

    def __getitem__(self, index):
        """Return the window dict for ``index``.

        When a batch size larger than the dataset was recorded, indices wrap
        around the stored windows.

        Args:
            index (int): Index
        Returns:
            dict: the stored window (see the class docstring for keys).
        """
        if self._batch_size and self._data:
            return self._data[index % len(self._data)]
        return self._data[index]

    def __len__(self):
        """Return the recorded batch size if set, else the window count.

        An empty dataset always reports length 0, so it is never indexed
        through the wrap-around path.
        """
        if self._batch_size and self._data:
            return self._batch_size
        return len(self._data)
