# Modified from the original script in dynamic_irl-main:
# converts the time-varying IRL setup to a time-invariant one
# and adapts it to a two-agent collaborative foraging task.

# Given a set of hyperparameters, first simulate a (time-invariant) reward map from them;
# next, generate trajectory data within the gridworld environment.
import os, argparse
import numpy as np
import matplotlib.pyplot as plt
from src.envs import gridworld
from collections import namedtuple
import pickle
from plot_utils.plot_simulated_data_gridworld import plot_rewards_all, plot_gridworld_trajectories
from src.value_iteration import *
# Seed NumPy's global RNG so goal weights, start positions and sampled actions are reproducible.
np.random.seed(50)

# Named record whose field order matches the 5-tuple unpacked from gw.step()
# in generate_expert_trajectories.
Step = namedtuple('Step', ['cur_state', 'action', 'next_state', 'reward', 'done'])

def create_goal_maps(num_gridworld_states, LOCATION_WATER, LOCATION_HOME):
    '''
    Build one-hot goal reward maps over the flattened gridworld.

    INPUT:
        num_gridworld_states: number of cells in the (flattened) grid
        LOCATION_WATER: flat index of the water cell
        LOCATION_HOME: flat index of the home cell
    OUTPUT:
        goal_maps: float array of shape (2, num_gridworld_states);
                   row 0 is the home map, row 1 the water map
    '''
    goal_cells = (LOCATION_HOME, LOCATION_WATER)  # row order of the returned maps
    goal_maps = np.zeros((len(goal_cells), num_gridworld_states))
    for row, cell in enumerate(goal_cells):
        goal_maps[row, cell] = 1
    return goal_maps


def generate_expert_trajectories(gw, policy, action_list, T):
    '''
    Roll out one trajectory in `gw` by sampling actions from a fixed stochastic policy.

    The rollout starts from the environment's current state (callers reset `gw`
    beforehand) and takes at most T-1 actions, stopping early when gw.step
    reports the episode as done.

    INPUT:
        gw: gridworld object providing get_current_state(), pos2idx1(state) -> row
            index into `policy`, and step(action) -> (prev_state, action, next_state,
            reward, done), i.e. the same field order as `Step`
        policy: array of size (num_states**2, num_actions**2) for the two-agent joint
            problem; row s is the action distribution in joint state s
        action_list: sequence mapping a policy column index to the action passed to gw.step
        T: maximum number of states in the trajectory (initial state + at most T-1 steps)
    OUTPUT:
        traj: dict with
            'states':    (L,) flat state indices (via pos2idx1), including the initial state
            'states4d':  the raw environment states (both agents' positions)
            'actions':   (L-1,) sampled policy column indices
            'actions2d': the actions echoed back by gw.step
            'rewards':   (L-1,) rewards returned by gw.step
        where L <= T.
    '''
    num_actions = policy.shape[1]
    cur_state = gw.get_current_state()
    states = [gw.pos2idx1(cur_state)]  # flat 1d state indices
    states4d = [cur_state]  # raw (multi-coordinate) states
    actions = []
    actions2d = []
    rewards = []
    for _ in range(T - 1):
        action_idx = np.random.choice(num_actions, p=policy[gw.pos2idx1(cur_state)])
        # The first value returned by gw.step is the state *before* the step; the agent
        # now occupies next_state, so the next action must be drawn from that state's row.
        _, action, next_state, reward, is_done = gw.step(action_list[action_idx])
        states.append(gw.pos2idx1(next_state))
        states4d.append(next_state)
        actions2d.append(action)
        actions.append(action_idx)
        rewards.append(reward)
        cur_state = next_state
        if is_done:
            break
    return {'states': np.array(states), 'states4d': np.array(states4d),
            'actions': np.array(actions), 'actions2d': np.array(actions2d),
            'rewards': np.array(rewards)}



def _joint_reward_from_map(gw, r_map):
    '''
    Lift a single-agent reward map onto the joint two-agent state space.

    INPUT:
        gw: two-agent gridworld providing idx12pos(i) -> 4-element joint position
            (agent-1 row, agent-1 col, agent-2 row, agent-2 col)
        r_map: (H, W) reward map
    OUTPUT:
        reward_joint: ((H*W)**2, 1) array; joint state i receives r_map[row, col] when
            both agents are in the same cell (row, col), and 0 otherwise
    '''
    num_joint_states = r_map.size ** 2
    reward_joint = np.zeros((num_joint_states, 1))
    for i in range(num_joint_states):
        cur = gw.idx12pos(i)
        # reward is only collected when the two agents are co-located
        if cur[0] == cur[2] and cur[1] == cur[3]:
            reward_joint[i] = r_map[cur[0], cur[1]]
    return reward_joint


