import random
import traceback
from collections import deque
from io import BytesIO
from pathlib import Path
from typing import Sequence

import clip
import h5py
import numpy as np
import torch
from ml_collections import ConfigDict
from tqdm import trange

from .arp_furniturebench_dataset_inmemory_stream import get_indices_with_same_values
from .instruct import get_maniskill_instruct


def worker_init_fn(worker_id):
    """Seed NumPy's global RNG and the ``random`` module for one DataLoader worker.

    The seed is the first word of the current global NumPy RNG state plus
    ``worker_id``, so workers that start from the same state end up with
    different seeds.

    Args:
        worker_id: Integer index of the worker (the argument a DataLoader
            passes to its ``worker_init_fn``).
    """
    # ``get_state()[1][0]`` is a NumPy unsigned scalar. Convert it to a plain
    # ``int`` because ``random.seed`` rejects NumPy integer types on
    # Python >= 3.11, and wrap the sum because ``np.random.seed`` only accepts
    # values in [0, 2**32 - 1].
    base_seed = int(np.random.get_state()[1][0])
    seed = (base_seed + int(worker_id)) % 2**32
    np.random.seed(seed)
    random.seed(seed)


def episode_len(episode):
    """Return the size of the leading axis of the first array in ``episode``.

    Only the first value of the mapping is inspected; nothing is subtracted
    from its length.
    """
    first_entry = next(iter(episode.values()))
    return first_entry.shape[0]


def _process_failure_skills(phase, threshold=20):
    """Build skill labels and rewards for a failed trajectory.

    Steps that belong to the final phase (the phase id found at ``phase[-1]``)
    are relabelled as ``phase + 100``; every other step keeps its phase id.
    The rewards are simply the phase ids cast to ``float32``.

    Args:
        phase: Array of per-step phase ids (presumably 1-D and integer-valued;
            verify against the producer).
        threshold: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(failure_skills, failure_rewards)``, each a modified copy of
        ``phase``.
    """
    num_phases = int(phase[-1] + 1)
    last_phase = num_phases - 1

    # Mapping from phase id to the step indices carrying that id, e.g.
    # {0: [0, ..., 9], 1: [10, ..., 36], ...}.
    steps_by_phase = get_indices_with_same_values(phase)

    failure_skills = phase.copy()
    failure_rewards = phase.copy().astype(np.float32)

    for phase_id in steps_by_phase:
        members = steps_by_phase[phase_id]
        if phase_id != last_phase:
            continue
        failure_skills[members] = phase[members] + 100

    return failure_skills, failure_rewards


