import os
import pickle

import h5py
import numpy as np
import torch
from ml_collections import ConfigDict
from tqdm import trange


class PrefDataset(torch.utils.data.Dataset):
    """Preference queries over an offline trajectory dataset.

    Every item is one query: two windows of ``config.query_len`` consecutive steps,
    taken from two different trajectories, plus a label saying which window is
    preferred. Queries are either passed in by the caller (``query_1`` / ``query_2`` /
    ``label``) or sampled from the dataset and labelled from summed rewards by
    :meth:`set_query`.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ``ConfigDict``, merged with ``updates`` when given."""
        config = ConfigDict()

        config.data_dir = ""
        config.env_type = "factorworld"
        config.env_name = "pick-place-v2"

        # Items are served from [start_index, start_index + len(self)), where
        # len(self) = min(num_query - start_index, max_length); the window can be
        # rotated by a random (random_start) or caller-given offset.
        config.start_index = 0
        config.max_length = int(1e9)
        config.random_start = False

        config.use_image = False
        config.image_size = 224

        config.image_key = "corner2"
        config.action_dim = 4
        # Actions are clipped to [-clip_action, clip_action] in __getitem__.
        config.clip_action = 0.999

        # Stride used to subsample each query window in __getitem__.
        config.skip_frame = 1

        config.use_human_label = False
        config.num_query = 1000
        config.query_len = 25
        # 0: never reject; 1: always reject a pair whose reward sums are equal;
        # 2 / 3: reject such pairs only for the first 50% / 20% of the queries.
        config.skip_flag = 0
        # 0: hard labels (ties go to the first segment); 1: ties get 0.5 / 0.5.
        config.label_type = 0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, update, env, ds=None, query_1=None, query_2=None, label=None, start_offset_ratio=None):
        """Open the dataset, split it into trajectories and load or create queries.

        Args:
            update: overrides merged into :meth:`get_default_config`.
            env: environment; provides ``get_dataset()`` for non-file datasets and
                ``spec.id`` / ``_max_episode_steps`` for :meth:`get_trj_idx`.
            ds: optional pre-opened dataset (mapping of key -> array-like); takes
                precedence over ``config.data_dir``.
            query_1, query_2: start row of the first / second segment of every query.
                If any of ``query_1``, ``query_2``, ``label`` is None, new queries are
                generated with :meth:`set_query`.
            label: per-query labels. With ``config.use_human_label``, a 1-D sequence of
                0 (first preferred), 1 (second preferred) or -1 (tie) is converted to an
                (N, 2) array; an already 2-D label array is used unchanged.
            start_offset_ratio: fraction of ``len(self)`` by which item indices are
                rotated (ignored when ``config.random_start`` is set).
        """
        self.env = env
        self.config = self.get_default_config(update)

        if ds is not None:
            self.h5_file = ds
        elif self.config.env_type in ["factorworld", "metaworld", "robosuite"]:
            self.h5_file = h5py.File(os.path.join(self.config.data_dir, "episodes", "data.hdf5"), "r")
        else:
            self.h5_file = env.get_dataset()

        os.makedirs(os.path.join(self.config.data_dir, "queries"), exist_ok=True)

        if self.config.random_start:
            self.random_start_offset = np.random.default_rng().choice(len(self))
        elif start_offset_ratio is not None:
            self.random_start_offset = int(len(self) * start_offset_ratio) % len(self)
        else:
            self.random_start_offset = 0

        self.trj_idx_list = np.asarray(self.get_trj_idx())
        # Ranges are inclusive [start, end], hence the +1.
        self.trj_len_list = self.trj_idx_list[:, 1] - self.trj_idx_list[:, 0] + 1

        if query_1 is None or query_2 is None or label is None:
            self.query_1, self.query_2, self.label = self.set_query()
            return

        self.query_1 = query_1
        self.query_2 = query_2
        if self.config.use_human_label and np.ndim(label) == 1:
            # Raw human labels: 0 -> first preferred, 1 -> second preferred, -1 -> tie.
            raw = np.asarray(label)
            human_labels = np.zeros((len(raw), 2))
            human_labels[raw == 0, 0] = 1.0
            human_labels[raw == 1, 1] = 1.0
            human_labels[raw == -1] = 0.5
            self.label = human_labels
        else:
            # Non-human labels, or labels that are already (N, 2) (e.g. restored by
            # __setstate__), are used as given.
            self.label = label

    def get_trj_idx(self, terminate_on_end=False):
        """Split the flat dataset into trajectories.

        A row ends a trajectory when the goal in ``infos/goal`` changes (maze
        environments, see the note below) or ``terminals`` is set (all others). A row
        is also a final timestep when ``timeouts`` is set (if the dataset has that key)
        or after ``env._max_episode_steps`` steps otherwise.

        Args:
            terminate_on_end: when False (default), a final timestep closes the
                current trajectory without including that step.

        Returns:
            List of inclusive ``[start, end]`` pairs. The last pair covers whatever
            follows the last detected boundary; row ``N - 1`` itself is never
            checked for a boundary.
        """
        dataset = self.h5_file
        N = dataset["rewards"].shape[0]

        # The newer version of the dataset adds an explicit timeouts field.
        # Keep the step-count method for backwards compatibility.
        use_timeouts = "timeouts" in dataset
        spec = getattr(self.env, "spec", None)
        use_goal_changes = spec is not None and "maze" in spec.id

        # Read each boundary signal once instead of one dataset lookup per row.
        if use_goal_changes:
            goals = dataset["infos/goal"][:]
        else:
            terminals = dataset["terminals"][:]
        if use_timeouts:
            timeouts = dataset["timeouts"][:]
        else:
            max_episode_steps = self.env._max_episode_steps

        episode_step = 0
        start_idx, data_idx = 0, 0
        trj_idx_list = []
        for i in range(N - 1):
            if use_goal_changes:
                # NOTE(review): only a goal change whose summed difference is positive
                # counts as a boundary; confirm this is intended.
                done_bool = sum(goals[i + 1] - goals[i]) > 0
            else:
                done_bool = bool(terminals[i])
            if use_timeouts:
                final_timestep = timeouts[i]
            else:
                final_timestep = episode_step == max_episode_steps - 1
            if (not terminate_on_end) and final_timestep:
                # Skip this transition and don't apply terminals on the last step of an episode.
                # NOTE(review): `data_idx` is not advanced here, so afterwards it lags the raw
                # row index `i` by one per dropped step, while set_query uses these ranges to
                # index raw dataset rows; confirm the data layout matches.
                episode_step = 0
                trj_idx_list.append([start_idx, data_idx - 1])
                start_idx = data_idx
                continue
            if done_bool or final_timestep:
                episode_step = 0
                trj_idx_list.append([start_idx, data_idx])
                start_idx = data_idx + 1

            episode_step += 1
            data_idx += 1

        trj_idx_list.append([start_idx, data_idx])

        return trj_idx_list

    def get_time_idx(self, idx):
        """Return the offset of row ``idx`` inside its trajectory (None if no range contains it)."""
        for start, end in self.trj_idx_list:
            if start <= idx <= end:
                return idx - start

    def set_query(self):
        """Sample ``config.num_query`` segment pairs and label them from their rewards.

        The start indices and labels are also pickled to ``<data_dir>/queries``.

        Returns:
            ``(start_indices_1, start_indices_2, labels)``: int32 arrays of shape
            ``(num_query,)`` and a float array of shape ``(num_query, 2)``.

        Raises:
            ValueError: if ``config.label_type`` is not 0 or 1, or fewer than two
                sampleable trajectories are longer than ``config.query_len``.
        """
        num_query, query_len = self.config.num_query, self.config.query_len
        label_type = self.config.label_type
        print(f"Create new queries with {num_query} queries / query length {query_len}")

        # Validate before the (potentially long) sampling loop.
        if label_type not in (0, 1):
            raise ValueError(f"Unsupported label_type {label_type!r}; expected 0 or 1.")
        assert max(self.trj_len_list) > query_len
        # Segments are drawn from every trajectory except the last entry, and a query's
        # two segments must come from different trajectories; with fewer than two
        # long-enough candidates the sampling loop could never finish.
        if np.count_nonzero(self.trj_len_list[:-1] > query_len) < 2:
            raise ValueError(f"Need at least two trajectories longer than query_len={query_len} to build queries.")

        reward_seq_1 = np.zeros((num_query, query_len))
        reward_seq_2 = np.zeros((num_query, query_len))
        start_indices_1 = np.zeros(num_query, dtype=np.int32)
        start_indices_2 = np.zeros(num_query, dtype=np.int32)

        for query_count in trange(num_query, desc="create queries"):
            first, second = self._sample_query_pair(query_count)
            start_indices_1[query_count], reward_seq_1[query_count] = first
            start_indices_2[query_count], reward_seq_2[query_count] = second

        labels = self._reward_labels(reward_seq_1, reward_seq_2, label_type)

        query_dir = os.path.join(self.config.data_dir, "queries")
        indices_1_filename = os.path.join(query_dir, f"indices_num{num_query}_q{query_len}")
        indices_2_filename = os.path.join(query_dir, f"indices_2_num{num_query}_q{query_len}")
        label_dummy_filename = os.path.join(query_dir, f"label_scripted_num{num_query}_q{query_len}")
        with open(indices_1_filename, "wb") as fp, open(indices_2_filename, "wb") as gp, open(
            label_dummy_filename, "wb"
        ) as hp:
            pickle.dump(start_indices_1, fp)
            pickle.dump(start_indices_2, gp)
            pickle.dump(labels, hp)

        return start_indices_1, start_indices_2, labels

    def _sample_query_pair(self, query_count):
        """Sample the two segments of query number ``query_count``.

        Each segment is a window of ``config.query_len`` rows inside one trajectory
        (never the last entry of ``trj_idx_list``). The second segment comes from a
        different trajectory than the first. Depending on ``config.skip_flag``, a
        second segment whose reward sum equals the first one's is redrawn.

        Returns:
            ``((start_1, rewards_1), (start_2, rewards_2))`` with float64 reward windows.
        """
        num_query, query_len, skip_flag = self.config.num_query, self.config.query_len, self.config.skip_flag
        if skip_flag == 1:
            reject_ties = True
        elif skip_flag == 2:
            reject_ties = query_count < int(0.5 * num_query)
        elif skip_flag == 3:
            reject_ties = query_count < int(0.2 * num_query)
        else:
            reject_ties = False

        segments = []
        first_trj = -1  # trajectory of the accepted first segment; -1 matches none
        while len(segments) < 2:
            trj_idx = np.random.choice(np.arange(len(self.trj_idx_list) - 1))
            len_trj = self.trj_len_list[trj_idx]
            if len_trj <= query_len or trj_idx == first_trj:
                continue

            time_idx = np.random.choice(len_trj - query_len + 1)
            start_idx = self.trj_idx_list[trj_idx][0] + time_idx
            end_idx = start_idx + query_len
            reward_seq = np.asarray(self.h5_file["rewards"][start_idx:end_idx], dtype=np.float64)
            assert end_idx <= self.trj_idx_list[trj_idx][1] + 1

            # Compare against the first segment of *this* query.
            if segments and reject_ties and np.sum(reward_seq) == np.sum(segments[0][1]):
                continue
            if not segments:
                first_trj = trj_idx
            segments.append((start_idx, reward_seq))
        return segments[0], segments[1]

    @staticmethod
    def _reward_labels(reward_seq_1, reward_seq_2, label_type):
        """Build ``(N, 2)`` labels from per-segment rewards of shape ``(N, query_len)``.

        Column 1 is set when the second segment has the strictly larger reward sum,
        column 0 otherwise. With ``label_type == 1`` exact ties become 0.5 / 0.5.

        Raises:
            ValueError: if ``label_type`` is not 0 or 1.
        """
        if label_type not in (0, 1):
            raise ValueError(f"Unsupported label_type {label_type!r}; expected 0 or 1.")
        sum_r_t_1 = np.sum(reward_seq_1, axis=1)
        sum_r_t_2 = np.sum(reward_seq_2, axis=1)
        binary_label = 1 * (sum_r_t_1 < sum_r_t_2)
        labels = np.zeros((len(binary_label), 2))
        labels[np.arange(binary_label.size), binary_label] = 1.0
        if label_type == 1:
            margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
            labels[margin_index] = 0.5
        return labels

    def __getstate__(self):
        """Keep only what is needed to rebuild the dataset; the data source is reopened on unpickling."""
        return self.env, self.query_1, self.query_2, self.label, self.config, self.random_start_offset

    def __setstate__(self, state):
        """Rebuild the dataset from :meth:`__getstate__` output, reusing the stored queries."""
        env, query_1, query_2, label, config, random_start_offset = state
        # Keyword arguments are required: the third positional parameter of __init__ is `ds`.
        self.__init__(config, env, query_1=query_1, query_2=query_2, label=label)
        self.random_start_offset = random_start_offset

    def __len__(self):
        """Number of queries served: ``min(num_query - start_index, max_length)``."""
        return min(self.config.num_query - self.config.start_index, self.config.max_length)

    def process_index(self, index):
        """Map an item index to a query index (rotate by the start offset, then shift by ``start_index``)."""
        index = (index + self.random_start_offset) % len(self)
        return index + self.config.start_index

    def __getitem__(self, index):
        """Return query ``index`` as a dict of numpy arrays.

        Keys: ``observations`` / ``actions`` (plus ``images`` when ``config.use_image``)
        for the first segment, the same keys with a ``_2`` suffix for the second,
        ``timestep`` / ``timestep_2`` (offsets inside each segment's trajectory) and
        ``labels``. Windows are subsampled every ``config.skip_frame`` rows starting at
        a random offset below ``skip_frame``; actions are clipped to
        ``[-config.clip_action, config.clip_action]``.
        """
        index = self.process_index(index)
        batch = {}
        start_idx_1, start_idx_2 = self.query_1[index], self.query_2[index]
        # NOTE(review): with skip_frame > 1 this offset shifts the window past the range
        # validated in set_query, so it may reach into the next trajectory or past the
        # end of the data; confirm.
        idx = np.random.choice(np.arange(self.config.skip_frame))
        range_1 = np.arange(start_idx_1 + idx, start_idx_1 + idx + self.config.query_len)[:: self.config.skip_frame]
        range_2 = np.arange(start_idx_2 + idx, start_idx_2 + idx + self.config.query_len)[:: self.config.skip_frame]

        # extract target keys used for training reward model.
        target_keys = ["observations", "actions"]
        for key in target_keys:
            batch[key] = self.h5_file[key][range_1]
            batch[f"{key}_2"] = self.h5_file[key][range_2]
        if self.config.use_image:
            batch["images"] = self.h5_file[self.config.image_key][range_1]
            batch["images_2"] = self.h5_file[self.config.image_key][range_2]

        # get timestep for time embedding
        time_idx, time_idx_2 = self.get_time_idx(start_idx_1), self.get_time_idx(start_idx_2)
        batch["timestep"] = np.arange(time_idx + idx, time_idx + idx + self.config.query_len)[:: self.config.skip_frame]
        batch["timestep_2"] = np.arange(time_idx_2 + idx, time_idx_2 + idx + self.config.query_len)[
            :: self.config.skip_frame
        ]

        # use label.
        batch["labels"] = self.label[index]

        # clip action for stabilizing.
        batch["actions"] = np.clip(batch["actions"], -self.config.clip_action, self.config.clip_action)
        batch["actions_2"] = np.clip(batch["actions_2"], -self.config.clip_action, self.config.clip_action)

        return batch


if __name__ == "__main__":
    # Smoke test: build a dataset from a local MetaWorld dump and print the shape
    # of every field in the first query.
    from bpref_v2.envs import MetaWorld

    env = MetaWorld("pick-place-v2", seed=0)

    data_root = "/home/factor-world/data/pick-place-v2"
    config = PrefDataset.get_default_config()
    config.data_dir = data_root
    config.use_image = True

    # To reuse previously saved queries instead of sampling new ones, unpickle the
    # start indices and labels from "<data_dir>/queries" and pass them as
    # query_1=..., query_2=..., label=... (set config.use_human_label = True when the
    # labels are raw human labels).
    dataset = PrefDataset(update=config, env=env)

    first_batch = dataset[0]
    for name, value in first_batch.items():
        print(f"[INFO] {name}: {value.shape}")
