import os
import pickle

import numpy as np
from tqdm import tqdm
from jaxrl_m.learners.d4rl_utils import new_get_trj_idx

def sample_from_env(env, num_query, len_set, len_query, data_dir):
    """Sample random observation/action pairs from the env spaces and pickle them.

    Observations are drawn from ``env.reward_observation_space`` and actions
    from ``env.action_space`` via their ``sample()`` methods. Two independent
    sets ("_1" and "_2") are drawn, and every label is zero.

    Args:
        env: Object exposing ``reward_observation_space`` and ``action_space``,
            each with ``shape`` and ``sample()``.
        num_query: Number of query sets.
        len_set: Number of pairs in each query set.
        len_query: Segment length; only ``1`` is supported.
        data_dir: Directory the pickled batch is written to (created if missing).

    Returns:
        ``(batch, query_path)``. ``batch`` maps ``"labels"`` to zeros of shape
        ``(num_query, len_set, 1)``; ``"observations"`` / ``"observations_2"``
        to arrays of shape ``(num_query, len_set, len_query, observation_dim)``;
        and ``"actions"`` / ``"actions_2"`` to arrays of shape
        ``(num_query, len_set, len_query, action_dim)``. ``query_path`` is the
        file the batch was pickled to.
    """
    assert len_query == 1
    observation_dim = env.reward_observation_space.shape[-1]
    action_dim = env.action_space.shape[-1]
    num_samples = num_query * len_set

    def _draw(space, dim):
        # Stack along axis 0 so each sample stays one contiguous row; the
        # reshape below then keeps every sample's features together on the
        # last axis instead of interleaving them across samples.
        samples = np.stack([space.sample() for _ in range(num_samples)], axis=0)
        return samples.reshape(num_query, len_set, len_query, dim)

    # Draw order (obs_1, obs_2, act_1, act_2) matters for seeded spaces.
    seg_obs_1 = _draw(env.reward_observation_space, observation_dim)
    seg_obs_2 = _draw(env.reward_observation_space, observation_dim)
    # Actions use the action dimension, which may differ from observation_dim.
    seg_act_1 = _draw(env.action_space, action_dim)
    seg_act_2 = _draw(env.action_space, action_dim)

    os.makedirs(data_dir, exist_ok=True)
    query_path = os.path.join(
        data_dir, f"queries_num{num_query}_q{len_query}_s{len_set}"
    )

    batch = {
        "labels": np.zeros((num_query, len_set, 1)),
        "observations": seg_obs_1,
        "observations_2": seg_obs_2,
        "actions": seg_act_1,
        "actions_2": seg_act_2,
    }
    with open(query_path, "wb") as fp:
        pickle.dump(batch, fp)

    return batch, query_path


def _thin_first_half_trajectories(trj_idx_list, data_size, unsafe_ratio):
    """Randomly drop leading trajectories that end within the first half of the data.

    Walks ``trj_idx_list`` from the front and stops at the first trajectory
    whose end index exceeds ``data_size * 0.5`` (or at the end of the list).
    One ``np.random.uniform(0, 1)`` draw is made per visited trajectory, and
    the trajectory is dropped when the draw exceeds ``unsafe_ratio``. Every
    entry from the stopping point onwards is kept unchanged.
    """
    half = data_size * 0.5
    kept = []
    cursor = 0
    while cursor < len(trj_idx_list) and trj_idx_list[cursor][1] <= half:
        if np.random.uniform(0, 1) <= unsafe_ratio:
            kept.append(trj_idx_list[cursor])
        cursor += 1
    kept.extend(trj_idx_list[cursor:])
    return kept