def _solve_policy(TAG, P_a, reward_joint, rewards, r_map, gamma):
    '''
    Solve for the two-agent policy with the value-iteration variant selected by TAG.

    INPUT:
        TAG: 0 centralized, 1 independent control, 2 independent control with uniform
             prediction, 3 independent control with policy prediction
        P_a: joint transition matrix from GridWorld.get_permutation_mat
        reward_joint: ((H*W)**2, 1) joint reward vector
        rewards: (H*W, 1) single-agent reward vector (used only by TAG 3)
        r_map: (H, W) reward map (used only by TAG 3 to build the single-agent world)
        gamma: discount factor
    OUTPUT:
        (values, policy) from the selected solver
    RAISES:
        ValueError: if TAG is not 0-3
    '''
    P_a_int = P_a.astype(int)
    if TAG == 0:
        values, policy = two_agent_value_iteration(P_a_int, reward_joint, gamma)
    elif TAG == 1:
        values, policy, _, _ = two_agent_value_iteration_independent_control(
            P_a_int, reward_joint, gamma)
    elif TAG == 2:
        values, policy, _, _ = two_agent_value_iteration_independent_control_uniform_prediction(
            P_a_int, reward_joint, gamma)
    elif TAG == 3:
        gw_single = gridworld.GridWorld_SingleAgent(r_map, {},)
        P_a_single = gw_single.get_permutation_mat()
        values, policy = two_agent_value_iteration_independent_control_policy_prediction(
            P_a_int, P_a_single, reward_joint, rewards, gamma)
    else:
        raise ValueError('unknown TAG {}; expected 0-3'.format(TAG))
    return values, policy


