


from cacot.model import *
from cacot.dataprep import getID2Pos
from cacot.pretrain_decayer import pretrain_cacot_decayer
from w import *



# %matplotlib inline

# Select the compute device once at import time. When a GPU is present,
# bestGPU(0) (from the star imports) chooses which card becomes current.
CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda') if CUDA else torch.device('cpu')
if CUDA:
    torch.cuda.set_device(bestGPU(0))


def init_q_network_newyork():
    """Build the CausalityConeTransformer for the 28x7 New York road network.

    Locates the single ``data/NewYork/28_7/roadnet*`` file and builds the
    ID-to-position mapping with ``getID2Pos``. It then instantiates the network
    on ``DEVICE`` and calls ``load_cone(cone, 'newyork')`` on its ``cone``
    sub-module (presumably loading pretrained cone weights).

    Returns:
        (net, cone): the network and ``net.cone``.

    Raises:
        RuntimeError: if the roadnet glob does not match exactly one file.
    """
    dic_traffic_env_conf = {'ACTION_PATTERN': 'set', 'NUM_INTERSECTIONS': 196, 'MIN_ACTION_TIME': 10, 'YELLOW_TIME': 5, 'ALL_RED_TIME': 0, 'NUM_PHASES': 2, 'NUM_LANES': 1, 'ACTION_DIM': 2, 'MEASURE_TIME': 10, 'IF_GUI': False, 'DEBUG': False, 'INTERVAL': 1, 'THREADNUM': 8, 'SAVEREPLAY': False, 'RLTRAFFICLIGHT': True, 'DIC_FEATURE_DIM': {'D_LANE_QUEUE_LENGTH': (4,), 'D_LANE_NUM_VEHICLE': (4,), 'D_COMING_VEHICLE': (12,), 'D_LEAVING_VEHICLE': (12,), 'D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1': (4,), 'D_CUR_PHASE': (8,), 'D_NEXT_PHASE': (1,), 'D_TIME_THIS_PHASE': (1,), 'D_TERMINAL': (1,), 'D_LANE_SUM_WAITING_TIME': (4,), 'D_VEHICLE_POSITION_IMG': (4, 60), 'D_VEHICLE_SPEED_IMG': (4, 60), 'D_VEHICLE_WAITING_TIME_IMG': (4, 60), 'D_PRESSURE': (1,), 'D_ADJACENCY_MATRIX': (5,), 'D_ADJACENCY_MATRIX_LANE': (5,), 'D_CUR_PHASE_0': (1,), 'D_LANE_NUM_VEHICLE_0': (4,), 'D_CUR_PHASE_1': (1,), 'D_LANE_NUM_VEHICLE_1': (4,), 'D_CUR_PHASE_2': (1,), 'D_LANE_NUM_VEHICLE_2': (4,), 'D_CUR_PHASE_3': (1,), 'D_LANE_NUM_VEHICLE_3': (4,)}, 'LIST_STATE_FEATURE': ['cur_time', 'cur_phase', 'lane_num_vehicle', 'adjacency_matrix', 'adjacency_matrix_lane'], 'DIC_REWARD_INFO': {'flickering': 0, 'sum_lane_queue_length': 0, 'sum_lane_wait_time': 0, 'sum_lane_num_vehicle_left': 0, 'sum_duration_vehicle_left': 0, 'sum_num_vehicle_been_stopped_thres01': 0, 'sum_num_vehicle_been_stopped_thres1': -0.25, 'pressure': 0}, 'LANE_NUM': {'LEFT': 1, 'RIGHT': 1, 'STRAIGHT': 1}, 'PHASE': {'sumo': {0: [0, 1, 0, 1, 0, 0, 0, 0], 1: [0, 0, 0, 0, 0, 1, 0, 1], 2: [1, 0, 1, 0, 0, 0, 0, 0], 3: [0, 0, 0, 0, 1, 0, 1, 0]}, 'anon': {1: [0, 1, 0, 1, 0, 0, 0, 0], 2: [0, 0, 0, 0, 0, 1, 0, 1], 3: [1, 0, 1, 0, 0, 0, 0, 0], 4: [0, 0, 0, 0, 1, 0, 1, 0]}}, 'USE_LANE_ADJACENCY': True, 'ONE_MODEL': False, 'NUM_AGENTS': 1, 'TOP_K_ADJACENCY': 5, 'ADJACENCY_BY_CONNECTION_OR_GEO': False, 'TOP_K_ADJACENCY_LANE': 5, 'SIMULATOR_TYPE': 'anon', 'BINARY_PHASE_EXPANSION': True, 'FAST_COMPUTE': True, 'NEIGHBOR': False, 'MODEL_NAME': 'CoLight', 'NUM_ROW': 7, 'NUM_COL': 28, 'TRAFFIC_FILE': 'anon_28_7_newyork_real_double.json', 'VOLUME': 'newyork', 'ROADNET_FILE': 'roadnet_28_7.json', 'phase_expansion': {1: [0, 1, 0, 1, 0, 0, 0, 0], 2: [0, 0, 0, 0, 0, 1, 0, 1], 3: [1, 0, 1, 0, 0, 0, 0, 0], 4: [0, 0, 0, 0, 1, 0, 1, 0], 5: [1, 1, 0, 0, 0, 0, 0, 0], 6: [0, 0, 1, 1, 0, 0, 0, 0], 7: [0, 0, 0, 0, 0, 0, 1, 1], 8: [0, 0, 0, 0, 1, 1, 0, 0]}, 'phase_expansion_4_lane': {1: [1, 1, 0, 0], 2: [0, 0, 1, 1]}}

    mypath = 'data/NewYork/28_7/roadnet*'
    roadnet_file = glob(mypath)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(roadnet_file) != 1:
        raise RuntimeError(
            f'expected exactly one roadnet file matching {mypath!r}, found {roadnet_file}')
    ID2Pos = getID2Pos(roadnet_file[0], dic_traffic_env_conf)
    N_nodes = 28 * 7

    net = CausalityConeTransformer(ID2Pos, N_nodes).to(DEVICE)
    cone = net.cone
    # Pass the roadmap name explicitly: the previous code read a module-level
    # `which_roadmap` that only exists when this file is run as a script.
    load_cone(cone, 'newyork')

    # Pretrainable sub-modules of `cone` (author's notes):
    #   cone.causDecay.vFunc_o          [12,20,20,1], prep_str='/100'  -> init ~= 2 m/s
    #   cone.causDecay.vFunc_d          [12,20,20,1], prep_str='/100'  -> init ~= 2 m/s
    #   cone.causDecay.speedStLUT       (N_nodes, 1, N_nodes, 1)       -> init ~= 2 m/s
    #   cone.attnStLUT                  (N_nodes, 1, N_nodes, 1)       -> init ~= 0
    #   cone.causDecay.decayFun.mlp     [1,20,20,1], prep_str='/1e4'
    #   cone.timeDecay.mlp              [1,20,20,1]
    return net, cone

    