def get_queries_from_multi(
    env,
    dataset,
    num_query,
    len_query,
    len_set,
    data_dir=None,
    skip_flag=0,  #temp
    save_queries=True,
    unsafe_ratio=1.0,
    trajectory_clip=False,
):
    """Sample pairs of fixed-length segments from ``dataset`` as preference queries.

    ``num_query * len_set`` pairs are drawn. For each pair, two trajectories
    are chosen with probability proportional to their (possibly clipped)
    length, and a window of ``len_query`` consecutive steps is cut from each
    at a uniformly random offset. The results are reshaped to
    ``(num_query, len_set, ...)``.

    Args:
        env: Environment; only ``env.spec.id`` is read, and only when
            ``trajectory_clip`` is true.
        dataset: Mapping with ``"observations"``, ``"next_observations"``,
            ``"actions"``, ``"rewards"`` and optionally keys containing
            ``"infos/"``, all indexed by global step.
        num_query: Number of query sets.
        len_query: Number of steps in each segment.
        len_set: Number of pairs in each query set.
        data_dir: Directory for the pickled batch (created if missing). May be
            ``None`` only when ``save_queries`` is false.
        skip_flag: ``1`` rejects every pair whose two reward sums are equal;
            ``2`` / ``3`` do so only for the first 50% / 20% of the
            ``num_query * len_set`` pairs; any other value never rejects.
        save_queries: Whether to pickle the batch to ``query_path``.
        unsafe_ratio: Keep threshold for leading trajectories that end in the
            first half of the dataset (see ``_thin_first_half_trajectories``).
        trajectory_clip: If true and ``env.spec.id`` has an entry in the
            internal table, only the first steps of each trajectory are used.

    Returns:
        ``(batch, query_path)``. ``batch`` has ``"labels"`` of shape
        ``(num_query, len_set, 1)`` (1.0 where the first segment's reward sum
        is strictly larger, via ``get_labels``); ``"observations"``,
        ``"next_observations"``, ``"actions"`` and their ``"_2"``
        counterparts of shape ``(num_query, len_set, len_query, dim)``; and
        every ``"infos/..."`` key plus an ``"infos_2/..."`` twin.
        ``query_path`` is ``None`` when ``data_dir`` is ``None``.

    Raises:
        ValueError: If saving is requested without ``data_dir``, no trajectory
            is left after thinning, or no sampleable trajectory is at least
            ``len_query`` steps long.
    """
    if save_queries and data_dir is None:
        raise ValueError("data_dir must be provided when save_queries=True")

    total_query = num_query * len_set

    query_path = None
    if data_dir is not None:
        os.makedirs(data_dir, exist_ok=True)
        query_path = os.path.join(
            data_dir, f"queries_num{num_query}_q{len_query}_s{len_set}"
        )

    trj_idx_list = new_get_trj_idx(dataset)  # get_nonmdp_trj_idx(env)

    # By default, the first half of the trajectories are generated by an unsafe
    # policy and the last half by a safe policy.
    data_size = len(dataset["observations"])
    print('The number of traj before delete: ', len(trj_idx_list))
    trj_idx_list = _thin_first_half_trajectories(
        trj_idx_list, data_size, unsafe_ratio
    )
    print('The number of traj after delete: ', len(trj_idx_list))
    if len(trj_idx_list) == 0:
        raise ValueError("no trajectories left to sample queries from")

    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1
    # NOTE(review): the last entry of trj_idx_list is never sampled (all index
    # ranges below stop at len - 1) -- confirm this matches the format returned
    # by new_get_trj_idx.
    labeler_info = np.zeros(len(trj_idx_list) - 1)

    # In some envs most of the last steps are repetitive, while the first few
    # steps are important and decide the mode, so only the first few steps are
    # used to generate pairs.
    TRAJECTORY_LEN_CLIP = {
        "SafetyBallCircle-multimodal-v0": 65,  #temp
        "SafetyAntVelocity-multimodal-v0": 65,
        "SafetyHalfCheetahVelocity-multimodal-v0": 65,
        "SafetySwimmerVelocity-multimodal-v0": 500,
    }

    if trajectory_clip and env.spec.id in TRAJECTORY_LEN_CLIP:
        max_len = max(len_query, TRAJECTORY_LEN_CLIP[env.spec.id])
        trj_len_list = np.clip(trj_len_list, a_min=0, a_max=max_len)
        print(f'{env.spec.id} clip traj len to {max_len}')

    assert max(trj_len_list) >= len_query

    # labeler_info is never modified in this function, so the candidate set
    # and its sampling distribution are computed once.
    valid_idx = np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)]
    valid_len = trj_len_list[valid_idx]
    if total_query > 0 and not np.any(valid_len >= len_query):
        # Without this guard the rejection loop below would never terminate.
        raise ValueError(
            f"no sampleable trajectory has at least len_query={len_query} steps"
        )
    prob = valid_len / np.sum(valid_len)

    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]
    info_keys = [k for k in dataset.keys() if 'infos/' in k]

    def _new_buffers():
        buffers = {
            "rewards": np.zeros((total_query, len_query)),
            "observations": np.zeros((total_query, len_query, observation_dim)),
            "next_observations": np.zeros(
                (total_query, len_query, observation_dim)
            ),
            "actions": np.zeros((total_query, len_query, action_dim)),
        }
        for k in info_keys:
            # Scalar infos get (total_query, len_query); array-valued infos
            # get their per-step shape appended.
            buffers[k] = np.zeros(
                (total_query, len_query, *np.shape(dataset[k][0]))
            )
        return buffers

    # segments[0] holds the first member of every pair, segments[1] the second.
    segments = (_new_buffers(), _new_buffers())
    fields = ("rewards", "observations", "next_observations", "actions", *info_keys)

    # Pairs with an index below skip_limit are re-drawn when both reward sums
    # are equal.
    skip_limit = {
        1: total_query,
        2: int(0.5 * total_query),
        3: int(0.2 * total_query),
    }.get(skip_flag, 0)

    for query_count in tqdm(range(total_query), desc="get queries"):
        temp_count = 0
        labeler = -1
        while temp_count < 2:
            trj_idx = np.random.choice(valid_idx, p=prob)
            len_trj = trj_len_list[trj_idx]

            if len_trj < len_query:
                continue
            if temp_count == 1 and labeler_info[trj_idx] != labeler:
                continue

            labeler = labeler_info[trj_idx]
            time_idx = np.random.choice(len_trj - len_query + 1)
            start_idx = trj_idx_list[trj_idx][0] + time_idx
            end_idx = start_idx + len_query

            assert end_idx <= trj_idx_list[trj_idx][1] + 1

            if temp_count == 1 and query_count < skip_limit:
                # Compare against the first segment of THIS pair, not the last
                # (still zero-filled) row of the buffer.
                first_sum = np.sum(segments[0]["rewards"][query_count])
                second_sum = np.sum(dataset["rewards"][start_idx:end_idx])
                if first_sum == second_sum:
                    continue

            target = segments[temp_count]
            for name in fields:
                target[name][query_count] = dataset[name][start_idx:end_idx]

            temp_count += 1

    first, second = segments
    rational_labels = get_labels(first["rewards"], second["rewards"])

    obs_shape = (num_query, len_set, len_query, observation_dim)
    act_shape = (num_query, len_set, len_query, action_dim)

    batch = {}
    batch["labels"] = rational_labels.reshape(num_query, len_set, 1)
    # For compatibility the first segment's keys carry no "_1" suffix.
    batch["observations"] = first["observations"].reshape(obs_shape)
    batch["next_observations"] = first["next_observations"].reshape(obs_shape)
    batch["actions"] = first["actions"].reshape(act_shape)
    batch["observations_2"] = second["observations"].reshape(obs_shape)
    batch["next_observations_2"] = second["next_observations"].reshape(obs_shape)
    batch["actions_2"] = second["actions"].reshape(act_shape)
    for k in info_keys:
        batch[k] = first[k].reshape(num_query, len_set, len_query, -1)
        batch[k.replace('infos/', 'infos_2/')] = second[k].reshape(
            num_query, len_set, len_query, -1
        )

    if save_queries:
        with open(query_path, "wb") as fp:
            pickle.dump(batch, fp)

    return batch, query_path


