import pickle
import collections
import numpy as np



def compute_sample_preference(traj_1, traj_2, n_samples=1):
    """
    Sample preference labels between two trajectories from their total rewards.

    The probability of preferring a trajectory is the softmax of the two
    summed rewards. For every sampled label, one step of the preferred
    trajectory is picked uniformly at random and its state / optimal action
    are stored as the query for that sample.

    Args:
        traj_1: dict with the keys "context_states", "context_actions",
            "context_next_states", "context_rewards" and "optimal_actions".
            Each value is indexed along its first axis by step (the original
            notes describe them as [num_steps, ...] arrays -- confirm against
            the dataset writer).
        traj_2: dict with the same keys as ``traj_1``.
        n_samples: number of preference labels to draw for this pair.

    Returns:
        A list of ``n_samples`` dicts. Each holds the context arrays of both
        trajectories under 'traj_1' / 'traj_2', the sampled 'preference'
        (0 -> traj_1, 1 -> traj_2), 'preference_probs' as
        [p(traj_1), p(traj_2)], and the 'query_state' / 'optimal_action'
        taken from the preferred trajectory.

    Raises:
        ValueError: if the preferred trajectory has no steps to sample from.
    """
    reward_sum_1 = np.sum(traj_1["context_rewards"])
    reward_sum_2 = np.sum(traj_2["context_rewards"])

    # Shift by the larger sum before exponentiating so np.exp cannot overflow;
    # the resulting probabilities are unchanged by the shift.
    max_val = np.maximum(reward_sum_1, reward_sum_2)
    exp1 = np.exp(reward_sum_1 - max_val)
    exp2 = np.exp(reward_sum_2 - max_val)
    normalizer = exp1 + exp2
    preference_prob_traj_1 = exp1 / normalizer
    preference_prob_traj_2 = exp2 / normalizer

    # Internal sanity check (e.g. trips on NaN rewards).
    assert np.isclose(preference_prob_traj_1 + preference_prob_traj_2, 1.0)

    # Array of shape [n_samples,] holding 0 (traj_1) or 1 (traj_2).
    preference = np.random.choice(
        [0, 1],
        n_samples,
        p=[preference_prob_traj_1, preference_prob_traj_2],
    )

    candidates = (traj_1, traj_2)
    traj_pairs = []

    for i in range(n_samples):
        preferred = candidates[preference[i]]

        # Bound the step index by the preferred trajectory's real length
        # rather than a fixed horizon, so shorter trajectories do not raise
        # IndexError and longer ones are sampled over all of their steps.
        num_steps = min(
            len(preferred["context_states"]),
            len(preferred["optimal_actions"]),
        )
        if num_steps == 0:
            raise ValueError("preferred trajectory has no steps to sample a query from")
        index = np.random.randint(0, num_steps)

        traj_pairs.append({
            'traj_1': {
                'context_states': traj_1["context_states"],
                'context_actions': traj_1["context_actions"],
                'context_next_states': traj_1["context_next_states"],
                'context_rewards': traj_1["context_rewards"],
            },
            'traj_2': {
                'context_states': traj_2["context_states"],
                'context_actions': traj_2["context_actions"],
                'context_next_states': traj_2["context_next_states"],
                'context_rewards': traj_2["context_rewards"],
            },
            'preference': preference[i],
            'preference_probs': [preference_prob_traj_1, preference_prob_traj_2],
            'query_state': preferred["context_states"][index],
            'optimal_action': preferred["optimal_actions"][index],
        })

    return traj_pairs


# Generation settings: trajectory pairs drawn per task, and preference labels
# sampled for each pair.
num_traj = 3000
num_samples = 10

# Input trajectory dataset and output location for the preference dataset.
data_path = "./datasets/metaworld_DIT_tasks45_trajs1000_p80_train.pkl"
save_filepath = "./datasets/preference_DPT_train.pkl"

# NOTE: pickle executes arbitrary code on load; only use trusted dataset files.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

# Bucket the loaded trajectories by the task they belong to.
trajs = collections.defaultdict(list)
for traj in data:
    task_bucket = trajs[traj['task_id']]
    task_bucket.append(traj)

preference_trajs = []
for task_id, task_trajs in trajs.items():
    pool_size = len(task_trajs)
    for _ in range(num_traj):
        # Both members of a pair come from the same task, drawn with
        # replacement (the same trajectory may be picked twice).
        first_idx = np.random.randint(0, pool_size)
        traj_1 = task_trajs[first_idx]
        second_idx = np.random.randint(0, pool_size)
        traj_2 = task_trajs[second_idx]

        traj_pairs = compute_sample_preference(traj_1, traj_2, num_samples)
        # Tag every sampled record with its task before collecting it.
        for traj_pair in traj_pairs:
            traj_pair['task_id'] = task_id
            preference_trajs.append(traj_pair)

with open(save_filepath, 'wb') as file:
    pickle.dump(preference_trajs, file)