from Environment import *
from NavigationMultiStep import *
from torch.optim import Adam
from LocalizationPos import *
from GuidanceMultiStep import *

def navigation_train(pre_task=None):
    """Train a navigation task on the "train" split, validating every 20 epochs.

    Builds a fresh `NavigationMultiStep` (or reuses `pre_task`), loads the
    observation/action embeddings into its agents, then runs `Param.epoch`
    epochs. Within an epoch every batch is run through the task with
    `choose_method="sample"` and `task.backward` is called per batch; the
    optimizer is zeroed once before the batches and stepped once after them.
    On every epoch where `i % 20 == 0` (including epoch 0) the statistics
    accumulated since the previous report are printed and reset, and
    `navigation_eval` is run on the "validate" split under `torch.no_grad()`.

    Args:
        pre_task: An existing task object to keep training. When None, a new
            `NavigationMultiStep()` is created.

    Returns:
        None. Progress and metrics are written to stdout.
    """
    dataset = FullEnvDataset("train")
    dataset_val = FullEnvDataset("validate")
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    print("dataset loaded.", flush=True)
    task = NavigationMultiStep() if pre_task is None else pre_task
    # Only parameters with requires_grad=True are handed to the optimizer.
    opt = Adam(filter(lambda p: p.requires_grad, task.parameters()), lr=Param.lr, betas=(0.9, 0.98), eps=1e-8,
               weight_decay=5e-4)
    # Embeddings are loaded for both splits so the same task can be validated later.
    task.agent_b.load_obs_emb(dataset, dataset_val, dataset.scan_ids, dataset_val.scan_ids)
    task.agent_b.load_act_emb(dataset, dataset_val, dataset.scan_ids, dataset_val.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset_val, dataset.scan_ids, dataset_val.scan_ids)
    print("emb loaded.", flush=True)
    # Running statistics; they span several epochs and are reset only after each report.
    accum_loc_tgt, accum_loc_pred = [], []
    accum_nav_tgt, accum_nav_pred = [], []
    accum_path, accum_mask = [], []
    total_loss_a, total_loss_b = 0, 0
    for i in range(Param.epoch):
        # Zeroed once per epoch and stepped once after all batches.
        # NOTE(review): presumably task.backward accumulates gradients across
        # the batches of an epoch — confirm this is the intended schedule.
        opt.zero_grad()
        for step, scan_ids in enumerate(dataloader):
            # Re-enter train mode every batch because the report branch below switches to eval mode.
            task.train()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # ------- forward -------
            # Two candidate-view indices per scan, used as (start, end); random.choices
            # samples with replacement, so the two may coincide.
            temp = [random.choices(range(len(cur_cand_view_ids)), k=2) for cur_cand_view_ids in cand_view_ids]
            start_idxes, end_idxes = zip(*temp)
            end_ids = [cur_cand_view_ids[j] for cur_cand_view_ids, j in zip(cand_view_ids, end_idxes)]
            history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, path_a, guess_b, is_speak = \
                task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="sample")
            # ------ backward ------
            # Map agent A's index path back to view ids for the route lookup.
            path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                          for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
            next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
            next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
            nav_rewards, loc_rewards, msg_a_gain, msg_b_gain = \
                get_dense_rewards(path_a, next_a_idx, guess_b, mask, is_speak)
            loss_a, loss_b = task.backward(msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, msg_a_gain, msg_b_gain,
                                           loc_rewards, nav_rewards)
            # Localisation is scored as guess_b against path_a[:, 1:]; navigation as
            # path_a[:, 1:] against the reference next index next_a_idx[:, :-1].
            accum_loc_tgt.append(path_a[:, 1:]); accum_loc_pred.append(guess_b)
            accum_nav_tgt.append(next_a_idx[:, :-1]); accum_nav_pred.append(path_a[:, 1:])
            accum_path.append(path_a); accum_mask.append(mask)
            total_loss_a += loss_a; total_loss_b += loss_b
        opt.step()
        print("|", end="", flush=True)
        # Report and validate on epochs 0, 20, 40, ...
        if i % 20 == 0:
            task.eval()
            accum_loc_tgt, accum_loc_pred = torch.cat(accum_loc_tgt, dim=0), torch.cat(accum_loc_pred, dim=0)
            accum_nav_tgt, accum_nav_pred = torch.cat(accum_nav_tgt, dim=0), torch.cat(accum_nav_pred, dim=0)
            # accum_path is concatenated here but not used by the metrics below.
            accum_path = torch.cat(accum_path, dim=0)
            accum_mask = torch.cat(accum_mask, dim=0)
            # 0/1 hit indicators; their column-wise mean is the per-turn accuracy.
            temp_loc = torch.zeros_like(accum_loc_tgt).to(torch.float32)
            temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
            temp_loc[accum_loc_pred == accum_loc_tgt] = 1.0
            temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
            # Per-turn navigation accuracy restricted to rows whose mask entry is 0
            # (the mean of an empty selection is nan).
            nav_score_mask = torch.stack([torch.mean(temp_nav[:, k][accum_mask[:, k] == 0.0]) for k in range(temp_nav.shape[1])])
            print()
            print("epoch{}: loc acc = {}, nav acc = {}, nav acc (mask) = {}, loss A = {}, loss B = {}"
                  .format(i, torch.mean(temp_loc, dim=0), torch.mean(temp_nav, dim=0), nav_score_mask,
                          total_loss_a, total_loss_b), flush=True)
            # Start a fresh reporting window.
            total_loss_a, total_loss_b = 0, 0
            accum_loc_tgt, accum_loc_pred = [], []
            accum_nav_tgt, accum_nav_pred = [], []
            accum_path, accum_mask = [], []
            with torch.no_grad():
                navigation_eval(task, dataset_val)