class SingleQueryGenerator:
    """Draws one preference query (a pair of segments) at a time from a dataset.

    Trajectory boundaries come from ``new_get_trj_idx(dataset)``. Each segment
    is a window of ``len_query`` consecutive steps taken from a trajectory that
    is chosen with probability proportional to its length.
    """

    def __init__(self, env, dataset, len_query):
        self.env = env
        self.dataset = dataset
        self.len_query = len_query

        bounds = np.array(new_get_trj_idx(dataset))
        self.trj_idx_list = bounds
        # Both boundary indices are inclusive, hence the +1.
        self.trj_len_list = bounds[:, 1] - bounds[:, 0] + 1
        assert max(self.trj_len_list) >= len_query

        # One entry per sampleable trajectory; the last boundary row is excluded.
        self.labeler_info = np.zeros(len(bounds) - 1)

        self.info_keys = [key for key in dataset.keys() if 'infos/' in key]
        self.observation_dim = dataset["observations"].shape[-1]
        self.action_dim = dataset["actions"].shape[-1]

    def _sample_segment(self):
        """Return one random segment dict, or ``None`` if the pick was too short."""
        selectable = np.logical_not(self.labeler_info)
        candidates = np.arange(len(self.trj_idx_list) - 1)[selectable]
        lengths = self.trj_len_list[candidates]
        weights = lengths / np.sum(lengths)

        chosen = np.random.choice(candidates, p=weights)
        span = self.trj_len_list[chosen]
        if span < self.len_query:
            return None

        offset = np.random.choice(span - self.len_query + 1)
        first = self.trj_idx_list[chosen][0] + offset
        last = first + self.len_query
        n_steps = last - first

        # Flatten each info entry to (steps, features).
        infos = {}
        for key in self.info_keys:
            infos[key] = self.dataset[key][first:last].reshape(n_steps, -1)

        return {
            "rewards": self.dataset["rewards"][first:last],
            "observations": self.dataset["observations"][first:last],
            "next_observations": self.dataset["next_observations"][first:last],
            "actions": self.dataset["actions"][first:last],
            "infos": infos,
        }

    def generate_query(self):
        """Sample two segments and label them by comparing their reward sums.

        The label is ``1`` when the second segment's reward sum is strictly
        larger than the first's, and ``0`` otherwise (ties included).
        """
        collected = []
        while len(collected) < 2:
            candidate = self._sample_segment()
            if candidate is not None:
                collected.append(candidate)

        seg_a, seg_b = collected
        second_is_better = np.sum(seg_a["rewards"]) < np.sum(seg_b["rewards"])
        return {
            "obs_1": seg_a["observations"],
            "obs_2": seg_b["observations"],
            "act_1": seg_a["actions"],
            "act_2": seg_b["actions"],
            "infos_1": seg_a["infos"],
            "infos_2": seg_b["infos"],
            "label": int(second_is_better),
        }


