from Environment import *
from LocalizationPos import *
from torch.optim import Adam


def localization_train(pre_task=None):
    """Train the localization task and periodically report accuracy.

    Loads the "train" and "validate" splits, builds an Adam optimizer over
    the task parameters that require gradients, and runs ``Param.epoch``
    epochs. Within an epoch every batch is run through the task with
    ``choose_method="sample"`` and handed to ``task.backward``; gradients
    are cleared once at the start of the epoch and a single optimizer step
    is taken after the last batch. Every 20 epochs the train accuracy
    (last-turn guess vs. target) and the summed losses since the previous
    report are printed, followed by an evaluation on the validation split.

    Args:
        pre_task: Optional existing task to keep training. When ``None`` a
            fresh ``LocalizationPos()`` is created.
    """
    dataset = FullEnvDataset("train")
    dataset_eval = FullEnvDataset("validate")
    print("dataset loaded.", flush=True)
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    task = LocalizationPos() if pre_task is None else pre_task

    trainable_params = filter(lambda p: p.requires_grad, task.parameters())
    opt = Adam(trainable_params, lr=Param.lr, betas=(0.9, 0.98), eps=1e-8,
               weight_decay=5e-4)
    task.agent_b.load_obs_emb(dataset, dataset_eval, dataset.scan_ids, dataset_eval.scan_ids)
    print("obs emb loaded.", flush=True)

    # Running statistics; reset after every report.
    accum_tgt, accum_pred = [], []
    total_loss_a, total_loss_b = 0, 0
    for i in range(Param.epoch):
        # NOTE(review): zero_grad/step bracket the whole epoch, so whatever
        # gradients task.backward produces are accumulated over all batches
        # before one update -- confirm this is intended.
        opt.zero_grad()
        for scan_ids in dataloader:
            task.train()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            # ------ forward ---------
            # One uniformly random target index per candidate list.
            tgt_idxes = [random.choice(range(len(cur_cand_view_ids))) for cur_cand_view_ids in cand_view_ids]
            _, _, msg_a_prob, msg_b_prob, guess_b_prob, guess_b = \
                task(scan_ids, cand_view_ids, tgt_idxes, dataset, choose_method="sample")
            # ------ backward ---------
            tgt_idxes = torch.from_numpy(np.array(tgt_idxes))
            cur_pred = guess_b
            accum_tgt.append(tgt_idxes)
            accum_pred.append(cur_pred[:, -1])  # last column = final guess
            # Per-guess reward: +Param.reward where the guess equals the
            # target, -Param.reward otherwise. The tensor is sized from the
            # comparison mask rather than Param.batch_size so that a final,
            # smaller batch does not trigger a mask/shape mismatch.
            hits = cur_pred == tgt_idxes.unsqueeze(1)
            rewards = torch.full(hits.shape, -1.0)
            rewards[hits] = 1.0
            rewards *= Param.reward
            # loss_a, loss_b = task.backward_alg_blu(msg_a_prob, msg_b_prob, guess_b_prob, rewards)
            loss_a, loss_b = task.backward(msg_a_prob, msg_b_prob, guess_b_prob, rewards)
            total_loss_a += loss_a
            total_loss_b += loss_b
        opt.step()
        print("|", end="", flush=True)  # one progress tick per epoch
        if i % 20 == 0:
            task.eval()
            all_pred = torch.cat(accum_pred, dim=0)
            all_tgt = torch.cat(accum_tgt, dim=0)
            train_acc = (all_pred == all_tgt).to(torch.float32).mean()
            print()
            print("epoch {}: train acc = {}, loss A = {}, loss B = {}".format(i, train_acc, total_loss_a,
                                                                              total_loss_b), flush=True)
            accum_pred, accum_tgt = [], []
            total_loss_a, total_loss_b = 0, 0
            with torch.no_grad():
                localization_eval(task, dataset_eval)


def localization_eval(task, dataset):
    """Evaluate ``task`` on ``dataset`` and print the accuracy.

    The task is switched to eval mode and every batch is run with
    ``choose_method="greedy"`` against a randomly drawn target index per
    candidate list. Accuracy is the fraction of samples whose last-turn
    guess equals the target. Autograd is disabled inside the function so it
    is safe to call without an enclosing ``torch.no_grad()``.

    Args:
        task: The task to evaluate; called once per batch.
        dataset: Dataset providing batches of scan ids and ``get_view_ids``.
    """
    task.eval()
    dataloader = DataLoader(dataset, batch_size=Param.batch_size)
    accum_tgt, accum_pred = [], []
    with torch.no_grad():
        for scan_ids in dataloader:
            cand_view_ids = dataset.get_view_ids(scan_ids)
            tgt_idxes = [random.choice(range(len(cur_cand_view_ids))) for cur_cand_view_ids in cand_view_ids]
            _, _, _, _, _, guess_b = \
                task(scan_ids, cand_view_ids, tgt_idxes, dataset, choose_method="greedy")
            accum_tgt.append(torch.from_numpy(np.array(tgt_idxes)))
            accum_pred.append(guess_b[:, -1])  # last column = final guess
    all_pred = torch.cat(accum_pred, dim=0)
    all_tgt = torch.cat(accum_tgt, dim=0)
    eval_acc = (all_pred == all_tgt).to(torch.float32).mean()
    print("          eval acc = {}".format(eval_acc), flush=True)


def setup_seed(seed):
    """Seed every random number generator used here and pin cuDNN settings.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices) with ``seed``, then sets the cuDNN ``deterministic`` flag and
    clears the ``benchmark`` flag.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

