import numpy as np
import torch
import json
from GuidanceMultiStep import *
from Guidance import *
from Navigation import *


def trajectory_length(graphs, scan_ids, path_ids):
    """Return the mean travelled length over the given paths.

    Args:
        graphs: Mapping from scan id to that scan's graph; forwarded unchanged
            to ``get_actual_distance``.
        scan_ids: One scan id per path.
        path_ids: One sequence of visited node ids per path.

    Returns:
        The arithmetic mean of the per-path lengths.

    Raises:
        ZeroDivisionError: If no paths are given.
    """
    lengths = get_actual_distance(graphs, scan_ids, path_ids)
    total = sum(lengths)
    return total / len(lengths)


def navigation_error(graph_paths, graph, scan_ids, end_ids, tgt_ids):
    """Return the mean remaining route distance from each end node to its target.

    Args:
        graph_paths: Per-scan next-hop tables, forwarded to
            ``get_shortest_path_distance``.
        graph: Per-scan graphs, forwarded to ``get_shortest_path_distance``.
        scan_ids: One scan id per episode.
        end_ids: Node id where each episode stopped.
        tgt_ids: Node id each episode was supposed to reach.

    Returns:
        The arithmetic mean of the per-episode remaining distances.

    Raises:
        ZeroDivisionError: If no episodes are given.
    """
    remaining = get_shortest_path_distance(graph_paths, graph, scan_ids, end_ids, tgt_ids)
    total = sum(remaining)
    return total / len(remaining)


def success_rate(graph_paths, graph, scan_ids, end_ids, tgt_ids):
    """Return the fraction of episodes that stopped within 3 of their target.

    An episode counts as a success when the route distance from its end node
    to its target node (as computed by ``get_shortest_path_distance``) is
    strictly below 3.

    Returns:
        A 0-dim ``torch.float32`` tensor holding the mean success indicator.
    """
    remaining = torch.from_numpy(
        np.array(get_shortest_path_distance(graph_paths, graph, scan_ids, end_ids, tgt_ids))
    )
    # 1.0 where the episode ended close enough to the target, 0.0 elsewhere.
    reached = (remaining < 3).to(torch.float32)
    return torch.mean(reached)


def success_weighted_by_path_length(graph_paths, graph, scan_ids, start_ids, end_ids, tgt_ids, path_ids):
    """Return success weighted by path length (SPL) averaged over episodes.

    Per episode the score is ``success * shortest / max(actual, shortest)``
    where ``success`` is 1 when the end node is within a route distance of 3
    from the target, ``shortest`` is the route distance from start to target
    and ``actual`` is the length of the path that was really walked.

    NOTE: an episode whose ``shortest`` and ``actual`` lengths are both zero
    divides 0 by 0 and therefore turns the mean into NaN.

    Returns:
        A 0-dim tensor holding the mean per-episode score.
    """
    # Success indicator: did the episode end close enough to the target?
    end_to_tgt = torch.from_numpy(
        np.array(get_shortest_path_distance(graph_paths, graph, scan_ids, end_ids, tgt_ids))
    )
    reached = (end_to_tgt < 3).to(torch.float32)

    # Reference length: route distance from the start node to the target.
    shortest = torch.from_numpy(
        np.array(get_shortest_path_distance(graph_paths, graph, scan_ids, start_ids, tgt_ids))
    )

    # Length of the path the agent actually walked.
    travelled = torch.from_numpy(np.array(get_actual_distance(graph, scan_ids, path_ids)))

    # Element-wise maximum: walking less than the reference is never rewarded.
    denominator = torch.max(travelled, shortest)
    per_episode = reached * shortest / denominator
    return torch.mean(per_episode)