def navigation_eval(task, dataset):
    """Run `task` once over `dataset` with greedy choices and print accuracies.

    The task is switched to eval mode, every batch is rolled out with
    `choose_method="greedy"` between a randomly drawn start and end view, and
    three per-turn metrics are printed: localisation accuracy (guess_b vs.
    path_a[:, 1:]), navigation accuracy (path_a[:, 1:] vs. the reference next
    index) and navigation accuracy restricted to rows whose mask entry is 0.

    This function does not disable gradient tracking itself.

    Args:
        task: The task object to evaluate; called once per batch.
        dataset: Dataset providing `get_view_ids` and `graph_paths`.

    Returns:
        None. The metrics are written to stdout.
    """
    task.eval()
    loader = DataLoader(dataset, batch_size=Param.batch_size)

    loc_targets, loc_guesses = [], []
    nav_targets, nav_moves = [], []
    paths, masks = [], []

    for scan_ids in loader:
        view_ids = dataset.get_view_ids(scan_ids)

        # One (start, end) index pair per scan, drawn with replacement.
        start_idxes, end_idxes = zip(*(random.choices(range(len(ids)), k=2) for ids in view_ids))
        end_ids = [ids[end] for ids, end in zip(view_ids, end_idxes)]

        (_history_a, _history_b, _msg_a_prob, _msg_b_prob, _act_a_prob, _guess_b_prob,
         path_a, guess_b, is_speak) = task(scan_ids, view_ids, start_idxes, end_idxes, dataset,
                                           choose_method="greedy")

        # Index path -> view ids, then look up the reference next step for each visited view.
        visited_ids = [[ids[idx] for idx in row] for row, ids in zip(path_a.tolist(), view_ids)]
        next_ids, mask = get_next_route(dataset.graph_paths, scan_ids, visited_ids, end_ids, is_speak)
        next_idx = id2idx2d(next_ids, view_ids)  # (batch, turns + 1)

        moves = path_a[:, 1:]
        loc_targets.append(moves)
        loc_guesses.append(guess_b)
        nav_targets.append(next_idx[:, :-1])
        nav_moves.append(moves)
        paths.append(path_a)
        masks.append(mask)

    loc_tgt = torch.cat(loc_targets, dim=0)
    loc_pred = torch.cat(loc_guesses, dim=0)
    nav_tgt = torch.cat(nav_targets, dim=0)
    nav_pred = torch.cat(nav_moves, dim=0)
    torch.cat(paths, dim=0)  # concatenated as before; not used by the metrics below
    all_masks = torch.cat(masks, dim=0)

    # 0/1 hit indicators whose column-wise mean is the per-turn accuracy.
    loc_hits = torch.zeros_like(loc_tgt).to(torch.float32)
    nav_hits = torch.zeros_like(nav_tgt).to(torch.float32)
    loc_hits[loc_pred == loc_tgt] = 1.0
    nav_hits[nav_pred == nav_tgt] = 1.0

    # Navigation accuracy per turn, counting only rows whose mask entry is 0.
    per_turn = []
    for k in range(nav_hits.shape[1]):
        unmasked = all_masks[:, k] == 0.0
        per_turn.append(torch.mean(nav_hits[:, k][unmasked]))
    nav_score_mask = torch.stack(per_turn)

    loc_acc = torch.mean(loc_hits, dim=0)
    nav_acc = torch.mean(nav_hits, dim=0)
    print("          eval loc acc = {}, nav acc = {}, nav acc (mask) = {}"
          .format(loc_acc, nav_acc, nav_score_mask), flush=True)


