# Modified based on simulate_data_gridworld_ind.py
# Adapted to a two-agent collaborative foraging task.
# Takes separate policies for agent 1 and agent 2 and executes them independently.


import os, argparse
import numpy as np
import matplotlib.pyplot as plt
from src.envs import gridworld
from collections import namedtuple
import pickle
from plot_utils.plot_simulated_data_gridworld import plot_rewards_all, plot_gridworld_trajectories
from src.value_iteration import *
# Fix the global NumPy RNG so that the simulated trajectories are reproducible.
np.random.seed(50)

# Field layout of a single environment transition.
Step = namedtuple('Step', ['cur_state', 'action', 'next_state', 'reward', 'done'])

def create_goal_maps(num_gridworld_states, LOCATION_WATER, LOCATION_HOME):
    '''
    Build one-hot goal maps over the flattened gridworld states.

    INPUT:
        num_gridworld_states: number of (single-agent) gridworld states
        LOCATION_WATER: flat index of the water port
        LOCATION_HOME: flat index of the home port
    OUTPUT:
        goal_maps: float array of shape (2, num_gridworld_states);
                   row 0 is the home map, row 1 is the water map
    '''
    home_map, water_map = (np.zeros(num_gridworld_states) for _ in range(2))
    home_map[LOCATION_HOME] = 1
    water_map[LOCATION_WATER] = 1
    return np.stack([home_map, water_map])


def generate_expert_trajectories(gw, policy, action_list, T):
    '''
    Roll out one trajectory in ``gw`` by sampling joint actions from ``policy``.

    At every step a joint-action index is drawn from the row of ``policy`` that
    belongs to the agents' current joint state. The matching entry of
    ``action_list`` is passed to ``gw.step``, and the transition is recorded.
    The rollout stops after T-1 steps (so at most T states are recorded) or as
    soon as ``gw.step`` reports that the episode is done.

    INPUT:
        gw: gridworld object, already reset to the desired start state.
            Uses gw.get_current_state(), gw.pos2idx1() and gw.step().
        policy: array of size (num_joint_states, num_joint_actions); row s is the
            probability distribution over joint-action indices in joint state s
        action_list: sequence mapping a joint-action index to the action passed
            to gw.step (in this script, (agent1_action, agent2_action) tuples)
        T: maximum number of states in the trajectory
    OUTPUT:
        traj: dict of numpy arrays
            'states'    - flattened joint-state indices (gw.pos2idx1), length <= T
            'states4d'  - joint states as returned by the gridworld
                          (presumably (x1, y1, x2, y2) -- TODO confirm)
            'actions'   - sampled joint-action indices, length len(states) - 1
            'actions2d' - actions as returned by gw.step
            'rewards'   - rewards as returned by gw.step
    '''

    cur_state = gw.get_current_state()
    states = [gw.pos2idx1(cur_state)]  # flattened 1d joint-state indices
    states4d = [cur_state]  # raw joint-state coordinates
    actions = []
    actions2d = []
    rewards = []
    for _ in range(T - 1):
        action_idx = np.random.choice(range(policy.shape[1]), p=policy[gw.pos2idx1(cur_state)])
        _, action, next_state, reward, is_done = gw.step(action_list[action_idx])
        states.append(gw.pos2idx1(next_state))
        states4d.append(next_state)
        actions2d.append(action)
        actions.append(action_idx)
        rewards.append(reward)
        # The next action must be conditioned on the state we just moved to.
        # gw.step's first return value is (per the Step field order) the state
        # the step started from, so re-using it would make every action lag one
        # step behind the recorded trajectory.
        cur_state = next_state
        if is_done:
            break
    traj = {'states': np.array(states), 'states4d': np.array(states4d),
            'actions': np.array(actions), 'actions2d': np.array(actions2d),
            'rewards': np.array(rewards)}
    return traj



def _joint_goal_rewards(gw, r_map, num_joint_states):
    '''
    Build the (num_joint_states, 1) joint reward vector: a joint state gets the
    r_map value of a cell only when both agents occupy that same cell.
    '''
    reward_joint = np.zeros((num_joint_states, 1))
    for i in range(num_joint_states):
        cur = gw.idx12pos(i)
        if cur[0] == cur[2] and cur[1] == cur[3]:
            reward_joint[i] = r_map[cur[0], cur[1]]
    return reward_joint


