"""
This script provides functions to generate trajectories for RL environments.
"""

import os

import pickle
import torch

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from envs import *
from utils import build_darkroom_data_filename

def rollin_mdp(env, rollin_type, mode, random_p):
    """
    Roll out one trajectory of ``env.horizon`` steps in ``env``.

    The action at each step comes from the roll-in policy:

    * ``'uniform'``: a fresh state is drawn with ``env.sample_state()`` (the
      previous next-state is discarded) and an action with
      ``env.sample_action()``.
    * ``'expert'``: the action is ``env.opt_action(state)`` for the current
      state.

    When ``mode == 'step'`` the policy is re-drawn before every step:
    ``'uniform'`` with probability ``random_p`` and ``'expert'`` otherwise.
    This overrides the ``rollin_type`` argument.

    Args:
        env: environment exposing ``reset``, ``sample_state``,
            ``sample_action``, ``opt_action``, ``transit`` and ``horizon``.
        rollin_type (str): ``'uniform'`` or ``'expert'``.
        mode (str): ``'step'`` enables per-step policy re-drawing; any other
            value keeps ``rollin_type`` fixed.
        random_p (float): probability of ``'uniform'`` in ``'step'`` mode.

    Returns:
        tuple: ``(states, actions, next_states, rewards)``. Each is an
        ``np.ndarray`` built from the per-step values, with ``env.horizon``
        entries along the first axis.

    Raises:
        NotImplementedError: if the active policy name is not recognised.
    """
    policy_names = ['uniform', 'expert']
    # Columns: states, actions, next_states, rewards.
    columns = ([], [], [], [])

    current = env.reset()
    policy = rollin_type
    for _ in range(env.horizon):
        if mode == 'step':
            policy = policy_names[np.random.choice(range(2), p=[random_p, 1 - random_p])]

        if policy == 'expert':
            act = env.opt_action(current)
        elif policy == 'uniform':
            # Uniform roll-in resets the state before choosing a random action.
            current = env.sample_state()
            act = env.sample_action()
        else:
            raise NotImplementedError

        nxt, rew = env.transit(current, act)
        for column, value in zip(columns, (current, act, nxt, rew)):
            column.append(value)
        current = nxt

    return tuple(np.array(column) for column in columns)
    
def generate_mdp_histories_from_envs(envs, n_hists, n_samples, rollin_type, mode, random_p):
    """
    Generate (query state, optimal action) samples paired with roll-in contexts.

    For every environment:
    1. Roll in ``n_hists`` context trajectories with ``rollin_mdp``.
    2. For each context trajectory, draw ``n_samples`` query states with
       ``env.sample_state()``.
    3. Label each query state with ``env.opt_action``.
    4. Store the context and the labelled query in one dictionary.

    Args:
        envs: iterable of environments (each must expose ``goal`` plus the
            methods used by ``rollin_mdp``).
        n_hists (int): number of context trajectories per environment.
        n_samples (int): number of query states per context trajectory.
        rollin_type (str or sequence of str): roll-in policy. Either a single
            name shared by all environments, or one entry per environment,
            indexed by the environment's position in ``envs``.
        mode (str): forwarded to ``rollin_mdp``; ``'step'`` re-draws the
            policy at every step.
        random_p (float): forwarded to ``rollin_mdp``; probability of the
            ``'uniform'`` policy in ``'step'`` mode.

    Returns:
        list of dict: one dict per query sample, with keys ``query_state``,
        ``optimal_action``, ``context_states``, ``context_actions``,
        ``context_next_states``, ``context_rewards``, ``goal`` and, when the
        environment has it, ``perm_index``.
    """
    trajs = []
    # Wrapping ``envs`` (not the enumerate object) lets tqdm read its length
    # and show a proper total.
    for env_id, env in enumerate(tqdm(envs)):
        # A plain string names one policy for every env; indexing it would
        # otherwise pick out a single character (e.g. 'uniform'[0] == 'u').
        if isinstance(rollin_type, str):
            env_rollin_type = rollin_type
        else:
            env_rollin_type = rollin_type[env_id]

        for _ in range(n_hists):
            (
                context_states,
                context_actions,
                context_next_states,
                context_rewards,
            ) = rollin_mdp(env, rollin_type=env_rollin_type, mode=mode, random_p=random_p)

            for _ in range(n_samples):
                query_state = env.sample_state()
                optimal_action = env.opt_action(query_state)

                traj = {
                    'query_state': query_state,
                    'optimal_action': optimal_action,
                    'context_states': context_states,
                    'context_actions': context_actions,
                    'context_next_states': context_next_states,
                    'context_rewards': context_rewards,
                    'goal': env.goal,
                }

                # Only some environments carry a permutation index.
                if hasattr(env, 'perm_index'):
                    traj['perm_index'] = env.perm_index

                trajs.append(traj)

    return trajs