def get_next_route(graph_paths, scan_ids, path_a_ids, end_ids, is_speak):
    """Look up the reference next view for every visited view and flag off-route steps.

    For each episode, every visited view id is paired with the episode's target
    and `graph_paths[scan_id][(view_id, target_id)]["next"]` is recorded. A
    parallel 0/1 flag is produced per step. The flag is 0 when any of these holds
    (checked in this order), and 1 otherwise:

    * it is the first step of the episode;
    * the view equals the previously recorded "next" and the previous flag is 0
      (the last move followed the route);
    * `is_speak` at the previous step equals 1;
    * the view equals the previously visited view and the previous flag is 0
      (the last move was a stop).

    Args:
        graph_paths: Mapping scan id -> mapping (view_id, target_id) -> dict with a "next" entry.
        scan_ids: Scan id of each episode.
        path_a_ids: Per episode, the sequence of visited view ids.
        end_ids: Target view id of each episode.
        is_speak: Tensor whose row `b` is indexed by step for episode `b`.

    Returns:
        (next_route, mask): nested list of "next" ids, and the flags as a tensor
        built from a NumPy array of the nested flag lists.
    """
    speak_rows = is_speak.tolist()
    all_routes, all_flags = [], []
    for scan_id, visited, goal, spoke in zip(scan_ids, path_a_ids, end_ids, speak_rows):
        lookup = graph_paths[scan_id]
        route, flags = [], []
        for pos, view in enumerate(visited):
            if not route:  # first step
                flag = 0
            elif route[-1] == view and flags[-1] == 0:  # last step move right
                flag = 0
            elif spoke[pos - 1] == 1:  # a message was exchanged at the previous step
                flag = 0
            elif view == visited[pos - 1] and flags[-1] == 0:  # last step stop
                flag = 0
            else:  # last step move wrong
                flag = 1
            flags.append(flag)
            route.append(lookup[(view, goal)]["next"])
        all_routes.append(route)
        all_flags.append(flags)
    return all_routes, torch.from_numpy(np.array(all_flags))


def get_dense_rewards(path_a, next_a_idx, guess_b, mask, is_speak):
    """Build per-turn rewards for navigation and localisation plus message gains.

    Navigation reward per turn: +1 where `path_a[:, 1:]` equals
    `next_a_idx[:, :-1]`, -1.1 where it differs, then 0 wherever
    `mask[:, :-1] == 1` and for the whole first column.
    Localisation reward per turn: +1 where `guess_b` equals `path_a[:, 1:]`,
    -1.1 where it differs, then 0 wherever `is_speak` is False.
    Both tables are scaled by `Param.reward` before the message gains are
    derived from them with `get_msg_gain`.

    The reward tables are sized from `path_a[:, 1:]` rather than from
    `Param.batch_size` / `Param.max_turns`, so the boolean-mask assignments
    below stay shape-consistent when a batch holds fewer rows than
    `Param.batch_size` (e.g. the last batch of a DataLoader that keeps
    incomplete batches).

    Args:
        path_a: Index tensor of agent A's path — assumed (batch, turns + 1).
        next_a_idx: Reference next index per visited position, same shape as `path_a`.
        guess_b: Agent B's guesses — assumed (batch, turns).
        mask: 0/1 tensor from `get_next_route`, same shape as `path_a`.
        is_speak: Per-turn speak indicator — assumed (batch, turns).

    Returns:
        (nav_rewards, loc_rewards, msg_a_gain, msg_b_gain).
    """
    moves = path_a[:, 1:]
    reference = next_a_idx[:, :-1]

    # Size from the actual batch so a partial batch does not break the masked writes.
    nav_rewards = torch.zeros(moves.shape)
    loc_rewards = torch.zeros(moves.shape)

    nav_rewards[moves == reference] = 1.0
    nav_rewards[moves != reference] = - 1.1
    nav_rewards[mask[:, :-1] == 1] = 0.0
    nav_rewards[:, 0] = 0.0  # first move is forced to stop

    loc_rewards[guess_b == moves] = 1.0
    loc_rewards[guess_b != moves] = - 1.1
    loc_rewards[is_speak == False] = 0.0  # noqa: E712 - element-wise tensor comparison

    nav_rewards *= Param.reward
    loc_rewards *= Param.reward
    msg_a_gain, msg_b_gain = get_msg_gain(nav_rewards, loc_rewards, is_speak)
    return nav_rewards, loc_rewards, msg_a_gain, msg_b_gain


