import numpy as np
from envs import darkroom_env
from tqdm import tqdm

def generate_mdp_histories_from_envs(envs, n_hists, n_samples, rollin_type, mode, random_p):
    """Collect (context, query state, optimal action) samples from MDP envs.

    For every environment, ``n_hists`` context histories are rolled in with
    ``rollin_mdp``. Each history is paired with ``n_samples`` query states
    drawn from ``env.sample_state()``, labelled with
    ``env.opt_action(query_state)``.

    Args:
        envs: Iterable of environments.
        n_hists: Number of roll-in histories per environment.
        n_samples: Number of query states sampled per history.
        rollin_type: Indexable by environment position; ``rollin_type[env_id]``
            is the roll-in type passed to ``rollin_mdp`` for that environment.
        mode: Roll-in mode forwarded to ``rollin_mdp``.
        random_p: Forwarded to ``rollin_mdp``.

    Returns:
        List of dicts with keys ``query_state``, ``optimal_action``,
        ``context_states``, ``context_actions``, ``context_next_states``,
        ``context_rewards``, ``goal`` and, when the env has a ``perm_index``
        attribute, ``perm_index``. Samples built from the same history
        reference the same context arrays (they are not copied).
    """
    trajs = []
    # Wrap ``envs`` itself (not ``enumerate(envs)``) so tqdm can read its
    # length and show a real progress bar; enumerate objects have no len().
    for env_id, env in enumerate(tqdm(envs)):
        for _ in range(n_hists):
            (
                context_states,
                context_actions,
                context_next_states,
                context_rewards,
            ) = rollin_mdp(env, rollin_type=rollin_type[env_id], mode=mode, random_p=random_p)
            for _ in range(n_samples):
                query_state = env.sample_state()
                optimal_action = env.opt_action(query_state)

                traj = {
                    'query_state': query_state,
                    'optimal_action': optimal_action,
                    'context_states': context_states,
                    'context_actions': context_actions,
                    'context_next_states': context_next_states,
                    'context_rewards': context_rewards,
                    'goal': env.goal,
                }

                # Add perm_index for DarkroomEnvPermuted
                if hasattr(env, 'perm_index'):
                    traj['perm_index'] = env.perm_index

                trajs.append(traj)
    return trajs

def generate_darkroom_histories(goals, dim, horizon, **kwargs):
    """Build one ``DarkroomEnv`` per goal and generate histories from them.

    ``dim`` and ``horizon`` are shared by every environment; all remaining
    keyword arguments are forwarded unchanged to
    ``generate_mdp_histories_from_envs``.
    """
    darkroom_envs = [
        darkroom_env.DarkroomEnv(dim, goal, horizon)
        for goal in goals
    ]
    return generate_mdp_histories_from_envs(darkroom_envs, **kwargs)

def generate_mdp_pref_histories_from_envs(envs, n_hists, n_samples, random_p, in_traj=False):
    """Collect preference-annotated (context, query, optimal action) samples.

    For every environment, ``n_hists`` histories are rolled in with
    ``rollin_pref_mdp`` (NOTE(review): not defined in this chunk; assumed to
    return the 7-tuple unpacked below). Each history is paired with
    ``n_samples`` query states, labelled with ``env.opt_action(query_state)``.

    Args:
        envs: Iterable of environments.
        n_hists: Number of roll-in histories per environment.
        n_samples: Number of query states sampled per history.
        random_p: Forwarded to ``rollin_pref_mdp``.
        in_traj: If True, query states are drawn uniformly from the history's
            own ``context_states``; otherwise from ``env.sample_state()``.

    Returns:
        List of dicts with keys ``query_state``, ``optimal_action``,
        ``context_states``, ``context_actions_1``, ``context_actions_2``,
        ``pref_actions``, ``non_pref_actions``, ``context_next_states``,
        ``context_rewards``, ``goal`` and, when the env has a ``perm_index``
        attribute, ``perm_index``. Samples built from the same history
        reference the same context objects (they are not copied).
    """
    trajs = []
    # Wrap ``envs`` itself (not ``enumerate(envs)``) so tqdm can read its
    # length and show a real progress bar; enumerate objects have no len().
    for env_id, env in enumerate(tqdm(envs)):
        for _ in range(n_hists):
            (
                context_states,
                context_actions_1,
                context_actions_2,
                pref_actions,
                non_pref_actions,
                context_next_states,
                context_rewards,
            ) = rollin_pref_mdp(env, random_p=random_p)

            for _ in range(n_samples):
                if in_traj:
                    # Restrict the query to a state visited in this history.
                    idx = np.random.choice(range(len(context_states)))
                    query_state = context_states[idx]
                else:
                    # presumably a numpy array (per original author's note)
                    query_state = env.sample_state()

                optimal_action = env.opt_action(query_state)

                traj = {
                    'query_state': query_state,
                    'optimal_action': optimal_action,
                    'context_states': context_states,
                    'context_actions_1': context_actions_1,
                    'context_actions_2': context_actions_2,
                    'pref_actions': pref_actions,
                    'non_pref_actions': non_pref_actions,
                    'context_next_states': context_next_states,
                    'context_rewards': context_rewards,
                    'goal': env.goal,
                }

                # Add perm_index for DarkroomEnvPermuted
                if hasattr(env, 'perm_index'):
                    traj['perm_index'] = env.perm_index

                trajs.append(traj)
    return trajs