def init_q_network_6x6():
    """Build the CausalityConeTransformer for the synthetic 6x6 grid network.

    Locates the single ``data/template_lsr/6_6/roadnet*`` file and builds the
    ID-to-position mapping with ``getID2Pos``. It then instantiates the network
    on ``DEVICE`` and calls ``load_cone(cone, '6x6')`` on its ``cone``
    sub-module (presumably loading pretrained cone weights).

    Returns:
        (net, cone): the network and ``net.cone``.

    Raises:
        RuntimeError: if the roadnet glob does not match exactly one file.
    """
    dic_traffic_env_conf = {'ACTION_PATTERN': 'set', 'NUM_INTERSECTIONS': 36, 'MIN_ACTION_TIME': 10, 'YELLOW_TIME': 5, 'ALL_RED_TIME': 0, 'NUM_PHASES': 2, 'NUM_LANES': 1, 'ACTION_DIM': 2, 'MEASURE_TIME': 10, 'IF_GUI': False, 'DEBUG': False, 'INTERVAL': 1, 'THREADNUM': 8, 'SAVEREPLAY': False, 'RLTRAFFICLIGHT': True, 'DIC_FEATURE_DIM': {'D_LANE_QUEUE_LENGTH': (4,), 'D_LANE_NUM_VEHICLE': (4,), 'D_COMING_VEHICLE': (12,), 'D_LEAVING_VEHICLE': (12,), 'D_LANE_NUM_VEHICLE_BEEN_STOPPED_THRES1': (4,), 'D_CUR_PHASE': (8,), 'D_NEXT_PHASE': (1,), 'D_TIME_THIS_PHASE': (1,), 'D_TERMINAL': (1,), 'D_LANE_SUM_WAITING_TIME': (4,), 'D_VEHICLE_POSITION_IMG': (4, 60), 'D_VEHICLE_SPEED_IMG': (4, 60), 'D_VEHICLE_WAITING_TIME_IMG': (4, 60), 'D_PRESSURE': (1,), 'D_ADJACENCY_MATRIX': (5,), 'D_ADJACENCY_MATRIX_LANE': (5,), 'D_CUR_PHASE_0': (1,), 'D_LANE_NUM_VEHICLE_0': (4,), 'D_CUR_PHASE_1': (1,), 'D_LANE_NUM_VEHICLE_1': (4,), 'D_CUR_PHASE_2': (1,), 'D_LANE_NUM_VEHICLE_2': (4,), 'D_CUR_PHASE_3': (1,), 'D_LANE_NUM_VEHICLE_3': (4,)}, 'LIST_STATE_FEATURE': ['cur_time', 'cur_phase', 'lane_num_vehicle', 'adjacency_matrix', 'adjacency_matrix_lane'], 'DIC_REWARD_INFO': {'flickering': 0, 'sum_lane_queue_length': 0, 'sum_lane_wait_time': 0, 'sum_lane_num_vehicle_left': 0, 'sum_duration_vehicle_left': 0, 'sum_num_vehicle_been_stopped_thres01': 0, 'sum_num_vehicle_been_stopped_thres1': -0.25, 'pressure': 0}, 'LANE_NUM': {'LEFT': 1, 'RIGHT': 1, 'STRAIGHT': 1}, 'PHASE': {'sumo': {0: [0, 1, 0, 1, 0, 0, 0, 0], 1: [0, 0, 0, 0, 0, 1, 0, 1], 2: [1, 0, 1, 0, 0, 0, 0, 0], 3: [0, 0, 0, 0, 1, 0, 1, 0]}, 'anon': {1: [0, 1, 0, 1, 0, 0, 0, 0], 2: [0, 0, 0, 0, 0, 1, 0, 1], 3: [1, 0, 1, 0, 0, 0, 0, 0], 4: [0, 0, 0, 0, 1, 0, 1, 0]}}, 'USE_LANE_ADJACENCY': True, 'ONE_MODEL': False, 'NUM_AGENTS': 1, 'TOP_K_ADJACENCY': 5, 'ADJACENCY_BY_CONNECTION_OR_GEO': False, 'TOP_K_ADJACENCY_LANE': 5, 'SIMULATOR_TYPE': 'anon', 'BINARY_PHASE_EXPANSION': True, 'FAST_COMPUTE': True, 'NEIGHBOR': False, 'MODEL_NAME': 'CoLight', 'NUM_ROW': 6, 'NUM_COL': 6, 'TRAFFIC_FILE': 'anon_6_6_300_0.3_bi.json', 'VOLUME': '300', 'ROADNET_FILE': 'roadnet_6_6.json', 'phase_expansion': {1: [0, 1, 0, 1, 0, 0, 0, 0], 2: [0, 0, 0, 0, 0, 1, 0, 1], 3: [1, 0, 1, 0, 0, 0, 0, 0], 4: [0, 0, 0, 0, 1, 0, 1, 0], 5: [1, 1, 0, 0, 0, 0, 0, 0], 6: [0, 0, 1, 1, 0, 0, 0, 0], 7: [0, 0, 0, 0, 0, 0, 1, 1], 8: [0, 0, 0, 0, 1, 1, 0, 0]}, 'phase_expansion_4_lane': {1: [1, 1, 0, 0], 2: [0, 0, 1, 1]}}

    mypath = 'data/template_lsr/6_6/roadnet*'
    roadnet_file = glob(mypath)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(roadnet_file) != 1:
        raise RuntimeError(
            f'expected exactly one roadnet file matching {mypath!r}, found {roadnet_file}')
    ID2Pos = getID2Pos(roadnet_file[0], dic_traffic_env_conf)
    N_nodes = 6 * 6

    net = CausalityConeTransformer(ID2Pos, N_nodes).to(DEVICE)
    cone = net.cone
    # Pass the roadmap name explicitly: the previous code read a module-level
    # `which_roadmap` that only exists when this file is run as a script.
    load_cone(cone, '6x6')

    # Pretrainable sub-modules of `cone` (author's notes):
    #   cone.causDecay.vFunc_o          [12,20,20,1], prep_str='/100'  -> init ~= 2 m/s
    #   cone.causDecay.vFunc_d          [12,20,20,1], prep_str='/100'  -> init ~= 2 m/s
    #   cone.causDecay.speedStLUT       (N_nodes, 1, N_nodes, 1)       -> init ~= 2 m/s
    #   cone.attnStLUT                  (N_nodes, 1, N_nodes, 1)       -> init ~= 0
    #   cone.causDecay.decayFun.mlp     [1,20,20,1], prep_str='/1e4'
    #   cone.timeDecay.mlp              [1,20,20,1]
    return net, cone