def get_msg_gain(nav_rewards, loc_rewards, is_speak):
    """Derive the per-turn gains credited to agent A's and agent B's messages.

    With T = nav_rewards.shape[1] and rows processed independently:

    * A's base gain at turn t is the navigation reward of turn t + 1
      (0 for the last turn).
    * B's base gain at turn t is the navigation reward of turn t + 1 plus
      `Param.gamma` times B's base gain at turn t + 1, where the carried term
      is dropped when `is_speak` is truthy at turn t + 1 and for the last two
      turns (0 for the last turn).

    Then, using the localisation rewards with non-speaking turns zeroed,
    A's gain adds the localisation reward of the same turn and B's gain adds
    the localisation reward of the next turn (0 for the last turn). Finally
    both gains are zeroed wherever `is_speak` is False.

    Args:
        nav_rewards: 2-D tensor of navigation rewards, (batch, turns).
        loc_rewards: Tensor of localisation rewards, same shape.
        is_speak: Per-turn speak indicator, same shape; bool or 0/1 valued.

    Returns:
        (msg_a_gain, msg_b_gain): float64 tensors shaped like `nav_rewards`.
    """
    nav_rows = nav_rewards.tolist()
    speak_rows = is_speak.tolist()
    n_rows, n_turns = nav_rewards.shape[0], nav_rewards.shape[1]

    gains_a, gains_b = [], []
    for i, nav_row in enumerate(nav_rows):
        row_a, row_b = [], []
        # j counts turns back from the end of the episode; rows are built
        # back-to-front and reversed afterwards.
        for j in range(n_turns):
            reward = 0.0 if j == 0 else nav_row[-j]
            row_a.append(reward)
            # TODO check gain before (open question carried over from the original code)
            # Truthiness instead of `is True`, so 0/1 numeric indicators are
            # handled the same way as booleans.
            if j < 2 or bool(speak_rows[i][-j]):
                carried = 0.0
            else:
                carried = row_b[-1]
            row_b.append(reward + Param.gamma * carried)
        gains_a.append(row_a[::-1])
        gains_b.append(row_b[::-1])

    # Explicit float64 (what NumPy produced for float rows) and an explicit
    # shape, so single-turn and empty batches still yield a 2-D float tensor.
    msg_a_gain = torch.tensor(gains_a, dtype=torch.float64).reshape(n_rows, n_turns)
    msg_b_gain = torch.tensor(gains_b, dtype=torch.float64).reshape(n_rows, n_turns)

    silent = is_speak == False  # noqa: E712 - element-wise tensor comparison
    spoken_loc = loc_rewards.clone()
    spoken_loc[silent] = 0.0

    msg_a_gain += spoken_loc
    # Shift the localisation rewards one turn to the left; the padding column
    # follows the actual number of rows rather than Param.batch_size.
    padding = spoken_loc.new_zeros((spoken_loc.shape[0], 1))
    msg_b_gain += torch.cat([spoken_loc[:, 1:], padding], dim=1)

    msg_a_gain[silent] = 0.0
    msg_b_gain[silent] = 0.0
    return msg_a_gain, msg_b_gain


def setup_seed(seed):
    """Seed every random number generator used here and request deterministic cuDNN.

    Seeds, in order, torch (CPU), torch (all CUDA devices), NumPy's global
    generator and Python's `random` module with the same value, then sets
    `torch.backends.cudnn.deterministic = True` and
    `torch.backends.cudnn.benchmark = False`.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
