from Environment import *
from Param import *
from GuidanceMultiStep import *
from torch.optim import Adam


def guidance_train(pre_task=None):
    """Train a GuidanceMultiStep task, periodically reporting navigation accuracy.

    Runs ``Param.epoch`` outer iterations. Each iteration zeroes gradients once,
    makes one pass over the "train" split and then calls ``opt.step()`` once.
    Every 20 iterations (including the first), it does three things:
      * prints per-turn navigation accuracy over the training batches seen
        since the previous report, together with the summed losses;
      * resets those accumulators;
      * runs ``guidance_eval`` on the "validate" split under
        ``torch.no_grad()``.

    Args:
        pre_task: an existing task object to keep training. When ``None``, a
            fresh ``GuidanceMultiStep()`` is created.

    Side effects:
        Mutates the global ``Param.gamma`` (see the schedule in the loop) and
        prints progress to stdout.
    """
    # The validation split feeds both load_act_emb and guidance_eval.
    dataset = FullEnvDataset("train")
    dataset_eval = FullEnvDataset("validate")
    print("dataset loaded.", flush=True)
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    task = GuidanceMultiStep() if pre_task is None else pre_task
    # Only parameters with requires_grad=True are handed to the optimizer.
    opt = Adam(filter(lambda p: p.requires_grad, task.parameters()), lr=Param.lr, betas=(0.9, 0.98), eps=1e-8,
               weight_decay=5e-4)
    # Both agents receive action embeddings for the train and validation scans.
    task.agent_b.load_act_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    task.agent_a.load_act_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    print("act emb loaded.", flush=True)
    # Training-set predictions/targets/masks and loss sums; reset after every report.
    accum_nav_tgt, accum_nav_pred, accum_mask = [], [], []
    total_loss_a, total_loss_b = 0, 0
    # Param.max_turns = 3
    for i in range(Param.epoch):
        # Gradients are zeroed once here and the optimizer steps once after the
        # full pass. Gradients from every batch are therefore summed, presumably
        # via backward() inside task.backward_alu_blg (TODO confirm).
        opt.zero_grad()
        for step, scan_ids in enumerate(dataloader):
            task.train()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # ---------- forward ---------------
            # Random (start, end) candidate indices per episode. random.choices samples
            # with replacement, so start may equal end.
            temp = [random.choices(range(len(cur_cand_view_ids)), k=2) for cur_cand_view_ids in cand_view_ids]
            start_idxes, end_idxes = zip(*temp)
            end_ids = [cur_cand_view_ids[j] for cur_cand_view_ids, j in zip(cand_view_ids, end_idxes)]
            history_a, history_b, msg_a_prob, msg_b_prob, act_a_prob, path_a, is_speak = \
                task(scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="sample")
            # ---------- backward --------------
            # Map agent A's visited candidate indices back to view ids.
            path_a_ids = [[cur_cand_view_ids[cur_idx] for cur_idx in batch_path_a]
                          for batch_path_a, cur_cand_view_ids in zip(path_a.tolist(), cand_view_ids)]
            next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
            next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
            nav_rewards, msg_a_gain, msg_b_gain = get_dense_rewards(path_a, next_a_idx, mask, is_speak)
            loss_a, loss_b = task.backward_alu_blg(msg_a_prob, msg_b_prob, act_a_prob, msg_a_gain, msg_b_gain, nav_rewards)
            # loss_a, loss_b = task.backward(msg_a_prob, msg_b_prob, act_a_prob, msg_a_gain, msg_b_gain, nav_rewards)
            # Pair the move made at each turn (path_a[:, 1:]) with that turn's reference
            # next hop. The final next-hop column has no matching move, so it is dropped.
            accum_nav_tgt.append(next_a_idx[:, :-1]); accum_nav_pred.append(path_a[:, 1:])
            accum_mask.append(mask[:, :-1])
            # NOTE(review): if loss_a/loss_b are graph-attached tensors, summing them here
            # keeps autograd history alive until the next report. Consider .item(); this
            # depends on backward_alu_blg's return type.
            total_loss_a += loss_a; total_loss_b += loss_b
        opt.step()
        print("|", end="", flush=True)
        # Discount schedule: at i = 0, 200, 400, ..., raise the global Param.gamma
        # (read by get_msg_gain) by 0.1 while it is below 0.9.
        if i % 200 == 0 and Param.gamma < 0.9:
            Param.gamma += 0.1
        # Every 20 iterations (including i == 0): report training accuracy, then validate.
        if i % 20 == 0:
            task.eval()
            accum_nav_pred = torch.cat(accum_nav_pred, dim=0)
            accum_nav_tgt = torch.cat(accum_nav_tgt, dim=0)
            accum_mask = torch.cat(accum_mask, dim=0)
            temp_nav = torch.zeros_like(accum_nav_tgt).to(torch.float32)
            temp_nav[accum_nav_pred == accum_nav_tgt] = 1.0
            # Per-turn accuracy over unmasked entries (mask == 0). A turn with no
            # unmasked entries yields nan.
            nav_score = [torch.mean(temp_nav[:, k][accum_mask[:, k] == 0]) for k in range(temp_nav.shape[1])]
            print()
            print("epoch{}: nav acc = {}, loss A = {}, loss B = {}".format(i, nav_score,
                                                                           total_loss_a, total_loss_b), flush=True)
            accum_nav_tgt, accum_nav_pred, accum_mask = [], [], []
            total_loss_a, total_loss_b = 0, 0
            with torch.no_grad():
                guidance_eval(task, dataset_eval)
            

