# Process hallway data

import os, argparse
import numpy as np
import pickle  
import matplotlib.pyplot as plt
from src.envs import gridworld
from plot_utils.plot_simulated_data_gridworld import *
import pandas as pd


def hallway_action2idx(action):
    """Map a hallway action name to its integer action index.

    Recognised names are 'right' (0), 'left' (1), 'down' (2), 'up' (3)
    and 'wait' (4).  Any other value yields ``None``.
    """
    index_by_name = (
        ('up', 3),
        ('down', 2),
        ('left', 1),
        ('right', 0),
        ('wait', 4),
    )
    # Scan with == (rather than hashing) so comparison semantics match a
    # plain chain of equality checks for any kind of input value.
    return next((idx for name, idx in index_by_name if action == name), None)

if __name__ == '__main__':
    # Convert the recorded hallway trials into per-round trajectories of
    # joint state / joint action indices, then write three subsets
    # (1: successful, 2: expert, 3: failed) to numbered sub-directories
    # together with the inference hyperparameters and a few sample plots.

    GEN_DIR_NAME = 'data/experiment_hallway/'
    df = pd.read_csv(GEN_DIR_NAME + 'trials.tsv', sep='\t')
    # One identifier per (worker, round) pair.
    df['uid'] = df['workerid'] + '_' + df['round'].astype(str)
    # rounds = df[df['exp_condition']=='f']['uid']
    rounds = df['uid']
    # Computed once: its order defines the trajectory order below and is
    # reused to look up the uids belonging to each saved subset.
    unique_uids = rounds.unique()

    grid_H, grid_W = 3, 5
    # Temporary grid world, used only to turn the target position into an index.
    gw_pre = gridworld.GridWorld_v2(grid_H, grid_W, {}, {})
    # create a joint reward map, agent1: yellow; agent2: blue, state = (y1,x1,y2,x2)
    target_location = (1, 4, 1, 0)
    reward_map = np.zeros(((grid_H * grid_W) ** 2, 1))
    target_idx = gw_pre.pos2idx1(target_location)
    reward_map[target_idx] = 1

    # The grid world used for everything below: reward 1 at the target,
    # which is also its only terminal state.
    gw = gridworld.GridWorld_v2(grid_H, grid_W, reward_map, [target_location])

    # Joint actions: every (participant action, learner action) pair.
    action_list = [(i, j) for i in range(5) for j in range(5)]
    trajectories = []
    for uid in unique_uids:
        df_select_sorted = df[df['uid'] == uid].sort_values('round_step')
        states, actions = [], []
        for _, row in df_select_sorted.iterrows():
            state_4d = (row['participant_y'], row['participant_x'], row['learner_y'], row['learner_x'])
            states.append(gw.pos2idx1(state_4d))
            action_2d = (hallway_action2idx(row['participant_action']), hallway_action2idx(row['learner_action']))
            actions.append(action_list.index(action_2d))
        # Replay the last logged joint action from the last logged state so
        # the trajectory also contains the state that action leads to
        # (states therefore has one more entry than actions).
        gw.reset(state_4d)
        _, _, next_state, _, _ = gw.step(action_2d)
        states.append(gw.pos2idx1(next_state))
        trajectories.append({'states': states, 'actions': actions})

    # States in which at least one agent sits on its own target cell; with
    # the joint target removed these are the states counted as failures.
    marginal_states = (
        [(1, 4, c, d) for c in range(grid_H) for d in range(grid_W)]
        + [(a, b, 1, 0) for a in range(grid_H) for b in range(grid_W)]
    )
    terminal_states = list(set(marginal_states))  # remove repetitive entries
    terminal_states.remove(target_location)
    terminal_state_idx = [gw.pos2idx1(pos) for pos in terminal_states]

    goal_idx = gw.pos2idx1(target_location)
    # Rounds whose final state is the joint target.
    success_idx = [i for i, traj in enumerate(trajectories) if traj['states'][-1] == goal_idx]
    # Successful rounds with exactly 7 states, i.e. 6 joint actions
    # (presumably the shortest possible solution -- TODO confirm).
    expert_idx = [i for i in success_idx if len(trajectories[i]['states']) == 7]
    # Rounds that ended with only one agent on its target.
    fail_idx = [i for i, traj in enumerate(trajectories) if traj['states'][-1] in terminal_state_idx]

    idx_list = [success_idx, expert_idx, fail_idx]

    # save hyperparameters for inference
    P_a = gw.get_permutation_mat(action_list)
    sigmas = 0.01
    # Distinct values of i**2 + j**2 over the grid offsets, scaled to [0, 1]
    # and stored as a single row.
    diff_square = np.unique([i**2 + j**2 for i in range(gw.height) for j in range(gw.width)])
    diff_square = diff_square / np.max(diff_square)
    diff_map = diff_square.reshape((1, diff_square.shape[0]))
    generative_parameters = {'P_a': P_a, 'sigmas': sigmas, 'height': gw.height, 'width': gw.width, 'diff_map_guess': diff_map}

    for VERSION in range(1, 4):
        out_dir = GEN_DIR_NAME + str(VERSION)
        os.makedirs(out_dir, exist_ok=True)

        selected_idx = idx_list[VERSION - 1]
        traj_selected = [trajectories[idx] for idx in selected_idx]
        with open(out_dir + '/expert_trajectories.pickle', 'wb') as handle:
            pickle.dump(traj_selected, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Record which rounds the saved trajectories came from.
        generative_parameters['uid'] = unique_uids[selected_idx]
        with open(out_dir + '/generative_parameters.pickle', 'wb') as handle:
            pickle.dump(generative_parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Plot every 10th trajectory among the first 50 as a visual check.
        for i in range(0, min(50, len(traj_selected)), 10):
            traj = traj_selected[i]
            states4d = [gw.idx12pos(s) for s in traj['states']]
            traj1 = np.array([(a, b) for (a, b, c, d) in states4d])
            traj2 = np.array([(c, d) for (a, b, c, d) in states4d])
            # plot the trajectories
            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_gridworld_trajectories(grid_H, grid_W, {'states2d': traj1}, fig, ax[0])
            plot_gridworld_trajectories(grid_H, grid_W, {'states2d': traj2}, fig, ax[1])
            plt.tight_layout()
            fig.savefig(out_dir + '/traj{}.png'.format(i))
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)