import os
import time
import warnings
import numpy as np
import gym

import random
import torch
import pickle
from pathlib import Path

warnings.filterwarnings('ignore')


def load_queries_with_indices(dataset, num_query, len_query, saved_indices, saved_labels=None, label_type=1,
                              equivalence_threshold=0, modality="state", partition_idx=None, human_label_flag=False):
    """Assemble paired trajectory segments and their preference labels into a batch.

    For each selected query ``i``, two segments of ``len_query`` consecutive steps are
    sliced from ``dataset``, starting at ``saved_indices[0][i]`` and ``saved_indices[1][i]``.

    Args:
        dataset: Mapping with ``"observations"``, ``"actions"`` and ``"rewards"`` arrays
            indexed by step.
        num_query: Number of query pairs to build.
        len_query: Number of steps per segment.
        saved_indices: Pair ``[indices_1, indices_2]`` of segment start indices.
        saved_labels: Stored labels. When ``None``, the first ``num_query`` queries are
            used. Otherwise the last ``num_query`` queries are used, or the
            ``partition_idx``-th slice of ``num_query`` queries when ``partition_idx``
            is given.
        label_type: How ``script_labels`` are derived from summed segment rewards.
            0 gives a one-hot label on the larger return (ties go to segment 1).
            1 does the same, but pairs whose returns differ by at most
            ``equivalence_threshold`` get ``[0.5, 0.5]``.
        equivalence_threshold: Tie margin, used only when ``label_type == 1``.
        modality: ``"state"`` (the last observation axis is the feature dim) or
            ``"pixel"`` (the last three axes form one observation).
        partition_idx: Optional slice index for loading labels in chunks.
        human_label_flag: If falsy, ``saved_labels`` is stored unchanged under
            ``"labels"``. If truthy, ``saved_labels`` codes 0 / 1 / -1 become
            ``[1, 0]`` / ``[0, 1]`` / ``[0.5, 0.5]`` (other codes stay ``[0, 0]``), and
            only the selected queries are kept.

    Returns:
        dict with keys ``script_labels``, ``labels``, ``observations_1``, ``actions_1``,
        ``observations_2``, ``actions_2``, ``timestep_1``, ``timestep_2`` (1-based,
        int32), ``start_indices_1`` and ``start_indices_2``.

    Raises:
        ValueError: if ``modality`` or ``label_type`` is unknown, or if
            ``human_label_flag`` is set without ``saved_labels``.
    """
    if modality == "state":
        observation_dim = (dataset["observations"].shape[-1], )
    elif modality == "pixel":
        observation_dim = dataset["observations"].shape[-3:]
    else:
        raise ValueError("Modality error")
    # Validate the remaining arguments before any data is sliced.
    if label_type not in (0, 1):
        raise ValueError(f"Unsupported label_type: {label_type!r} (expected 0 or 1)")
    if human_label_flag and saved_labels is None:
        raise ValueError("human_label_flag requires saved_labels")

    action_dim = dataset["actions"].shape[-1]

    if saved_labels is None:
        query_range = np.arange(num_query)
    elif partition_idx is None:
        # Do not query all labels: use only the last num_query of them.
        query_range = np.arange(len(saved_labels) - num_query, len(saved_labels))
    else:
        # For large label sets, load the queries in slices of num_query.
        query_range = np.arange(partition_idx * num_query, (partition_idx + 1) * num_query)

    # Index 0 holds segment 1 of every pair; index 1 holds segment 2.
    seg_rewards = [np.zeros((num_query, len_query)) for _ in range(2)]
    seg_obs = [np.zeros((num_query, len_query) + observation_dim) for _ in range(2)]
    seg_acts = [np.zeros((num_query, len_query, action_dim)) for _ in range(2)]
    seg_timesteps = [np.zeros((num_query, len_query), dtype=np.int32) for _ in range(2)]
    timestep_seq = np.arange(1, len_query + 1)

    # NOTE: penalty sequences (dataset['penalties']) are currently not loaded.
    for query_count, i in enumerate(query_range):
        for seg in range(2):
            start_idx = saved_indices[seg][i]
            end_idx = start_idx + len_query
            # reshape(-1) flattens rewards that may be stored with an extra axis, e.g. (N, 1).
            seg_rewards[seg][query_count] = dataset['rewards'][start_idx:end_idx].reshape(-1)
            seg_obs[seg][query_count] = dataset['observations'][start_idx:end_idx]
            seg_acts[seg][query_count] = dataset['actions'][start_idx:end_idx]
            seg_timesteps[seg][query_count] = timestep_seq

    # Scripted labels: one-hot on the higher summed reward (label_type 0). With
    # label_type 1, near-ties (within equivalence_threshold) become [0.5, 0.5].
    sum_r_t_1 = seg_rewards[0].sum(axis=1)
    sum_r_t_2 = seg_rewards[1].sum(axis=1)
    binary_label = (sum_r_t_1 < sum_r_t_2).astype(int)
    rational_labels = np.zeros((len(binary_label), 2))
    rational_labels[np.arange(binary_label.size), binary_label] = 1.0
    if label_type == 1:
        rational_labels[np.abs(sum_r_t_1 - sum_r_t_2) <= equivalence_threshold] = 0.5

    batch = {'script_labels': rational_labels}
    if not human_label_flag:
        # Fake labels are passed through as stored.
        # NOTE(review): unlike the human branch, these labels are not restricted to
        # query_range — confirm that callers pass exactly the selected labels.
        batch['labels'] = saved_labels
    else:
        # Human labels: 0 -> prefer segment 1, 1 -> prefer segment 2, -1 -> equal.
        codes = np.asarray(saved_labels)
        human_labels = np.zeros((len(codes), 2))
        human_labels[codes == 0, 0] = 1.
        human_labels[codes == 1, 1] = 1.
        human_labels[codes == -1] = 0.5
        batch['labels'] = human_labels[query_range]

    batch['observations_1'] = seg_obs[0]
    batch['actions_1'] = seg_acts[0]
    batch['observations_2'] = seg_obs[1]
    batch['actions_2'] = seg_acts[1]
    batch['timestep_1'] = seg_timesteps[0]
    batch['timestep_2'] = seg_timesteps[1]
    # NOTE(review): these are the full index arrays, not restricted to query_range.
    batch['start_indices_1'] = saved_indices[0]
    batch['start_indices_2'] = saved_indices[1]
    return batch


