from smac_plus.starcraft2.starcraft2 import StarCraft2Env
import numpy as np

def get_close(loc, avail):
    """Pick a move action that reduces the 2-D offset ``loc``.

    Args:
        loc: ``(x, y)`` offset toward a target (indexable, ``loc[0]`` is x,
            ``loc[1]`` is y).
        avail: availability mask indexed by action id, as returned by
            ``env.get_avail_agent_actions``; an entry equal to 1 means the
            action may be taken.

    Returns:
        2 if y > 0, 3 if y < 0, 4 if x > 0, 5 if x < 0 -- checked in that
        priority order and only among available actions. A first pass only
        accepts components whose magnitude exceeds 0.01; a second pass accepts
        any nonzero component so a nearly aligned target still yields a move.
        If nothing qualifies (e.g. ``loc`` is ``(0, 0)`` or every candidate is
        masked out) returns 1, the default action ``agent()`` starts from,
        instead of implicitly returning ``None``.
    """
    # (action id, axis of loc, sign of the component that action reduces)
    candidates = ((2, 1, 1), (3, 1, -1), (4, 0, 1), (5, 0, -1))
    for threshold in (0.01, 0.0):
        for action, axis, sign in candidates:
            if sign * loc[axis] > threshold and avail[action] == 1:
                return action
    return 1

def _chase_then_attack(env, agent_id, vis, arrive, actions, target_col, min_seen, offset):
    """Walk ``agent_id`` along ``offset`` until enough targets show up, then attack.

    While ``arrive[agent_id]`` is False and ``vis[agent_id, :, target_col]`` sums
    to less than ``min_seen``, a move toward ``offset`` is chosen via
    ``get_close``. Otherwise ``arrive[agent_id]`` is latched to True (mutating
    the caller's list) and the first enemy slot with a positive entry in that
    column is attacked with action ``6 + slot``; if no slot is positive the
    action already in ``actions`` is left untouched.
    """
    if not arrive[agent_id] and vis[agent_id, :, target_col].sum() < min_seen:
        actions[agent_id] = get_close(offset, env.get_avail_agent_actions(agent_id))
        return
    arrive[agent_id] = True
    for slot, flag in enumerate(vis[agent_id, :, target_col]):
        if flag > 0:
            actions[agent_id] = 6 + slot
            break


def agent(env,n_agents,obs,arrive):
    """Hand-scripted policy for the 4-agent map used in ``main``.

    Args:
        env: environment exposing ``get_avail_agent_actions(agent_id)``.
        n_agents: number of agents; sizes the action vector and message tensor.
        obs: per-agent observations, stacked into an array of shape
            ``(n_agents, obs_dim)``.
        arrive: per-agent flags, updated in place; once True an agent stops
            approaching and only selects attack actions on later calls.

    Returns:
        ``(actions, messages)`` where ``actions`` is a float array of length
        ``n_agents`` (default 1) and ``messages`` is reshaped to ``(4, 12)``.
    """
    obs = np.array(obs)
    actions = np.ones(n_agents)

    # Agent 2 is the observer: it never changes its own action, but turns what it
    # sees into offsets. messages[2, j, :2] is the offset sent to agent j.
    messages = np.zeros((n_agents, n_agents, 3))
    enemy_xy = obs[2, 4:4 + 7 * 3].reshape(3, 7)[:, 2:4]
    ally_xy = obs[2, -(7 * 3 + 4):-4].reshape(3, 7)[:, 2:4]
    # NOTE(review): this halves the sum over all three enemy slots; it equals the
    # midpoint of enemies 0 and 1 only when slot 2 is zero (e.g. not visible).
    rally = enemy_xy.sum(axis=0) / 2
    # `0 - x` (not `-x`) is kept deliberately: it yields +0.0 rather than -0.0.
    messages[2, :2, :2] = 0 - ally_xy[:2]          # agents 0 and 1 ("2h")
    messages[2, 3, :2] = rally - ally_xy[2]        # agent 3 ("1r")

    # Last three of the seven per-enemy features for every agent (4, 3, 3).
    # Column 0 gates agent 3, column 1 gates agents 0 and 1 -- presumably
    # per-unit-type flags; TODO confirm against the env's feature layout.
    vis = obs[:, 4:4 + 7 * 3].reshape(4, 3, 7)[:, :, -3:]

    _chase_then_attack(env, 3, vis, arrive, actions, 0, 2, messages[2, 3, :2])
    _chase_then_attack(env, 0, vis, arrive, actions, 1, 1, messages[2, 0, :2])
    _chase_then_attack(env, 1, vis, arrive, actions, 1, 1, messages[2, 1, :2])

    # A zero at obs[i, -4] forces action 0; presumably own health (the env is
    # built with obs_own_health=True), i.e. dead units issue no-op.
    for idx in (0, 1, 3):
        if obs[idx, -4] == 0:
            actions[idx] = 0

    return actions, messages.reshape(4, 4 * 3)