def get_shortest_path_distance(graph_paths, graph, scan_ids, end_ids, tgt_ids):
    """Sum the edge values along the precomputed route from each node to its target.

    For every ``(scan_id, start_id, tgt_id)`` triple the walk repeatedly follows
    ``graph_paths[scan_id][(cur_id, tgt_id)]["next"]`` until the current node
    equals the target, adding ``graph[scan_id][cur_id][next_id][0][3]``
    (presumably the edge length -- confirm against the graph loader) per hop.

    Args:
        graph_paths: Mapping ``scan_id -> {(node_id, tgt_id): {"next": node_id}}``.
        graph: Mapping ``scan_id -> node_id -> node_id -> sequence`` whose
            first element carries the edge value at index 3.
        scan_ids: One scan id per query.
        end_ids: Node id each walk starts from (despite the name, this is the
            walk's origin; callers pass either end or start nodes).
        tgt_ids: Node id each walk has to reach.

    Returns:
        A list with ``len(end_ids)`` accumulated distances; an entry stays the
        integer ``0`` when the walk already starts on its target.

    Raises:
        KeyError: If a required routing or edge entry is missing.

    Note:
        The walk only terminates if the next-hop table really leads to the
        target; a cyclic table would loop forever.
    """
    distances = [0] * len(end_ids)
    # Each query is independent, so walk them one at a time instead of
    # advancing the whole batch in lock-step and recording unused paths.
    for i, (scan_id, cur_id, tgt_id) in enumerate(zip(scan_ids, end_ids, tgt_ids)):
        while cur_id != tgt_id:
            next_id = graph_paths[scan_id][(cur_id, tgt_id)]["next"]
            distances[i] += graph[scan_id][cur_id][next_id][0][3]
            cur_id = next_id
    return distances


def get_actual_distance(graph, scan_ids, path_ids):
    """Return the travelled length of every path.

    Consecutive nodes of a path contribute
    ``graph[scan_id][prev_id][node_id][0][3]`` (presumably the edge length --
    confirm against the graph loader); repeating the same node contributes
    ``0.0``.

    Args:
        graph: Mapping ``scan_id -> node_id -> node_id -> sequence`` whose
            first element carries the edge value at index 3.
        scan_ids: One scan id per path.
        path_ids: One non-empty sequence of visited node ids per path.

    Returns:
        A list with one accumulated length per path; a single-node path
        yields the integer ``0``.

    Raises:
        IndexError: If a path is empty.
        KeyError: If a scan or an edge between consecutive nodes is missing.
    """
    totals = []
    for scan_id, node_seq in zip(scan_ids, path_ids):
        edges = graph[scan_id]
        prev_id = node_seq[0]
        travelled = 0
        for node_id in node_seq[1:]:
            # Staying on the same node is a "stop" step and adds no distance.
            travelled += 0.0 if node_id == prev_id else edges[prev_id][node_id][0][3]
            prev_id = node_id
        totals.append(travelled)
    return totals


def load_r2r_data(file_path):
    """Load R2R samples from a JSON file.

    The file must hold a list of objects, each providing the keys ``'scan'``,
    ``'path'`` (a non-empty list of node ids) and ``'distance'``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A pair ``(episodes, distances)`` where ``episodes`` is a list of
        ``(scan_id, first_path_node, last_path_node)`` tuples and
        ``distances`` is the list of each sample's ``'distance'`` as float.
    """
    with open(file_path, 'r') as f:
        samples = json.load(f)

    episodes, distances = [], []
    for sample in samples:
        scan_id = sample['scan']
        start_id = sample['path'][0]
        tgt_id = sample['path'][-1]
        distances.append(float(sample['distance']))
        episodes.append((scan_id, start_id, tgt_id))
    return episodes, distances


