import glob
import os
import pickle
from collections import deque
from io import BytesIO
from typing import Sequence

import clip
import numpy as np
import torch
from ml_collections import ConfigDict
from tqdm import tqdm

from .instruct import get_furniturebench_instruct


class ARPFurnitureBenchDataset(torch.utils.data.Dataset):
    @staticmethod
    def get_default_config(updates=None):
        """Build the default dataset configuration, optionally overridden.

        Args:
            updates: Optional overrides. When given they are wrapped in a
                ``ConfigDict``, references are resolved, and the result is
                merged over the defaults with ``ConfigDict.update``.

        Returns:
            A ``ConfigDict`` containing every default field below with the
            overrides applied.
        """
        defaults = {
            # Data location / episode selection.
            "data_dir": "",
            "env_type": "furnituresim",
            "max_episode_steps": 500,
            "task_name": "one_leg",
            "is_sim": True,
            # Index range exposed by the dataset.
            "start_index": 0,
            "max_length": int(1e9),
            "random_start": False,
            # Observation / action layout.
            "image_size": 224,
            "image_keys": "color_image2",
            "image_main_key": "color_image2",
            "state_key": "",
            "action_dim": 8,
            "clip_action": 0.999,
            # Temporal windowing.
            "skip_frame": 16,
            "window_size": 2,
            # Text tokenization.
            "use_bert_tokenizer": False,
            "tokenizer_max_length": 77,
            # Demonstration filtering.
            "offset": 0,
            "num_demos": 100,
            "target_skill": -1,
            # Reward Learning option
            "output_type": "feature",
            "pvr_type": "LIV",
            "use_sparse": False,
            # RTG related configs.
            "rtg_key": "",
            "rtg_scale": 1000.0,
        }

        config = ConfigDict()
        for field_name, field_value in defaults.items():
            setattr(config, field_name, field_value)

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, update, start_offset_ratio=None, split="train", demo_type="success", h5_file=None):
        """Create the dataset, loading demonstrations unless ``h5_file`` is supplied.

        Args:
            update: Overrides forwarded to ``get_default_config``.
            start_offset_ratio: When ``config.random_start`` is false and this is
                not ``None``, the start offset is ``int(len * ratio) % len``.
            split: Sub-directory name searched under each data directory by
                ``preprocess``.
            demo_type: ``"all"`` selects every ``*.pkl``; any other value is
                turned into the filename suffix ``_<demo_type>``.
            h5_file: Pre-built dictionary of arrays. When ``None`` the data is
                built by calling ``preprocess()``.
        """
        self.config = self.get_default_config(update)
        self.split = split

        if demo_type == "all":
            self.demo_type = ""
        else:
            self.demo_type = f"_{demo_type}"

        if h5_file is None:
            self.h5_file = self.preprocess()
        else:
            self.h5_file = h5_file

        # Offset added (modulo the dataset length) to every requested index.
        if self.config.random_start:
            self.random_start_offset = np.random.default_rng().choice(len(self))
        elif start_offset_ratio is not None:
            self.random_start_offset = int(len(self) * start_offset_ratio) % len(self)
        else:
            self.random_start_offset = 0

        # An empty suffix means raw text is tokenized on the fly; otherwise the
        # suffix names the pre-computed "instruct_<pvr_type>" entry of h5_file.
        if self.config.output_type == "raw":
            self.suffix = ""
        else:
            self.suffix = f"_{self.config.pvr_type}"
        if self.suffix == "":
            self.tokenizer = self.build_tokenizer()

        if self.config.rtg_key != "":
            self._rtg_key = f"rtg_{self.config.rtg_key}"
            rtg_values = self.h5_file[self._rtg_key]
            self._rtg_min, self._rtg_max = np.min(rtg_values), np.max(rtg_values)
            print(f"[INFO] {self._rtg_key}: {self._rtg_min} ~ {self._rtg_max}")

    def preprocess(self):
        """Load the pickled demonstrations and flatten them into per-frame arrays.

        ``config.data_dir`` and ``config.task_name`` are ``"|"``-separated lists
        that are paired element-wise; ``config.num_demos`` is split evenly
        (integer division) across the tasks. For every kept frame one row is
        appended to each output array, so all arrays share the same row index.

        Returns:
            dict of numpy arrays with keys ``action``, ``timestep``, ``skill``,
            ``attn_mask``, ``demo_number``, ``demo_offset``, ``demo_idx``, one
            entry per image key, ``nfp_next_<image key>``, ``terminals`` and
            ``task_name``. ``demo_idx`` rows hold the ``[start, end)`` row range
            of the demo a row belongs to, ``demo_offset`` is the row's position
            inside that range, and ``nfp_next_<key>`` is the row index
            ``skip_frame`` rows ahead (clamped to the demo's last row).
        """
        data_dirs = self.config.data_dir.split("|")
        task_names = self.config.task_name.split("|")
        num_demos_per_task = self.config.num_demos // len(task_names)

        episodes = []
        for dataset_path, tn in zip(data_dirs, task_names):
            if self.config.is_sim:
                print(f"collect {self.demo_type} demonstrations.")
                pattern = f"*{self.demo_type}.pkl"
            else:
                pattern = "2023*.pkl"
            paths = sorted(glob.glob(os.path.join(dataset_path, self.split, pattern)))[:num_demos_per_task]
            episodes.extend((tn, path) for path in paths)

        image_keys = self.config.image_keys.split("|")
        target_keys = (
            ["action", "timestep", "skill", "attn_mask", "demo_number", "demo_offset", "demo_idx"]
            + image_keys
            + [f"nfp_next_{ik}" for ik in image_keys]
        )
        ret = {key: [] for key in target_keys}
        done_bool_, task_name_ = [], []

        offset = self.config.offset
        skip_frame = self.config.skip_frame
        target_skill = self.config.target_skill
        demo_cnt = 0  # number of rows emitted so far == first row of the next demo
        for demo_number, (tn, ep_path) in enumerate(tqdm(episodes, desc="load data", ncols=0)):
            # NOTE: pickle can execute arbitrary code on load; only point data_dir at trusted files.
            with open(ep_path, "rb") as f:
                ep = pickle.load(f)
            ep = {key: np.asarray(val) for key, val in ep.items()}

            cumsum_skills = np.cumsum(ep["skills"])
            if target_skill != -1:
                # Keep only the frames whose cumulative skill count equals target_skill,
                # plus the frame at which that skill's subgoal is reached.
                subgoals = np.asarray(list(np.nonzero(ep["skills"])[0]) + [len(ep["skills"]) - 1])
                target_idx = list(np.where(cumsum_skills == target_skill)[0])
                target_idx.append(subgoals[target_skill])
                if target_skill != 0:
                    target_idx = target_idx[1:]
                ep = {key: val[target_idx] for key, val in ep.items() if val.ndim > 0}
                step_rewards = cumsum_skills[target_idx]
            else:
                step_rewards = cumsum_skills

            N = ep["rewards"].shape[0]
            # Rows actually emitted for this demo (frames before `offset` are skipped).
            num_rows = max(N - offset, 0)
            # Exactly one demo_idx row per emitted frame; the original extended this list
            # twice per episode, which misaligned it with every other array.
            ret["demo_idx"].extend([[demo_cnt, demo_cnt + num_rows]] * num_rows)

            encoded_tn = np.frombuffer(str(tn).encode("utf-8"), dtype=np.uint8)
            for row, i in enumerate(range(offset, N)):
                ret["action"].append(ep["actions"][i].astype(np.float32))
                ret["timestep"].append(np.asarray(i).astype(np.int32))
                ret["attn_mask"].append(np.asarray(1).astype(np.int32))
                ret["skill"].append(step_rewards[i].astype(np.int32))
                # Row position inside the demo, so demo_start + demo_offset is the global row.
                ret["demo_offset"].append(row)
                ret["demo_number"].append(demo_number)

                done_bool_.append(bool(i == N - 1))
                task_name_.append(encoded_tn)

                for ik in image_keys:
                    # The stored image for frame i is the observation at i + 1 (clamped).
                    next_image = ep["observations"][min(i + 1, N - 1)][ik].astype(np.uint8)
                    ret[ik].append(next_image)
                    ret[f"nfp_next_{ik}"].append(demo_cnt + min(row + skip_frame, num_rows - 1))
            demo_cnt += num_rows

        ret = {key: np.asarray(val) for key, val in ret.items()}
        # NOTE(review): task names with different byte lengths give rows of unequal length
        # here — confirm before mixing tasks whose names differ in length.
        ret.update({"terminals": np.asarray(done_bool_), "task_name": np.asarray(task_name_)})
        return ret

    def __getstate__(self):
        """Return the pickling state as ``(config, random_start_offset, h5_file)``."""
        state = (self.config, self.random_start_offset, self.h5_file)
        return state

    def __setstate__(self, state):
        """Rebuild the dataset from a ``(config, random_start_offset, h5_file)`` tuple.

        The constructor is re-run with the stored config and data, then the
        stored start offset replaces whatever the constructor computed.
        """
        restored_config, restored_offset, restored_data = state
        self.__init__(restored_config, h5_file=restored_data)
        self.random_start_offset = restored_offset

    def __len__(self):
        demo_indicators = np.unique(self.h5_file["demo_idx"])
        return min(demo_indicators[self.config.num_demos] - self.config.start_index, self.config.max_length)

    def process_index(self, index):
        index = (index + self.random_start_offset) % len(self)
        return index + self.config.start_index

    def build_tokenizer(self):
        """Return a function mapping instructions to ``(token_ids, padding_mask)``.

        With ``config.use_bert_tokenizer`` the ``bert-base-uncased`` tokenizer
        from ``transformers`` is used and the mask is ``1 - attention_mask``;
        an empty input yields all-zero ids and an all-one mask of length
        ``config.tokenizer_max_length``. Otherwise ``clip.tokenize`` is used and
        the mask is always all ones of length ``config.tokenizer_max_length``.
        """
        max_len = self.config.tokenizer_max_length

        if self.config.use_bert_tokenizer:
            import transformers

            bert_tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

            def tokenizer_fn(instruct):
                if len(instruct) == 0:
                    empty_ids = np.zeros(max_len, dtype=np.int32)
                    full_mask = np.ones(max_len, dtype=np.float32)
                    return empty_ids, full_mask
                encoded = bert_tokenizer(
                    instruct,
                    padding="max_length",
                    truncation=True,
                    max_length=max_len,
                    return_tensors="np",
                    add_special_tokens=False,
                )
                token_ids = encoded["input_ids"].astype(np.int32)
                mask = 1.0 - encoded["attention_mask"].astype(np.float32)
                return token_ids, mask

        else:
            clip_tokenize = clip.tokenize

            def tokenizer_fn(instruct):
                token_ids = np.asarray(clip_tokenize(instruct)).astype(np.int32)
                mask = np.ones(max_len).astype(np.float32)
                return token_ids, mask

        return tokenizer_fn

    def process_rtg(self, rtg):
        return (rtg - self._rtg_min) / self.config.rtg_scale

    def compute_window_indices(self, index):
        window_size, frame = self.config.window_size, self.config.skip_frame
        (demo_start, _), demo_offset = self.h5_file["demo_idx"][index], self.h5_file["demo_offset"][index]
        stack_start, stack_end = (
            max(demo_start + demo_offset - window_size * frame, demo_start),
            demo_start + demo_offset,
        )
        batch_indices = [stack_end]
        e = stack_end
        for _ in range(self.config.window_size - 1):
            e -= frame
            if e >= stack_start:
                batch_indices.insert(0, e)
            else:
                break
        batch_indices = np.asarray(batch_indices)
        return batch_indices

    def get_tokenized_instructions(self, task_name: bytes, skills: Sequence):
        """Tokenize one instruction per window slot for the given task.

        Args:
            task_name: UTF-8 encoded task name as a bytes-like buffer.
            skills: Indexable with at least ``config.window_size`` entries;
                entry ``slot`` is passed (as a numpy array) to
                ``get_furniturebench_instruct`` together with the task name.

        Returns:
            The ``(instruct, text_padding_mask)`` pair produced by
            ``self.tokenizer`` for the list of instructions.
        """
        decoded_task_name = memoryview(task_name).tobytes().decode("utf-8")
        instructions = [
            get_furniturebench_instruct(decoded_task_name, np.asarray(skills[slot]))
            for slot in range(self.config.window_size)
        ]
        instruct, text_padding_mask = self.tokenizer(instructions)
        return instruct, text_padding_mask

    def __getitem__(self, index):
        index = self.process_index(index)
        batch = {"image": {}}
        image_keys = self.config.image_keys.split("|")

        batch_indices = self.compute_window_indices(index)

        batch.update({"nfp_next_image": {}})
        image_stack = {ik: deque([], maxlen=self.config.window_size) for ik in image_keys}
        for _ in range(self.config.window_size):
            for ik in image_keys:
                image_stack[ik].append(np.zeros((self.config.image_size, self.config.image_size, 3), dtype=np.uint8))
        for ik in image_keys:
            images = self.h5_file[ik][self.h5_file[f"nfp_next_{ik}"][batch_indices]]
            if images.shape[1] == 3:
                images = images.transpose(0, 2, 3, 1)
            image_stack[ik].extend(images)
        batch["nfp_next_image"] = {key: np.asarray(val) for key, val in image_stack.items()}

        stack = {key: deque([], maxlen=self.config.window_size) for key in ["action", "timestep", "skill", "attn_mask"]}
        for _ in range(self.config.window_size):
            stack["action"].append(np.zeros((self.config.action_dim,), dtype=np.float32))
            stack["timestep"].append(np.asarray(0).astype(np.int32))
            stack["skill"].append(np.asarray(0).astype(np.int32))
            stack["attn_mask"].append(np.asarray(0).astype(np.int32))

        for key in stack:
            stack[key].extend(self.h5_file[key][batch_indices])

        # apply rtgs
        if self.config.rtg_key != "":
            stack["rtg"] = deque([], maxlen=self.config.window_size)
            for _ in range(self.config.window_size):
                stack["rtg"].append(np.asarray([0]).astype(np.int32))
            stack["rtg"].extend(
                np.vectorize(self.process_rtg)(self.h5_file[f"rtg_{self.config.rtg_key}"][batch_indices][..., None])
            )

        stack = {key: np.asarray(val) for key, val in stack.items()}
        for key in stack:
            batch[key] = stack[key]

        # build image stack.
        image_stack = {ik: deque([], maxlen=self.config.window_size) for ik in image_keys}
        for _ in range(self.config.window_size):
            for ik in image_keys:
                image_stack[ik].append(np.zeros((self.config.image_size, self.config.image_size, 3), dtype=np.uint8))

        for ik in image_keys:
            images = self.h5_file[ik][batch_indices]
            if images.shape[1] == 3:
                images = images.transpose(0, 2, 3, 1)
            image_stack[ik].extend(images)

        batch["image"] = {key: np.asarray(val) for key, val in image_stack.items()}

        # clip action for stabilizing.
        batch["action"] = np.clip(batch["action"], -self.config.clip_action, self.config.clip_action)

        _instruct = []
        for idx in range(self.config.window_size):
            with BytesIO(self.h5_file["task_name"][index]) as fin:
                decoded_task_name = fin.read().decode("utf-8")
            _instruct.append(get_furniturebench_instruct(decoded_task_name, np.asarray(stack["skill"][idx])))
        batch["instruct"], batch["text_padding_mask"] = self.tokenizer(_instruct)

        if self.suffix == "":
            batch["instruct"], _ = self.get_tokenized_instructions(self.h5_file["task_name"][index], stack["skill"])
        else:
            batch["instruct"] = self.h5_file[f"instruct{self.suffix}"][index]

        # GT reward
        if self.config.use_sparse:
            batch["reward"] = self.h5_file["terminals"][batch_indices].astype(np.float32)
        else:
            batch["reward"] = batch["skill"].astype(np.float32)

        return batch

    @property
    def num_actions(self):
        return self.config.action_dim

    @property
    def rtg(self):
        if self.config.rtg_key != "":
            return np.quantile(self.h5_file[f"rtg_{self.config.rtg_key}"], 0.9)
        else:
            return None

    @property
    def rtg_scale(self):
        if self.config.rtg_key != "":
            return self.config.rtg_scale
        else:
            return None

    @property
    def obs_shape(self):
        res = {"image": {}}
        for key in self.config.image_keys.split("|"):
            res["image"][key] = (self.config.image_size, self.config.image_size, 3)
        res["rtg"] = (1,)
        if self.config.state_key != "":
            res["state"] = self.config.state_dim
        return res


