import os
import pickle

import numpy as np
from tqdm import tqdm


def get_goal(name):
    """Return the hard-coded goal coordinates for a maze environment name.

    The name is matched by substring in priority order: "large", then "medium",
    then "umaze". Returns a 2-tuple of floats, or None if no keyword matches.
    """
    goal_table = (
        ("large", (32.0, 24.0)),
        ("medium", (20.0, 20.0)),
        ("umaze", (0.0, 8.0)),
    )
    for keyword, goal in goal_table:
        if keyword in name:
            return goal
    return None


def new_get_trj_idx(env, terminate_on_end=False, **kwargs):
    """Split a flat transition dataset into trajectories and return their index ranges.

    Args:
        env: environment object. If it has a ``get_dataset`` method, the dataset is taken
            from it and ``kwargs["dataset"]`` is ignored. ``env.spec.id`` is read (when
            ``env`` has a ``spec`` attribute) to detect maze tasks. ``env._max_episode_steps``
            is read when the dataset has no "timeouts" field.
        terminate_on_end: if False (default), a timeout step closes the current trajectory
            and is not counted in the running index. If True, a timeout is handled like a
            terminal and the step is counted.
        **kwargs: must contain "dataset" when ``env`` has no ``get_dataset`` method.

    Returns:
        A list of ``[start, end]`` pairs (both inclusive) of running indices. The running
        index advances only on steps that are not skipped as timeouts, and the final
        dataset step (index N - 1) is never inspected. A pair may be empty
        (``end == start - 1``), e.g. when a timeout directly follows a terminal.
        NOTE(review): because skipped timeout steps do not advance the running index, the
        ranges presumably index a dataset from which those transitions were dropped
        (d4rl ``qlearning_dataset`` style) — confirm against callers.
    """
    # An env-provided dataset takes precedence over kwargs["dataset"].
    if not hasattr(env, "get_dataset"):
        dataset = kwargs["dataset"]
    else:
        dataset = env.get_dataset()
    N = dataset["rewards"].shape[0]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = False
    if "timeouts" in dataset:
        use_timeouts = True

    # episode_step: steps since the last boundary (only consulted when there is no "timeouts" field).
    # start_idx: running index at which the current trajectory starts.
    # data_idx: running index of the current step; not advanced for skipped timeout steps.
    episode_step = 0
    start_idx, data_idx = 0, 0
    trj_idx_list = []
    for i in range(N - 1):
        if hasattr(env, "spec") and "maze" in env.spec.id:
            # Maze tasks: a boundary is where the recorded goal changes between consecutive steps.
            # NOTE(review): only a positive sum of goal-coordinate differences counts, so a goal
            # change whose differences sum to <= 0 is not detected — confirm this is intended.
            done_bool = sum(dataset["infos/goal"][i + 1] - dataset["infos/goal"][i]) > 0
        else:
            done_bool = bool(dataset["terminals"][i])
        if use_timeouts:
            final_timestep = dataset["timeouts"][i]
        else:
            # Fallback: infer the timeout from the step counter and the env's episode limit.
            final_timestep = episode_step == env._max_episode_steps - 1
        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            # The trajectory ends at the previous step; data_idx is intentionally not advanced,
            # so the next kept step reuses the same running index.
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx - 1])
            start_idx = data_idx
            continue
        if done_bool or final_timestep:
            # The current step is the last one of its trajectory; the next one starts a new one.
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx])
            start_idx = data_idx + 1

        episode_step += 1
        data_idx += 1

    # Close the trailing trajectory at the last running index reached.
    trj_idx_list.append([start_idx, data_idx])

    return trj_idx_list