def pretrain_cone(which_roadmap, stages=('m2',)):
    """Run the selected ``pretrain_cacot_decayer`` stages for one roadmap.

    Requested stages always run in this order. Output names are
    ``pre_<stage>@<roadmap>``, mirroring the m1/m2/v1/v2 handles noted in the
    ``init_q_network_*`` functions:
      'm1' - my_type='causDecay_<roadmap>', per-roadmap input range.
      'm2' - my_type='timeDecay'.
      'v1' - my_type='const_2', 12 features.
      'v2' - my_type='const_2', 12 features.

    The previous version ran 'm2' and then hit a bare ``raise`` with no
    active exception. That raised "RuntimeError: No active exception to
    reraise" and left v1/v2 unreachable. The default ``stages`` keeps the
    'm2'-only behaviour, but the function now returns normally.

    Args:
        which_roadmap: roadmap name (e.g. 'newyork' or '6x6').
        stages: iterable of stage names; a single name string is also accepted.

    Raises:
        ValueError: for an unknown stage name, or for stage 'm1' with a
            roadmap that has no defined input range.
    """
    if isinstance(stages, str):
        stages = (stages,)
    stages = tuple(stages)
    known = ('m1', 'm2', 'v1', 'v2')
    unknown = [s for s in stages if s not in known]
    if unknown:
        raise ValueError(f'unknown pretraining stage(s) {unknown}; expected any of {known}')

    # === pretrain causDecay m1 ===
    if 'm1' in stages:
        m1_ranges = {'newyork': (-1e4, 1e4), '6x6': (-3000, 3000)}
        if which_roadmap not in m1_ranges:
            raise ValueError(f'no m1 input range defined for roadmap {which_roadmap!r}')
        lo, hi = m1_ranges[which_roadmap]
        pretrain_cacot_decayer(f'pre_m1@{which_roadmap}', lo=lo, hi=hi,
                               my_type=f'causDecay_{which_roadmap}', n_feature=1,
                               prep_str='/1e4', N_epochs=1000, lr=0.02, balance=10)

    # === pretrain timeDecay m2 ===
    if 'm2' in stages:
        pretrain_cacot_decayer(f'pre_m2@{which_roadmap}', lo=-12, hi=12, my_type='timeDecay',
                               n_feature=1, N_epochs=1500, lr=0.001, prep_str='', balance=-20)
        # Experimental alternative: my_type='const_0' with the same other arguments.

    # === pretrain v1 / v2 (identical settings) ===
    for stage in ('v1', 'v2'):
        if stage in stages:
            pretrain_cacot_decayer(f'pre_{stage}@{which_roadmap}', lo=0, hi=1000,
                                   my_type='const_2', n_feature=12, prep_str='/100', lr=0.06)

    return


