import os

import numpy as np
from smac_plus.starcraft2.starcraft2 import StarCraft2Env

# Cross-call state for the scripted policy in agent(): 'last_messages' holds the
# previous call's (sender, receiver, 3) message tensor. agent() may also add a
# 'decision' key. main() resets this dict at the start of every episode.
globol = dict(last_messages=np.zeros(shape=(7, 7, 3)))

def agent(env, n_agents, obs, arrive):
    """Scripted policy: relay messages among agents 1-6 and steer agent 0.

    Only agent 0's action is ever changed; every other agent gets action 1.
    Cross-call state lives in the module-level ``globol`` dict:
    ``'last_messages'`` is the previous call's message tensor and
    ``'decision'`` (once set) is agent 0's direction, one of ``'right'``,
    ``'left'`` or ``'down'``. The caller must reset both between episodes.

    Args:
        env: unused; kept for interface compatibility.
        n_agents: number of agents. The feature slicing below hard-codes
            agent 0 plus 6 others, so this is assumed to be 7.
        obs: per-agent observations, convertible to a 2-D array
            ``(n_agents, obs_dim)``.
        arrive: unused; kept for interface compatibility.

    Returns:
        ``(actions, messages)``: ``actions`` is a float array of length
        ``n_agents``; ``messages`` is the ``(sender, receiver, 3)`` message
        tensor flattened to ``(n_agents, n_agents * 3)``.

    Raises:
        RuntimeError: if, when the direction is first decided, the messages
            addressed to agent 0 do not sum to exactly 1.
    """
    obs = np.array(obs)
    actions = np.ones(n_agents)

    # --- Observer: agents 1..6 decide whether to relay this step. ---
    messages = np.zeros((n_agents, n_agents, 3))
    enemy = obs[1:, 4]  # obs column 4 of agents 1..6; > 0 triggers relaying
    # For agents 1..6: the 36 features (6 allies x 6) that precede the last 3
    # features; column 0 of each ally block is what gets relayed.
    ally = obs[1:, -(6 * 6 + 3):-3].reshape(6, 6, 6)[:, :, 0]  # (6, 6)
    # Agent 0's feature 2 for each of its 6 allies (author's note: > 0 right,
    # < 0 left).
    ally_loc = obs[0, -(6 * 6 + 3):-3].reshape(6, 6)[:, 2]  # (6,)
    last_messages = globol['last_messages']
    for i in range(1, 7):
        # Relay if agent i was messaged last step or currently flags the enemy.
        if last_messages[:, i, 0].sum() > 0 or enemy[i - 1] > 0:
            # Agent i's ally list omits agent i itself, so the receiver index
            # skips over i.
            messages[i, :i, 0] = ally[i - 1, :i]
            messages[i, i + 1:, 0] = ally[i - 1, i:]
    globol['last_messages'] = messages

    # --- Attack: agent 0's obs column 4 equal to 1 -> action 6. ---
    if obs[0, 4] == 1:
        actions[0] = 6
        return actions, messages.reshape(n_agents, n_agents * 3)

    # --- Decision for agent 0. ---
    to_agent0 = messages[:, 0, 0]
    if 'decision' not in globol and to_agent0.sum() > 0:
        # Explicit check instead of `assert`, which would vanish under -O.
        if to_agent0.sum() != 1:
            raise RuntimeError(
                'expected exactly one message to agent 0 when deciding '
                'direction, got total {}'.format(to_agent0.sum()))
        # Decide on the side of the ally that messaged agent 0. On this
        # step agent 0 keeps the default action.
        if ally_loc[to_agent0[1:] == 1].sum() > 0:
            globol['decision'] = 'right'
        else:
            globol['decision'] = 'left'
    elif 'decision' in globol:
        # Switch to 'down' (permanently) once two messages reach agent 0 and
        # the summed ally_loc is (almost) balanced.
        if (globol['decision'] != 'down'
                and to_agent0.sum() == 2
                and np.abs(ally_loc.sum() / 2) < 0.01):
            globol['decision'] = 'down'
        # right -> 4, left -> 5, down (or anything else) -> 3
        actions[0] = {'right': 4, 'left': 5}.get(globol['decision'], 3)
    elif obs[0, -3] == 0:
        actions[0] = 0

    return actions, messages.reshape(n_agents, n_agents * 3)

def _make_env():
    """Create the StarCraft II environment used for data collection."""
    return StarCraft2Env(
        continuing_episode=False,
        difficulty="7",
        move_amount=2,
        obs_instead_of_state=False,
        obs_last_action=False,
        obs_own_health=True,
        obs_pathing_grid=False,
        obs_terrain_height=False,
        obs_timestep_number=False,
        reward_death_value=10,
        reward_defeat=0,
        reward_negative_scale=0.5,
        reward_only_positive=True,
        reward_scale=True,
        reward_scale_rate=20,
        reward_sparse=False,
        reward_win=200,
        replay_dir="",
        replay_prefix="",
        state_last_action=True,
        state_timestep_number=False,
        step_mul=8,
        debug=False,
        print_rew=False,
        is_print=False,
        print_steps=1000,
        # Alternative maps noted by the author: 1o_5b_vs_1h_plain, 1o_5b_vs_3h,
        # 1o_8b_vs_2h, 1o_10b_vs_1r, MMM, 5z_vs_1ul, bane_vs_bane.
        map_name='6o1b_vs_1r',
        sight_range=13,
        shoot_range=13,
        obs_all_health=False,
        obs_enemy_health=False,
    )


