import os
import pickle

from tqdm import tqdm
import gym
import d4rl
import numpy as np

class DatasetSampler:
    """Sample pairs of fixed-length segments from a flat d4rl-style dataset.

    Episode boundaries are recovered from the flat transition arrays, and
    every sampled segment lies inside a single recovered episode.
    """

    def __init__(self, cfg, **kwargs):
        """Store the sampling configuration.

        Args:
            cfg: Mapping providing "rollout_length", "rollout_batch_size",
                "model_retain_epochs", "max_episode_length" and "task".
            **kwargs: Must contain "dataset", the mapping of flat arrays that
                ``sample`` draws from.
        """
        self.cfg = cfg
        self.len_query = cfg["rollout_length"]
        self.num_query = cfg["rollout_batch_size"] * cfg["model_retain_epochs"]
        self.max_episode_length = cfg["max_episode_length"]
        self.dataset = kwargs["dataset"]
        self.task = cfg["task"]

    def get_episode_boundaries(self, **kwargs):
        """Split the flat dataset into episodes.

        Args:
            **kwargs: Must contain "dataset", a mapping with a "rewards" array
                and either "goals" (when the task name contains "maze") or
                "terminals". A "timeouts" array is used when present.

        Returns:
            A list of ``[start, end]`` pairs of inclusive indices into the
            dataset, one per episode. The final pair is whatever remains after
            the last detected boundary.
        """
        dataset = kwargs['dataset']
        N = dataset['rewards'].shape[0]

        # Newer datasets carry an explicit timeouts field; older ones fall
        # back to counting steps, kept for backwards compatibility.
        use_timeouts = 'timeouts' in dataset
        is_maze = 'maze' in self.task

        episode_step = 0
        start_idx = 0
        trj_idx_list = []
        for i in range(N - 1):
            if is_maze:
                # An episode ends where the goal changes between steps.
                # NOTE(review): only a positive summed difference counts as a
                # change; confirm goals never move in the negative direction.
                done_bool = sum(dataset['goals'][i + 1] - dataset['goals'][i]) > 0
            else:
                done_bool = bool(dataset['terminals'][i])

            if use_timeouts:
                final_timestep = dataset['timeouts'][i]
            else:
                final_timestep = (episode_step == self.max_episode_length - 1)

            if final_timestep:
                # Close the running episode just before step i; step i itself
                # becomes the first index of the next episode.
                # NOTE(review): the original comment said the timeout step is
                # skipped, which is not what happens here; confirm intent.
                episode_step = 0
                trj_idx_list.append([start_idx, i - 1])
                start_idx = i

            if done_bool:
                # A terminal step is the last index of its episode.
                episode_step = 0
                trj_idx_list.append([start_idx, i])
                start_idx = i + 1

            episode_step += 1

        # Whatever is left after the last boundary forms the final episode.
        trj_idx_list.append([start_idx, max(N - 1, 0)])
        return trj_idx_list

    def sample(self):
        """Draw ``num_query`` pairs of segments of ``len_query`` steps each.

        Episodes are drawn uniformly by rejection sampling, then a start
        offset is drawn uniformly inside the chosen episode. The last
        recovered episode is never drawn (behaviour kept from the original;
        presumably because it can be a truncated tail - TODO confirm).

        Returns:
            ``(start_indices_1, start_indices_2, end_indices_1,
            end_indices_2)``: four int32 arrays of shape ``(num_query,)``.
            Starts are inclusive and each end equals ``start + len_query``.

        Raises:
            ValueError: If no drawable episode is longer than ``len_query``.
        """
        trj_idx_list = np.array(self.get_episode_boundaries(dataset=self.dataset))
        # Bounds are inclusive, hence the +1; one entry per recovered episode.
        trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1

        # Validate against the episodes that can actually be drawn. Checking
        # all episodes (as an assert used to) let the rejection loop below
        # spin forever when only the excluded last episode was long enough.
        num_candidates = len(trj_idx_list) - 1
        if not np.any(trj_len_list[:num_candidates] > self.len_query):
            raise ValueError(
                f"no sampleable episode is longer than len_query={self.len_query} "
                f"({num_candidates} candidate episode(s) after excluding the last one)"
            )

        candidates = np.arange(num_candidates)
        start_indices_1 = np.zeros(self.num_query, dtype=np.int32)
        start_indices_2 = np.zeros(self.num_query, dtype=np.int32)
        end_indices_1 = np.zeros(self.num_query, dtype=np.int32)
        end_indices_2 = np.zeros(self.num_query, dtype=np.int32)
        starts = (start_indices_1, start_indices_2)
        ends = (end_indices_1, end_indices_2)

        for query_count in range(self.num_query):
            temp_count = 0  # which member of the pair is being filled
            while temp_count < 2:
                trj_idx = np.random.choice(candidates)
                len_trj = trj_len_list[trj_idx]
                if len_trj <= self.len_query:
                    # Too short to hold a segment; draw another episode.
                    continue

                time_idx = np.random.choice(len_trj - self.len_query + 1)
                start_idx = trj_idx_list[trj_idx][0] + time_idx
                end_idx = start_idx + self.len_query

                # Internal invariant: the segment stays inside its episode.
                assert end_idx <= trj_idx_list[trj_idx][1] + 1

                starts[temp_count][query_count] = start_idx
                ends[temp_count][query_count] = end_idx
                temp_count += 1

        return start_indices_1, start_indices_2, end_indices_1, end_indices_2
    
    
