import copy
import numpy as np
from UTIL.tensor_ops import my_view, __hash__, repeat_at, gather_righthand
from .foundation import AlgorithmConfig

def random_group(n_thread, hete_type, n_hete_types, n_group, selected_tps):
    """Randomly assign every heterogeneous type to a group, independently per thread.

    For each thread one group id per type is drawn with ``np.random.randint``;
    the types listed in ``selected_tps[i]`` are then forced into group 0, and
    every agent receives the group id of its type.

    Args:
        n_thread: number of threads, i.e. number of rows in both outputs.
        hete_type: per-agent type ids. Used as a 1D array of length
            ``n_agent`` (the boolean masks built from it index the agent axis).
        n_hete_types: number of types; type ids ``0 .. n_hete_types - 1`` are
            considered.
        n_group: exclusive upper bound of the randomly drawn group ids.
        selected_tps: indexable by thread; ``selected_tps[i]`` holds the type
            id(s) that are placed in group 0 for thread ``i``.

    Returns:
        A tuple ``(res, gp_sel_summary)``:
        ``res`` is an int array of shape ``(n_thread, n_agent)`` holding the
        group id of each agent (agents whose type id is not in
        ``0 .. n_hete_types - 1`` keep the initial value 0);
        ``gp_sel_summary`` is the per-thread, per-type group assignment
        stacked into shape ``(n_thread, n_hete_types)``.
    """
    n_agent = hete_type.shape[-1]
    res = np.zeros(shape=(n_thread, n_agent), dtype=int)

    # Loop invariants, computed once instead of once per thread.
    # When hete_sel_exclude_frontend is set, random draws start at 1, so
    # group 0 can only be reached through selected_tps.
    low_group = 1 if AlgorithmConfig.hete_sel_exclude_frontend else 0
    # One boolean mask per type id, marking the agents of that type.
    type_masks = [hete_type == ht for ht in range(n_hete_types)]

    gp_sel_summary = []
    for i in range(n_thread):
        group_assignment = np.random.randint(low=low_group, high=n_group, size=(n_hete_types))
        group_assignment[selected_tps[i]] = 0
        # The array is freshly allocated every iteration and np.stack copies
        # its inputs, so no defensive copy is needed here.
        gp_sel_summary.append(group_assignment)
        for mask, group in zip(type_masks, group_assignment):
            res[i, mask] = group
    return res, np.stack(gp_sel_summary)

# Example usage:
# select_nets_for_shellenv(n_types=n_types,
#                          policy=self.RL_functional,
#                          hete_type_list=self.hete_type,
#                          n_thread=n_thread,
#                          n_gp=AlgorithmConfig.hete_n_net_placeholder,
#                          testing=False,  # required by the signature below
#                          )

def select_nets_for_shellenv(n_types, policy, hete_type_list, n_thread, n_gp, testing):
    """Choose, per thread and per agent, the index of the network placeholder to use.

    A placeholder index encodes a (type, group) pair as
    ``group * n_types + type``, so ``type == index % n_types`` and
    ``group == index // n_types``.

    Args:
        n_types: number of heterogeneous types.
        policy: not used by this function.
        hete_type_list: per-agent type ids.
        n_thread: number of threads.
        n_gp: number of groups, forwarded to ``random_group`` as ``n_group``.
        testing: when truthy, every type is selected for every thread, which
            makes ``random_group`` place all types in group 0.

    Returns:
        A tuple ``(selected_nets, gp_sel_summary)``: the placeholder index of
        each agent in each thread, and the per-thread, per-type group
        assignment reported by ``random_group``.
    """
    frontend_count = AlgorithmConfig.hete_n_alive_frontend
    type_ids = np.arange(n_types)

    # Per thread, draw `frontend_count` distinct types without replacement.
    # The draw is performed even when testing (its result is then discarded),
    # so the global NumPy RNG is consumed the same way in both modes.
    per_thread_choice = [
        np.random.choice(a=type_ids, size=(frontend_count), replace=False, p=None)
        for _ in range(n_thread)
    ]
    selected_types = np.stack(per_thread_choice)
    if testing:
        selected_types = np.stack([type_ids] * n_thread)

    # Random group id for every agent; selected types end up in group 0.
    group_sel_arr, gp_sel_summary = random_group(
        n_thread=n_thread,
        hete_type=hete_type_list,
        n_hete_types=n_types,
        n_group=n_gp,
        selected_tps=selected_types,
    )

    # presumably repeats hete_type_list along a new leading axis of length
    # n_thread -- TODO confirm against UTIL.tensor_ops.repeat_at
    hete_type_arr = repeat_at(hete_type_list, 0, n_thread)

    # (type, group) -> placeholder index
    selected_nets = group_sel_arr * n_types + hete_type_arr
    return selected_nets, gp_sel_summary




