from Environment import *
import math
from GuidanceMultiStep import *
from Guidance import *
from Navigation import *
from NavigationMultiStep import *

def guidance_eval(task, cvdn_data):
    """
    Greedily run a guidance task over CVDN episodes and print navigation metrics.

    :param task: guidance model exposing ``eval()``, ``agent_a``/``agent_b`` with
        ``load_act_emb``, and callable as
        ``task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method=...)``
        returning 7 values, the last of which is ``is_speak``.
    :param cvdn_data: [(scan_id, start_id, end_id), ...]
    :return: None. Prints the per-turn navigation accuracy (restricted to steps whose
        mask from ``get_next_route`` is 0) and the mean progress, i.e. shortest-path
        distance start->goal minus final position->goal.
    :raises ValueError: if ``cvdn_data`` is empty, or a start/end id is not found
        among the candidate view ids of its scan.
    TODO: AgentB's route plan should be changed
    """
    if not cvdn_data:
        raise ValueError("cvdn_data is empty; nothing to evaluate")
    task.eval()
    dataset = FullEnvDataset("validate")
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    batch_size = Param.batch_size
    accum_nav_tgt, accum_nav_pred, accum_mask = [], [], []
    accum_progress = []
    for step in range(math.ceil(len(cvdn_data) / batch_size)):
        # Slicing clamps at the end of the list, so this is also correct for the last,
        # partial batch (the previous condition produced oversized batches for large inputs).
        cur_data = list(cvdn_data[step * batch_size: (step + 1) * batch_size])
        n_real = len(cur_data)
        # The task expects a full batch: pad with copies of the first sample and drop the
        # padded rows again before accumulating so they do not bias the metrics.
        cur_data += [cvdn_data[0]] * (batch_size - n_real)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)
        # A missing (or duplicated) id would silently misalign indices with episodes.
        if len(start_idxes) != len(start_ids) or len(end_idxes) != len(end_ids):
            raise ValueError("start/end viewpoint ids not found among candidate view ids for some episodes")
        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a, is_speak = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        final_ids = [cur_path_a_ids[-1] for cur_path_a_ids in path_a_ids]
        distance1 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, start_ids, end_ids)
        distance2 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, final_ids, end_ids)
        diff_dist = [d1 - d2 for d1, d2 in zip(distance1, distance2)]
        accum_progress += diff_dist[:n_real]
        next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
        accum_nav_tgt.append(next_a_idx[:n_real, :-1])
        accum_nav_pred.append(path_a[:n_real, 1:])
        accum_mask.append(mask[:n_real, :-1])
    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    accum_mask = torch.cat(accum_mask, dim=0)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    # Per-turn accuracy over unmasked steps only (NaN for a turn where every step is masked).
    nav_score = torch.stack([torch.mean(temp_nav[:, k][accum_mask[:, k] == 0]) for k in range(temp_nav.shape[1])])
    print("          eval nav acc = {}".format(nav_score), flush=True)
    print("          progress = {}".format(sum(accum_progress) / len(accum_progress)))