def guidance_eval(task, dataset):
    """Run one greedy pass over ``dataset`` and print per-turn navigation accuracy.

    Each episode gets a random (start, end) candidate pair. Agent A's move at
    every turn is compared with the reference next hop from ``get_next_route``.
    Accuracy is averaged per turn over entries whose mask is 0.

    Args:
        task: the guidance model. It is switched to eval mode and called with
            ``choose_method="greedy"``.
        dataset: dataset providing ``get_view_ids`` and ``graph_paths``.
    """
    task.eval()
    loader = DataLoader(dataset, batch_size=Param.batch_size)
    preds, targets, masks = [], [], []
    for scan_ids in loader:
        cand_view_ids = dataset.get_view_ids(scan_ids)
        # One (start, end) candidate-index pair per episode, sampled with replacement.
        pairs = [random.choices(range(len(views)), k=2) for views in cand_view_ids]
        start_idxes, end_idxes = zip(*pairs)
        end_ids = [views[end] for views, end in zip(cand_view_ids, end_idxes)]
        _, _, _, _, _, path_a, is_speak = task(
            scan_ids, cand_view_ids, start_idxes, end_idxes, dataset, choose_method="greedy"
        )
        rows = path_a.tolist()
        path_a_ids = [[views[idx] for idx in row] for row, views in zip(rows, cand_view_ids)]
        next_a_ids, mask = get_next_route(dataset.graph_paths, scan_ids, path_a_ids, end_ids, is_speak)
        next_a_idx = id2idx2d(next_a_ids, cand_view_ids)  # (batch, turns + 1)
        # Move taken at each turn vs. the reference next hop for that turn.
        targets.append(next_a_idx[:, :-1])
        preds.append(path_a[:, 1:])
        masks.append(mask[:, :-1])
    all_pred = torch.cat(preds, dim=0)
    all_tgt = torch.cat(targets, dim=0)
    all_mask = torch.cat(masks, dim=0)
    hits = torch.zeros_like(all_tgt).to(torch.float32)
    hits[all_pred == all_tgt] = 1.0
    n_turns = hits.shape[1]
    nav_score = [torch.mean(hits[:, turn][all_mask[:, turn] == 0]) for turn in range(n_turns)]
    print("          eval nav acc = {}".format(nav_score), flush=True)


def get_dense_rewards(path_a, next_a_idx, mask, is_speak):
    """Compute per-turn navigation rewards and the derived message gains.

    For turn ``t`` the reward is +1 when agent A's move ``path_a[:, t + 1]``
    equals the reference next hop ``next_a_idx[:, t]``, and -1 otherwise. It
    is forced to 0 where ``mask[:, t] == 1``, and finally scaled by
    ``Param.reward``.

    Args:
        path_a: 2-D index tensor of visited candidate indices, one row per
            episode (``turns + 1`` columns).
        next_a_idx: reference next-hop indices with the same layout as ``path_a``.
        mask: 0/1 tensor with the same layout; entries equal to 1 get no reward.
        is_speak: per-turn speak flags, forwarded to ``get_msg_gain``.

    Returns:
        ``(nav_rewards, msg_a_gain, msg_b_gain)``. ``nav_rewards`` is a float
        tensor of shape ``(batch, turns)``.
    """
    moves = path_a[:, 1:]
    reference = next_a_idx[:, :-1]
    # Size the reward tensor from the inputs instead of Param.batch_size /
    # Param.max_turns. A DataLoader with drop_last=False yields a smaller final
    # batch, and a fixed-size tensor makes the mask assignments below fail.
    nav_rewards = torch.zeros((moves.shape[0], moves.shape[1]))
    nav_rewards[moves == reference] = 1.0
    nav_rewards[moves != reference] = -1.0
    nav_rewards[mask[:, :-1] == 1] = 0.0
    nav_rewards *= Param.reward
    msg_a_gain, msg_b_gain = get_msg_gain(nav_rewards, is_speak)
    return nav_rewards, msg_a_gain, msg_b_gain


