import os
import sys
import torch
from torch import nn
from Environment import *
from torch.optim import Adam
from Components import *
import torch.nn.functional as F
# Device string used by every `.to(device)` call in this file.
# NOTE(review): hard-coded to the second CUDA device; this will fail on hosts
# with fewer than two GPUs -- confirm this is intended.
device = 'cuda:1'

def nav_obs_train():
    """Jointly train two observation encoders on a localization and a navigation loss.

    For every batch of scan ids a random target view and a random end view are
    drawn per scan:

    * localization: encoder A embeds a single observation of the target view,
      encoder B embeds every candidate view of the scan, and the dot-product
      scores are trained with cross entropy against the target's index.
    * navigation: encoder B embeds the path from each candidate to the end
      view (``collect_path_obs``); the row of the true target is selected and
      scored against the stop token + neighbor embeddings produced by encoder A
      (``collect_candidate``), trained with cross entropy against the index of
      the next view on the path.

    Every 20th epoch (starting with epoch 0) the accumulated training accuracy
    is printed, ``nav_obs_eval`` is run, and both encoder/stop-token pairs are
    saved to ``Param.obs_encode_model_a_pth`` / ``Param.obs_encode_model_b_pth``
    whenever the mean of the two validation accuracies improves.

    :return: None
    """
    dataset = FullEnvDataset("train")
    print("data loaded.", flush=True)
    loader = DataLoader(dataset, batch_size=Param.batch_size)
    obs_encoder_a = NavObsTransEncoder().to(device)
    obs_encoder_b = NavObsTransEncoder().to(device)
    stop_token_a = nn.Embedding(1, Param.emb_size).to(device)
    stop_token_b = nn.Embedding(1, Param.emb_size).to(device)
    modules = (obs_encoder_a, obs_encoder_b, stop_token_a, stop_token_b)
    opt = Adam([{'params': module.parameters()} for module in modules],
               lr=Param.lr, betas=(0.9, 0.98), eps=1e-8, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()
    max_val_acc = 0.0
    accum_is_right_loc, accum_is_right_nav = [], []
    for i in range(Param.epoch):
        # Saving a checkpoint moves the modules to the CPU in place and the
        # evaluation switches them to eval mode, so restore both once per epoch.
        for module in modules:
            module.to(device)
            module.train()
        for scan_ids in loader:
            # The last batch of an epoch may be smaller than Param.batch_size,
            # so every per-batch quantity is sized from the actual batch.
            batch_len = len(scan_ids)
            opt.zero_grad()
            cand_view_ids = dataset.get_view_ids(scan_ids)
            tgt_idxes = [random.choice(range(len(cand_view_ids[k]))) for k in range(batch_len)]
            tgt_ids = [cur_cand_view_ids[j] for j, cur_cand_view_ids in zip(tgt_idxes, cand_view_ids)]
            end_ids = [random.choice(cur_cand_view_ids) for cur_cand_view_ids in cand_view_ids]
            # presumably (heading, elevation) in degrees -- confirm against the dataset
            poses = [(random.choice(range(0, 360, 30)), random.choice([-30, 0, 30]))
                     for _ in range(batch_len)]  # (batch)
            tgt_loc = torch.tensor(tgt_idxes, dtype=torch.int64)  # (batch, ), kept on CPU for the accuracy
            # -------- localization ----------
            obs_a = dataset.collect_single_obs_multi_views(scan_ids, tgt_ids, poses).to(torch.float32).to(device)  # (batch, 36, feat len)
            obs_b = dataset.collect_obs(scan_ids, cand_view_ids).to(torch.float32).to(device)  # (batch, cand len, 36, feat len)
            obs_emb_a = obs_encoder_a(obs_a)
            num_batch, num_cand = obs_b.shape[0], obs_b.shape[1]
            # Fold the candidate axis into the batch axis so encoder B runs once.
            obs_emb_b = obs_encoder_b(obs_b.view(num_batch * num_cand, obs_b.shape[2], obs_b.shape[3]))
            obs_emb_b = obs_emb_b.view((num_batch, num_cand, Param.emb_size))
            # Raw logits: CrossEntropyLoss applies log-softmax itself. No mask is applied here.
            scores_loc = torch.bmm(obs_emb_b, obs_emb_a.unsqueeze(2)).squeeze(2)  # (batch, cand len)
            loss_loc = loss_fn(scores_loc, tgt_loc.to(device))
            # -------- navigation ------------
            paths = []
            for cur_scan_id, cur_end_id, cur_cand_ids in zip(scan_ids, end_ids, cand_view_ids):
                cur_path_dict = dataset.graph_paths[cur_scan_id]
                # TODO, there is node without neighbors, so check codes before
                paths.append([cur_path_dict[(cur_cand_id, cur_end_id)] if cur_cand_id != cur_end_id else None
                              for cur_cand_id in cur_cand_ids])  # (cand len)
            nav_emb_b = collect_path_obs(scan_ids, cand_view_ids, paths, dataset, obs_encoder_b, stop_token_b)  # (batch, max_num_node, emb dim)
            # Teacher forcing: a one-hot over the candidates picks the row of the true target view.
            true_loc = torch.zeros((batch_len, Param.max_num_node), dtype=torch.float32, device=device)
            true_loc[torch.arange(batch_len, device=device), tgt_loc.to(device)] = 1.0
            nav_emb_b = torch.bmm(true_loc.unsqueeze(1), nav_emb_b).squeeze(1)  # (batch, emb dim)
            nav_emb_a, neighbor_ids = collect_candidate(scan_ids, tgt_ids, dataset, obs_encoder_a, stop_token_a)  # (batch, neighbor len, emb dim)
            max_neigh_len = nav_emb_a.shape[1]
            # True marks padded neighbor slots that must never be predicted.
            cand_mask = torch.tensor(
                [[0] * len(cur_cand_neigh_ids) + [1] * (max_neigh_len - len(cur_cand_neigh_ids))
                 for cur_cand_neigh_ids in neighbor_ids],
                dtype=torch.bool, device=device)  # (batch, neigh len)
            scores_nav = torch.bmm(nav_emb_a, nav_emb_b.unsqueeze(2)).squeeze(2)  # (batch, neighbor_len)
            scores_nav = scores_nav.masked_fill(cand_mask, -math.inf)
            tgt_next_ids = [dataset.graph_paths[cur_scan_id][(cur_tgt_id, cur_end_id)]['next']
                            for cur_scan_id, cur_tgt_id, cur_end_id in zip(scan_ids, tgt_ids, end_ids)]
            tgt_next_idx = id2idx1d(tgt_next_ids, neighbor_ids)
            loss_nav = loss_fn(scores_nav, tgt_next_idx.to(device))
            # ---------------
            loss = loss_loc + loss_nav
            loss.backward()
            opt.step()
            pred_loc = torch.argmax(scores_loc, dim=1).cpu()  # (batch, )
            pred_nav = torch.argmax(scores_nav, dim=1).cpu()  # (batch, )
            accum_is_right_loc.append((pred_loc == tgt_loc).to(torch.float32))
            accum_is_right_nav.append((pred_nav == tgt_next_idx).to(torch.float32))
        print("|", end="", flush=True)
        if i % 20 == 0:
            print("")
            train_loc_acc = torch.mean(torch.cat(accum_is_right_loc, dim=0))
            train_nav_acc = torch.mean(torch.cat(accum_is_right_nav, dim=0))
            print("epoch{}: train loc acc = {}, nav acc = {}".format(i, train_loc_acc, train_nav_acc), flush=True)
            accum_is_right_nav, accum_is_right_loc = [], []
            with torch.no_grad():
                val_loc_acc, val_nav_acc = nav_obs_eval(obs_encoder_a, obs_encoder_b, stop_token_a, stop_token_b)
            val_acc = (val_loc_acc + val_nav_acc) / 2
            if val_acc > max_val_acc:
                # `.cpu()` moves the live modules in place; they are moved back
                # to `device` at the start of the next epoch.
                torch.save({"model": obs_encoder_a.cpu(), "stop_token": stop_token_a.cpu()}, Param.obs_encode_model_a_pth)
                torch.save({"model": obs_encoder_b.cpu(), "stop_token": stop_token_b.cpu()}, Param.obs_encode_model_b_pth)
                max_val_acc = val_acc


def nav_obs_eval(obs_encoder_a, obs_encoder_b, stop_token_a, stop_token_b):
    """Evaluate localization and navigation accuracy on the "validate" split.

    Mirrors the training forward pass, except that the path embeddings are
    combined with the softmax of the predicted localization scores instead of
    the one-hot ground-truth location. Targets, end views and poses are drawn
    with ``random``, so results vary between calls unless the RNG is seeded.

    The four modules are switched to eval mode and left in that mode.

    :param obs_encoder_a: encoder for the single-view query observation and the neighbor candidates
    :param obs_encoder_b: encoder for the per-candidate observations and paths
    :param stop_token_a: embedding providing the stop token on the candidate side
    :param stop_token_b: embedding providing the stop token on the path side
    :return: (localization accuracy, navigation accuracy) as 0-dim tensors
    """
    dataset = FullEnvDataset("validate")
    loader = DataLoader(dataset, batch_size=Param.batch_size)
    for module in (obs_encoder_a, obs_encoder_b, stop_token_a, stop_token_b):
        module.eval()
    accum_is_right_loc, accum_is_right_nav = [], []
    # Guard locally so a standalone call does not build autograd graphs.
    with torch.no_grad():
        for scan_ids in loader:
            # The final batch may be smaller than Param.batch_size.
            batch_len = len(scan_ids)
            cand_view_ids = dataset.get_view_ids(scan_ids)
            tgt_idxes = [random.choice(range(len(cand_view_ids[k]))) for k in range(batch_len)]
            tgt_ids = [cur_cand_view_ids[j] for j, cur_cand_view_ids in zip(tgt_idxes, cand_view_ids)]
            end_ids = [random.choice(cur_cand_view_ids) for cur_cand_view_ids in cand_view_ids]
            # presumably (heading, elevation) in degrees -- confirm against the dataset
            poses = [(random.choice(range(0, 360, 30)), random.choice([-30, 0, 30]))
                     for _ in range(batch_len)]  # (batch)
            # --------- localization -------------------
            obs_a = dataset.collect_single_obs_multi_views(scan_ids, tgt_ids, poses).to(torch.float32).to(device)  # (batch, 36, feat len)
            obs_b = dataset.collect_obs(scan_ids, cand_view_ids).to(torch.float32).to(device)  # (batch, cand len, 36, feat len)
            obs_emb_a = obs_encoder_a(obs_a)
            # Encode the candidates of one scan at a time.
            obs_emb_b = torch.stack([obs_encoder_b(obs_b[b, :, :, :]) for b in range(obs_b.shape[0])],
                                    dim=0)  # (batch, cand len, emb dim)
            scores_loc = torch.bmm(obs_emb_b, obs_emb_a.unsqueeze(2)).squeeze(2)
            scores_loc_softmax = torch.softmax(scores_loc, dim=1)
            # --------- navigation --------------------
            paths = []
            for cur_scan_id, cur_end_id, cur_cand_ids in zip(scan_ids, end_ids, cand_view_ids):
                cur_path_dict = dataset.graph_paths[cur_scan_id]
                # TODO, there is node without neighbors, so check codes before
                paths.append([cur_path_dict[(cur_cand_id, cur_end_id)] if cur_cand_id != cur_end_id else None
                              for cur_cand_id in cur_cand_ids])  # (cand len)
            nav_emb_b = collect_path_obs(scan_ids, cand_view_ids, paths, dataset, obs_encoder_b,
                                         stop_token_b)  # (batch, cand len, emb dim)
            # Soft location: weight each candidate's path embedding by its predicted probability.
            nav_emb_b = torch.bmm(scores_loc_softmax.unsqueeze(1), nav_emb_b).squeeze(1)  # (batch, emb dim)
            nav_emb_a, neighbor_ids = collect_candidate(scan_ids, tgt_ids, dataset, obs_encoder_a,
                                                        stop_token_a)  # (batch, neighbor len, emb dim)
            max_neigh_len = nav_emb_a.shape[1]
            # True marks padded neighbor slots that must never be predicted.
            cand_mask = torch.tensor(
                [[0] * len(cur_cand_neigh_ids) + [1] * (max_neigh_len - len(cur_cand_neigh_ids))
                 for cur_cand_neigh_ids in neighbor_ids],
                dtype=torch.bool, device=device)  # (batch, neigh len)
            scores_nav = torch.bmm(nav_emb_a, nav_emb_b.unsqueeze(2)).squeeze(2)  # (batch, neighbor_len)
            scores_nav = scores_nav.masked_fill(cand_mask, -math.inf)
            tgt_next_ids = [dataset.graph_paths[cur_scan_id][(cur_tgt_id, cur_end_id)]['next']
                            for cur_scan_id, cur_tgt_id, cur_end_id in zip(scan_ids, tgt_ids, end_ids)]
            tgt_next_idx = id2idx1d(tgt_next_ids, neighbor_ids)
            # ------------------------
            pred_loc = torch.argmax(scores_loc, dim=1).cpu()
            pred_nav = torch.argmax(scores_nav, dim=1).cpu()
            tgt_loc = torch.tensor(tgt_idxes, dtype=torch.int64)
            accum_is_right_loc.append((pred_loc == tgt_loc).to(torch.float32))
            accum_is_right_nav.append((pred_nav == tgt_next_idx).to(torch.float32))
    loc_acc = torch.mean(torch.cat(accum_is_right_loc, dim=0))
    nav_acc = torch.mean(torch.cat(accum_is_right_nav, dim=0))
    print("          eval loc acc = {}, nav acc = {}".format(loc_acc, nav_acc), flush=True)
    return loc_acc, nav_acc



def collect_path_obs(scan_ids, cand_view_ids, paths, env, observe_model, stop_token):
    """Embed, for every candidate view, the observations along its path to the end view.

    A candidate whose path is ``None`` is the end view itself: its slot is fed
    zeros through the encoder and then overwritten with the stop-token
    embedding. Candidate lists are zero-padded to ``Param.max_num_node``.

    :param scan_ids: (batch, )
    :param cand_view_ids: (batch, cand len)
    :param paths: (batch, cand len) {"next":, "angles": [..]} or None for the end view
    :param env: dataset exposing ``features[scan_id][view_id][(h, e)]`` as numpy arrays
    :param observe_model: encoder called as ``observe_model(obs, view_idx)``
    :param stop_token: embedding whose index 0 is the stop token
    :return: (batch, Param.max_num_node, emb dim)
    """
    obs, view_idx, stop_pos = [], [], [0] * len(scan_ids)
    for i, (path, scan_id, cand_view_id) in enumerate(zip(paths, scan_ids, cand_view_ids)):  # batch
        batch_obs, batch_view_idx = [], []
        for j, (cur_path, cur_view_id) in enumerate(zip(path, cand_view_id)):  # cand len
            if cur_path is not None:
                cur_cand_obs, cur_cand_view_idx = [], []
                for ix, h, e, d in cur_path['angles']:
                    cur_cand_obs.append(torch.from_numpy(env.features[scan_id][cur_view_id][(h, e)]))
                    # Flatten (h, e) into one view index: 3 elevation slots per 30-degree heading step.
                    cur_cand_view_idx.append(h / 30 * 3 + e / 30 + 1)
                cur_cand_obs = torch.stack(cur_cand_obs, dim=0)  # (path img num, feat len)
                cur_cand_view_idx = torch.from_numpy(np.array(cur_cand_view_idx))  # (path img num, )
            else:
                # End view: remember its slot; the placeholder is replaced by the stop token below.
                stop_pos[i] = j
                cur_cand_obs = torch.zeros((Param.max_path_img_num, 1000))
                cur_cand_view_idx = torch.zeros((Param.max_path_img_num, ))
            batch_obs.append(cur_cand_obs)
            batch_view_idx.append(cur_cand_view_idx)
        batch_obs = torch.stack(batch_obs, dim=0)  # (cand len, path img num, feat len)
        batch_view_idx = torch.stack(batch_view_idx, dim=0)  # (cand len, path img num)
        # Pad the candidate axis; trailing sizes come from the stacked tensors
        # so the padding always matches the real observations.
        pad_len = Param.max_num_node - batch_obs.shape[0]
        batch_obs = torch.cat([batch_obs, torch.zeros((pad_len, *batch_obs.shape[1:]))], dim=0)
        batch_view_idx = torch.cat([batch_view_idx, torch.zeros((pad_len, *batch_view_idx.shape[1:]))], dim=0)
        obs.append(batch_obs)
        view_idx.append(batch_view_idx)
    obs = torch.stack(obs, dim=0).to(torch.float32).to(device)  # (batch, max_num_node, path img num, feat len)
    view_idx = torch.stack(view_idx, dim=0).to(torch.int64).to(device)  # (batch, max_num_node, path img num)
    num_batch, num_cand = obs.shape[0], obs.shape[1]
    # Fold the candidate axis into the batch axis so the encoder runs once.
    obs_emb = observe_model(obs.view((num_batch * num_cand, obs.shape[2], obs.shape[3])),
                            view_idx.view((num_batch * num_cand, obs.shape[2])))
    obs_emb = obs_emb.view((num_batch, num_cand, Param.emb_size))
    stop_pos = torch.tensor(stop_pos, dtype=torch.int64, device=device)
    batch_idx = torch.arange(len(scan_ids), device=device)
    obs_emb[batch_idx, stop_pos, :] = \
        stop_token(torch.zeros((len(scan_ids), ), dtype=torch.int64, device=device))
    return obs_emb


def collect_candidate(scan_ids, view_ids, env, observe_model, stop_token):
    """Embed the stop token and every graph neighbor of each given view.

    Neighbors are read from ``env.graphs[scan_id][view_id]``, padded with zeros
    up to ``Param.max_neigh_num`` and encoded one neighbor slot at a time. The
    stop-token embedding is placed in front of the neighbor embeddings, and the
    view's own id is placed in front of its neighbor ids to match.

    :param scan_ids: (batch, )
    :param view_ids: (batch, )
    :param env: dataset exposing ``graphs`` and ``features``
    :param observe_model: encoder called as ``observe_model(obs, view_idx)``
    :param stop_token: embedding whose index 0 is the stop token
    :return: embeddings (batch, 1 + Param.max_neigh_num, emb dim) and, per batch
        element, the unpadded id list ``[view_id] + neighbor ids``
    """
    raw_neighbor_ids, padded_obs, padded_view_idx = [], [], []
    for scan_id, view_id in zip(scan_ids, view_ids):  # for batch
        neighbors = env.graphs[scan_id][view_id]
        cur_ids = list(neighbors.keys())
        raw_neighbor_ids.append(cur_ids)
        obs_per_neigh, idx_per_neigh = [], []
        for neigh_id in cur_ids:
            neigh_angles = neighbors[neigh_id]
            assert len(neigh_angles) == Param.max_path_img_num
            frames = [torch.from_numpy(env.features[scan_id][neigh_id][(h, e)])
                      for ix, h, e, d in neigh_angles]
            # NOTE do not deal with the order of view id
            flat_idx = [h / 30 * 3 + e / 30 + 1 for ix, h, e, d in neigh_angles]
            obs_per_neigh.append(torch.stack(frames, dim=0))  # (path img num, feat len)
            idx_per_neigh.append(torch.from_numpy(np.array(flat_idx)))  # (path img num, )
        # Zero-pad the neighbor axis (dim 0) up to the fixed slot count.
        missing = Param.max_neigh_num - len(obs_per_neigh)
        stacked_obs = torch.stack(obs_per_neigh, dim=0)
        stacked_idx = torch.stack(idx_per_neigh, dim=0)
        padded_obs.append(F.pad(stacked_obs, (0, 0, 0, 0, 0, missing)))  # (neigh len, path img num, feat len)
        padded_view_idx.append(F.pad(stacked_idx, (0, 0, 0, missing)))  # (neigh len, path img num)
    all_obs = torch.stack(padded_obs, dim=0).to(device)  # (batch, neigh len, path img num, feat len)
    all_view_idx = torch.stack(padded_view_idx, dim=0).to(device)  # (batch, neigh len, path img num)
    # Encode slot by slot: each call sees the k-th neighbor of every batch element.
    slot_embs = []
    for k in range(Param.max_neigh_num):
        slot_obs = all_obs[:, k, :, :].to(torch.float32)
        slot_idx = all_view_idx[:, k, :].to(torch.int64)
        slot_embs.append(observe_model(slot_obs, slot_idx).unsqueeze(1))  # (batch, 1, emb dim)
    neighbor_embs = torch.cat(slot_embs, dim=1).to(device)  # (batch, neigh len, emb dim)
    zero_ids = torch.zeros(neighbor_embs.shape[0], 1).to(torch.int64).to(device)
    stop_tokens = stop_token(zero_ids)  # (batch, 1, emb dim)
    neighbor_embs = torch.cat([stop_tokens, neighbor_embs], dim=1)
    # Slot 0 is the stop action, so it is labelled with the current view's own id.
    neighbor_ids = [[own_id] + ids for own_id, ids in zip(view_ids, raw_neighbor_ids)]
    return neighbor_embs, neighbor_ids


def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    :param seed: integer seed applied to every generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
    # presumably only affects child processes -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