def _solve_agent_policies(tag, P_a_int, reward_joint, gamma):
    '''
    Run the two-agent value-iteration variant selected by ``tag`` and return
    its (agent-1 policy, agent-2 policy) pair (the 3rd and 4th return values).
    '''
    if tag == 0:
        solver = two_agent_value_iteration_selfish
    elif tag == 1:
        solver = two_agent_value_iteration_independent_control
    elif tag == 2:
        solver = two_agent_value_iteration_independent_control_uniform_prediction
    else:
        raise ValueError('unknown interaction tag: {!r}'.format(tag))
    _, _, pol1, pol2 = solver(P_a_int, reward_joint, gamma)
    return pol1, pol2


def _plot_example_trajectories(trajs, gridworld_H, gridworld_W, save_dir, stride=5, max_plots=10):
    '''
    Save side-by-side per-agent plots for every ``stride``-th trajectory
    (at most ``max_plots`` of them, and never past the end of ``trajs``).
    '''
    for i in range(0, min(len(trajs), stride * max_plots), stride):
        states4d = trajs[i]['states4d']
        traj1 = np.array([(a, b) for (a, b, c, d) in states4d])
        traj2 = np.array([(c, d) for (a, b, c, d) in states4d])
        fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        plot_gridworld_trajectories(gridworld_H, gridworld_W, {'states2d': traj1}, fig, ax[0])
        plot_gridworld_trajectories(gridworld_H, gridworld_W, {'states2d': traj2}, fig, ax[1])
        plt.tight_layout()
        fig.savefig(save_dir + '/traj{}.png'.format(i))
        plt.close(fig)  # don't accumulate open figures across iterations


