from Environment import *
from Guidance import *
from torch.optim import Adam


def guidance_train(pre_task):
    """Train the two-agent guidance task with per-step navigation rewards.

    Loads the "train" and "validate" splits, builds an Adam optimizer over the
    trainable parameters of the task, and runs ``Param.epoch`` epochs. Every
    20 epochs the accumulated navigation accuracy and losses are printed and
    ``guidance_eval`` is run on the validation split.

    Args:
        pre_task: an existing task object to keep training, or ``None`` to
            start from a freshly constructed ``Guidance()``.

    Returns:
        None. Progress is reported on stdout only.
    """
    dataset = FullEnvDataset("train")
    dataset_eval = FullEnvDataset("validate")
    print("dataset loaded.", flush=True)
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    task = Guidance() if pre_task is None else pre_task
    # Only parameters that still require gradients are handed to the optimizer.
    trainable_params = [p for p in task.parameters() if p.requires_grad]
    opt = Adam(trainable_params, lr=Param.lr, betas=(0.9, 0.98), eps=1e-8,
               weight_decay=5e-4)
    for agent in (task.agent_a, task.agent_b):
        agent.load_act_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    print("act emb loaded.", flush=True)

    # Statistics accumulated between two consecutive reports.
    accum_nav_tgt, accum_nav_pred = [], []
    total_loss_a, total_loss_b = 0, 0

    # Training starts with short episodes; the limit is raised after epoch 500.
    Param.max_turns = 3
    for i in range(Param.epoch):
        # NOTE(review): the optimizer is zeroed and stepped once per epoch, not
        # once per batch, so whatever gradients task.backward produces are
        # summed over all batches of the epoch -- confirm this is intended.
        opt.zero_grad()
        for step, scan_ids in enumerate(dataloader):
            task.train()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # ------ forward ---------
            # random.choices samples with replacement, so the start and the end
            # index of a sample may coincide.
            endpoint_pairs = [random.choices(range(len(view_ids)), k=2) for view_ids in cand_view_ids]
            start_idxes, end_idxes = zip(*endpoint_pairs)
            end_ids = [view_ids[j] for view_ids, j in zip(cand_view_ids, end_idxes)]
            _history_a, _history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a = \
                task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="sample")
            # ----- backward ---------
            # Translate the visited indexes back to view ids, look up the
            # "next" entry of dataset.graph_paths towards the target for each
            # visited view, and convert those ids back to candidate indexes.
            path_a_ids = [[view_ids[idx] for idx in sample_path]
                          for sample_path, view_ids in zip(path_a.tolist(), cand_view_ids)]
            next_a_ids = [[dataset.graph_paths[scan_id][(cur_id, tgt_id)]["next"] for cur_id in cur_ids]
                          for scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
            next_a_idx = id2idx2d(next_a_ids, cand_view_ids)

            taken = path_a[:, 1:]
            expected = next_a_idx[:, :-1]
            followed = taken == expected
            # Size the rewards from the comparison itself rather than from
            # Param.batch_size: the DataLoader keeps a final partial batch, for
            # which a (batch_size, max_turns) tensor would not match the mask.
            rewards = torch.zeros(followed.shape)
            rewards[followed] = 1.0
            rewards[~followed] = -1.0
            rewards *= Param.reward
            # loss_a, loss_b = task.backward_alu_blg(msg_a_prob, msg_b_prob, act_a_prob, rewards)
            loss_a, loss_b = task.backward(msg_a_prob, msg_b_prob, act_a_prob, rewards)
            accum_nav_tgt.append(expected)
            accum_nav_pred.append(taken)
            total_loss_a += loss_a
            total_loss_b += loss_b
        opt.step()
        print("|", end="", flush=True)
        if i % 20 == 0:
            task.eval()
            accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
            accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
            # 1.0 where the taken step equals the expected step, else 0.0; the
            # mean over dim 0 gives one accuracy value per remaining column.
            temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
            temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
            print()
            print("epoch{}: nav acc = {}, loss A = {}, loss B = {}".format(i, torch.mean(temp_nav, dim=0),
                                                                           total_loss_a, total_loss_b), flush=True)
            # Start a fresh accumulation window for the next report.
            accum_nav_tgt, accum_nav_pred = [], []
            total_loss_a, total_loss_b = 0, 0
            with torch.no_grad():
                guidance_eval(task, dataset_eval)
        if i == 500:
            Param.max_turns = 20


def guidance_eval(task, dataset):
    """Run one greedy pass over ``dataset`` and print the navigation accuracy.

    For every batch a random start and end view is drawn per sample, the task
    is run with ``choose_method="greedy"``, and each step taken is compared
    with the "next" entry stored in ``dataset.graph_paths`` for the visited
    view and the target.

    Args:
        task: the task object to evaluate; it is switched to eval mode and is
            left in that mode.
        dataset: dataset providing ``get_view_ids`` and ``graph_paths``.

    Returns:
        None. The accuracy is printed on stdout.
    """
    task.eval()
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    accum_nav_tgt, accum_nav_pred = [], []
    # Evaluation never needs gradients; disabling them here keeps the function
    # safe to call even when the caller does not wrap it in torch.no_grad().
    with torch.no_grad():
        for step, scan_ids in enumerate(dataloader):
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # random.choices samples with replacement, so the start and the end
            # index of a sample may coincide.
            endpoint_pairs = [random.choices(range(len(view_ids)), k=2) for view_ids in cand_view_ids]
            start_idxes, end_idxes = zip(*endpoint_pairs)
            end_ids = [view_ids[j] for view_ids, j in zip(cand_view_ids, end_idxes)]
            _history_a, _history_b, _msg_a_prob, _msg_b_prob, _act_a_prob, path_a = \
                task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy")
            # Map visited indexes to view ids, look up the stored next view
            # towards the target, then map those ids back to candidate indexes.
            path_a_ids = [[view_ids[idx] for idx in sample_path]
                          for sample_path, view_ids in zip(path_a.tolist(), cand_view_ids)]
            next_a_ids = [[dataset.graph_paths[scan_id][(cur_id, tgt_id)]["next"] for cur_id in cur_ids]
                          for scan_id, cur_ids, tgt_id in zip(scan_ids, path_a_ids, end_ids)]
            next_a_idx = id2idx2d(next_a_ids, cand_view_ids)
            accum_nav_tgt.append(next_a_idx[:, :-1])
            accum_nav_pred.append(path_a[:, 1:])
    accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
    accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
    # 1.0 where the taken step equals the expected step, else 0.0; the mean
    # over dim 0 gives one accuracy value per remaining column.
    temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
    temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
    print("          eval nav acc = {}".format(torch.mean(temp_nav, dim=0)), flush=True)