def _avail_actions(env, n_agents):
    """Return the per-agent action-availability masks for the current state."""
    return np.array([env.get_avail_agent_actions(i) for i in range(n_agents)])


def _collect_episodes(env, n_episodes):
    """Roll out ``n_episodes`` episodes of ``agent`` and return padded buffers.

    Returns a dict mapping buffer name to a zero-padded array whose first two
    axes are (episode, timestep). Index ``k`` holds the data observed before
    step ``k + 1`` together with that step's action/reward/message. The slot
    right after the last step holds the post-termination obs/state/avail.

    Raises:
        RuntimeError: if an episode runs longer than the buffer can hold.
    """
    env_info = env.get_env_info()
    n_agents = env_info["n_agents"]
    n_actions = env_info["n_actions"]
    # One slot per step plus one for the data observed after termination.
    # NOTE(review): assumes episodes end within env_info["episode_limit"] steps.
    max_t = env_info["episode_limit"] + 1

    buffers = {
        'actions_onehot': np.zeros((n_episodes, max_t, n_agents, n_actions)),
        'actions': np.zeros((n_episodes, max_t, n_agents, 1)),
        'avail_actions': np.zeros((n_episodes, max_t, n_agents, n_actions)),
        'filled': np.zeros((n_episodes, max_t, 1)),
        'obs': np.zeros((n_episodes, max_t, n_agents, env_info["obs_shape"])),
        'reward': np.zeros((n_episodes, max_t, 1)),
        'state': np.zeros((n_episodes, max_t, env_info["state_shape"])),
        'terminated': np.zeros((n_episodes, max_t, 1)),
        'messages': np.zeros((n_episodes, max_t, n_agents, n_agents * 3)),
    }
    one_hot = np.eye(n_actions)

    for epi in range(n_episodes):
        # Reset the scripted policy's cross-step state for a new episode.
        globol['last_messages'] = np.zeros((n_agents, n_agents, 3))
        globol.pop('decision', None)
        env.reset()
        terminated = False
        episode_reward = 0
        t = 0  # steps taken so far == buffer index of the current step
        arrive = [False] * 4  # forwarded to agent() unchanged

        while not terminated:
            if t + 1 >= max_t:
                raise RuntimeError(
                    'episode {} exceeded the buffer length of {} steps'
                    .format(epi, max_t - 1))
            obs = env.get_obs()
            state = env.get_state()
            # Availability is read before stepping so it lines up with the
            # obs/state stored at the same index.
            avail_actions = _avail_actions(env, n_agents)

            actions, message = agent(env, n_agents, obs, arrive)
            reward, terminated, _ = env.step(actions)
            episode_reward += reward

            buffers['actions_onehot'][epi, t] = one_hot[actions.astype(int)]
            buffers['actions'][epi, t, :, 0] = actions
            buffers['avail_actions'][epi, t] = avail_actions
            buffers['filled'][epi, t] = 1
            buffers['obs'][epi, t] = np.array(obs)
            buffers['reward'][epi, t] = reward
            buffers['state'][epi, t] = state
            buffers['messages'][epi, t] = message
            t += 1

        # NOTE(review): every episode end is marked terminal, including a
        # possible time-limit cut-off (the step() info dict is ignored).
        buffers['terminated'][epi, t - 1] = 1
        # Final slot: what the agents observe after the last step.
        buffers['filled'][epi, t] = 1
        buffers['obs'][epi, t] = np.array(env.get_obs())
        buffers['state'][epi, t] = env.get_state()
        buffers['avail_actions'][epi, t] = _avail_actions(env, n_agents)

        print('total steps', t)
        print("Total reward in episode {} = {}".format(epi, episode_reward))

    return buffers


def _save_buffers(save_dir, buffers):
    """Write each buffer to ``<save_dir>/<name>.npy``.

    Index/flag buffers are stored with builtin ``int`` (what the removed
    ``np.int`` alias meant); everything else as float32.
    """
    int_buffers = {'actions', 'avail_actions', 'filled', 'terminated'}
    for name, data in buffers.items():
        dtype = int if name in int_buffers else np.float32
        np.save(os.path.join(save_dir, name + '.npy'), data.astype(dtype))


def main(save_dir='/home/myh/myh/buffer/sc2/6o1b_vs_1r_from_start/buffer_9026002/',
         n_episodes=500):
    """Collect scripted-policy episodes and save them as .npy buffers.

    Args:
        save_dir: output directory; created if it does not already exist.
        n_episodes: number of episodes to roll out.
    """
    # Create the output directory up front so a bad path fails before the
    # long simulation instead of at the final np.save.
    os.makedirs(save_dir, exist_ok=True)

    env = _make_env()
    try:
        buffers = _collect_episodes(env, n_episodes)
    finally:
        # Shut the game down even if an episode raises.
        env.close()

    _save_buffers(save_dir, buffers)

# Only run the (long) data collection when executed as a script, not on import.
if __name__ == "__main__":
    main()