def main(gridworld_H, gridworld_W, N_experts, T, reward_strength,
         LOCATION_WATER, LOCATION_HOME, VERSION = 1, TAG1=0, TAG2=0, GAMMA=0.9, plot_data=False):
    '''
    Simulate two-agent trajectories and pickle them with their generative parameters.

    INPUT:
        gridworld_H, gridworld_W: grid height and width
        N_experts: number of trajectories to generate
        T: maximum number of states per trajectory
        reward_strength: reward value placed on each goal cell
        LOCATION_WATER, LOCATION_HOME: flat indices of the goal cells, in the
            column-major order used to reshape the reward map (order='F')
        VERSION: environment variant. 1 and 3 make the rewarded joint states
            terminal; 6 starts both agents in the grid centre. It is also
            appended to the save directory name.
        TAG1, TAG2: value-iteration variant used for agent 1's / agent 2's policy
            (0: selfish, 1: perfect_prediction, 2: uniform_prediction)
        GAMMA: discount factor passed to value iteration
        plot_data: if True, also save reward and example-trajectory plots
    OUTPUT:
        None. Writes expert_trajectories.pickle and generative_parameters.pickle
        (plus optional .png plots) under data/simulated_gridworld_data/.
    RAISES:
        ValueError: if TAG1 or TAG2 is not a valid tag index.
    '''

    tags = ['selfish', 'perfect_prediction', 'uniform_prediction']
    # Validate up front instead of failing after expensive value iteration.
    for tag_name, tag in (('TAG1', TAG1), ('TAG2', TAG2)):
        if tag not in range(len(tags)):
            raise ValueError('{} must be one of {}, got {!r}'.format(
                tag_name, list(range(len(tags))), tag))

    N_STATES, N_ACTIONS = gridworld_H*gridworld_W, 5

    num_maps = 2
    sigmas = 0.01  # stored with the generative parameters
    goal_maps = create_goal_maps(N_STATES, LOCATION_WATER, LOCATION_HOME)  # num_maps x num_states
    weights = np.ones((1, num_maps))  # 1 x num_maps

    # now obtain reward maps
    rewards = weights @ goal_maps * reward_strength  # 1 x num_states
    rewards = rewards.T  # num_states x 1

    r_map = np.reshape(np.array(rewards[:, 0]), (gridworld_H, gridworld_W), order='F')

    gw = gridworld.GridWorld(r_map, {},)
    # Joint action (agent1_action, agent2_action); column i of the joint
    # policy below corresponds to action_list[i].
    action_list = [(i, j) for i in range(N_ACTIONS) for j in range(N_ACTIONS)]
    P_a = gw.get_permutation_mat(action_list)
    reward_joint = _joint_goal_rewards(gw, r_map, N_STATES**2)

    reward_idx = np.argwhere(reward_joint[:, 0] != 0)[:, 0]
    terminal_states = [gw.idx12pos(r) for r in reward_idx]

    if VERSION == 1 or VERSION == 3:  # add terminal states
        gw = gridworld.GridWorld(r_map, terminal_states)

    gamma = GAMMA  # honour the caller's discount factor

    # Each solver returns both agents' policies, so run each distinct
    # variant once (in TAG1, TAG2 order) and reuse it when TAG1 == TAG2.
    P_a_int = P_a.astype(int)
    solved = {}
    for tag in (TAG1, TAG2):
        if tag not in solved:
            solved[tag] = _solve_agent_policies(tag, P_a_int, reward_joint, gamma)
    p1 = solved[TAG1][0]
    p2 = solved[TAG2][1]

    # Agents act independently: joint-action probability is the product.
    policy = np.zeros((p1.shape[0], p1.shape[1]*p2.shape[1]))
    for i, (a1, a2) in enumerate(action_list):
        policy[:, i] = p1[:, a1] * p2[:, a2]

    trajs_all_experts = []
    success = 0
    for expert in range(N_experts):
        # Always draw a random start (keeps the RNG stream identical across versions).
        start_pos = (np.random.randint(0, gw.height), np.random.randint(0, gw.width),
                     np.random.randint(0, gw.height), np.random.randint(0, gw.width))
        if VERSION == 6:
            start_pos = (int(gw.height/2), int(gw.width/2), int(gw.height/2), int(gw.width/2))
        gw.reset(start_pos)
        traj = generate_expert_trajectories(gw, policy, action_list, T)
        trajs_all_experts.append(traj)
        if np.sum(traj['rewards']) > 0:
            success += 1

    success_rate = success / N_experts if N_experts > 0 else 0.0

    generative_parameters = {}
    generative_parameters['P_a'] = P_a
    generative_parameters['generative_rewards'] = rewards
    generative_parameters['goal_maps'] = goal_maps
    generative_parameters['time_invariant_weights'] = weights
    generative_parameters['sigmas'] = sigmas
    generative_parameters['policy'] = policy
    generative_parameters['success_rate'] = success_rate

    print('Success rate: {}'.format(success_rate))

    # now save everything
    save_dir = 'data/simulated_gridworld_data/{}_{}{}'.format(tags[TAG1], tags[TAG2], VERSION)
    os.makedirs(save_dir, exist_ok=True)

    with open(save_dir+'/expert_trajectories.pickle', 'wb') as handle:
        pickle.dump(trajs_all_experts, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(save_dir+'/generative_parameters.pickle', 'wb') as handle:
        pickle.dump(generative_parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)
    if plot_data:
        plot_rewards_all(goal_maps, np.tile(weights, (T, 1)),
                         np.tile(rewards, (T, 1)),
                         gridworld_H, gridworld_W, LOCATION_WATER, LOCATION_HOME,
                         save_name=save_dir+'/generative_rewards.png')
        # plot a few example trajectories (every 5th, up to 10 of them)
        _plot_example_trajectories(trajs_all_experts, gridworld_H, gridworld_W, save_dir)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='enter environment specifics')
    tag_help = ' (0: selfish, 1: perfect_prediction, 2: uniform_prediction)'
    # `choices` makes argparse reject unknown tags with a clear usage error.
    parser.add_argument('--TAG1', type=int, default=0, choices=[0, 1, 2],
                        help='which interaction type to use for agent 1' + tag_help)
    parser.add_argument('--TAG2', type=int, default=0, choices=[0, 1, 2],
                        help='which interaction type to use for agent 2' + tag_help)
    parser.add_argument('--VERSION', type=int, default=1, help='which grid env to use')

    args = parser.parse_args()
    VERSION = args.VERSION

    # VERSION 3 uses a 9x9 grid; every other version uses 5x5.
    gridworld_H, gridworld_W = (9, 9) if VERSION == 3 else (5, 5)
    N_experts = 500  # number of trajectories to generate
    T = 200  # length of each trajectory

    # Goal cells as column-major (order='F') flat indices:
    # water = row 0 of the middle column, home = middle row of the last column.
    LOCATION_WATER = gridworld_H * (gridworld_W // 2)
    LOCATION_HOME = (gridworld_W - 1) * gridworld_H + gridworld_H // 2
    # Reward magnitude per goal: VERSION 4 -> 1, VERSION 5 -> 5, otherwise 2.
    reward_strength = {4: 1, 5: 5}.get(VERSION, 2)

    main(gridworld_H, gridworld_W, N_experts, T, reward_strength,
         LOCATION_WATER, LOCATION_HOME, VERSION, args.TAG1, args.TAG2, GAMMA=0.9, plot_data=True)