class ARPManiSkillDataset(torch.utils.data.IterableDataset):
    @staticmethod
    def get_default_config(updates=None):
        """Build the default dataset configuration, optionally overridden.

        Args:
            updates: Optional mapping / ConfigDict of overrides applied on top
                of the defaults.

        Returns:
            A ``ConfigDict`` holding every default below, updated with
            ``updates`` when given.
        """
        defaults = {
            "data_dir": "",
            "env_type": "furnituresim",
            "max_episode_steps": 200,
            "task_name": "PickSingleYCB",
            "is_sim": True,
            # episode slicing
            "start_index": 0,
            "max_length": int(1e9),
            "random_start": False,
            # observations / actions
            "image_size": 128,
            "image_keys": "hand_camera|base_camera",
            "image_main_key": "base_camera",
            "state_key": "",
            "action_dim": 7,
            "clip_action": 0.999,
            # temporal windowing
            "skip_frame": 16,
            "window_size": 4,
            # text tokenization
            "use_bert_tokenizer": False,
            "tokenizer_max_length": 77,
            # demo selection
            "offset": 0,
            "num_demos": 100,
            "target_skill": -1,
            # in-memory episode buffer
            "fetch_every": 5000,
            "num_workers": 8,
            "max_size": 10000,
            # Reward Learning option
            "output_type": "feature",
            "pvr_type": "LIV",
            "use_sparse": False,
        }

        config = ConfigDict()
        for field_name, default_value in defaults.items():
            setattr(config, field_name, default_value)

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, update, start_offset_ratio=None, split="train", demo_type="success"):
        """Open the trajectory file and prepare an empty in-memory episode buffer.

        Args:
            update: Overrides forwarded to ``get_default_config``.
            start_offset_ratio: Not used by this constructor.
            split: Label used in log and progress-bar messages.
            demo_type: Selects ``trajectory_{demo_type}.hdf5`` inside
                ``config.data_dir``. The value ``"all"`` is stored as an empty
                string in ``self.demo_type``.
        """
        self.config = self.get_default_config(update)
        self.split = split

        h5_path = Path(self.config.data_dir) / f"trajectory_{demo_type}.hdf5"
        self.h5_file = h5py.File(h5_path, "r")

        if demo_type == "all":
            self.demo_type = ""
        else:
            self.demo_type = demo_type

        # Trajectory groups ordered by their numeric suffix, capped at ``num_demos``.
        traj_keys = [key for key in self.h5_file.keys() if "traj" in key]
        traj_keys.sort(key=lambda key: int(key.split("_")[-1]))
        self._total_episode_fns = [(self.config.task_name, key) for key in traj_keys[: self.config.num_demos]]
        print(f"[INFO] {self.split}_{self.demo_type} | {len(self._total_episode_fns)} episodes are loaded.")

        # Buffer bookkeeping; starting the counter at ``fetch_every`` makes the
        # first ``_try_fetch`` call load data immediately.
        self._episode_fns = []
        self._episodes = dict()
        self._size = 0
        self._samples_since_last_fetch = self.config.fetch_every

        self.random_start_offset = 0

        # Raw output tokenizes instructions on the fly; any other output type
        # uses a key suffix derived from ``pvr_type``.
        if self.config.output_type == "raw":
            self.suffix = ""
            self.tokenizer = self.build_tokenizer()
        else:
            self.suffix = f"_{self.config.pvr_type}"

        self.mode = "local"

    def set_mode(self, mode):
        """Store ``mode`` on the dataset as ``self.mode``.

        NOTE(review): in the code visible here the attribute is only read by
        the ``__main__`` debug script — confirm whether other consumers exist.
        """
        self.mode = mode

    def _load_episode(self, tn, ep):
        image_keys = self.config.image_keys.split("|")
        target_keys = ["action", "timestep", "phase", "attn_mask"] + image_keys + [f"next_{ik}" for ik in image_keys]
        ret = {key: [] for key in target_keys}
        task_name_ = []

        ep = self.h5_file[ep]
        N = ep["dict_str_rewards"].shape[0]
        cumsum_skills = ep["dict_str_rewards"]

        ret["action"] = ep["dict_str_actions"]
        ret["phase"] = cumsum_skills[:].squeeze().astype(np.int32)
        ret["terminals"] = ep["dict_str_episode_dones"][:].squeeze()

        for i in range(N):
            timestep = np.asarray(i).astype(np.int32)
            attn_mask = np.asarray(1).astype(np.int32)

            ret["timestep"].append(timestep)
            ret["attn_mask"].append(attn_mask)
            encoded_tn = np.frombuffer(str(tn).encode("utf-8"), dtype=np.uint8)
            task_name_.append(encoded_tn)

            next_images = np.split(ep["dict_str_obs"]["dict_str_rgb"][min(i + 1, N - 1)], (3,), axis=0)
            for idx, ik in enumerate(image_keys):
                # image = ep["observations"][i][ik].astype(np.uint8)
                ret[ik].append(next_images[idx].astype(np.uint8))
                ret[f"next_{ik}"].append(min(i + 1 + self.config.skip_frame, N - 1))

        ret = {key: np.asarray(val) for key, val in ret.items()}
        ret.update({"task_name": np.asarray(task_name_)})
        return ret

    def __getstate__(self):
        return self.config, self.random_start_offset

    def __setstate__(self, state):
        config, random_start_offset = state
        self.__init__(config)
        self.random_start_offset = random_start_offset

    def build_tokenizer(self):
        """Create the instruction tokenizer selected by the config.

        Returns:
            A function mapping an instruction (or list of instructions) to
            ``(token_ids, padding_mask)``. With the BERT tokenizer the mask is
            ``1 - attention_mask``, and an empty input yields all-zero ids
            with an all-ones mask. With the CLIP tokenizer the mask is always
            a vector of ones of length ``tokenizer_max_length``.
        """
        max_length = self.config.tokenizer_max_length
        use_bert = self.config.use_bert_tokenizer

        if not use_bert:
            tokenizer = clip.tokenize
        else:
            # Imported lazily so the dependency is only needed for this branch.
            import transformers

            tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

        def tokenizer_fn(instruct):
            if not use_bert:
                token_ids = np.asarray(tokenizer(instruct)).astype(np.int32)
                return token_ids, np.ones(max_length).astype(np.float32)

            if len(instruct) == 0:
                empty_ids = np.zeros(max_length, dtype=np.int32)
                return empty_ids, np.ones(max_length, dtype=np.float32)

            encoded = tokenizer(
                instruct,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="np",
                add_special_tokens=False,
            )
            token_ids = encoded["input_ids"].astype(np.int32)
            return token_ids, 1.0 - encoded["attention_mask"].astype(np.float32)

        return tokenizer_fn

    def _compute_window_indices(self, demo_offset):
        """Return the step indices of the window that ends at ``demo_offset``.

        Starting from ``demo_offset`` the method walks backwards in strides of
        ``skip_frame`` and collects at most ``window_size`` indices, stopping
        once the next index would fall below
        ``max(demo_offset - window_size * skip_frame, 0)``.

        Returns:
            Ascending NumPy array of indices whose last element is
            ``demo_offset``.
        """
        stride = self.config.skip_frame
        lower_bound = max(demo_offset - self.config.window_size * stride, 0)

        newest_first = [demo_offset]
        cursor = demo_offset - stride
        for _ in range(self.config.window_size - 1):
            if cursor < lower_bound:
                break
            newest_first.append(cursor)
            cursor -= stride

        return np.asarray(newest_first[::-1])

    def _get_tokenized_instructions(self, task_name: bytes, skills: Sequence):
        """Tokenize one instruction per window step.

        Args:
            task_name: UTF-8 encoded task name (any bytes-like object).
            skills: Per-step skill ids; the first ``window_size`` entries are
                used.

        Returns:
            Tuple ``(token_ids, padding_mask)`` as produced by
            ``self.tokenizer``.
        """
        with BytesIO(task_name) as buffer:
            task = buffer.read().decode("utf-8")

        prompts = [
            get_maniskill_instruct(task, np.asarray(skills[step]))
            for step in range(self.config.window_size)
        ]
        token_ids, padding_mask = self.tokenizer(prompts)
        return token_ids, padding_mask

    def _sample_episode(self):
        """Return one buffered episode, chosen uniformly at random by name."""
        return self._episodes[random.choice(self._episode_fns)]

    def _store_episode(self, tn, eps_fn):
        """Load one episode into the in-memory buffer, evicting others if needed.

        Args:
            tn: Task name forwarded to ``_load_episode``.
            eps_fn: Trajectory key; also used as the buffer key.

        Returns:
            ``(episode_length, True)`` on success, or ``(None, False)`` when
            loading raised (the error is printed and the buffer is left
            untouched).
        """
        try:
            episode = self._load_episode(tn, eps_fn)
        except Exception as e:
            # Deliberate best-effort: report and let the caller decide.
            print(f"error in loading episode {eps_fn}: {e}")
            return None, False

        ep_len = episode_len(episode)
        # Make room for the new episode. The emptiness check stops the loop
        # once nothing is left to evict, so an episode longer than
        # ``max_size`` is stored on its own instead of popping from an empty
        # list.
        # NOTE(review): the list is re-sorted after every insert, so the
        # entry evicted first is the lexicographically smallest key, not
        # necessarily the oldest one — confirm this is intended.
        while self._episode_fns and ep_len + self._size > self.config.max_size:
            evicted_fn = self._episode_fns.pop(0)
            evicted = self._episodes.pop(evicted_fn)
            self._size -= episode_len(evicted)

        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += ep_len

        return ep_len, True

    def _try_fetch(self):
        """Refill the episode buffer once ``fetch_every`` samples have been drawn.

        Does nothing while ``_samples_since_last_fetch`` is below
        ``config.fetch_every``. Otherwise it resets the counter, shuffles the
        known episodes and stores those that belong to this worker (numeric
        key suffix modulo ``config.num_workers`` equals the worker id) and
        are not buffered yet, until more than ``config.max_size`` steps have
        been fetched.

        Raises:
            RuntimeError: If an episode fails to load.
        """
        if self._samples_since_last_fetch < self.config.fetch_every:
            return
        self._samples_since_last_fetch = 0

        # ``get_worker_info()`` returns None outside a DataLoader worker.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id

        shuffled_episode_fns = self._total_episode_fns.copy()
        np.random.shuffle(shuffled_episode_fns)
        fetched_size = 0

        # The context manager closes the progress bar on ``break`` and on errors.
        with trange(
            len(shuffled_episode_fns),
            desc=f"[{self.demo_type.upper()} {self.split.upper()} WORKER {worker_id}] fetch data",
            ncols=0,
            leave=False,
        ) as pbar:
            for idx in pbar:
                tn, eps_fn = shuffled_episode_fns[idx]
                eps_idx = int(eps_fn.split("_")[-1])
                if eps_idx % self.config.num_workers != worker_id:
                    continue
                if eps_fn in self._episodes:
                    continue
                if fetched_size > self.config.max_size:
                    break
                eps_len, flag = self._store_episode(tn, eps_fn)
                if not flag:
                    # Previously a bare ``raise`` with no active exception,
                    # which also surfaced as a RuntimeError but without any
                    # hint about the episode involved.
                    raise RuntimeError(f"failed to load episode {eps_fn}")
                fetched_size += eps_len
                pbar.set_postfix({"fetched_size": fetched_size})

    def _sample(self):
        try:
            self._try_fetch()
        except StopIteration:
            traceback.print_exc()
        # print(f"sample episode from {len(self._episode_fns)} episodes.")
        self._samples_since_last_fetch += 1
        episode = self._sample_episode()
        # eps_fn = random.choice(self._episode_fns)
        # print(f"fetch episode from {eps_fn}")
        # episode = self._load_episode(*eps_fn)

        index = np.random.randint(0, len(episode["terminals"]))
        batch = {"image": {}}
        image_keys = self.config.image_keys.split("|")

        batch_indices = self._compute_window_indices(index)

        batch.update({"next_image": {}})
        image_stack = {ik: deque([], maxlen=self.config.window_size) for ik in image_keys}
        for _ in range(self.config.window_size):
            for ik in image_keys:
                image_stack[ik].append(np.zeros((self.config.image_size, self.config.image_size, 3), dtype=np.uint8))
        for ik in image_keys:
            images = episode[ik][episode[f"next_{ik}"][batch_indices]]
            if images.shape[1] == 3:
                images = images.transpose(0, 2, 3, 1)
            image_stack[ik].extend(images)
        batch["next_image"] = {key: np.asarray(val) for key, val in image_stack.items()}

        stack = {key: deque([], maxlen=self.config.window_size) for key in ["action", "timestep", "phase", "attn_mask"]}
        for _ in range(self.config.window_size):
            stack["action"].append(np.zeros((self.config.action_dim,), dtype=np.float32))
            stack["timestep"].append(np.asarray(0).astype(np.int32))
            stack["phase"].append(np.asarray(0).astype(np.int32))
            stack["attn_mask"].append(np.asarray(0).astype(np.int32))

        for key in stack:
            stack[key].extend(episode[key][batch_indices])

        stack = {key: np.asarray(val) for key, val in stack.items()}
        for key in stack:
            batch[key] = stack[key]

        # build image stack.
        image_stack = {ik: deque([], maxlen=self.config.window_size) for ik in image_keys}
        for _ in range(self.config.window_size):
            for ik in image_keys:
                image_stack[ik].append(np.zeros((self.config.image_size, self.config.image_size, 3), dtype=np.uint8))

        image_indices = [min(i + 1, len(episode["terminals"]) - 1) for i in batch_indices]
        for ik in image_keys:
            images = episode[ik][image_indices]
            if images.shape[1] == 3:
                images = images.transpose(0, 2, 3, 1)
            image_stack[ik].extend(images)

        batch["image"] = {key: np.asarray(val) for key, val in image_stack.items()}

        # clip action for stabilizing.
        batch["action"] = np.clip(batch["action"], -self.config.clip_action, self.config.clip_action)

        if self.suffix == "":
            batch["instruct"], _ = self._get_tokenized_instructions(episode["task_name"][index], stack["phase"])
        else:
            batch["instruct"] = episode[f"instruct{self.suffix}"][index]

        # GT reward
        if self.config.use_sparse:
            batch["reward"] = episode["terminals"][batch_indices].astype(np.float32)
        else:
            batch["reward"] = batch["phase"].astype(np.float32)

        if self.demo_type == "failure":
            failure_skills, failure_rewards = _process_failure_skills(episode["phase"])
            keys = ["failure_skill", "failure_reward"]
            for key in keys:
                stack[key] = deque([], maxlen=self.config.window_size)
            for _ in range(self.config.window_size):
                stack["failure_skill"].append(np.asarray(0).astype(np.int32))
                stack["failure_reward"].append(np.asarray(0).astype(np.float32))
            stack["failure_skill"].extend(failure_skills[batch_indices])
            stack["failure_reward"].extend(failure_rewards[batch_indices])

            batch["skill"] = np.asarray(stack["failure_skill"])
            batch["reward"] = np.asarray(stack["failure_reward"])

        return batch

    def __iter__(self):
        """Yield freshly sampled batches forever; the stream never ends on its own."""
        while True:
            yield self._sample()

    @property
    def num_actions(self):
        """Action dimensionality, read from ``config.action_dim``."""
        return self.config.action_dim

    @property
    def obs_shape(self):
        """Shapes of the observation components.

        Returns:
            Dict with ``"image"`` (one ``(image_size, image_size, 3)`` entry
            per image key), ``"rtg"`` (``(1,)``) and, when
            ``config.state_key`` is non-empty, ``"state"`` taken from
            ``config.state_dim``.
        """
        frame_shape = (self.config.image_size, self.config.image_size, 3)
        shapes = {
            "image": {ik: frame_shape for ik in self.config.image_keys.split("|")},
            "rtg": (1,),
        }
        if self.config.state_key == "":
            return shapes
        # NOTE(review): ``state_dim`` is not among the defaults visible in
        # this file; it has to be supplied through the config overrides.
        shapes["state"] = self.config.state_dim
        return shapes