def get_fake_labels_with_indices(dataset, num_query, len_query, saved_indices, end_indices, equivalence_threshold=0):
    """Build scripted preference labels by comparing summed segment rewards.

    Args:
        dataset: Mapping whose "rewards" entry is sliced per segment.
        num_query: Number of segment pairs to label.
        len_query: Number of consecutive steps in every segment.
        saved_indices: Two index sequences; ``saved_indices[k][i]`` is the
            start of segment ``k`` of pair ``i``.
        end_indices: Accepted but not read; every segment ends at
            ``start + len_query``.
        equivalence_threshold: Pairs whose summed rewards differ by at most
            this amount are labelled as a tie.

    Returns:
        A dict with "script_labels" (array of shape ``(num_query, 2)``) plus
        "start_indices_1" and "start_indices_2" (the two entries of
        ``saved_indices``). A label row is ``[1, 0]`` unless the second
        segment has the strictly larger reward sum, in which case it is
        ``[0, 1]``; ties are ``[0.5, 0.5]``.
    """
    # One row of per-step rewards per pair, for each side of the pair.
    # A commented-out variant also tracked dataset['penalties'] the same way.
    segment_rewards = [np.zeros((num_query, len_query)) for _ in range(2)]

    progress = tqdm(np.arange(num_query), desc="get queries from saved indices")
    for row, query in enumerate(progress):
        for side in range(2):
            first = saved_indices[side][query]
            window = dataset['rewards'][first:first + len_query]
            segment_rewards[side][row] = window.reshape(-1)

    # Total reward collected along each segment.
    return_1 = segment_rewards[0].sum(axis=1)
    return_2 = segment_rewards[1].sum(axis=1)

    # Column index of the preferred segment: 1 only when the second wins.
    preferred = np.where(return_1 < return_2, 1, 0)
    labels = np.zeros((preferred.shape[0], 2))
    labels[np.arange(preferred.size), preferred] = 1.0

    # Near-equal pairs get equal weight on both segments.
    ties = (np.abs(return_1 - return_2) <= equivalence_threshold).reshape(-1)
    labels[ties] = 0.5

    return {
        'script_labels': labels,
        'start_indices_1': saved_indices[0],
        'start_indices_2': saved_indices[1],
    }

def collect_preference_data(args, dataset, num_query, len_query, human_label=True):
    """Sample segment pairs, label them from dataset rewards, and pickle them.

    Three files are written under
    ``<args.preference_data_path>/<kind>/<args.task>_<kind>/`` where ``kind``
    is "human_labels" or "fake_labels": the two start-index arrays and the
    label array.

    Args:
        args: Namespace-like object. ``vars(args)`` is handed to
            ``DatasetSampler``; the ``domain``, ``task``,
            ``preference_data_path``, ``num_query`` and ``len_query``
            attributes are read here.
        dataset: Dataset mapping forwarded to the sampler and the labeller.
        num_query: Number of segment pairs to label.
        len_query: Segment length used when labelling.
        human_label: Selects the output directory only; the labels written
            are the reward-based script labels in both cases.

    Raises:
        KeyError: If ``args.domain`` is not "mujoco", "antmaze" or "adroit".
    """
    # DatasetSampler expects a plain dict of settings.
    cfg = vars(args)
    sampler = DatasetSampler(cfg, dataset=dataset)
    start_indices_1, start_indices_2, end_indices_1, end_indices_2 = sampler.sample()

    # NOTE(review): the sampler sizes its output from cfg["rollout_*"] /
    # cfg["model_retain_epochs"], labelling uses the num_query/len_query
    # arguments, and the filename uses args.num_query/args.len_query.
    # Confirm callers keep the three in agreement.

    # Reward-sum gap (per domain) below which a pair is labelled as a tie.
    equivalence_threshold_dict = {"mujoco": 10, "antmaze": 0, "adroit": 0}
    batch = get_fake_labels_with_indices(
        dataset,
        num_query=num_query,
        len_query=len_query,
        saved_indices=[start_indices_1, start_indices_2],
        end_indices=[end_indices_1, end_indices_2],
        equivalence_threshold=equivalence_threshold_dict[args.domain],
    )

    label_kind = "human_labels" if human_label else "fake_labels"
    save_dir = os.path.join(
        args.preference_data_path, label_kind, f"{args.task}_{label_kind}"
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    suffix = f"_domain_{args.domain}_env_{args.task}_num_{args.num_query}_len_{args.len_query}"

    outputs = (
        ("indices_1", batch["start_indices_1"]),
        ("indices_2", batch["start_indices_2"]),
        ("fake_label", batch["script_labels"]),
    )
    for prefix, payload in outputs:
        with open(os.path.join(save_dir, prefix + suffix + ".pkl"), "wb") as f:
            pickle.dump(payload, f)

    # A "penalties" pickle is not written: the labelling step does not
    # populate batch["penalties"].