def get_msg_gain(nav_rewards, is_speak):
    """Turn per-turn navigation rewards into discounted credit for each agent's messages.

    Gains are built backwards in time with discount ``Param.gamma``. With ``T``
    turns and ``r = nav_rewards[i]``:

    * Agent A: ``a[T-1] = 0``. For ``t < T-1``, ``a[t] = r[t] + gamma * c``.
      ``c`` is ``a[t+1]``, or 0 when turn ``t+1`` is a speaking turn.
    * Agent B: ``b[t] = r[t] + gamma * c``. ``c`` is ``b[t+1]``, or 0 when
      ``t`` is the last turn or turn ``t`` itself is a speaking turn.

    A's gain is then zeroed on turns where ``is_speak`` is false. B's gain is
    zeroed on turns ``t > 0`` whose preceding turn ``t - 1`` is not a speaking
    turn.

    Args:
        nav_rewards: 2-D float tensor of shape ``(batch, turns)``.
        is_speak: tensor of the same shape; truthy entries mark speaking turns.

    Returns:
        ``(msg_a_gain, msg_b_gain)`` as float64 tensors of shape ``(batch, turns)``.
    """
    rewards = nav_rewards.tolist()
    speaks = is_speak.tolist()
    n_rows, turns = nav_rewards.shape[0], nav_rewards.shape[1]
    msg_a_gain, msg_b_gain = [], []
    for i in range(n_rows):
        row_r, row_s = rewards[i], speaks[i]
        gain_a = [0.0] * turns
        gain_b = [0.0] * turns
        # Walk backwards so each turn can discount the gain of the turn after it.
        for t in range(turns - 1, -1, -1):
            is_last = t == turns - 1
            # NOTE(review): A's last turn gets no reward while B's does, and A uses
            # r[t] rather than the next turn's reward. This matches the original;
            # confirm that the offset is intended.
            reward_a = 0 if is_last else row_r[t]
            # a[T-1] is always 0, so the last two turns carry nothing. Truthiness
            # (instead of `is True`) also works for int/float speak flags.
            carry_a = 0 if t >= turns - 2 or row_s[t + 1] else gain_a[t + 1]
            gain_a[t] = reward_a + Param.gamma * carry_a
            carry_b = 0 if is_last or row_s[t] else gain_b[t + 1]
            gain_b[t] = row_r[t] + Param.gamma * carry_b
        msg_a_gain.append(gain_a)
        msg_b_gain.append(gain_b)
    # The explicit reshape keeps the 2-D layout even for an empty batch.
    msg_a_gain = torch.from_numpy(np.array(msg_a_gain, dtype=np.float64).reshape(n_rows, turns))
    msg_b_gain = torch.from_numpy(np.array(msg_b_gain, dtype=np.float64).reshape(n_rows, turns))
    spoke = is_speak.bool()
    msg_a_gain[~spoke] = 0.0
    # B's gain at turn t only counts if turn t-1 was a speaking turn; turn 0 always counts.
    # Take the batch size from the input rather than Param.batch_size so a
    # smaller final batch does not break the concatenation.
    prev_spoke = torch.cat([torch.ones((spoke.shape[0], 1), dtype=torch.bool), spoke[:, :-1]], dim=1)
    msg_b_gain[~prev_spoke] = 0.0
    return msg_a_gain, msg_b_gain


def get_next_route(graph_paths, scan_ids, path_a_ids, end_ids, is_speak):
    """Look up the reference next hop for every visited node and flag off-route steps.

    For each episode, every visited node ``node`` maps to
    ``graph_paths[scan_id][(node, target)]["next"]``. A 0/1 flag is produced
    for each step as well.

    The first step is always 0. A later step is 0 when any of these holds:
      * the agent arrived at the previous step's next hop and that step was 0;
      * the previous turn's ``is_speak`` entry equals 1;
      * the agent stayed on the same node and the previous step was 0.
    Otherwise the step is 1.

    Returns:
        ``(next_route, mask)``. ``next_route`` is a nested list of next-hop ids
        per episode; ``mask`` is a tensor built from the nested 0/1 flags.
    """
    speak_rows = is_speak.tolist()
    routes, flags_per_episode = [], []
    for scan_id, episode, target, speaks in zip(scan_ids, path_a_ids, end_ids, speak_rows):
        lookup = graph_paths[scan_id]
        route, flags = [], []
        for step, node in enumerate(episode):
            if step == 0:
                on_track = True
            else:
                prev_ok = flags[-1] == 0
                on_track = ((route[-1] == node and prev_ok)
                            or speaks[step - 1] == 1
                            or (node == episode[step - 1] and prev_ok))
            flags.append(0 if on_track else 1)
            route.append(lookup[(node, target)]["next"])
        routes.append(route)
        flags_per_episode.append(flags)
    return routes, torch.from_numpy(np.array(flags_per_episode))


def setup_seed(seed):
    """Seed the torch (CPU and all CUDA devices), NumPy and ``random`` RNGs.

    Also makes cuDNN deterministic and turns off its autotuning benchmark mode.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