# Debug entry point: builds a small "failure" dataset from a hard-coded path,
# draws ten samples, prints their skill labels and dumps the image windows as
# JPEG files into the current working directory.
if __name__ == "__main__":
    config = ARPManiSkillDataset.get_default_config()
    base_path = "/home/ManiSkill2/demos_preprocessed/v0/rigid_body/PickSingleYCB-v0"
    config.data_dir = base_path
    config.window_size = 4
    config.skip_frame = 32
    config.num_demos = 10
    config.use_bert_tokenizer = False
    config.is_sim = True
    config.use_sparse = False
    # "raw" output makes the dataset tokenize instructions itself.
    config.output_type = "raw"
    # One worker so that the main process is assigned every episode.
    config.num_workers = 1
    config.max_size = 1000
    # config.rtg_key = "clip"

    # NOTE(review): ``split`` is assigned but not passed to the dataset below.
    split = "train"
    ds = ARPManiSkillDataset(update=config, demo_type="failure")
    ds.mode = "global"
    from tqdm import tqdm

    # for i in trange(len(ds)):
    # The dataset is an endless stream, so the loop stops itself after ten samples.
    for idx, batch in tqdm(enumerate(ds), total=100):
        if idx == 10:
            print("early break.")
            break
        # batch = ds[i]
        # for key, val in batch.items():
        #     if isinstance(val, dict):
        #         for ik in val:
        #             print(f"[INFO] {key}|{ik}: {val[ik].shape}")
        #     elif isinstance(val, np.ndarray):
        #         print(f"[INFO] {key}: {val.shape}")
        #     else:
        #         print(f"[INFO] {key}: {val}")
        # "skill" is only present because the dataset was built with demo_type="failure".
        print(f"[INFO] skill: {batch['skill']}")
        # # print(f"[INFO] rtg: {batch['rtg']}")
        # # print(f"[INFO] rtg of dataset: {ds.rtg}")
        # # print(f"[INFO] rtg_scale of dataset: {ds.rtg_scale}")
        from PIL import Image

        image_keys = config.image_keys.split("|")
        # File names do not include ``idx``: each sample overwrites the previous dump.
        for ik in image_keys:
            for ws in range(config.window_size):
                img = Image.fromarray(batch["image"][ik][ws])
                img.save(f"{ik}_ws{ws}.jpeg")
                img = Image.fromarray(batch["next_image"][ik][ws])
                img.save(f"{ik}_next_ws{ws}.jpeg")
                # NOTE(review): ``_sample`` as visible in this file does not emit
                # "initial_image" / "goal_image"; this branch is skipped here only
                # because ``ds.mode`` was set to "global" above.
                if ds.mode == "local":
                    img = Image.fromarray(batch["initial_image"][ik][ws])
                    img.save(f"{ik}_initial_ws{ws}.jpeg")
                    img = Image.fromarray(batch["goal_image"][ik][ws])
                    img.save(f"{ik}_goal_ws{ws}.jpeg")
