from Environment import *
from Navigation import *
from torch.optim import Adam
from LocalizationPos import *
from Guidance import *


def _signed_rewards(pred, tgt):
    """Return per-element rewards: +1.0 where pred == tgt, -1.1 elsewhere, scaled by Param.reward.

    The buffer takes its shape from the comparison itself rather than from
    Param.batch_size, so a final, smaller batch from the DataLoader is handled.
    """
    hit = pred == tgt
    rewards = torch.zeros(hit.shape)
    rewards[hit] = 1.0
    rewards[pred != tgt] = - 1.1
    rewards *= Param.reward
    return rewards


def _mean_agreement(pred_batches, tgt_batches):
    """Concatenate per-batch tensors along dim 0 and return the mean of (pred == tgt) over dim 0."""
    pred = torch.cat(pred_batches, dim=0)
    tgt = torch.cat(tgt_batches, dim=0)
    return torch.mean((pred == tgt).to(torch.float32), dim=0)


def navigation_train(pre_task=None, start_epoch=0):
    """Train the navigation task and periodically print train / validation accuracy.

    Loads the "train" and "validate" FullEnvDataset splits, optimises the
    trainable parameters of the task with Adam, and every 20 epochs prints the
    accumulated accuracy and loss before calling navigation_eval on the
    validation split.

    Side effect: mutates the global Param.max_turns (3 turns up to and including
    epoch 500, 10 turns afterwards).

    :param pre_task: an already constructed task to continue training; a fresh
        Navigation() is created when None.
    :param start_epoch: index of the first epoch to run (the loop runs up to
        Param.epoch). Used when resuming.
    :return: None
    """
    turn_switch_epoch = 500  # last epoch trained with the short horizon
    report_every = 20        # 500 is a multiple of 20, so the accumulators below are
    #                          flushed before the number of turns changes

    dataset = FullEnvDataset("train")
    dataset_eval = FullEnvDataset("validate")
    print("dataset loaded.", flush=True)
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    task = Navigation() if pre_task is None else pre_task
    opt = Adam(filter(lambda p: p.requires_grad, task.parameters()), lr=Param.lr, betas=(0.9, 0.98), eps=1e-8,
               weight_decay=5e-4)
    task.agent_b.load_obs_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    task.agent_b.load_act_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    print("act emb loaded.", flush=True)
    accum_loc_tgt, accum_loc_pred = [], []
    accum_nav_tgt, accum_nav_pred = [], []
    total_loss_a, total_loss_b = 0, 0
    # When resuming past the switch epoch, start directly with the long horizon;
    # otherwise the "i == turn_switch_epoch" check below would never fire again.
    Param.max_turns = 10 if start_epoch > turn_switch_epoch else 3
    for i in range(start_epoch, Param.epoch):
        # NOTE(review): zero_grad/step run once per epoch, so gradients are accumulated
        # over every batch of the epoch -- presumably task.backward calls .backward();
        # confirm this is intended.
        opt.zero_grad()
        for scan_ids in dataloader:
            task.train()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # ------- forward ---------------
            # One (start, end) candidate index pair per scan, drawn with replacement.
            temp = [random.choices(range(len(cur_cand_view_ids)), k=2) for cur_cand_view_ids in cand_view_ids]
            start_idxes, end_idxes = zip(*temp)
            end_ids = [cur_cand_view_ids[j] for cur_cand_view_ids, j in zip(cand_view_ids, end_idxes)]
            _, _, msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, path_a, guess_b = \
                task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="sample")
            # ------- backward ---------------
            # Map the visited candidate indices back to view ids, then look up the
            # 'next' entry of dataset.graph_paths towards the target for each of them.
            path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                          for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
            next_a_ids = [[dataset.graph_paths[cur_scan_id][(cur_id, tgt_id)]['next'] for cur_id in cur_ids]
                          for cur_scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
            next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)

            visited = path_a[:, 1:]       # positions after each move
            expected = next_a_idx[:, :-1]  # graph 'next' step from each pre-move position
            rewards_loc = _signed_rewards(visited, guess_b)
            rewards_nav = _signed_rewards(visited, expected)
            loss_a, loss_b = task.backward(msg_a_prob, msg_b_prob, act_a_prob, guess_b_prob, rewards_loc, rewards_nav)
            accum_loc_tgt.append(visited)
            accum_loc_pred.append(guess_b)
            accum_nav_tgt.append(expected)
            accum_nav_pred.append(visited)
            # NOTE(review): if task.backward returns tensors attached to a graph, summing
            # them here keeps that graph alive -- confirm they are detached / floats.
            total_loss_a += loss_a
            total_loss_b += loss_b
        opt.step()
        print("|", end="", flush=True)
        if i % report_every == 0:
            task.eval()
            loc_acc = _mean_agreement(accum_loc_pred, accum_loc_tgt)
            nav_acc = _mean_agreement(accum_nav_pred, accum_nav_tgt)
            print()
            print("epoch{}: loc acc = {}, nav acc = {}, loss A = {}, loss B = {}".format(i,
                                                                                         loc_acc,
                                                                                         nav_acc,
                                                                                         total_loss_a, total_loss_b),
                  flush=True)
            accum_loc_tgt, accum_loc_pred, accum_nav_tgt, accum_nav_pred = [], [], [], []
            total_loss_a, total_loss_b = 0.0, 0.0
            with torch.no_grad():
                navigation_eval(task, dataset_eval)
        if i == turn_switch_epoch:
            Param.max_turns = 10