def create_queries(env, dataset, num_query, len_query, data_dir=None, balance=False, label_type=0, skip_flag=0):
    """Sample pairs of fixed-length segments from ``dataset`` and label them by summed reward.

    If both index cache files for ``(num_query, len_query)`` already exist in ``data_dir``,
    the queries are rebuilt from those indices via ``load_queries_with_indices``. Otherwise
    new segment pairs are sampled at random, labelled, and their start indices are pickled
    into ``data_dir``.

    Args:
        env: environment passed to ``new_get_trj_idx`` to split ``dataset`` into trajectories.
        dataset: dict of arrays with at least "rewards", "observations", "next_observations"
            and "actions"; optionally "images" and "next_images".
        num_query: number of segment pairs to produce.
        len_query: number of steps in each segment.
        data_dir: directory for the cached start indices (created if missing).
        balance: if True, keep every non-tie query and resample (with replacement) the same
            number of tie queries (label ``[0.5, 0.5]``).
        label_type: 0 -> one-hot label for the higher-return segment (ties go to the first
            segment); 1 -> same, but equal returns get ``[0.5, 0.5]``.
        skip_flag: 0 -> never resample. 1 -> resample the second segment whenever its return
            equals the first segment's. 2 / 3 -> do so only for the first 50% / 20% of queries.

    Returns:
        dict of batched arrays: "labels", "observations", "next_observations", "actions",
        their "_2" counterparts, "timestep"/"timestep_2", "start_indices"/"start_indices_2",
        plus image keys when the dataset has images.

    Raises:
        ValueError: if ``label_type`` is not 0 or 1.
    """
    if label_type not in (0, 1):
        raise ValueError(f"unsupported label_type: {label_type!r} (expected 0 or 1)")

    os.makedirs(data_dir, exist_ok=True)
    indices_1_filename = os.path.join(data_dir, f"indices_num{num_query}_q{len_query}")
    indices_2_filename = os.path.join(data_dir, f"indices_2_num{num_query}_q{len_query}")
    label_dummy_filename = os.path.join(data_dir, "label_dummy")

    if os.path.exists(indices_1_filename) and os.path.exists(indices_2_filename):
        # NOTE: pickle.load can execute arbitrary code; data_dir must only hold trusted caches.
        with open(indices_1_filename, "rb") as fp, open(indices_2_filename, "rb") as gp:
            indices_1, indices_2 = pickle.load(fp), pickle.load(gp)

        return load_queries_with_indices(
            env,
            dataset,
            num_query,
            len_query,
            label_type=label_type,
            saved_indices=[indices_1, indices_2],
            saved_labels=None,
            balance=balance,
            scripted_teacher=True,
        )

    # TODO: parallel implementation
    trj_idx_list = np.array(new_get_trj_idx(env, dataset=dataset))  # get_nonmdp_trj_idx(env)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1

    assert max(trj_len_list) > len_query

    # Per-trajectory labeler id. It is never modified, so the set of candidate trajectories
    # (every trajectory except the last one) is loop-invariant.
    labeler_info = np.zeros(len(trj_idx_list) - 1)
    candidate_trjs = np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)]

    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]
    use_image = dataset.get("images") is not None
    if use_image:
        image_shape = dataset["images"][0].shape

    # One buffer dict per side of the comparison: index 0 -> first segment, 1 -> second ("_2").
    segments = []
    for _ in range(2):
        seg = {
            "rewards": np.zeros((num_query, len_query)),
            "observations": np.zeros((num_query, len_query, observation_dim)),
            "next_observations": np.zeros((num_query, len_query, observation_dim)),
            "actions": np.zeros((num_query, len_query, action_dim)),
            "timestep": np.zeros((num_query, len_query), dtype=np.int32),
            "start_indices": np.zeros(num_query),
        }
        if use_image:
            seg["images"] = np.zeros((num_query, len_query, *image_shape), dtype=np.uint8)
            seg["next_images"] = np.zeros((num_query, len_query, *image_shape), dtype=np.uint8)
        segments.append(seg)

    # For the first `num_skip_queries` queries, a second segment whose return equals the
    # first segment's return is rejected and resampled.
    skip_fraction = {1: 1.0, 2: 0.5, 3: 0.2}.get(skip_flag, 0.0)
    num_skip_queries = int(skip_fraction * num_query)

    for query_count in tqdm(range(num_query), desc="get queries"):
        temp_count = 0
        labeler = -1
        while temp_count < 2:
            trj_idx = np.random.choice(candidate_trjs)
            len_trj = trj_len_list[trj_idx]

            # The trajectory must be longer than a segment; the second segment must also come
            # from a trajectory with the same labeler as the first.
            if not (len_trj > len_query and (temp_count == 0 or labeler_info[trj_idx] == labeler)):
                continue

            labeler = labeler_info[trj_idx]
            time_idx = np.random.choice(len_trj - len_query + 1)
            start_idx = trj_idx_list[trj_idx][0] + time_idx
            end_idx = start_idx + len_query

            assert end_idx <= trj_idx_list[trj_idx][1] + 1

            reward_seq = dataset["rewards"][start_idx:end_idx]

            if temp_count == 1 and query_count < num_skip_queries:
                # Compare against THIS query's first segment. Sum in float64, exactly as the
                # stored buffers (and hence the labels) are summed, so "equal" means "tie label".
                first_return = np.sum(segments[0]["rewards"][query_count])
                if first_return == np.sum(np.asarray(reward_seq, dtype=np.float64)):
                    continue

            seg = segments[temp_count]
            seg["start_indices"][query_count] = start_idx
            seg["rewards"][query_count] = reward_seq
            seg["observations"][query_count] = dataset["observations"][start_idx:end_idx]
            seg["next_observations"][query_count] = dataset["next_observations"][start_idx:end_idx]
            seg["actions"][query_count] = dataset["actions"][start_idx:end_idx]
            # 1-based position of each step within its trajectory.
            seg["timestep"][query_count] = np.arange(time_idx + 1, time_idx + len_query + 1)
            if use_image:
                seg["images"][query_count] = dataset["images"][start_idx:end_idx]
                seg["next_images"][query_count] = dataset["next_images"][start_idx:end_idx]

            temp_count += 1

    first, second = segments

    sum_r_t_1 = np.sum(first["rewards"], axis=1)
    sum_r_t_2 = np.sum(second["rewards"], axis=1)
    binary_label = 1 * (sum_r_t_1 < sum_r_t_2)
    rational_labels = np.zeros((len(binary_label), 2))
    rational_labels[np.arange(binary_label.size), binary_label] = 1.0
    if label_type == 1:
        # Equal returns become a tie label.
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
        rational_labels[margin_index] = 0.5

    batch = {}
    batch["labels"] = rational_labels
    batch["observations"] = first["observations"]  # for compatibility, remove "_1"
    batch["next_observations"] = first["next_observations"]
    batch["actions"] = first["actions"]
    batch["observations_2"] = second["observations"]
    batch["next_observations_2"] = second["next_observations"]
    batch["actions_2"] = second["actions"]
    batch["timestep"] = first["timestep"]
    batch["timestep_2"] = second["timestep"]
    batch["start_indices"] = first["start_indices"].astype(np.int32)
    batch["start_indices_2"] = second["start_indices"].astype(np.int32)
    if use_image:
        batch["images"] = first["images"]
        batch["images_2"] = second["images"]
        batch["next_images"] = first["next_images"]
        batch["next_images_2"] = second["next_images"]

    # balancing data with zero_labels (ties, [0.5, 0.5])
    if balance:
        nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
        (nonzero_idx,) = np.where(nonzero_condition)
        (zero_idx,) = np.where(np.logical_not(nonzero_condition))
        selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
        for key, val in batch.items():
            batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
        print(f"size of batch after balancing: {len(batch['labels'])}")

    with open(indices_1_filename, "wb") as fp, open(indices_2_filename, "wb") as gp, open(
        label_dummy_filename, "wb"
    ) as hp:
        pickle.dump(batch["start_indices"], fp)
        pickle.dump(batch["start_indices_2"], gp)
        pickle.dump(np.ones_like(batch["labels"]), hp)

    return batch