def _make_env():
    """Build the StarCraft2Env used for data collection on map '1o_1r2h_vs_2r1b'."""
    return StarCraft2Env(
        continuing_episode=False,
        difficulty="7",
        move_amount=2,
        obs_instead_of_state=False,
        obs_last_action=False,
        obs_own_health=True,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_timestep_number=False,
        reward_death_value=10,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_only_positive=True,
        reward_scale=True,
        reward_scale_rate=20,
        reward_sparse=False,
        reward_win=200,
        replay_dir="",
        replay_prefix="",
        state_last_action=True,
        state_timestep_number=False,
        step_mul=8,
        debug=False,
        print_rew=False,
        is_print=False,
        print_steps=1000,
        # Other maps tried with this script: '1o_5b_vs_1h_plain', '1o_5b_vs_3h',
        # '1o_8b_vs_2h', '1o_10b_vs_1r', 'MMM', '5z_vs_1ul', 'bane_vs_bane'.
        map_name='1o_1r2h_vs_2r1b',
        sight_range=6,
        shoot_range=6,
        obs_all_health=False,
        obs_enemy_health=False)


def _save_buffers(out_dir, buffers):
    """``np.save`` each ``(file_name, array, dtype)`` in *buffers* into *out_dir*.

    The directory is created first so a long collection run is not lost to a
    missing output path.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    for file_name, array, dtype in buffers:
        np.save(os.path.join(out_dir, file_name), array.astype(dtype))


def main():
    """Collect 500 episodes of the scripted ``agent`` policy and save them as .npy files.

    Every episode owns 101 time slots. Step ``t`` (1-based) is written to slot
    ``t - 1``. After termination, slot ``t - 1`` is marked terminated and slot
    ``t`` receives the final observation/state with ``filled = 1``.

    An episode that raises is discarded: the traceback is printed, its slots
    are zeroed so no stale rows survive, and the same index is retried.
    ``KeyboardInterrupt`` is not swallowed, and the environment is always
    closed.
    """
    import traceback

    env = _make_env()
    try:
        env_info = env.get_env_info()
        n_actions = env_info["n_actions"]
        n_agents = env_info["n_agents"]
        # The buffers below and agent() hard-code 4 agents and 9 actions; any
        # other env would make every episode fail and be retried forever.
        if n_agents != 4 or n_actions != 9:
            raise ValueError(
                f"expected 4 agents and 9 actions, got {n_agents} agents "
                f"and {n_actions} actions"
            )

        n = 500
        actions_onehot = np.zeros((n, 101, 4, 9))
        acts = np.zeros((n, 101, 4, 1))
        avail_actions = np.zeros((n, 101, 4, 9))
        filled = np.zeros((n, 101, 1))
        obss = np.zeros((n, 101, 4, 50))
        rewards = np.zeros((n, 101, 1))
        states = np.zeros((n, 101, 82))
        terms = np.zeros((n, 101, 1))
        messages = np.zeros((n, 101, 4, 4 * 3))
        per_episode = (actions_onehot, acts, avail_actions, filled, obss,
                       rewards, states, terms, messages)

        epi = 0
        while epi < n:
            try:
                env.reset()
                terminated = False
                episode_reward = 0
                t = 0
                arrive = [False] * 4

                while not terminated:
                    t += 1
                    obs = env.get_obs()
                    state = env.get_state()

                    actions, message = agent(env, n_agents, obs, arrive)

                    reward, terminated, _ = env.step(actions)
                    episode_reward += reward

                    actions_onehot[epi, t - 1] = np.eye(9)[actions.astype('int')]
                    acts[epi, t - 1, :, 0] = actions
                    avail_actions[epi, t - 1] = np.array(
                        [env.get_avail_agent_actions(i) for i in range(4)])
                    filled[epi, t - 1] = 1
                    obss[epi, t - 1] = np.array(obs)
                    rewards[epi, t - 1] = reward
                    states[epi, t - 1] = state
                    messages[epi, t - 1] = message

                terms[epi, t - 1] = 1
                filled[epi, t] = 1
                obss[epi, t] = np.array(env.get_obs())
                states[epi, t] = env.get_state()
            except Exception:
                print('bug')
                traceback.print_exc()
                # Drop the partial episode so a shorter retry can't inherit
                # stale rows (notably filled == 1) from the failed attempt.
                for buf in per_episode:
                    buf[epi] = 0
                continue

            print('total steps', t)
            print("Total reward in episode {} = {}".format(epi, episode_reward))
            epi += 1
    finally:
        env.close()

    # `int` replaces the `np.int` alias, which NumPy >= 1.24 no longer provides.
    _save_buffers(
        '/home/myh/myh/buffer/sc2/1o_1r2h_vs_2r1b_from_start/buffer_9024001/',
        [
            ('actions_onehot.npy', actions_onehot, np.float32),
            ('actions.npy', acts, int),
            ('avail_actions.npy', avail_actions, int),
            ('filled.npy', filled, int),
            ('obs.npy', obss, np.float32),
            ('reward.npy', rewards, np.float32),
            ('state.npy', states, np.float32),
            ('terminated.npy', terms, int),
            ('messages.npy', messages, np.float32),
        ],
    )

# Entry point: starts the full collection run whenever this module is executed,
# including when it is merely imported.
# NOTE(review): consider an `if __name__ == "__main__":` guard.
main()