if __name__ == "__main__":
    # Manual smoke test: build a small dataset, print the shape of every batch entry for a
    # range of indices, and dump the windowed frames of each sample as JPEG files in the
    # current directory (files are overwritten on every iteration).
    from PIL import Image
    from tqdm import trange

    config = ARPFurnitureBenchDataset.get_default_config()
    base_path = "/home/data/furniture_sim_preprocessed/low/one_leg"
    config.task_name = "one_leg"
    config.data_dir = base_path
    config.window_size = 4
    config.skip_frame = 16
    config.num_demos = 10
    config.image_keys = "color_image2|color_image1"
    config.use_bert_tokenizer = False
    config.is_sim = True
    config.use_sparse = False
    config.output_type = "raw"
    config.target_skill = -1
    # config.rtg_key = "clip"

    split = "train"
    ds = ARPFurnitureBenchDataset(update=config, split=split)
    image_keys = config.image_keys.split("|")

    for i in trange(300, 400):
        batch = ds[i]
        for key, val in batch.items():
            if isinstance(val, dict):
                for ik in val:
                    print(f"[INFO] {key}|{ik}: {val[ik].shape}")
            elif isinstance(val, np.ndarray):
                print(f"[INFO] {key}: {val.shape}")
            else:
                print(f"[INFO] {key}: {val}")

        for ik in image_keys:
            for ws in range(config.window_size):
                img = Image.fromarray(batch["image"][ik][ws])
                img.save(f"{ik}_ws{ws}.jpeg")
                img = Image.fromarray(batch["nfp_next_image"][ik][ws])
                img.save(f"{ik}_nfp_next_ws{ws}.jpeg")