def compute_sample_preference(traj_1, traj_2, n_samples=1):
    """
    Sample Bradley-Terry preference labels between two trajectories.

    The probability that ``traj_1`` is preferred is
    ``exp(R_1) / (exp(R_1) + exp(R_2))``, where ``R_i`` is the summed reward
    of trajectory ``i``. It is evaluated in log-space so that large returns do
    not overflow.

    Args:
        traj_1: tuple ``(states, actions, next_states, rewards)``, laid out as
            returned by ``rollin_mdp``. Only ``rewards`` (index 3) is used to
            compute the preference.
        traj_2: same layout as ``traj_1``.
        n_samples (int): number of independent preference labels to draw.

    Returns:
        list of dict: ``n_samples`` dicts. Each has:

        - ``'traj_1'`` / ``'traj_2'``: the trajectory components under
          ``context_*`` keys.
        - ``'preference'``: 0 if ``traj_1`` was sampled as preferred, 1 if
          ``traj_2`` was.
        - ``'preference_probs'``: ``[p_1, p_2]``.
    """
    reward_sum_1 = np.sum(traj_1[3])
    reward_sum_2 = np.sum(traj_2[3])

    # Two-way softmax via log-sum-exp. The naive exp(R1) / (exp(R1) + exp(R2))
    # becomes inf / inf = nan once a return exceeds about 709 in float64.
    log_normalizer = np.logaddexp(reward_sum_1, reward_sum_2)
    preference_prob_traj_1 = np.exp(reward_sum_1 - log_normalizer)
    preference_prob_traj_2 = np.exp(reward_sum_2 - log_normalizer)

    assert np.isclose(preference_prob_traj_1 + preference_prob_traj_2, 1.0)
    # Array of shape (n_samples,): 0 -> traj_1 preferred, 1 -> traj_2 preferred.
    preference = np.random.choice(
        [0, 1], n_samples, p=[preference_prob_traj_1, preference_prob_traj_2]
    )

    def _as_context(traj):
        # Build a fresh dict on every call so sample entries never share
        # mutable containers.
        return {
            'context_states': traj[0],
            'context_actions': traj[1],
            'context_next_states': traj[2],
            'context_rewards': traj[3],
        }

    return [
        {
            'traj_1': _as_context(traj_1),
            'traj_2': _as_context(traj_2),
            'preference': preference[i],
            'preference_probs': [preference_prob_traj_1, preference_prob_traj_2],
        }
        for i in range(n_samples)
    ]

def generate_preference_histories_from_envs(envs, n_hists, n_samples, rollin_type, mode, random_p):
    """
    Generate pairwise-preference samples between independent roll-in trajectories.

    For every environment, ``n_hists`` times:

    1. Roll in two independent trajectories with ``rollin_mdp``.
    2. Draw ``n_samples`` preference labels for the pair with
       ``compute_sample_preference``.
    3. Tag each resulting dict with the environment's ``goal``.

    Args:
        envs: iterable of environments (each must expose ``goal`` plus the
            methods used by ``rollin_mdp``).
        n_hists (int): number of trajectory pairs per environment.
        n_samples (int): number of preference labels drawn per pair.
        rollin_type (str or sequence of str): roll-in policy. Either a single
            name shared by all environments, or one entry per environment,
            indexed by the environment's position in ``envs``.
        mode (str): forwarded to ``rollin_mdp``; ``'step'`` re-draws the
            policy at every step.
        random_p (float): forwarded to ``rollin_mdp``; probability of the
            ``'uniform'`` policy in ``'step'`` mode.

    Returns:
        list of dict: the dicts produced by ``compute_sample_preference``
        (keys ``traj_1``, ``traj_2``, ``preference``, ``preference_probs``),
        each with an added ``goal`` key.
    """
    preference_trajs = []
    # Wrapping ``envs`` (not the enumerate object) lets tqdm show a total.
    for env_id, env in enumerate(tqdm(envs)):
        # A per-environment list must be indexed here. Forwarding the whole
        # list would make rollin_mdp raise NotImplementedError in any mode
        # other than 'step'. A plain string applies to every env.
        if isinstance(rollin_type, str):
            env_rollin_type = rollin_type
        else:
            env_rollin_type = rollin_type[env_id]

        for _ in range(n_hists):
            # Each trajectory is a tuple (states, actions, next_states, rewards).
            traj_1 = rollin_mdp(env, env_rollin_type, mode, random_p)
            traj_2 = rollin_mdp(env, env_rollin_type, mode, random_p)
            traj_pairs = compute_sample_preference(traj_1, traj_2, n_samples)
            for traj_pair in traj_pairs:
                traj_pair['goal'] = env.goal
            preference_trajs.extend(traj_pairs)

    return preference_trajs