def guidance_eval(task, data_r2r):
    """Run ``task`` greedily over R2R episodes and print navigation metrics.

    The episodes are processed in chunks of ``Param.batch_size``. The last
    chunk is padded with copies of the first episode so that ``task`` always
    receives a full batch, but the padding is dropped again before anything
    is accumulated, so it does not influence the reported numbers.

    Printed metrics: per-step next-hop accuracy (restricted to steps whose
    mask from ``get_next_route`` is 0), trajectory length, navigation error,
    success rate and SPL.

    Args:
        task: Callable model exposing ``eval()``, ``agent_a`` and ``agent_b``;
            it must return a 7-tuple whose last two items are the chosen path
            indices and the speak flags (assumed to be tensors of shape
            (batch, steps) -- confirm against the task implementation).
        data_r2r: List of ``(scan_id, start_id, target_id)`` tuples.

    Returns:
        None. Results are printed.
    """
    task.eval()
    dataset = FullEnvDataset("validate")
    accum_nav_tgt, accum_nav_pred, accum_mask = [], [], []
    accum_scan_ids, accum_start_ids, accum_end_ids, accum_tgt_ids, accum_path_ids = [], [], [], [], []
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    print("act emb loaded.", flush=True)

    batch_size = Param.batch_size
    for offset in range(0, len(data_r2r), batch_size):
        # Slicing clamps at the end of the list, so no special case is needed
        # for the final (possibly shorter) chunk.
        cur_data = list(data_r2r[offset: offset + batch_size])
        n_valid = len(cur_data)
        if n_valid < batch_size:
            # Pad to a full batch; only the first n_valid rows are real.
            cur_data += [data_r2r[0]] * (batch_size - n_valid)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)

        # Translate node ids into positions inside each scan's candidate list.
        # NOTE(review): an id missing from the candidates is skipped silently
        # and would misalign these lists -- assumes every id is present.
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)

        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a, is_speak = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)

        # Keep only the real episodes; the target for step k is the next hop
        # computed at step k, the prediction is the node visited at step k + 1.
        accum_nav_tgt.append(next_a_idx[:n_valid, :-1])
        accum_nav_pred.append(path_a[:n_valid, 1:])
        accum_mask.append(mask[:n_valid, :-1])
        valid_path_ids = path_a_ids[:n_valid]
        accum_start_ids += start_ids[:n_valid]
        accum_end_ids += [cur_path_ids[-1] for cur_path_ids in valid_path_ids]
        accum_tgt_ids += end_ids[:n_valid]
        accum_scan_ids += scan_ids[:n_valid]
        accum_path_ids += valid_path_ids

    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    accum_mask = torch.cat(accum_mask, dim=0)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    # Per-step accuracy over the episodes whose mask is 0 at that step.
    nav_score = [torch.mean(temp_nav[:, k][accum_mask[:, k] == 0]) for k in range(temp_nav.shape[1])]
    print("          eval nav acc = {}".format(nav_score), flush=True)
    tl = trajectory_length(dataset.graphs, accum_scan_ids, accum_path_ids)
    ne = navigation_error(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    sr = success_rate(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    spl = success_weighted_by_path_length(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_start_ids,
                                          accum_end_ids, accum_tgt_ids, accum_path_ids)
    print("tl = {}, ne = {}, sr = {}, spl = {}".format(tl, ne, sr, spl), flush=True)


def get_next_route(graph_paths, scan_ids, path_a_ids, end_ids, is_speak):
    """Compute the routed next hop and a validity mask for every visited node.

    For each visited node the next hop towards the target is looked up in
    ``graph_paths[scan_id][(node_id, target_id)]["next"]``. The mask entry of
    a step is 0 when any of the following holds, and 1 otherwise:

    * it is the first step of the episode;
    * the previous step is unmasked and the agent moved to the hop that was
      routed for it;
    * the speak flag of the previous step equals 1;
    * the previous step is unmasked and the agent stayed on the same node.

    Args:
        graph_paths: Per-scan next-hop tables.
        scan_ids: One scan id per episode.
        path_a_ids: One sequence of visited node ids per episode.
        end_ids: Target node id per episode.
        is_speak: Object with ``tolist()`` yielding one row of flags per
            episode, indexed by step.

    Returns:
        ``(next_route, mask)``: a nested list of next-hop node ids and a
        tensor built from the nested 0/1 mask list.
    """
    speak_rows = is_speak.tolist()
    all_routes, all_masks = [], []
    for scan_id, visited, tgt_id, spoke in zip(scan_ids, path_a_ids, end_ids, speak_rows):
        hop_table = graph_paths[scan_id]
        route_row, mask_row = [], []
        for j, cur_id in enumerate(visited):
            if j == 0:
                # Nothing precedes the first step, so it is always valid.
                flag = 0
            else:
                prev_clean = mask_row[-1] == 0
                followed_route = prev_clean and route_row[-1] == cur_id
                stayed_put = prev_clean and cur_id == visited[j - 1]
                # Evaluated left to right; the speak flag is only read when
                # the agent did not follow the routed hop.
                flag = 0 if (followed_route or spoke[j - 1] == 1 or stayed_put) else 1
            mask_row.append(flag)
            route_row.append(hop_table[(cur_id, tgt_id)]["next"])
        all_routes.append(route_row)
        all_masks.append(mask_row)
    mask = torch.from_numpy(np.array(all_masks))
    return all_routes, mask


def guidance_eval_single(task, r2r_data):
    """Run ``task`` greedily over R2R episodes and print navigation metrics.

    The episodes are processed in chunks of ``Param.batch_size``. The last
    chunk is padded with copies of the first episode so that ``task`` always
    receives a full batch, but the padding is dropped again before anything
    is accumulated, so it does not influence the reported numbers.

    Printed metrics: per-step next-hop accuracy, trajectory length,
    navigation error, success rate and SPL.

    Args:
        task: Callable model exposing ``eval()``, ``agent_a`` and ``agent_b``;
            it must return a 6-tuple whose last item holds the chosen path
            indices (assumed to be a tensor of shape (batch, steps) --
            confirm against the task implementation).
        r2r_data: List of ``(scan_id, start_id, target_id)`` tuples.

    Returns:
        None. Results are printed.
    """
    task.eval()
    dataset = FullEnvDataset("validate")
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    accum_nav_tgt, accum_nav_pred = [], []
    accum_scan_ids, accum_start_ids, accum_end_ids, accum_tgt_ids, accum_path_ids = [], [], [], [], []

    batch_size = Param.batch_size
    for offset in range(0, len(r2r_data), batch_size):
        # Slicing clamps at the end of the list, so no special case is needed
        # for the final (possibly shorter) chunk.
        cur_data = list(r2r_data[offset: offset + batch_size])
        n_valid = len(cur_data)
        if n_valid < batch_size:
            # Pad to a full batch; only the first n_valid rows are real.
            cur_data += [r2r_data[0]] * (batch_size - n_valid)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)

        # Translate node ids into positions inside each scan's candidate list.
        # NOTE(review): an id missing from the candidates is skipped silently
        # and would misalign these lists -- assumes every id is present.
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)

        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        # Routed next hop towards the target for every visited node.
        next_a_ids = [[dataset.graph_paths[cur_scan_id][(cur_id, tgt_id)]["next"] for cur_id in cur_ids]
                      for cur_scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)

        # Keep only the real episodes; the target for step k is the next hop
        # computed at step k, the prediction is the node visited at step k + 1.
        accum_nav_tgt.append(next_a_idx[:n_valid, :-1])
        accum_nav_pred.append(path_a[:n_valid, 1:])
        valid_path_ids = path_a_ids[:n_valid]
        accum_start_ids += start_ids[:n_valid]
        accum_end_ids += [cur_path_ids[-1] for cur_path_ids in valid_path_ids]
        accum_tgt_ids += end_ids[:n_valid]
        accum_scan_ids += scan_ids[:n_valid]
        accum_path_ids += valid_path_ids

    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    tl = trajectory_length(dataset.graphs, accum_scan_ids, accum_path_ids)
    ne = navigation_error(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    sr = success_rate(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    spl = success_weighted_by_path_length(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_start_ids,
                                          accum_end_ids, accum_tgt_ids, accum_path_ids)
    print("          eval nav acc = {}".format(torch.mean(temp_nav, dim=0)), flush=True)
    print("tl = {}, ne = {}, sr = {}, spl = {}".format(tl, ne, sr, spl), flush=True)


def navigation_eval_single(task, r2r_data):
    """Run ``task`` greedily over R2R episodes and print localisation and navigation metrics.

    The episodes are processed in chunks of ``Param.batch_size``. The last
    chunk is padded with copies of the first episode so that ``task`` always
    receives a full batch, but the padding is dropped again before anything
    is accumulated, so it does not influence the reported numbers.

    Printed metrics: per-step accuracy of agent B's location guess against
    the nodes agent A actually visited, per-step next-hop accuracy,
    trajectory length, navigation error, success rate and SPL.

    Args:
        task: Callable model exposing ``eval()``, ``agent_a`` and ``agent_b``;
            it must return an 8-tuple whose last two items are the chosen
            path indices and agent B's guesses (assumed to be tensors of
            shape (batch, steps + 1) and (batch, steps) -- confirm against
            the task implementation).
        r2r_data: List of ``(scan_id, start_id, target_id)`` tuples.

    Returns:
        None. Results are printed.
    """
    task.eval()
    dataset = FullEnvDataset("validate")
    task.agent_b.load_obs_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_b.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset, dataset.scan_ids, dataset.scan_ids)
    accum_loc_tgt, accum_loc_pred = [], []
    accum_nav_tgt, accum_nav_pred = [], []
    accum_scan_ids, accum_start_ids, accum_end_ids, accum_tgt_ids, accum_path_ids = [], [], [], [], []

    batch_size = Param.batch_size
    for offset in range(0, len(r2r_data), batch_size):
        # Slicing clamps at the end of the list, so no special case is needed
        # for the final (possibly shorter) chunk.
        cur_data = list(r2r_data[offset: offset + batch_size])
        n_valid = len(cur_data)
        if n_valid < batch_size:
            # Pad to a full batch; only the first n_valid rows are real.
            cur_data += [r2r_data[0]] * (batch_size - n_valid)
        scan_ids, start_ids, end_ids = zip(*cur_data)
        cand_view_ids = dataset.get_view_ids(scan_ids)

        # Translate node ids into positions inside each scan's candidate list.
        # NOTE(review): an id missing from the candidates is skipped silently
        # and would misalign these lists -- assumes every id is present.
        start_idxes, end_idxes = [], []
        for cur_start_id, cur_end_id, cur_cand_view_ids in zip(start_ids, end_ids, cand_view_ids):
            for idx, cur_cand_id in enumerate(cur_cand_view_ids):
                if cur_cand_id == cur_start_id:
                    start_idxes.append(idx)
                if cur_cand_id == cur_end_id:
                    end_idxes.append(idx)

        history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, path_a, guess_b = \
            task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
        path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                      for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
        # Routed next hop towards the target for every visited node.
        next_a_ids = [[dataset.graph_paths[cur_scan_id][(cur_id, tgt_id)]['next'] for cur_id in cur_ids]
                      for cur_scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)

        # Keep only the real episodes. Agent B's guess is scored against the
        # node agent A visited after each move; the navigation target for
        # step k is the next hop computed at step k.
        accum_loc_tgt.append(path_a[:n_valid, 1:])
        accum_loc_pred.append(guess_b[:n_valid])
        accum_nav_tgt.append(next_a_idx[:n_valid, :-1])
        accum_nav_pred.append(path_a[:n_valid, 1:])
        valid_path_ids = path_a_ids[:n_valid]
        accum_start_ids += start_ids[:n_valid]
        accum_end_ids += [cur_path_ids[-1] for cur_path_ids in valid_path_ids]
        accum_tgt_ids += end_ids[:n_valid]
        accum_scan_ids += scan_ids[:n_valid]
        accum_path_ids += valid_path_ids

    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    accum_loc_pred = torch.cat(accum_loc_pred, dim=0)
    accum_loc_tgt = torch.cat(accum_loc_tgt, dim=0)
    temp_loc = torch.zeros_like(accum_loc_tgt).to(torch.float32)
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_loc[accum_loc_pred == accum_loc_tgt] = 1.0
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    print("          eval loc acc = {}, nav acc = {}".format(torch.mean(temp_loc, dim=0),
                                                             torch.mean(temp_nav, dim=0)), flush=True)
    tl = trajectory_length(dataset.graphs, accum_scan_ids, accum_path_ids)
    ne = navigation_error(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    sr = success_rate(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_end_ids, accum_tgt_ids)
    spl = success_weighted_by_path_length(dataset.graph_paths, dataset.graphs, accum_scan_ids, accum_start_ids,
                                          accum_end_ids, accum_tgt_ids, accum_path_ids)
    print("tl = {}, ne = {}, sr = {}, spl = {}".format(tl, ne, sr, spl), flush=True)