def navigation_eval(task, dataset):
    """Run the task with greedy choices over `dataset` and print its accuracy.

    For every batch a random (start, end) candidate pair is drawn per scan and
    the task is called with choose_method="greedy". Two statistics, each
    averaged over dim 0 of the concatenated batches, are printed:

    * loc acc: how often guess_b equals path_a[:, 1:];
    * nav acc: how often path_a[:, 1:] equals the 'next' entry of
      dataset.graph_paths towards the sampled target.

    The function does not disable gradients itself.

    :param task: callable model exposing eval(); called once per batch.
    :param dataset: dataset providing get_view_ids() and graph_paths.
    :return: None
    """
    task.eval()
    loader = DataLoader(dataset, batch_size=Param.batch_size)

    visited_batches = []  # path_a[:, 1:] per batch (loc target and nav prediction)
    guess_batches = []    # guess_b per batch
    expert_batches = []   # next_a_idx[:, :-1] per batch

    for scan_ids in loader:
        views_per_scan = dataset.get_view_ids(scan_ids)

        # Draw a start and an end candidate index for each scan (with replacement).
        endpoint_pairs = [random.choices(range(len(views)), k=2) for views in views_per_scan]
        start_idxes, end_idxes = zip(*endpoint_pairs)
        end_ids = [views[end] for views, end in zip(views_per_scan, end_idxes)]

        _, _, _, _, _, _, path_a, guess_b = task(
            scan_ids, views_per_scan, start_idxes, end_idxes, dataset, choose_method="greedy")

        # For every visited candidate, fetch the 'next' view id towards the target.
        next_a_ids = []
        for scan_id, views, visited_row, tgt_id in zip(scan_ids, views_per_scan, path_a.tolist(), end_ids):
            visited_ids = [views[idx] for idx in visited_row]
            next_a_ids.append(
                [dataset.graph_paths[scan_id][(view_id, tgt_id)]['next'] for view_id in visited_ids])
        next_a_idx = id2idx2d(next_a_ids, views_per_scan)  # (batch, turns + 1)

        visited_batches.append(path_a[:, 1:])
        guess_batches.append(guess_b)
        expert_batches.append(next_a_idx[:, :-1])

    visited_all = torch.cat(visited_batches, dim=0)
    guess_all = torch.cat(guess_batches, dim=0)
    expert_all = torch.cat(expert_batches, dim=0)

    loc_acc = (guess_all == visited_all).to(torch.float32).mean(dim=0)
    nav_acc = (visited_all == expert_all).to(torch.float32).mean(dim=0)
    print("          eval loc acc = {}, nav acc = {}".format(loc_acc, nav_acc), flush=True)