def guidance_eval_single(task, cvdn_data):
    """
    Greedily run a guidance task over CVDN episodes and print unmasked navigation metrics.

    :param task: guidance model exposing ``eval()``, ``agent_a``/``agent_b`` with
        ``load_act_emb``, and callable as
        ``task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method=...)``
        returning 6 values.
    :param cvdn_data: [(scan_id, start_id, end_id), ...]
    :return: None. Prints the per-turn navigation accuracy (every step counted, no mask)
        and the mean progress, i.e. shortest-path distance start->goal minus
        final position->goal.
    :raises ValueError: if ``cvdn_data`` is empty, or a start/end id is not found
        among the candidate view ids of its scan.
    """
    if not cvdn_data:
        raise ValueError("cvdn_data is empty; nothing to evaluate")
    task.eval()
    dataset = FullEnvDataset("train")
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    batch_size = Param.batch_size
    accum_nav_tgt, accum_nav_pred = [], []
    accum_progress = []
    for step in range(math.ceil(len(cvdn_data) / batch_size)):
        # Slicing clamps at the end of the list, so this is also correct for the last,
        # partial batch (the previous condition produced oversized batches for large inputs).
        cur_data = list(cvdn_data[step * batch_size: (step + 1) * batch_size])
        n_real = len(cur_data)
        # Pad to a full batch with copies of the first sample; padded rows are dropped
        # again before accumulating so they do not bias the metrics.
        cur_data += [cvdn_data[0]] * (batch_size - n_real)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)
        # A missing (or duplicated) id would silently misalign indices with episodes.
        if len(start_idxes) != len(start_ids) or len(end_idxes) != len(end_ids):
            raise ValueError("start/end viewpoint ids not found among candidate view ids for some episodes")
        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        final_ids = [cur_path_a_ids[-1] for cur_path_a_ids in path_a_ids]
        distance1 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, start_ids, end_ids)
        distance2 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, final_ids, end_ids)
        diff_dist = [d1 - d2 for d1, d2 in zip(distance1, distance2)]
        accum_progress += diff_dist[:n_real]
        # Oracle next hop from every visited position towards the episode's goal.
        next_a_ids = [[dataset.graph_paths[cur_scan_id][(cur_id, tgt_id)]["next"] for cur_id in cur_ids]
                      for cur_scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)
        accum_nav_tgt.append(next_a_idx[:n_real, :-1])
        accum_nav_pred.append(path_a[:n_real, 1:])
    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    print("          eval nav acc = {}".format(torch.mean(temp_nav, dim=0)), flush=True)
    print("          progress = {}".format(sum(accum_progress) / len(accum_progress)))


def navigation_eval(task, cvdn_data):
    """
    Greedily run a navigation task over CVDN episodes and print localization/navigation metrics.

    :param task: navigation model exposing ``eval()``, ``agent_b.load_obs_emb``,
        ``agent_a``/``agent_b.load_act_emb``, and callable as
        ``task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method=...)``
        returning 9 values (including ``guess_b`` and ``is_speak``).
    :param cvdn_data: [(scan_id, start_id, end_id), ...]
    :return: None. Prints the per-turn accuracy of ``guess_b`` against agent A's actual
        positions, the per-turn navigation accuracy (all steps, and only steps whose
        ``get_next_route`` mask is 0), and the mean progress, i.e. shortest-path
        distance start->goal minus final position->goal.
    :raises ValueError: if ``cvdn_data`` is empty, or a start/end id is not found
        among the candidate view ids of its scan.
    """
    if not cvdn_data:
        raise ValueError("cvdn_data is empty; nothing to evaluate")
    task.eval()
    dataset = FullEnvDataset("train")
    task.agent_b.load_obs_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    batch_size = Param.batch_size
    accum_loc_tgt, accum_loc_pred = [], []
    accum_nav_tgt, accum_nav_pred = [], []
    accum_mask = []
    accum_progress = []
    for step in range(math.ceil(len(cvdn_data) / batch_size)):
        # Slicing clamps at the end of the list, so this is also correct for the last,
        # partial batch (the previous condition produced oversized batches for large inputs).
        cur_data = list(cvdn_data[step * batch_size: (step + 1) * batch_size])
        n_real = len(cur_data)
        # Pad to a full batch with copies of the first sample; padded rows are dropped
        # again before accumulating so they do not bias the metrics.
        cur_data += [cvdn_data[0]] * (batch_size - n_real)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)
        # A missing (or duplicated) id would silently misalign indices with episodes.
        if len(start_idxes) != len(start_ids) or len(end_idxes) != len(end_ids):
            raise ValueError("start/end viewpoint ids not found among candidate view ids for some episodes")
        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, path_a, guess_b, is_speak = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        final_ids = [cur_path_a_ids[-1] for cur_path_a_ids in path_a_ids]
        distance1 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, start_ids, end_ids)
        distance2 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, final_ids, end_ids)
        diff_dist = [d1 - d2 for d1, d2 in zip(distance1, distance2)]
        accum_progress += diff_dist[:n_real]
        next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
        accum_loc_tgt.append(path_a[:n_real, 1:])
        accum_loc_pred.append(guess_b[:n_real])
        accum_nav_tgt.append(next_a_idx[:n_real, :-1])
        accum_nav_pred.append(path_a[:n_real, 1:])
        accum_mask.append(mask[:n_real])
    accum_loc_tgt, accum_loc_pred = torch.cat(accum_loc_tgt, dim=0), torch.cat(accum_loc_pred, dim=0)
    accum_nav_tgt, accum_nav_pred = torch.cat(accum_nav_tgt, dim=0), torch.cat(accum_nav_pred, dim=0)
    accum_mask = torch.cat(accum_mask, dim=0)
    temp_loc = torch.zeros_like(accum_loc_tgt).to(torch.float32)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_loc[accum_loc_pred == accum_loc_tgt] = 1.0
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    # Per-turn accuracy over unmasked steps only (NaN for a turn where every step is masked).
    nav_score_mask = torch.stack([torch.mean(temp_nav[:, k][accum_mask[:, k] == 0.0]) for k in range(temp_nav.shape[1])])
    print("          eval loc acc = {}, nav acc = {}, nav acc (mask) = {}"
          .format(torch.mean(temp_loc, dim=0), torch.mean(temp_nav, dim=0), nav_score_mask), flush=True)
    print("          progress = {}".format(sum(accum_progress) / len(accum_progress)), flush=True)