def find_time_idx(trj_idx_list, idx):
    """Return the offset of ``idx`` within the first ``[start, end]`` range that contains it.

    Ranges are inclusive on both ends and are scanned in order. Returns None when no range
    contains ``idx``.
    """
    matches = (idx - begin for begin, finish in trj_idx_list if begin <= idx <= finish)
    return next(matches, None)


def load_queries_with_indices(
    env, dataset, num_query, len_query, label_type, saved_indices, saved_labels, balance=False, scripted_teacher=False
):
    """Rebuild query segment pairs from previously saved start indices.

    Args:
        env: environment passed to ``new_get_trj_idx`` to split ``dataset`` into trajectories.
        dataset: dict of arrays with at least "rewards", "observations", "next_observations"
            and "actions"; optionally "images" and "next_images".
        num_query: number of segment pairs to rebuild.
        len_query: number of steps in each segment.
        label_type: 0 -> one-hot reward-based label for the higher-return segment (ties go to
            the first segment); 1 -> same, but equal returns get ``[0.5, 0.5]``.
        saved_indices: pair ``[indices_1, indices_2]`` of segment start indices.
        saved_labels: None to use the first ``num_query`` saved indices. Otherwise a sequence
            of human labels (0 -> first segment preferred, 1 -> second, -1 -> tie), in which
            case the last ``num_query`` saved queries are used.
        balance: if True, keep every non-tie query and resample (with replacement) the same
            number of tie queries (label ``[0.5, 0.5]``).
        scripted_teacher: if True, "labels" holds the reward-based labels; otherwise it holds
            the human labels from ``saved_labels``.

    Returns:
        dict of batched arrays with ``num_query`` rows each (before balancing). The
        reward-based labels are always stored under "script_labels".

    Raises:
        ValueError: on an unsupported ``label_type``, missing or too few ``saved_labels``,
            or a saved start index that lies in no trajectory.
    """
    if label_type not in (0, 1):
        raise ValueError(f"unsupported label_type: {label_type!r} (expected 0 or 1)")
    if saved_labels is None and not scripted_teacher:
        raise ValueError("saved_labels is required when scripted_teacher is False")

    # TODO: parallel implementation
    trj_idx_list = np.array(new_get_trj_idx(env, dataset=dataset))  # get_nonmdp_trj_idx(env)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1

    assert max(trj_len_list) > len_query

    if saved_labels is None:
        query_range = np.arange(num_query)
    else:
        if len(saved_labels) < num_query:
            raise ValueError(f"need at least num_query={num_query} saved labels, got {len(saved_labels)}")
        # Use the last num_query saved queries.
        query_range = np.arange(len(saved_labels) - num_query, len(saved_labels))

    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]
    use_image = dataset.get("images") is not None
    if use_image:
        image_shape = dataset["images"][0].shape

    # One buffer dict per side of the comparison: index 0 -> first segment, 1 -> second ("_2").
    segments = []
    for _ in range(2):
        seg = {
            "rewards": np.zeros((num_query, len_query)),
            "observations": np.zeros((num_query, len_query, observation_dim)),
            "next_observations": np.zeros((num_query, len_query, observation_dim)),
            "actions": np.zeros((num_query, len_query, action_dim)),
            "timestep": np.zeros((num_query, len_query), dtype=np.int32),
        }
        if use_image:
            seg["images"] = np.zeros((num_query, len_query, *image_shape), dtype=np.uint8)
            seg["next_images"] = np.zeros((num_query, len_query, *image_shape), dtype=np.uint8)
        segments.append(seg)

    for query_count, i in enumerate(tqdm(query_range, desc="get queries from saved indices")):
        for side, seg in enumerate(segments):
            start_idx = saved_indices[side][i]
            end_idx = start_idx + len_query

            time_idx = find_time_idx(trj_idx_list, start_idx)
            if time_idx is None:
                raise ValueError(f"saved start index {start_idx} does not lie inside any trajectory")

            seg["rewards"][query_count] = dataset["rewards"][start_idx:end_idx]
            seg["observations"][query_count] = dataset["observations"][start_idx:end_idx]
            seg["next_observations"][query_count] = dataset["next_observations"][start_idx:end_idx]
            seg["actions"][query_count] = dataset["actions"][start_idx:end_idx]
            # 1-based position of each step within its trajectory.
            seg["timestep"][query_count] = np.arange(time_idx + 1, time_idx + len_query + 1)
            if use_image:
                seg["images"][query_count] = dataset["images"][start_idx:end_idx]
                seg["next_images"][query_count] = dataset["next_images"][start_idx:end_idx]

    first, second = segments

    sum_r_t_1 = np.sum(first["rewards"], axis=1)
    sum_r_t_2 = np.sum(second["rewards"], axis=1)
    binary_label = 1 * (sum_r_t_1 < sum_r_t_2)
    rational_labels = np.zeros((len(binary_label), 2))
    rational_labels[np.arange(binary_label.size), binary_label] = 1.0
    if label_type == 1:
        # Equal returns become a tie label.
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) <= 0).reshape(-1)
        rational_labels[margin_index] = 0.5

    batch = {}
    if scripted_teacher:
        # counter part of human label for comparing with human label.
        batch["labels"] = rational_labels
    else:
        labels_arr = np.asarray(saved_labels)
        human_labels = np.zeros((len(labels_arr), 2))
        human_labels[labels_arr == 0, 0] = 1.0
        human_labels[labels_arr == 1, 1] = 1.0
        human_labels[labels_arr == -1] = 0.5
        batch["labels"] = human_labels[query_range]
    batch["script_labels"] = rational_labels

    batch["observations"] = first["observations"]  # for compatibility, remove "_1"
    batch["next_observations"] = first["next_observations"]
    batch["actions"] = first["actions"]
    batch["observations_2"] = second["observations"]
    batch["next_observations_2"] = second["next_observations"]
    batch["actions_2"] = second["actions"]
    batch["timestep"] = first["timestep"]
    batch["timestep_2"] = second["timestep"]
    # Slice by query_range so the start indices line up row-for-row with the other entries
    # (the full saved arrays may be longer or offset when saved_labels is given).
    batch["start_indices"] = np.asarray(saved_indices[0])[query_range]
    batch["start_indices_2"] = np.asarray(saved_indices[1])[query_range]
    if use_image:
        batch["images"] = first["images"]
        batch["images_2"] = second["images"]
        batch["next_images"] = first["next_images"]
        batch["next_images_2"] = second["next_images"]

    # balancing data with zero_labels (ties, [0.5, 0.5])
    if balance:
        nonzero_condition = np.any(batch["labels"] != [0.5, 0.5], axis=1)
        (nonzero_idx,) = np.where(nonzero_condition)
        (zero_idx,) = np.where(np.logical_not(nonzero_condition))
        selected_zero_idx = np.random.choice(zero_idx, len(nonzero_idx))
        for key, val in batch.items():
            batch[key] = val[np.concatenate([selected_zero_idx, nonzero_idx])]
        print(f"size of batch after balancing: {len(batch['labels'])}")

    return batch
