import os
from io import BytesIO
from typing import Sequence

import clip
import h5py
import numpy as np
import torch
from ml_collections import ConfigDict

from .instruct import get_furniturebench_instruct


class ARPFurnitureBenchDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves samples from a preprocessed HDF5 file.

    Every item is a ``dict`` assembled from datasets stored in the file:
    ``action``, ``timestep``, ``skill``, ``attn_mask``, one dataset per image
    key, and optional LIV / next-frame / return-to-go entries depending on the
    config.  When ``config.output_type == "raw"`` the instruction is tokenized
    on the fly; otherwise datasets suffixed with ``_{config.pvr_type}`` are
    read instead.

    The HDF5 handle is re-opened in ``__setstate__`` so that pickled copies of
    the dataset do not share a file handle with the original instance.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the default ``ConfigDict`` and overlay ``updates`` on top of it.

        Args:
            updates: optional mapping / ``ConfigDict`` whose entries override
                the defaults.

        Returns:
            The resulting ``ConfigDict``.
        """
        config = ConfigDict()

        config.data_dir = ""
        config.env_type = "furnituresim"
        config.max_episode_steps = 500
        config.task_name = "one_leg"
        config.is_sim = False

        config.start_index = 0
        config.max_length = int(1e9)
        config.random_start = False

        config.image_size = 224

        # Several image keys can be given, separated by "|".
        config.image_keys = "color_image2"
        config.image_main_key = "color_image2"
        config.state_key = ""
        config.action_dim = 8
        config.clip_action = 0.999

        config.skip_frame = 16
        config.window_size = 2

        config.use_bert_tokenizer = False
        config.tokenizer_max_length = 77

        # TODO: Only for factor-world
        config.offset = 0
        config.num_demos = 100

        # Reward Learning option
        config.use_nfp = True
        config.use_liv = False
        config.output_type = "feature"
        config.pvr_type = "LIV"
        config.use_sparse = False

        # RTG related configs.
        config.rtg_key = ""
        config.rtg_scale = 1000.0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, update, h5_file_name, start_offset_ratio=None, split="train"):
        """Open the HDF5 file and derive the per-instance settings.

        Args:
            update: config overrides forwarded to ``get_default_config``.
            h5_file_name: path of the HDF5 file to open read-only.
            start_offset_ratio: optional fraction of the dataset length used as
                a fixed rotation of the sample order; ignored when
                ``config.random_start`` is set.
            split: label of the split, used in log messages.
        """
        self.config = self.get_default_config(update)
        self.split = split

        self.h5_file_name = h5_file_name
        assert h5_file_name is not None, "You must specify h5_file_name."
        self.h5_file = h5py.File(h5_file_name, "r")
        # Lazily filled by __len__; depends only on the (read-only) file.
        self._demo_indicators = None

        if self.config.random_start:
            self.random_start_offset = np.random.default_rng().choice(len(self))
        elif start_offset_ratio is not None:
            self.random_start_offset = int(len(self) * start_offset_ratio) % len(self)
        else:
            self.random_start_offset = 0

        # An empty suffix means raw inputs, which need on-the-fly tokenization.
        self.suffix = "" if self.config.output_type == "raw" else f"_{self.config.pvr_type}"
        if self.suffix == "":
            self.tokenizer = self.build_tokenizer()
        if self.config.rtg_key != "":
            self._rtg_key = f"rtg_{self.config.rtg_key}"
            # Read the dataset once instead of once per statistic.
            rtg_values = np.asarray(self.h5_file[self._rtg_key])
            self._rtg_min, self._rtg_max = np.min(rtg_values), np.max(rtg_values)
            print(f"[INFO] [{self.split}] {self._rtg_key}: {self._rtg_min} ~ {self._rtg_max}")

    def __getstate__(self):
        """Return the picklable state; the open HDF5 handle is deliberately left out."""
        return self.config, self.random_start_offset, self.h5_file_name, self.split

    def __setstate__(self, state):
        """Rebuild the instance (re-opening the file) from ``__getstate__`` output.

        Legacy 3-tuples without the split label are still accepted and fall
        back to ``"train"``.
        """
        config, random_start_offset, h5_file_name, *rest = state
        split = rest[0] if rest else "train"
        self.__init__(config, h5_file_name=h5_file_name, split=split)
        # Restore the offset of the pickled instance rather than a fresh one.
        self.random_start_offset = random_start_offset

    def close(self):
        """Close the underlying HDF5 file handle, if it is open."""
        h5_file = getattr(self, "h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.h5_file = None

    def __len__(self):
        """Return the number of samples served, bounded by ``config.max_length``.

        The sorted unique values of ``demo_idx`` are cached because computing
        them reads the whole dataset and ``len(self)`` is needed for every
        sample.  The length itself is recomputed from the current config.
        NOTE(review): ``demo_idx`` rows look like demo boundary indices —
        confirm against the preprocessing script.

        Raises:
            IndexError: if the file has no boundary at ``config.num_demos``.
        """
        demo_indicators = getattr(self, "_demo_indicators", None)
        if demo_indicators is None:
            demo_indicators = np.unique(self.h5_file["demo_idx"])
            self._demo_indicators = demo_indicators
        return int(min(demo_indicators[self.config.num_demos] - self.config.start_index, self.config.max_length))

    def _last_valid_index(self):
        """Return the largest absolute HDF5 row index this dataset serves."""
        return self.config.start_index + len(self) - 1

    def process_index(self, index):
        """Rotate ``index`` by the start offset and shift it by ``config.start_index``."""
        index = (index + self.random_start_offset) % len(self)
        return index + self.config.start_index

    def process_rtg(self, rtg):
        """Min-max scale ``rtg`` to ``[-1, 1]`` using the statistics gathered in ``__init__``.

        NOTE(review): yields NaN/inf when all stored return-to-go values are equal.
        """
        # return (rtg - self._rtg_min) / self.config.rtg_scale
        return 2 * (rtg - self._rtg_min) / (self._rtg_max - self._rtg_min) - 1

    def build_tokenizer(self):
        """Create the instruction tokenizer selected by ``config.use_bert_tokenizer``.

        Returns:
            A callable mapping instructions to ``(token_ids, padding_mask)``
            with ``int32`` ids and a ``float32`` mask.
        """
        use_bert_tokenizer = self.config.use_bert_tokenizer
        tokenizer_max_length = self.config.tokenizer_max_length

        if use_bert_tokenizer:
            # Imported lazily so that transformers is only needed for BERT.
            import transformers

            tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
        else:
            tokenizer = clip.tokenize

        def tokenizer_fn(instruct):
            if use_bert_tokenizer:
                if len(instruct) == 0:
                    # Empty input: all-zero ids, everything marked as padding.
                    tokenized_instruct = np.zeros(tokenizer_max_length, dtype=np.int32)
                    padding_mask = np.ones(tokenizer_max_length, dtype=np.float32)
                else:
                    encoded_instruct = tokenizer(
                        instruct,
                        padding="max_length",
                        truncation=True,
                        max_length=tokenizer_max_length,
                        return_tensors="np",
                        add_special_tokens=False,
                    )
                    tokenized_instruct = encoded_instruct["input_ids"].astype(np.int32)
                    # The mask is the complement of the attention mask: 1 marks padding.
                    padding_mask = 1.0 - encoded_instruct["attention_mask"].astype(np.float32)
            else:
                tokenized_instruct = np.asarray(tokenizer(instruct)).astype(np.int32)
                # The CLIP branch always returns an all-ones mask.
                padding_mask = np.ones(tokenizer_max_length).astype(np.float32)
            return tokenized_instruct, padding_mask

        return tokenizer_fn

    def get_tokenized_instructions(self, task_name: bytes, skills: Sequence):
        """Tokenize one instruction per window step.

        Args:
            task_name: UTF-8 encoded task name as stored in the HDF5 file.
            skills: per-step skill entries; the first ``config.window_size``
                entries are used.

        Returns:
            The ``(token_ids, padding_mask)`` pair produced by the tokenizer.
        """
        with BytesIO(task_name) as fin:
            decoded_task_name = fin.read().decode("utf-8")
        _instruct = [
            get_furniturebench_instruct(decoded_task_name, np.asarray(skills[idx]))
            for idx in range(self.config.window_size)
        ]
        instruct, text_padding_mask = self.tokenizer(_instruct)
        return instruct, text_padding_mask

    def _read_frames(self, image_key, row_index):
        """Read ``{image_key}{suffix}`` at ``row_index``, moving a size-3 axis 1 to the end."""
        frames = self.h5_file[f"{image_key}{self.suffix}"][row_index]
        if frames.shape[1] == 3:
            # Channel-first storage -> channel-last.
            frames = frames.transpose(0, 2, 3, 1)
        return frames

    def __getitem__(self, index):
        """Assemble the sample dict for dataset position ``index``.

        ``index`` is first mapped to an absolute HDF5 row via ``process_index``.
        The ``next_skill`` / ``next_instruct`` / ``next_reward`` outputs are
        currently disabled and therefore not produced.
        """
        index = self.process_index(index)
        batch = {"image": {}}
        image_keys = self.config.image_keys.split("|")

        # LIV specific data: it makes process slower.
        if self.config.use_liv:
            batch.update({"initial_image": {}, "goal_image": {}, "next_image": {}})
            # Successor row, kept inside the current demo and inside the rows
            # this dataset serves.  ``index`` is absolute (already shifted by
            # start_index), so the upper bound must be absolute as well.
            liv_next_index = min(index + 1, self.h5_file["demo_idx"][index][-1] - 1, self._last_valid_index())
            for ik in image_keys:
                initial_index = self.h5_file[f"initial_{ik}"][index]
                goal_index = self.h5_file[f"goal_{ik}"][index]
                batch["initial_image"][ik] = self._read_frames(ik, initial_index)
                batch["goal_image"][ik] = self._read_frames(ik, goal_index)
                batch["next_image"][ik] = self._read_frames(ik, liv_next_index)

            # Masks and timesteps use the goal/initial indices of the LAST image key.
            batch["goal_attn_mask"] = self.h5_file["attn_mask"][goal_index]
            batch["goal_timestep"] = self.h5_file["timestep"][goal_index]
            batch["initial_attn_mask"] = self.h5_file["attn_mask"][initial_index]
            batch["initial_timestep"] = self.h5_file["timestep"][initial_index]
            batch["next_attn_mask"] = self.h5_file["attn_mask"][liv_next_index]
            batch["next_timestep"] = self.h5_file["timestep"][liv_next_index]

            batch["r"] = self.h5_file["r"][index]

        # batch item for next observations
        if self.config.use_nfp:
            batch.update({"nfp_next_image": {}})
            for ik in image_keys:
                batch["nfp_next_image"][ik] = self._read_frames(ik, self.h5_file[f"next_{ik}"][index])

        # Batch data
        for key in ("action", "timestep", "skill", "attn_mask"):
            batch[key] = self.h5_file[key][index]
        for ik in image_keys:
            batch["image"][ik] = self._read_frames(ik, index)

        # clip action for stabilizing.
        batch["action"] = np.clip(batch["action"], -self.config.clip_action, self.config.clip_action)

        # Text feature
        if self.suffix == "":
            batch["instruct"], _ = self.get_tokenized_instructions(self.h5_file["task_name"][index], batch["skill"])
        else:
            batch["instruct"] = self.h5_file[f"instruct{self.suffix}"][index]

        # GT reward
        if self.config.use_sparse:
            batch["reward"] = self.h5_file["terminals"][index].astype(np.float32)
        else:
            batch["reward"] = batch["skill"].astype(np.float32)

        # deprecated: apply rtgs
        if self.config.rtg_key != "":
            raw_rtg = self.h5_file[f"rtg_{self.config.rtg_key}"][index]
            batch["rtg"] = np.asarray(np.vectorize(self.process_rtg)(raw_rtg[..., None]))

        return batch

    @property
    def num_actions(self):
        """Action dimensionality taken from ``config.action_dim``."""
        return self.config.action_dim

    @property
    def rtg(self):
        """0.9-quantile of the stored return-to-go values, or ``None`` when RTG is disabled."""
        if self.config.rtg_key != "":
            return np.quantile(self.h5_file[f"rtg_{self.config.rtg_key}"], 0.9)
        else:
            return None

    @property
    def rtg_scale(self):
        """``config.rtg_scale`` when RTG is enabled, otherwise ``None``."""
        if self.config.rtg_key != "":
            return self.config.rtg_scale
        else:
            return None

    @property
    def obs_shape(self):
        """Shapes of the observation entries, derived from the config only."""
        res = {"image": {}}
        for key in self.config.image_keys.split("|"):
            res["image"][key] = (self.config.image_size, self.config.image_size, 3)
        res["rtg"] = (1,)
        if self.config.state_key != "":
            # NOTE(review): ``state_dim`` has no default in get_default_config;
            # it must be supplied through ``update`` when ``state_key`` is set.
            res["state"] = self.config.state_dim
        return res


def _run_smoke_test():
    """Load a training split from a hard-coded path, print batch contents and dump frames as JPEGs.

    Intended for manual inspection only: the data path is machine-specific and
    the JPEG files are written to the current working directory, overwriting
    the files of the previous sample.
    """
    # Imported here (once, not per sample) so the module itself does not
    # depend on PIL / tqdm.
    from PIL import Image
    from tqdm import trange

    config = ARPFurnitureBenchDataset.get_default_config()
    base_path = "/home/data/num1000_furniture_legacy_preprocessed/one_leg"
    # Alternative data root: "/home/data/furniture_sim_preprocessed/low/one_leg"
    config.data_dir = base_path
    config.window_size = 4
    config.skip_frame = 16
    config.num_demos = 100
    config.image_keys = "color_image2|color_image1"
    config.use_bert_tokenizer = True
    # Set config.rtg_key = "clip" to also exercise the return-to-go path.
    config.use_liv = True
    config.use_nfp = True
    config.use_sparse = False
    # "raw" is required below because the frames are saved as images.
    config.output_type = "raw"

    split = "train"
    h5_file_name = os.path.join(base_path, split, f"data_w{config.window_size}_s{config.skip_frame}.hdf5")
    ds = ARPFurnitureBenchDataset(update=config, h5_file_name=h5_file_name)

    image_keys = config.image_keys.split("|")
    # (batch entry, output file-name pattern), in the order the files are written.
    outputs = (
        ("goal_image", "{ik}_goal_images_ws{ws}.jpeg"),
        ("initial_image", "{ik}_initial_images_ws{ws}.jpeg"),
        ("image", "{ik}_ws{ws}.jpeg"),
        ("next_image", "{ik}_next_ws{ws}.jpeg"),
        ("nfp_next_image", "{ik}_nfp_next_ws{ws}.jpeg"),
    )

    for i in trange(390, 400):
        batch = ds[i]
        for key, val in batch.items():
            if isinstance(val, dict):
                for ik in val:
                    print(f"[INFO] {key}|{ik}: {val[ik].shape}")
            elif isinstance(val, np.ndarray):
                print(f"[INFO] {key}: {val.shape}")
            else:
                print(f"[INFO] {key}: {val}")
        print(f"[INFO] reward: {batch['reward']}")

        for ik in image_keys:
            for ws in range(config.window_size):
                for entry, pattern in outputs:
                    img = Image.fromarray(batch[entry][ik][ws])
                    img.save(pattern.format(ik=ik, ws=ws))


if __name__ == "__main__":
    _run_smoke_test()