def tt_2_Tw(attn_cone):
    """Reduce a token-token map to the centre line seen from the centre node.

    Among the first ``N_nodes`` rows of ``attn_cone`` (the first time step),
    takes the row of the grid node at (height // 2, width // 2). That row is
    then reduced to the grid's centre row with ``select_center_line``.

    Relies on the module-level ``N_nodes``, ``height`` and ``width`` bound in the
    ``__main__`` block (and, via ``select_center_line``, on ``Tmax``).

    Args:
        attn_cone: 2-D array/tensor of shape [num_tokens, num_tokens].

    Returns:
        The centre-line slice of shape [Tmax, width].
    """
    center_node = (height // 2) * width + width // 2
    first_step_rows = attn_cone[:N_nodes, :]  # [N_nodes, num_tokens]
    row_of_center = first_step_rows[center_node, :]  # [num_tokens]
    return select_center_line(row_of_center)
    


def select_center_line(tokens):
    """Keep only the tokens lying on the grid's centre row.

    The flat vector is viewed as [Tmax, height, width], and row ``height // 2``
    is returned for every time step.

    Relies on the module-level ``Tmax``, ``height`` and ``width`` bound in the
    ``__main__`` block.

    Args:
        tokens: 1-D array/tensor with Tmax * height * width entries.

    Returns:
        Array/tensor of shape [Tmax, width].
    """
    middle_row = height // 2
    grid = tokens.reshape(Tmax, height, width)
    return grid[:, middle_row, :]


def plot_attention_heatMap():
    """Plot a saved cone-attention map and its centre-line slice.

    Loads ``cacot/attn_cone_1.npy`` (author's note:
    [batch_size, n_head, num_tokens, num_tokens]; the normal attention is
    [n_blocks, batch_size, n_head, num_tokens, num_tokens]) and keeps
    batch 0 / head 0. It prints that map's shape and draws it, then draws the
    [Tmax, width] centre line from ``tt_2_Tw`` in a new figure.
    """
    saved = np.load('cacot/attn_cone_1.npy')
    attn_tt = saved[0, 0]  # [num_tokens, num_tokens]
    p(attn_tt.shape)

    center_line = tt_2_Tw(attn_tt)

    # The full map goes onto whichever figure is current.
    plot_heatMap(attn_tt)

    figure()
    plot_heatMap(center_line)
    plt.title('Attention at Center Line')
    


def plot_attention_ingredients():
    """Plot centre-line slices of the saved epsilon / time-diff / pos-diff maps.

    The ``*.npy-bkup-0523`` inputs come from cacot/model.py (search for
    ``plot_attention_ingredients`` there and enable it). The first entry of
    each file ([t, t]) is reduced with ``tt_2_Tw`` and drawn in its own figure.
    """
    ingredients = (
        ('cacot/epsilon.npy-bkup-0523', 'eps'),
        ('cacot/timeDiff.npy-bkup-0523', 'time diff'),
        ('cacot/posDiff.npy-bkup-0523', 'pos diff'),
    )

    # Load everything first, then reduce, then plot.
    maps_tt = [np.load(path)[0] for path, _ in ingredients]
    maps_Tw = [tt_2_Tw(m) for m in maps_tt]

    for map_Tw, (_, title) in zip(maps_Tw, ingredients):
        figure()
        plot_heatMap(map_Tw)
        plt.title(title)


def plot_attenCone_fromEps_alsoSave():
    """Map the saved epsilon centre line through a decay MLP, then save and plot it.

    Loads ``cacot/epsilon.npy-bkup-0523`` and reduces its first [t, t] entry
    with ``tt_2_Tw``. Each entry is fed through an MLP of the module-level
    ``cone`` and post-processed with ``u_postproc(..., 'causDecay')``. The
    result is saved to ``cacot/tmp_Tw_causDecay.npy`` and drawn as a heat map.

    To regenerate the epsilon file: in cacot/model.py, search and enable
    ``plot_attention_ingredients``.
    """
    epsilon_tt = np.load('cacot/epsilon.npy-bkup-0523')[0]  # [t, t]
    eps_Tw = tt_2_Tw(epsilon_tt)

    # NOTE(review): the original bound cone.causDecay.decayFun.mlp and then
    # overrode it with the time-decay MLP (`m1 = m2`). The "causDecay" map is
    # therefore produced by cone.timeDecay.mlp. Kept as-is; confirm whether the
    # causal-decay MLP was intended.
    mlp = cone.timeDecay.mlp

    # The network was moved to DEVICE, so the input must live there too. The
    # dtype follows torch's default (presumably what the MLP was built with)
    # instead of NumPy's float64. Results go back to the CPU before NumPy
    # conversion; `.numpy()` fails on CUDA tensors.
    inp = torch.as_tensor(eps_Tw, dtype=torch.get_default_dtype(), device=DEVICE).unsqueeze(-1)
    with torch.no_grad():
        attn_Tw = mlp(inp).squeeze(-1).cpu().numpy()
    attnCaus2 = u_postproc(attn_Tw, 'causDecay')

    np.save('cacot/tmp_Tw_causDecay.npy', attnCaus2)
    figure(); plot_heatMap(attnCaus2); plt.title('attnCaus2')
    return





def plot_attenTime_and_attnNormal_alsoSave():
    """Build, plot and save the time-decay (+LUT) and residual normal-attention maps.

    - Time decay: the saved time-difference map is reduced with ``tt_2_Tw``,
      passed through ``cone.timeDecay.mlp`` and post-processed as 'timeDecay'.
      It is then summed with block 1 / head 0 of the saved normal attention,
      post-processed as 'lut'. Saved to ``cacot/tmp_Tw_timeDecay.npy``.
    - Normal: block 1 / head 1 of the saved normal attention, post-processed
      as 'normal'. Saved to ``cacot/tmp_Tw_attnNormal.npy``.

    To regenerate the inputs: in cacot/model.py, search and enable
    ``plot_attention_ingredients``.
    """
    timeDiff_tt = np.load('cacot/timeDiff.npy-bkup-0523')[0]  # [t, t]
    timeDiff_Tw = tt_2_Tw(timeDiff_tt)
    attn_normal = np.load('cacot/attn_normal_1.npy')[:, 0, ...]  # batch 0 -> [block, head, t, t]

    mlp = cone.timeDecay.mlp
    # Feed the MLP on the device the network was moved to, in torch's default
    # dtype (presumably the MLP's), and bring the result back to the CPU
    # before NumPy conversion; `.numpy()` fails on CUDA tensors.
    inp = torch.as_tensor(timeDiff_Tw, dtype=torch.get_default_dtype(), device=DEVICE).unsqueeze(-1)
    with torch.no_grad():
        timeDecay_Tw = mlp(inp).squeeze(-1).cpu().numpy()
    timeDecay_Tw = u_postproc(timeDecay_Tw, 'timeDecay')

    an_Tw = u_postproc(tt_2_Tw(attn_normal[1, 0]), 'lut')
    timeDecay_2 = timeDecay_Tw + an_Tw
    figure(); plot_heatMap(timeDecay_2); plt.title('timeDecay_Tw')

    an_2 = u_postproc(tt_2_Tw(attn_normal[1, 1]), 'normal')
    figure(); plot_heatMap(an_2); plt.title('attnNormal_Tw')

    np.save('cacot/tmp_Tw_attnNormal.npy', an_2)
    np.save('cacot/tmp_Tw_timeDecay.npy', timeDecay_2)
    return



def u_plot_Tw_attnComb_final(is_0523=1):
    """Compose the final [Tmax, width] attention figures and save them as PDFs.

    Args:
        is_0523: truthy -> load the archived ``cacot/res0523-TwAttn-*.npy``
            maps. Falsy -> regenerate the maps with
            ``plot_attenCone_fromEps_alsoSave`` and
            ``plot_attenTime_and_attnNormal_alsoSave``, then load their
            ``cacot/tmp_Tw_*.npy`` outputs.

    Writes two stacked two-panel figures:
      ``cacot/TwAttn-1.pdf``: cone-decay prior over time-decay prior + LUT.
      ``cacot/TwAttn-2.pdf``: residual normal attention over the sum of all three.
    Tick positions are hard-coded (y: 0/5/10, x: 0..28 in steps of 7).
    """
    if not is_0523:
        plot_attenCone_fromEps_alsoSave()
        plot_attenTime_and_attnNormal_alsoSave()
        sources = ('cacot/tmp_Tw_attnNormal.npy',
                   'cacot/tmp_Tw_timeDecay.npy',
                   'cacot/tmp_Tw_causDecay.npy')
    else:
        sources = ('cacot/res0523-TwAttn-normal.npy',
                   'cacot/res0523-TwAttn-time.npy',
                   'cacot/res0523-TwAttn-caus.npy')
    attnNormal, attnTime, attnCause = (np.load(path) for path in sources)
    attnA = attnNormal + attnTime + attnCause

    panel_left, panel_width, panel_height = 0.1, 1.2, 0.42
    bottom_y, top_y = 0.1, 0.6
    y_ticks = [0, 5, 10]
    x_ticks = [0, 7, 14, 21, 28]

    def _two_panel_figure(top, top_title, bottom, bottom_title, out_path):
        # Upper panel: no x tick labels, shared x axis with the lower panel.
        figure()
        plt.axes([panel_left, top_y, panel_width, panel_height])
        plot_heatMap(top); plt.title(top_title)
        plt.xticks([])
        plt.yticks(y_ticks, y_ticks)
        # Lower panel carries the x tick labels.
        plt.axes([panel_left, bottom_y, panel_width, panel_height])
        plot_heatMap(bottom); plt.title(bottom_title)
        plt.yticks(y_ticks, y_ticks)
        plt.xticks(x_ticks, x_ticks)
        plt.savefig(out_path, bbox_inches='tight')

    _two_panel_figure(attnCause, 'Cone Decay Prior',
                      attnTime, 'Time Decay Prior + LUT',
                      'cacot/TwAttn-1.pdf')
    _two_panel_figure(attnNormal, 'Residual Normal Attention',
                      attnA, 'Overall Attention',
                      'cacot/TwAttn-2.pdf')
    return







    
if __name__ == '__main__':
    # Road network to work on: 'newyork' or '6x6'. The helper functions above
    # read the module-level names bound below (which_roadmap, height, width,
    # cone, N_nodes, Tmax).
    which_roadmap = 'newyork'

    if which_roadmap == 'newyork':
        height, width = 7, 28
        net, cone = init_q_network_newyork()
    elif which_roadmap == '6x6':
        height, width = 6, 6
        net, cone = init_q_network_6x6()

    N_nodes, Tmax, ID2Pos = net.N_nodes, net.Tmax, net.ID2Pos

    # Stages toggled by hand:
    # pretrain_cone(which_roadmap)
    # u_plot_Tw_attnComb_final(1)

    # NOTE(review): u_plot_tt_A_final is not defined in this file; it must be
    # provided by one of the star imports at the top. Confirm it still exists.
    u_plot_tt_A_final(0)