def load_preference_dataset(args, dataset, num_query, len_query, human_label=True):
    """Load stored preference labels and segment indices, then build the query batch.

    Files are looked up in
    ``<args.preference_data_path>/{human,fake}_labels/<args.task>_{human,fake}_labels``.
    They are selected by the tag
    ``_domain_<args.domain>_env_<args.task>_num_<args.num_query>_len_<args.len_query>``.
    Exactly three files must match: the labels, the start indices of segment 1, and the
    start indices of segment 2.

    Args:
        args: Config object. Reads ``preference_data_path``, ``task``, ``domain``,
            ``num_query`` and ``len_query``.
        dataset: Passed through to ``load_queries_with_indices``.
        num_query: Number of queries to build (file matching uses ``args.num_query``).
        len_query: Segment length (file matching uses ``args.len_query``).
        human_label: Load human-label files (True) or fake-label files (False).

    Returns:
        The batch dict produced by ``load_queries_with_indices``.

    Raises:
        ValueError: if the label directory does not exist or the number of matching
            files is not exactly three.
    """
    import re

    label_kind = "human_labels" if human_label else "fake_labels"
    data_dir = os.path.join(args.preference_data_path, label_kind, f"{args.task}_{label_kind}")
    if not os.path.isdir(data_dir):
        raise ValueError(f"Label not found: {data_dir}")

    suffix = f"_domain_{args.domain}_env_{args.task}_num_{args.num_query}_len_{args.len_query}"
    # (?!\d) keeps e.g. len_25 from also matching files tagged len_250.
    pattern = re.compile(re.escape(suffix) + r"(?!\d)")
    matched_files = sorted(name for name in os.listdir(data_dir) if pattern.search(name))
    if len(matched_files) != 3:
        raise ValueError(
            f"Expected 3 label files matching {suffix!r} in {data_dir}, "
            f"found {len(matched_files)}: {matched_files}")
    # NOTE(review): this assumes sorted() yields (labels, indices_1, indices_2), which
    # depends on the file-naming convention — confirm.
    label_name, indices_1_name, indices_2_name = matched_files

    def _unpickle(name):
        # pickle can execute arbitrary code while loading; only use trusted label files.
        with open(os.path.join(data_dir, name), "rb") as fp:
            return pickle.load(fp)

    saved_labels = _unpickle(label_name)
    human_indices_1 = _unpickle(indices_1_name)
    human_indices_2 = _unpickle(indices_2_name)

    return load_queries_with_indices(
        dataset, num_query, len_query, saved_indices=[human_indices_1, human_indices_2],
        saved_labels=saved_labels, human_label_flag=human_label)

