import argparse
import glob
import pandas as pd 
from multiprocessing import Pool
import tqdm
import os 
from collections import defaultdict

# Root folder that holds every dataset. Paths are built from it by plain string
# concatenation (e.g. f"{data_folder}{args.name}/env/"), so the trailing slash is required.
data_folder = "./data/datasets/"


def _rollout(start, policy, transitions, threshold, max_steps, memo):
    """Follow the greedy policy from ``start``; return the actions reaching a goal, or ``None``.

    ``memo`` maps already-resolved states to their outcome: the list of remaining actions
    to a goal, or ``None`` when the state is known never to reach one. It is updated in
    place with every state on this rollout's path.
    """
    state = start
    path = [start]  # path[i] is the state in which actions[i] is taken
    on_path = {start}
    actions = []
    reached = False
    doomed = False  # failed for a reason that does not depend on the step budget

    while len(actions) < max_steps:
        if state in memo:
            suffix = memo[state]
            if suffix is None:
                doomed = True
            else:
                actions += suffix
                reached = True
            break
        if state not in policy:
            # No learned action for this state: the rollout cannot continue.
            doomed = True
            break
        action = policy[state]
        if (state, action) not in transitions:
            doomed = True
            break
        state, reward = transitions[(state, action)]
        actions.append(action)
        path.append(state)
        if reward >= threshold:
            reached = True
            break
        if state in on_path:
            # Policy and transitions are deterministic lookups, so revisiting a state
            # means an endless loop whose rewards were all below the threshold.
            doomed = True
            break
        on_path.add(state)

    if reached:
        for i, visited in enumerate(path):
            memo[visited] = actions[i:]
        return actions
    if doomed:
        # Every state on the path deterministically leads to the same failure.
        for visited in path:
            memo[visited] = None
    return None


def generate_sequence(env_id, goal, args):
    """Roll out the greedy Q-policy from sampled start states and save the successful ones.

    Reads ``Q_policy.csv`` (columns ``state``, ``action``) and ``data_with_reward.csv``
    (columns ``obs``, ``action``, ``next_obs``, ``reward``) from
    ``{data_folder}{args.name}/env/{env_id}/{goal}``. Writes ``sequence.csv`` (columns
    ``state``, ``action``, the latter a list of actions) to the same folder.

    A rollout succeeds when a transition yields a reward ``>= max(reward) - std(reward)``.
    Start states that are reached with a reward above that threshold are skipped.
    A rollout fails when the policy or the transition table has no entry for the current
    state, when it revisits a state, or after ``args.max_steps`` actions.

    Args:
        env_id: environment sub-folder name.
        goal: goal sub-folder name.
        args: namespace providing ``name``, ``nb_sequences``, ``seed`` and ``max_steps``.
    """
    path_folder = f"{data_folder}{args.name}/env/{env_id}/{goal}"
    path_sequence = f"{path_folder}/sequence.csv"
    q_policy = pd.read_csv(f"{path_folder}/Q_policy.csv")
    data_with_rewards = pd.read_csv(f"{path_folder}/data_with_reward.csv")

    rewards = data_with_rewards["reward"].to_numpy()
    # An empty table has no max/std; an infinite threshold means nothing can succeed.
    threshold = rewards.max() - rewards.std() if rewards.size else float("inf")

    initial_states = data_with_rewards["obs"].drop_duplicates()
    # int(): the CLI default for nb_sequences may be a float such as 1e15.
    n_samples = min(int(args.nb_sequences), len(initial_states))
    initial_states = initial_states.sample(n_samples, random_state=args.seed, replace=False)

    # Column-wise dicts (much faster than iterrows). Later rows overwrite earlier ones,
    # just like a row-by-row loop. tolist() yields plain Python scalars.
    obs = data_with_rewards["obs"].tolist()
    taken = data_with_rewards["action"].tolist()
    next_obs = data_with_rewards["next_obs"].tolist()
    reward_list = data_with_rewards["reward"].tolist()
    reward_on_arrival = dict(zip(next_obs, reward_list))
    transitions = dict(zip(zip(obs, taken), zip(next_obs, reward_list)))
    policy = dict(zip(q_policy["state"].tolist(), q_policy["action"].tolist()))

    # Free the frames and column copies before the rollouts.
    del data_with_rewards, q_policy, obs, taken, next_obs, reward_list

    memo = {}
    results = []
    for start in initial_states:
        # States never seen as a next_obs count as reward -1.
        if reward_on_arrival.get(start, -1) > threshold:
            continue  # the start state is already a goal state
        sequence = _rollout(start, policy, transitions, threshold, args.max_steps, memo)
        if sequence is not None:
            results.append({"state": start, "action": sequence})

    # Explicit columns keep a header row even when no rollout succeeded.
    pd.DataFrame(results, columns=["state", "action"]).to_csv(path_sequence, index=False)
    
def run_process(p):
    """Pool worker: generate sequences for one ``(env_id, goal[, args])`` task tuple.

    The optional third element carries the parsed CLI arguments explicitly. Without it,
    the module-level ``args`` assigned in ``__main__`` is used. That global only exists in
    workers started with the ``fork`` start method; under ``spawn`` it is undefined.
    """
    env_id, goal, *extra = p
    run_args = extra[0] if extra else args
    generate_sequence(env_id, goal, run_args)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, default='minigrid_go_to_few_env', help='dataset to use')
    parser.add_argument('--env_id', type=str, nargs="+", default=[])
    parser.add_argument('--goals', type=str, nargs="+", default=[])
    parser.add_argument("--n_jobs", type=int, default=50, help="Number of jobs to run in parallel")
    # argparse does not apply `type` to defaults, so use an int literal ("effectively all").
    parser.add_argument("--nb_sequences", type=int, default=10**15, help="Number of sequences to generate")
    parser.add_argument("--max_steps", type=int, default=500, help="Maximum number of steps in a sequence")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random number generator")
    args = parser.parse_args()

    path_folder = f"{data_folder}{args.name}/env/"

    if args.env_id:
        list_env_id = args.env_id
    else:
        # Every sub-directory of the env folder is an environment id.
        list_env_id = [
            os.path.basename(os.path.normpath(env_dir))
            for env_dir in glob.glob(path_folder + "*/")
        ]

    list_process = []
    for env_id in list_env_id:
        if args.goals:
            list_goals = args.goals
        else:
            # Only goals that have a policy and no generated sequences yet.
            list_goals = [
                os.path.basename(os.path.normpath(goal_dir))
                for goal_dir in glob.glob(path_folder + f"{env_id}/*/")
                if os.path.exists(os.path.join(goal_dir, "Q_policy.csv"))
                and not os.path.exists(os.path.join(goal_dir, "sequence.csv"))
            ]
        # args travels with each task so workers do not depend on a forked global.
        list_process.extend((env_id, goal, args) for goal in list_goals)

    # The context manager shuts the pool down once every result has been consumed.
    with Pool(args.n_jobs) as pool:
        for _ in tqdm.tqdm(pool.imap_unordered(run_process, list_process), total=len(list_process)):
            pass
    