from Environment import *
from Param import *
from GuidanceMultiStep import *
from torch.optim import Adam


def guidance_eval(task, dataset):
    """Evaluate greedy next-hop navigation accuracy of ``task`` on ``dataset``.

    For every batch of scan ids, two candidate view indices are drawn per scan
    with ``random.choices(..., k=2)`` (sampling with replacement, so start may
    equal end). They are used as the start and end of a greedy rollout. Each
    predicted position ``path_a[:, 1:]`` is compared against the next hop that
    ``get_next_route`` returns for the position the agent occupied one step
    earlier. Steps whose ``get_next_route`` mask entry is non-zero are
    excluded from the accuracy of that step.

    The model is put in eval mode for the duration of the call and restored to
    its previous training mode afterwards.

    Prints, per batch, the mean number of speak steps per episode, and at the
    end the list of per-step accuracies.

    Returns:
        list[float]: accuracy for each step index. An entry is ``nan`` when
        that step has no unmasked samples. Returns an empty list if the
        dataset yields no batches.
    """
    # Remember the caller's mode so an eval call in the middle of training
    # does not silently leave dropout/batch-norm in eval mode.
    was_training = getattr(task, "training", False)
    task.eval()
    try:
        # Pure evaluation: no autograd graph is needed for the rollouts.
        with torch.no_grad():
            dataloader = DataLoader(dataset, batch_size=Param.batch_size)
            accum_nav_tgt, accum_nav_pred, accum_mask = [], [], []
            for scan_ids in dataloader:
                cand_view_ids = dataset.get_view_ids(scan_ids)
                # Random (start, end) candidate indices per scan, drawn with replacement.
                temp = [random.choices(range(len(cur_cand_view_ids)), k=2)
                        for cur_cand_view_ids in cand_view_ids]
                start_idxes, end_idxes = zip(*temp)
                end_ids = [cur_cand_view_ids[j]
                           for cur_cand_view_ids, j in zip(cand_view_ids, end_idxes)]
                # Only the navigation path and the speak flags are needed here.
                _, _, _, _, _, path_a, is_speak = task(
                    scan_ids, cand_view_ids, start_idxes, end_idxes, dataset,
                    choose_method="greedy")
                path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                              for batch_path_a, cur_cand_view_ids
                              in zip(path_a.tolist(), cand_view_ids)]
                next_a_ids, mask = get_next_route(
                    dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
                next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
                # The target computed at step t is compared with the prediction at t + 1.
                accum_nav_tgt.append(next_a_idx[:, :-1])
                accum_nav_pred.append(path_a[:, 1:])
                accum_mask.append(mask[:, :-1])
                print(torch.mean(torch.sum(is_speak.to(torch.float32), dim=1), dim=0), flush=True)

            if not accum_nav_pred:
                # torch.cat on an empty list would raise; report "no data" instead.
                nav_score = []
                print("          eval nav acc = {}".format(nav_score), flush=True)
                return nav_score

            accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
            accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
            accum_mask = torch.cat(accum_mask, dim=0)
            temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
            temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
            # Per-step accuracy over unmasked samples only; mean of an empty
            # selection is nan, which flags a step with no valid samples.
            nav_score = [torch.mean(temp_nav[:, k][accum_mask[:, k] == 0]).item()
                         for k in range(temp_nav.shape[1])]
            print("          eval nav acc = {}".format(nav_score), flush=True)
            return nav_score
    finally:
        if was_training:
            task.train()


def get_next_route(graph_paths, scan_ids, path_a_ids, end_ids, is_speak):
    """Compute the next-hop target and a validity mask for every visited node.

    Args:
        graph_paths: mapping ``scan_id -> {(node_id, target_id): {"next": ...}}``.
            The ``"next"`` entry is presumably the next node on the shortest
            path to ``target_id`` (TODO confirm against where it is built).
        scan_ids: one scan id per episode.
        path_a_ids: per episode, the sequence of node ids the agent occupied.
        end_ids: per episode, the target node id.
        is_speak: tensor-like with ``.tolist()``; per episode, speak flags
            indexed by step (a value equal to 1 marks a speak step).

    Returns:
        (next_route, mask):
        - ``next_route``: nested list holding ``graph_paths[scan][(node, target)]["next"]``
          for every visited node.
        - ``mask``: tensor built via ``np.array`` of the same nested shape.
          Entry 0 means the position is consistent, 1 means the agent left
          the route.

    A step is marked 0 in any of these cases:
    - it is the first step;
    - the agent is where the previous step's next hop pointed and the
      previous step was valid;
    - the previous step was a speak step;
    - the agent stayed put after a valid step.

    Otherwise it is marked 1. Once an episode is marked 1, it stays 1 until a
    speak step resets it.
    """
    speak_rows = is_speak.tolist()
    routes = []
    masks = []
    for scan_id, episode_ids, target_id, speak_row in zip(scan_ids, path_a_ids, end_ids, speak_rows):
        shortest = graph_paths[scan_id]
        episode_route = []
        episode_mask = []
        for step, node_id in enumerate(episode_ids):
            if not episode_route:
                on_track = True  # first step of the episode
            else:
                prev_ok = episode_mask[-1] == 0
                on_track = (
                    (episode_route[-1] == node_id and prev_ok)  # followed the hint
                    or speak_row[step - 1] == 1  # guidance was just given
                    or (node_id == episode_ids[step - 1] and prev_ok)  # stayed in place
                )
            episode_mask.append(0 if on_track else 1)
            episode_route.append(shortest[(node_id, target_id)]["next"])
        routes.append(episode_route)
        masks.append(episode_mask)
    return routes, torch.from_numpy(np.array(masks))