def get_labels(seg_reward_1, seg_reward_2):
    """Label each segment pair by comparing summed rewards.

    Args:
        seg_reward_1: Per-step rewards of the first segments; the last axis is
            summed.
        seg_reward_2: Per-step rewards of the second segments, same layout.

    Returns:
        A ``float32`` array of shape ``(N, 1)`` holding ``1.0`` where the first
        segment's reward sum is strictly greater than the second's and ``0.0``
        otherwise (ties included).
    """
    sum_r_t_1 = np.sum(seg_reward_1, axis=-1)
    sum_r_t_2 = np.sum(seg_reward_2, axis=-1)
    return (sum_r_t_1 > sum_r_t_2).reshape(-1, 1).astype(np.float32)

def get_demos_from_multi(
    env,
    dataset,
    num_demo,
    len_demo,
):
    """Sample ``num_demo`` fixed-length demonstration segments from ``dataset``.

    For each demo a trajectory is picked uniformly at random (re-drawn until
    it is at least ``len_demo`` steps long), and a window of ``len_demo``
    consecutive steps is cut out at a uniformly random offset.

    Args:
        env: Unused by this function.
        dataset: Mapping with ``"observations"``, ``"actions"`` and
            ``"rewards"`` indexed by global step.
        num_demo: Number of segments to sample.
        len_demo: Number of steps in each segment.

    Returns:
        Dict with ``"observations"`` of shape
        ``(num_demo, len_demo, observation_dim)``, ``"actions"`` of shape
        ``(num_demo, len_demo, action_dim)`` and ``"rewards"`` of shape
        ``(num_demo, len_demo)``.

    Raises:
        ValueError: If demos are requested but no sampleable trajectory has at
            least ``len_demo`` steps.
    """
    trj_idx_list = new_get_trj_idx(dataset)  # get_nonmdp_trj_idx(env)
    print("Number of trajectories: ", len(trj_idx_list) - 1)

    # to-do: parallel implementation
    trj_idx_list = np.array(trj_idx_list)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1

    # NOTE(review): the last entry of trj_idx_list is never sampled -- confirm
    # this matches the format returned by new_get_trj_idx.
    num_candidates = len(trj_idx_list) - 1
    candidates = np.arange(num_candidates)
    if num_demo > 0 and not np.any(trj_len_list[:num_candidates] >= len_demo):
        # Without this guard the rejection loop below would never terminate.
        raise ValueError(
            f"no sampleable trajectory has at least len_demo={len_demo} steps"
        )

    observation_dim = dataset["observations"].shape[-1]
    action_dim = dataset["actions"].shape[-1]

    reward_seq = np.zeros((num_demo, len_demo))
    obs_seq = np.zeros((num_demo, len_demo, observation_dim))
    act_seq = np.zeros((num_demo, len_demo, action_dim))

    for demo_count in tqdm(range(num_demo), desc="get queries"):
        while True:
            trj_idx = np.random.choice(candidates)
            len_trj = trj_len_list[trj_idx]

            # A trajectory of exactly len_demo steps is usable: it offers
            # exactly one window, starting at offset 0.
            if len_trj < len_demo:
                continue

            time_idx = np.random.choice(len_trj - len_demo + 1)
            start_idx = trj_idx_list[trj_idx][0] + time_idx
            end_idx = start_idx + len_demo

            reward_seq[demo_count] = dataset["rewards"][start_idx:end_idx]
            obs_seq[demo_count] = dataset["observations"][start_idx:end_idx]
            act_seq[demo_count] = dataset["actions"][start_idx:end_idx]
            break

    return {
        "observations": obs_seq,
        "actions": act_seq,
        "rewards": reward_seq,
    }