def main(gridworld_H, gridworld_W, N_experts, T, reward_strength,
         LOCATION_WATER, LOCATION_HOME, VERSION = 1, TAG=0, GAMMA=0.9, plot_data=False):
    '''
    Simulate expert trajectories for the two-agent foraging gridworld and save them.

    The steps are:
    1. Sample goal-map weights and build the reward map.
    2. Solve for a policy with the value-iteration variant selected by TAG.
    3. Roll out N_experts trajectories.
    4. Pickle the results (plus optional figures) to
       data/simulated_gridworld_data/<tag name><VERSION>/.

    INPUT:
        gridworld_H, gridworld_W: grid height and width
        N_experts: number of trajectories to generate
        T: maximum number of states per trajectory
        reward_strength: scale applied to the weighted goal maps
        LOCATION_WATER, LOCATION_HOME: flat goal-cell indices in column-major order
            (matching the order='F' reshape below)
        VERSION: environment variant; 1 and 3 make rewarded joint states terminal,
            6 starts both agents at the grid centre
        TAG: interaction model, index into `tags` (0-3)
        GAMMA: discount factor used by value iteration
        plot_data: if True, also save reward, trajectory and value-function figures
    RAISES:
        ValueError: if TAG is not in 0-3
    '''
    tags = ['centralized', 'independent_control', 'independent_control_w_uniform_prediction', 'independent_control_w_policy_prediction']
    # for tags[3], only agent 1 is predictive of agent 2's policy
    if not 0 <= TAG < len(tags):
        # fail fast instead of after the expensive simulation
        raise ValueError('TAG must be in 0..{}, got {}'.format(len(tags) - 1, TAG))

    num_maps = 2
    sigmas = 0.01  # std of the goal weights around 1 (shared by both goals)
    goal_maps = create_goal_maps(gridworld_H*gridworld_W, LOCATION_WATER, LOCATION_HOME)  # num_maps x num_states
    time_invariant_weights = np.random.normal(1., scale=sigmas, size=(1, num_maps))  # 1 x num_maps

    rewards = time_invariant_weights @ goal_maps * reward_strength  # 1 x num_states
    rewards = rewards.T  # num_states x 1
    # column-major reshape: flat index k -> (row k % H, col k // H)
    r_map = np.reshape(np.array(rewards[:, 0]), (gridworld_H, gridworld_W), order='F')

    gw = gridworld.GridWorld(r_map, {},)
    action_list = [(i, j) for i in range(5) for j in range(5)]  # pairs of per-agent actions
    P_a = gw.get_permutation_mat(action_list)
    reward_joint = _joint_reward_from_map(gw, r_map)

    if VERSION == 1 or VERSION == 3:
        # make every rewarded joint state terminal
        reward_idx = np.argwhere(reward_joint[:, 0] != 0)[:, 0]
        terminal_states = [gw.idx12pos(r) for r in reward_idx]
        gw = gridworld.GridWorld(r_map, terminal_states)

    values, policy = _solve_policy(TAG, P_a, reward_joint, rewards, r_map, GAMMA)

    trajs_all_experts = []
    success = 0
    for _ in range(N_experts):
        # NOTE: random starts are drawn even when VERSION 6 overrides them,
        # so RNG consumption is identical across versions.
        start_pos = (np.random.randint(0, gw.height), np.random.randint(0, gw.width),
                     np.random.randint(0, gw.height), np.random.randint(0, gw.width))
        if VERSION == 6:
            start_pos = (gw.height // 2, gw.width // 2, gw.height // 2, gw.width // 2)
        gw.reset(start_pos)
        traj = generate_expert_trajectories(gw, policy, action_list, T)
        trajs_all_experts.append(traj)
        if np.sum(traj['rewards']) > 0:
            success += 1
    success_rate = success / N_experts if N_experts > 0 else float('nan')

    generative_parameters = {
        'P_a': P_a,
        'generative_rewards': rewards,
        'goal_maps': goal_maps,
        'time_invariant_weights': time_invariant_weights,
        'sigmas': sigmas,
        'policy': policy,
        'success_rate': success_rate,
    }

    # now save everything
    save_dir = 'data/simulated_gridworld_data/{}{}'.format(tags[TAG], VERSION)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, 'expert_trajectories.pickle'), 'wb') as handle:
        pickle.dump(trajs_all_experts, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(save_dir, 'generative_parameters.pickle'), 'wb') as handle:
        pickle.dump(generative_parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if not plot_data:
        return

    plot_rewards_all(goal_maps, np.tile(time_invariant_weights, (T, 1)),
                     np.tile(rewards, (T, 1)),
                     gridworld_H, gridworld_W, LOCATION_WATER, LOCATION_HOME,
                     save_name=os.path.join(save_dir, 'generative_rewards.png'))

    # plot a few example trajectories (only those that exist when N_experts is small)
    for i in (0, 5, 10, 15, 20):
        if i >= len(trajs_all_experts):
            break
        states4d = np.asarray(trajs_all_experts[i]['states4d'])
        # split each joint state into the two agents' 2-D positions
        traj1, traj2 = states4d[:, :2], states4d[:, 2:4]
        fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        plot_gridworld_trajectories(gridworld_H, gridworld_W, {'states2d': traj1}, fig, ax[0])
        plot_gridworld_trajectories(gridworld_H, gridworld_W, {'states2d': traj2}, fig, ax[1])
        plt.tight_layout()
        fig.savefig(os.path.join(save_dir, 'traj{}.png'.format(i)))
        plt.close(fig)  # avoid accumulating open figures

    n_cells = gridworld_H * gridworld_W
    values_mat = np.zeros((n_cells, n_cells))
    # NOTE(review): assumes joint index i = agent1_idx + n_cells * agent2_idx;
    # confirm against GridWorld.pos2idx1.
    for i in range(values.shape[0]):
        values_mat[i % n_cells, i // n_cells] = values[i, 0]

    fig = plt.figure(figsize=(7, 7))
    plt.imshow(values_mat)
    plt.colorbar()
    plt.ylabel('Agent 1 idx')
    plt.xlabel('Agent 2 idx')
    plt.title('Success rate: {}; T: {}'.format(success_rate, T))
    fig.savefig(os.path.join(save_dir, 'value.png'))
    plt.close(fig)



if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='enter environment specifics')
    # choices rejects out-of-range tags up front instead of failing deep inside main()
    parser.add_argument('--TAG', type=int, default=0, choices=[0, 1, 2, 3],
                        help='which interaction type to use (0: centralized, 1: independent control, '
                             '2: independent control w/ uniform prediction, '
                             '3: independent control w/ policy prediction)')
    parser.add_argument('--VERSION', type=int, default=1,
                        help='which grid env to use (3: 9x9 grid; 4: reward strength 1; '
                             '5: reward strength 5; 6: both agents start at the grid centre; '
                             '1 and 3 make rewarded joint states terminal)')

    args = parser.parse_args()
    VERSION = args.VERSION
    TAG = args.TAG

    # 5x5 grid by default; VERSION 3 uses a larger 9x9 grid
    gridworld_H, gridworld_W = (9, 9) if VERSION == 3 else (5, 5)
    N_experts = 500  # number of trajectories to generate
    T = 200  # maximum length of each trajectory

    # Goal cells as flat column-major indices (rewards are reshaped with order='F',
    # so flat index k is row k % H, column k // H):
    #   water: row 0, middle column; home: middle row, last column
    LOCATION_WATER = gridworld_H * (gridworld_W // 2)
    LOCATION_HOME = (gridworld_W - 1) * gridworld_H + gridworld_H // 2
    reward_strength = {4: 1, 5: 5}.get(VERSION, 2)

    main(gridworld_H, gridworld_W, N_experts, T, reward_strength,
         LOCATION_WATER, LOCATION_HOME, VERSION, TAG, GAMMA=0.9, plot_data=True)
