import glob
import os
from collections import defaultdict, deque

import clip
import numpy as np
import torch
from ml_collections import ConfigDict
from tqdm import tqdm

from .instruct import get_factorworld_instruct


class ARPFactorworldDataset(torch.utils.data.Dataset):
    """Sub-trajectory dataset built from FactorWorld episodes stored as ``*.npz`` files.

    Every frame ``i >= config.offset`` of every episode becomes one row. A row holds a
    window of the most recent ``len_subtraj`` frames sharing the phase ``i % skip_frame``
    (zero-padded at the start of an episode), together with the indices of the episode's
    first/last rows, the next frame's image and a per-frame reward. All rows are
    materialised in memory by :meth:`preprocess`.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ``ConfigDict``, optionally overridden by ``updates``.

        Multi-valued string fields (``data_dir``, ``task_name``, ``image_keys``) use ``|``
        as separator; ``data_dir`` and ``task_name`` are paired one-to-one.
        """
        config = ConfigDict()

        config.data_dir = ""
        config.env_type = "factorworld"
        config.max_episode_steps = 500
        config.task_name = "pick-place-v2"

        config.start_index = 0
        config.max_length = int(1e9)
        config.random_start = False

        config.image_size = 224

        config.image_keys = "corner2"
        config.image_main_key = "corner2"
        config.action_dim = 4
        config.clip_action = 0.999

        config.skip_frame = 1
        config.len_subtraj = 2

        # TODO: Only for factor-world. Frames with index < offset are skipped per episode.
        config.offset = 1

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, update, ds=None, start_offset_ratio=None, split="train"):
        """Build the dataset.

        Args:
            update: overrides applied on top of :meth:`get_default_config`.
            ds: optional pre-built data dict (same layout as :meth:`preprocess` returns);
                when given, no files are read.
            start_offset_ratio: optional fraction of ``len(self)`` used to rotate indices;
                ignored when ``config.random_start`` is set.
            split: sub-directory name (e.g. ``"train"``) looked up under each data dir.
        """
        self.split = split
        self.config = self.get_default_config(update)
        # Named h5_file for historical reasons; it is a dict of numpy arrays.
        self.h5_file = ds if ds is not None else self.preprocess()
        self.random_start_offset = self._compute_start_offset(start_offset_ratio)

    def _compute_start_offset(self, start_offset_ratio=None):
        """Return the index rotation applied by :meth:`process_index` (0 for an empty dataset)."""
        n = len(self)
        if n == 0:
            # Avoid rng.choice(0) / modulo-by-zero on an empty dataset.
            return 0
        if self.config.random_start:
            return np.random.default_rng().choice(n)
        if start_offset_ratio is not None:
            return int(n * start_offset_ratio) % n
        return 0

    def _list_episodes(self):
        """Return sorted ``(task_name, npz_path)`` pairs for every episode of ``self.split``.

        The i-th ``|``-separated task name labels every ``*.npz`` file under
        ``<i-th data_dir>/<split>``.

        Raises:
            ValueError: if ``data_dir`` and ``task_name`` have a different number of
                entries (``zip`` would otherwise silently drop the unmatched ones).
        """
        data_dirs = self.config.data_dir.split("|")
        task_names = self.config.task_name.split("|")
        if len(data_dirs) != len(task_names):
            raise ValueError(
                f"data_dir has {len(data_dirs)} '|'-separated entries but task_name has "
                f"{len(task_names)}; they must be paired one-to-one."
            )
        episodes = []
        for data_dir, task_name in zip(data_dirs, task_names):
            pattern = os.path.join(data_dir, self.split, "*.npz")
            episodes.extend((task_name, path) for path in sorted(glob.glob(pattern)))
        return episodes

    def _new_frame_stacks(self, target_keys, image_keys):
        """Create one padded sliding window per frame phase ``0 .. skip_frame - 1``.

        Each window maps every key in ``target_keys`` to a ``deque(maxlen=len_subtraj)``
        pre-filled with padding (zero action, timestep 0, attention mask 0, black image),
        so stacking a window always yields exactly ``len_subtraj`` entries.
        """
        cfg = self.config
        stacks = {}
        for phase in range(cfg.skip_frame):
            window = {key: deque([], maxlen=cfg.len_subtraj) for key in target_keys}
            for _ in range(cfg.len_subtraj):
                window["actions"].append(np.zeros((cfg.action_dim,), dtype=np.float32))
                window["timesteps"].append(np.asarray(0).astype(np.int32))
                window["attn_mask"].append(np.asarray(0).astype(np.int32))
                for ik in image_keys:
                    window[ik].append(np.zeros((cfg.image_size, cfg.image_size, 3), dtype=np.uint8))
            stacks[phase] = window
        return stacks

    def preprocess(self):
        """Load every episode of ``self.split`` and flatten it into per-frame rows.

        Returns:
            dict of numpy arrays, one entry per row along axis 0:
              * ``actions``, ``timesteps``, ``attn_mask`` and each image key: the
                stacked ``len_subtraj`` window ending at that frame;
              * ``initial_<ik>`` / ``goal_<ik>``: row index of the episode's first / last row;
              * ``next_<ik>``: image of the following frame (last frame repeats);
              * ``r``: ``0`` where the episode's reward is non-zero, else ``-1``;
              * ``terminals``: ``bool(reward)``; ``task_name``: the row's task name.

        Raises:
            ValueError: if ``data_dir`` and ``task_name`` entry counts differ.
        """
        episodes = self._list_episodes()

        image_keys = self.config.image_keys.split("|")
        target_keys = ["actions", "timesteps", "attn_mask"] + image_keys
        liv_data = {f"{elem}_{ik}": [] for elem in ("initial", "next", "goal") for ik in image_keys}
        liv_data["r"] = []

        ret = {key: [] for key in target_keys}
        terminals, row_task_names = [], []

        cnt, offset = 0, self.config.offset
        for task_name, path in tqdm(episodes, desc="load data", ncols=0):
            # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted
            # file; only point data_dir at trusted episode files.
            with np.load(path, allow_pickle=True) as npz:
                ep = {key: npz[key] for key in npz.files}
            n_frames = ep["rewards"].shape[0]
            n_rows = n_frames - offset
            for ik in image_keys:
                liv_data[f"initial_{ik}"].extend([cnt] * n_rows)
                liv_data[f"goal_{ik}"].extend([cnt + n_rows - 1] * n_rows)

            frame_stacks = self._new_frame_stacks(target_keys, image_keys)
            for i in range(offset, n_frames):
                # Frames are interleaved into skip_frame independent windows.
                stack = frame_stacks[i % self.config.skip_frame]

                stack["actions"].append(ep["actions"][i].astype(np.float32))
                stack["timesteps"].append(np.asarray(i).astype(np.int32))
                stack["attn_mask"].append(np.asarray(1).astype(np.int32))

                done_bool = bool(ep["rewards"][i])
                terminals.append(done_bool)
                liv_data["r"].append(int(done_bool) - 1)
                row_task_names.append(task_name)

                for ik in image_keys:
                    stack[ik].append(ep[ik][i].astype(np.uint8))
                    liv_data[f"next_{ik}"].append(ep[ik][min(i + 1, n_frames - 1)].astype(np.uint8))

                for key in target_keys:
                    ret[key].append(np.stack(stack[key]))

                cnt += 1

        ret = {key: np.asarray(val) for key, val in ret.items()}
        ret.update({key: np.asarray(val) for key, val in liv_data.items()})
        ret.update({"terminals": np.asarray(terminals), "task_name": np.asarray(row_task_names)})
        return ret

    def __getstate__(self):
        # Only lightweight state is pickled; the loaded arrays are not.
        return self.config, self.random_start_offset, self.split

    def __setstate__(self, state):
        # Rebuilds via __init__ without `ds`, so preprocess() reloads all episodes from disk.
        config, random_start_offset, split = state
        self.__init__(config, split=split)
        self.random_start_offset = random_start_offset

    def __len__(self):
        """Number of rows after ``start_index``, capped at ``max_length`` (never negative)."""
        available = self.h5_file["actions"].shape[0] - self.config.start_index
        return max(0, min(available, self.config.max_length))

    def process_index(self, index):
        """Map a dataset index to an absolute row index (rotated, then shifted by ``start_index``)."""
        index = (index + self.random_start_offset) % len(self)
        return index + self.config.start_index

    def __getitem__(self, index):
        """Return one sub-trajectory sample as a dict of numpy arrays.

        ``initial_image`` / ``goal_image`` are the last window frames of the episode's
        first / last row; ``actions`` is the window without its last entry, clipped to
        ``[-clip_action, clip_action]``; ``instruct`` is the CLIP tokenisation of the
        task instruction repeated ``len_subtraj`` times.
        """
        index = self.process_index(index)
        data = self.h5_file
        main_key = self.config.image_main_key
        clip_value = self.config.clip_action

        batch = {}
        batch["initial_image"] = data[main_key][data[f"initial_{main_key}"][index]][-1]
        batch["goal_image"] = data[main_key][data[f"goal_{main_key}"][index]][-1]
        batch["r"] = data["r"][index]

        batch["images"] = data[main_key][index]
        batch["next_image"] = data[f"next_{main_key}"][index]

        # Last action of the window is dropped; NOTE(review): presumably to align actions
        # with transitions between window frames — confirm with the model code.
        # Actions are clipped for stabilizing.
        batch["actions"] = np.clip(data["actions"][index][:-1], -clip_value, clip_value)
        batch["timesteps"] = data["timesteps"][index]
        batch["attn_mask"] = data["attn_mask"][index]

        tokens = np.asarray(clip.tokenize(get_factorworld_instruct(data["task_name"][index]))).astype(np.int32)
        batch["instruct"] = np.tile(tokens, (self.config.len_subtraj, 1))

        return batch


if __name__ == "__main__":
    # Smoke test: build the train split, print each field's shape for one sample,
    # then dump that sample's main-camera window as JPEG files in the working directory.
    demo_config = ARPFactorworldDataset.get_default_config()
    demo_config.data_dir = "/home/ablation_data/drawer-close-v2/drawer-close-v2"
    demo_config.len_subtraj = 4
    demo_config.skip_frame = 4
    dataset = ARPFactorworldDataset(update=demo_config, split="train")

    sample = dataset[15]
    for key, val in sample.items():
        print(f"[INFO] {key}: {val.shape}")

    from PIL import Image

    window = sample["images"]
    prefix = demo_config.image_main_key
    for step in range(demo_config.len_subtraj):
        Image.fromarray(window[step]).save(f"{prefix}_{step}.jpeg")