def generate_darkroom_histories(goals, dim, horizon, func=generate_preference_histories_from_envs, **kwargs):
    """
    Build one Darkroom environment per goal and hand them to a generator.

    Args:
        goals: iterable of goals; each is passed as ``DarkroomEnv(dim, goal, horizon)``.
        dim: grid dimension passed to every ``DarkroomEnv``.
        horizon: episode horizon passed to every ``DarkroomEnv``.
        func: generator called as ``func(envs, **kwargs)``. Defaults to
            ``generate_preference_histories_from_envs``.
        **kwargs: remaining generator arguments (e.g. ``n_hists``,
            ``n_samples``, ``rollin_type``, ``mode``, ``random_p``).

    Returns:
        Whatever ``func`` returns for the constructed environments.
    """
    darkroom_envs = []
    for goal in goals:
        darkroom_envs.append(DarkroomEnv(dim, goal, horizon))
    return func(darkroom_envs, **kwargs)

if __name__ == '__main__':
    # ---- Dataset configuration ----------------------------------------------
    n_hists, n_samples, horizon, data_type = 5, 10, 100, 'preference'
    config = {
        'n_hists': n_hists,      # context trajectories (or pairs) per environment
        'n_samples': n_samples,  # samples drawn per trajectory (or pair)
        'horizon': horizon,      # length of each trajectory
    }

    if data_type == 'preference':
        func = generate_preference_histories_from_envs
    elif data_type == 'standard':
        func = generate_mdp_histories_from_envs
    else:
        raise NotImplementedError

    # ---- Environment parameters ---------------------------------------------
    dim = 10
    n_envs = 10000
    n_eval_envs = 100  # number of evaluation environments; also used in the eval filename
    config.update({'dim': dim, 'rollin_type': 'uniform', 'mode': 'step', 'func': func})

    # Every cell of the dim x dim grid is a candidate goal. A fixed-seed
    # shuffle makes the train/test split reproducible.
    goals = np.array([[(j, i) for i in range(dim)]
                      for j in range(dim)]).reshape(-1, 2)
    np.random.RandomState(seed=32).shuffle(goals)
    train_test_split = int(.8 * len(goals))
    train_goals = goals[:train_test_split]
    test_goals = goals[train_test_split:]

    # Test and eval goals are disjoint from the training goals.
    eval_goals = np.array(test_goals.tolist() *
                          int(n_eval_envs // len(test_goals)))
    train_goals = np.repeat(train_goals, n_envs // (dim * dim), axis=0)
    test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

    # ---- Generate the training, testing and evaluation trajectories --------
    for random_p in [0.2]:
        config['random_p'] = random_p
        # In 'step' mode rollin_mdp re-draws the roll-in policy every step, so
        # these per-environment entries do not affect the generated data.
        config['rollin_type'] = ['uniform'] * len(train_goals)
        train_trajs = generate_darkroom_histories(train_goals, **config)

        config['rollin_type'] = ['uniform'] * len(test_goals)
        test_trajs = generate_darkroom_histories(test_goals, **config)

        config['rollin_type'] = ['uniform'] * len(eval_goals)
        eval_trajs = generate_darkroom_histories(eval_goals, **config)

        env = 'darkroom_step'

        # NOTE(review): all three filenames are built from config as it stands
        # now, i.e. holding the eval rollin_type list. Confirm that
        # build_darkroom_data_filename does not depend on that entry.
        train_filepath = build_darkroom_data_filename(
            env, n_envs, config, random_p, mode=0)
        test_filepath = build_darkroom_data_filename(
            env, n_envs, config, random_p, mode=1)
        eval_filepath = build_darkroom_data_filename(
            env, n_eval_envs, config, random_p, mode=2)

        if data_type == 'preference':
            # Presumably the helper names files with 'standard'; TODO confirm.
            train_filepath = train_filepath.replace('standard', 'preference')
            test_filepath = test_filepath.replace('standard', 'preference')
            eval_filepath = eval_filepath.replace('standard', 'preference')

        # exist_ok=True already tolerates an existing directory, so a separate
        # existence check would be redundant (and racy).
        os.makedirs('datasets', exist_ok=True)

        # ---- Save the generated trajectories --------------------------------
        for filepath, split_trajs in (
            (train_filepath, train_trajs),
            (test_filepath, test_trajs),
            (eval_filepath, eval_trajs),
        ):
            with open(filepath, 'wb') as file:
                pickle.dump(split_trajs, file)