def generate_darkroom_pref_histories(goals, dim, horizon, **kwargs):
    """Build one ``DarkroomEnv`` per goal and generate preference histories.

    ``dim`` and ``horizon`` are shared by every environment; all remaining
    keyword arguments are forwarded unchanged to
    ``generate_mdp_pref_histories_from_envs``.
    """
    darkroom_envs = [
        darkroom_env.DarkroomEnv(dim, goal, horizon)
        for goal in goals
    ]
    return generate_mdp_pref_histories_from_envs(darkroom_envs, **kwargs)


def rollin_mdp(env, rollin_type, mode, random_p):
    """Roll in one history of ``env.horizon`` transitions.

    Args:
        env: Environment exposing ``reset()``, ``horizon``, ``sample_state()``,
            ``sample_action()``, ``opt_action(state)`` and
            ``transit(state, action) -> (next_state, reward)``.
        rollin_type: ``'uniform'`` or ``'expert'``. ``'uniform'`` draws a new
            state from ``env.sample_state()`` and an action from
            ``env.sample_action()`` at every step; ``'expert'`` follows
            ``env.opt_action`` from the current state. Ignored when
            ``mode == 'step'``.
        mode: If ``'step'``, the roll-in type is re-drawn at every step:
            ``'uniform'`` with probability ``random_p``, else ``'expert'``.
            Any other value uses ``rollin_type`` for the whole history.
        random_p: Probability of a uniform step in ``'step'`` mode; it must
            be a valid probability in [0, 1] when that mode is used.

    Returns:
        Tuple ``(states, actions, next_states, rewards)`` of numpy arrays with
        one entry per step.

    Raises:
        NotImplementedError: If ``mode`` is not ``'step'`` and
            ``rollin_type`` is not one of the supported types.
    """
    rollin_types = ['uniform', 'expert']
    # Validate up front so a bad type is reported even when the horizon is 0
    # (previously the check only ran inside the loop).
    if mode != 'step' and rollin_type not in rollin_types:
        raise NotImplementedError(
            f"Unsupported rollin_type {rollin_type!r}; expected one of {rollin_types}"
        )
    # Per-step sampling distribution over ``rollin_types`` for 'step' mode.
    step_probs = [random_p, 1 - random_p]

    states = []
    actions = []
    next_states = []
    rewards = []

    state = env.reset()
    for _ in range(env.horizon):
        if mode == 'step':
            rollin_type = rollin_types[np.random.choice(len(rollin_types), p=step_probs)]

        if rollin_type == 'uniform':
            # Uniform roll-in does not follow the chain: restart from a
            # freshly sampled state with a sampled action.
            state = env.sample_state()
            action = env.sample_action()
        else:  # 'expert' (validated above)
            action = env.opt_action(state)
        next_state, reward = env.transit(state, action)

        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        rewards.append(reward)
        state = next_state

    return np.array(states), np.array(actions), np.array(next_states), np.array(rewards)