def navigation_eval_single(task, cvdn_data):
    """
    Greedily run a navigation task over CVDN episodes and print unmasked metrics.

    :param task: navigation model exposing ``eval()``, ``agent_b.load_obs_emb``,
        ``agent_a``/``agent_b.load_act_emb``, and callable as
        ``task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method=...)``
        returning 8 values (including ``guess_b``).
    :param cvdn_data: [(scan_id, start_id, end_id), ...]
    :return: None. Prints the per-turn accuracy of ``guess_b`` against agent A's actual
        positions, the per-turn navigation accuracy, and the mean progress, i.e.
        shortest-path distance start->goal minus final position->goal.
    :raises ValueError: if ``cvdn_data`` is empty, or a start/end id is not found
        among the candidate view ids of its scan.
    """
    if not cvdn_data:
        raise ValueError("cvdn_data is empty; nothing to evaluate")
    task.eval()
    dataset = FullEnvDataset("validate")
    task.agent_b.load_obs_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    batch_size = Param.batch_size
    accum_loc_tgt, accum_loc_pred = [], []
    accum_nav_tgt, accum_nav_pred = [], []
    accum_progress = []
    for step in range(math.ceil(len(cvdn_data) / batch_size)):
        # Slicing clamps at the end of the list, so this is also correct for the last,
        # partial batch (the previous condition produced oversized batches for large inputs).
        cur_data = list(cvdn_data[step * batch_size: (step + 1) * batch_size])
        n_real = len(cur_data)
        # Pad to a full batch with copies of the first sample; padded rows are dropped
        # again before accumulating so they do not bias the metrics.
        cur_data += [cvdn_data[0]] * (batch_size - n_real)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)
        # A missing (or duplicated) id would silently misalign indices with episodes.
        if len(start_idxes) != len(start_ids) or len(end_idxes) != len(end_ids):
            raise ValueError("start/end viewpoint ids not found among candidate view ids for some episodes")
        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, path_a, guess_b = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        final_ids = [cur_path_a_ids[-1] for cur_path_a_ids in path_a_ids]
        distance1 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, start_ids, end_ids)
        distance2 = get_shortest_path_distance(dataset.graph_paths, dataset.graphs, scan_ids, final_ids, end_ids)
        diff_dist = [d1 - d2 for d1, d2 in zip(distance1, distance2)]
        accum_progress += diff_dist[:n_real]
        # Oracle next hop from every visited position towards the episode's goal.
        next_a_ids = [[dataset.graph_paths[cur_scan_id][(cur_id, tgt_id)]['next'] for cur_id in cur_ids]
                      for cur_scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
        accum_loc_tgt.append(path_a[:n_real, 1:])
        accum_loc_pred.append(guess_b[:n_real])
        accum_nav_tgt.append(next_a_idx[:n_real, :-1])
        accum_nav_pred.append(path_a[:n_real, 1:])
    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    accum_loc_pred = torch.cat(accum_loc_pred, dim=0)
    accum_loc_tgt = torch.cat(accum_loc_tgt, dim=0)
    temp_loc = torch.zeros_like(accum_loc_tgt).to(torch.float32)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_loc[accum_loc_pred == accum_loc_tgt] = 1.0
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    print("          eval loc acc = {}, nav acc = {}".format(torch.mean(temp_loc, dim=0),
                                                             torch.mean(temp_nav, dim=0)), flush=True)
    print("          progress = {}".format(sum(accum_progress) / len(accum_progress)), flush=True)


def get_next_route(graph_paths, scan_ids, path_a_ids, end_ids, is_speak):
    """
    Build the oracle next-hop route and a per-step 0/1 mask for each episode.

    For each visited position the oracle next hop is
    ``graph_paths[scan_id][(position, end_id)]["next"]``. Mask entry ``j`` is 0 when:
      * ``j == 0``; or
      * the current position equals the previous oracle next hop and the previous mask is 0; or
      * ``is_speak`` is 1 at step ``j - 1``; or
      * the position did not change since step ``j - 1`` and the previous mask is 0.
    Otherwise it is 1.

    :param graph_paths: {scan_id: {(cur_id, tgt_id): {"next": next_id, ...}}}
    :param scan_ids: one scan id per episode
    :param path_a_ids: per-episode sequence of visited viewpoint ids
    :param end_ids: per-episode target viewpoint id
    :param is_speak: object with ``.tolist()`` giving one row per episode, indexed by step
    :return: (next_route, mask) -- nested list of next-hop ids, and a tensor made from
        the nested 0/1 mask lists via ``torch.from_numpy(np.array(...))``
    """
    speak_rows = is_speak.tolist()
    routes, masks = [], []
    for scan_id, visited, tgt_id, speak_row in zip(scan_ids, path_a_ids, end_ids, speak_rows):
        scan_paths = graph_paths[scan_id]
        route, step_mask = [], []
        routes.append(route)
        masks.append(step_mask)
        for j, cur_id in enumerate(visited):
            if not route:
                on_track = True
            elif route[-1] == cur_id and step_mask[-1] == 0:  # arrived where the oracle pointed
                on_track = True
            elif speak_row[j - 1] == 1:  # is_speak set at the previous step
                on_track = True
            elif cur_id == visited[j - 1] and step_mask[-1] == 0:  # stayed in place
                on_track = True
            else:
                on_track = False
            step_mask.append(0 if on_track else 1)
            route.append(scan_paths[(cur_id, tgt_id)]["next"])
    return routes, torch.from_numpy(np.array(masks))


def get_shortest_path_distance(graph_paths, graph, scan_ids, start_ids, end_ids):
    """
    Sum the edge weights along the precomputed next-hop path from each start to its end.

    :param graph_paths: {scan_id: {(cur_id, end_id): {"next": next_id, ...}}}
    :param graph: per-scan nested mapping; the weight of hop cur->next is read from
        ``graph[scan_id][cur_id][next_id][0][3]``
    :param scan_ids: one scan id per sample
    :param start_ids: start viewpoint per sample
    :param end_ids: end viewpoint per sample
    :return: list of length ``len(start_ids)`` with the accumulated distance per sample
        (0 when start equals end; samples beyond the shortest input sequence stay 0)
    """
    distances = [0] * len(start_ids)
    for i, (scan_id, node, goal) in enumerate(zip(scan_ids, start_ids, end_ids)):
        hops = graph_paths[scan_id]
        weights = graph[scan_id]
        total = 0
        while node != goal:
            hop = hops[(node, goal)]["next"]
            total += weights[node][hop][0][3]
            node = hop
        distances[i] = total
    return distances


def load_cvdn_data(file_path):
    """
    Read a CVDN JSON file into (scan_id, start_viewpoint, end_viewpoint) triples.

    :param file_path: path to a JSON list of samples, each having a ``"scan"`` key and a
        non-empty ``"planner_nav_steps"`` list
    :return: list of tuples ``(sample["scan"], planner_nav_steps[0], planner_nav_steps[-1])``
    """
    with open(file_path, 'r') as f:
        samples = json.load(f)
        triples = []
        for sample in samples:
            scan_id = sample["scan"]
            steps = sample["planner_nav_steps"]
            triples.append((scan_id, steps[0], steps[-1]))
    return triples

