import torch
import json
import numpy as np
from Param import *
from torch.distributions import Categorical
import random


def cat_history(history, msg, is_speak=None):
    """
    Write each row of ``msg`` into the matching row of ``history``.

    The write offset of a row is the number of its entries that differ from
    ``Param.tokens["<msg A>"]["start pos"]``; entries equal to that id are
    treated as unfilled slots.
    NOTE(review): this presumably relies on the unfilled slots sitting at the
    tail of every row -- confirm against callers.

    :param history: (batch, max len)
    :param msg: (batch, sent len)
    :param is_speak: optional per-row flags; when given, only rows whose flag
        equals 1 are written, the others are copied through unchanged
    :return: a new tensor rebuilt through numpy from the updated rows
    """
    rows = history.tolist()
    new_tokens = msg.tolist()
    speak_flags = None if is_speak is None else is_speak.tolist()

    # 1 where a slot is already used, 0 where it still holds the marker id.
    used = torch.ones_like(history)
    used[history == Param.tokens["<msg A>"]["start pos"]] = 0
    offsets = torch.sum(used, dim=1).tolist()

    width = msg.shape[1]
    for row in range(history.shape[0]):
        if speak_flags is not None and speak_flags[row] != 1:
            continue
        begin = offsets[row]
        rows[row][begin: begin + width] = new_tokens[row]
    return torch.from_numpy(np.array(rows))

def mask_to_length(mask):
    """
    Count the entries equal to ``True`` in every row of a mask.

    :param mask: (batch size, sent len), [[True, True, ..., False, ..], ...]
    :return: (batch size, ) int64 tensor holding the per-row counts
    """
    # Explicit comparison (rather than a plain cast) so that only entries
    # equal to True are counted, whatever the dtype of ``mask`` is.
    hits = (mask == True)
    return hits.to(torch.int64).sum(dim=1)


def choose_token(token_score, token_type, choose_method="sample"):
    """
    Pick one token per row from the score columns reserved for ``token_type``.

    The column window is ``Param.tokens[token_type]["start pos"]`` up to and
    including ``["end pos"]``. For "<msg A>" / "<msg B>" the first column is
    skipped when ``Param.is_output_msg0`` is False.

    :param token_score: (batch, cand len); the selected window is handed to
        ``Categorical`` as its first positional argument
    :param token_type: key into ``Param.tokens``
    :param choose_method: "sample" draws from the distribution, anything else
        takes the arg-max of the window
    :return: (chosen token ids expressed in full-vocabulary positions,
        log-probability of each chosen token under the distribution)
    """
    span = Param.tokens[token_type]
    first, last = span["start pos"], span["end pos"]
    if Param.is_output_msg0 is False and token_type in ("<msg A>", "<msg B>"):
        first = span["start pos"] + 1

    window = token_score[:, first:last + 1]
    # Slicing silently truncates when token_score is too narrow; catch that.
    assert window.shape[1] == (last + 1 - first)

    dist = Categorical(window)
    if choose_method == "sample":
        local_choice = dist.sample()
    else:
        local_choice = torch.argmax(window, dim=1)

    # Log-probability is taken on window-relative indices; the returned ids
    # are shifted back to full-vocabulary positions.
    log_prob = dist.log_prob(local_choice)
    return local_choice + first, log_prob


def load_scan_ids(usage):
    """
    Read the scan ids of one data split, one id per line.

    :param usage: "train", "test" or "validate"
    :return: list of stripped lines from the split file configured in Param
    :raises NameError: if ``usage`` is not one of the three split names
    """
    if usage == "train":
        scan_file = Param.train_scan_file
    elif usage == "test":
        scan_file = Param.test_scan_file
    elif usage == "validate":
        scan_file = Param.val_scan_file
    else:
        raise NameError

    with open(scan_file, 'r') as handle:
        return [row.strip() for row in handle]


def load_features(scan_ids):
    """
    Load the per-observation features of every scan from
    ``<Param.feature_dir>/<scan id>.json``.

    Each JSON key is split on "_": the first field is the viewpoint id, the
    second the heading and the last the elevation, where a leading "m" marks
    a negative elevation (e.g. "m30" -> -30).

    :param scan_ids: iterable of scan ids
    :return: (viewpoint_ids, features) where
        viewpoint_ids[scan] is the list of viewpoint ids found in that scan and
        features[scan][viewpoint][(heading, elevation)] is a numpy array
    """
    viewpoint_ids, features = {}, {}
    for scan in scan_ids:
        feature_path = "{}/{}.json".format(Param.feature_dir, scan)
        with open(feature_path, 'r') as handle:
            raw = json.load(handle)

        per_view = {}
        for obs_id, values in raw.items():
            fields = obs_id.split("_")
            view = fields[0].strip()
            heading = int(fields[1])
            elev_field = fields[-1]
            if elev_field.startswith("m"):
                elevation = -int(elev_field[1:])
            else:
                elevation = int(elev_field)
            per_view.setdefault(view, {})[(heading, elevation)] = np.array(values)

        features[scan] = per_view
        viewpoint_ids[scan] = list(per_view)
    return viewpoint_ids, features


def load_view_ids(scan_ids):
    """
    Collect the viewpoint-id prefix of every observation key of every scan.

    Reads ``<Param.feature_dir>/<scan id>.json`` and keeps the text before the
    first "_" of each key. One entry is produced per observation key, so a
    viewpoint that has several observations appears several times.
    NOTE(review): load_features returns each viewpoint once and strips the
    id; confirm whether the repetition here is intended.

    :param scan_ids: iterable of scan ids
    :return: dict mapping scan id -> list of viewpoint ids
    """
    viewpoint_ids = {}
    for scan in scan_ids:
        feature_path = "{}/{}.json".format(Param.feature_dir, scan)
        with open(feature_path, 'r') as handle:
            raw = json.load(handle)
        viewpoint_ids[scan] = [obs_id.split("_")[0] for obs_id in raw]
    return viewpoint_ids


def get_context_vec(history_vec, observe_vec, fusion_encoder=None):
    """
    Combine the history and observation encodings according to
    ``Param.codec_type``.

    :param history_vec: history encoding
    :param observe_vec: observation encoding
    :param fusion_encoder: callable applied to the concatenation; only used
        (and required) when the codec type is "GRU"
    :return: "GRU": fusion_encoder(cat([history, observe], dim=1));
        "Transformer": history with the observation appended as one extra
        step along dim 1
    :raises NameError: for any other codec type
    """
    codec = Param.codec_type
    if codec == "GRU":
        fused_input = torch.cat([history_vec, observe_vec], dim=1)
        return fusion_encoder(fused_input)
    if codec == "Transformer":
        # (batch, history len + 1, emb dim)
        return torch.cat([history_vec, observe_vec.unsqueeze(1)], dim=1)
    raise NameError


def collect_neighbor_angls(angl_ids):
    """
    Get the neighbor angle ids of every angle id.

    An angle id encodes ``heading * 3 + elevation`` with 12 headings
    (wrapping around) and 3 elevation levels. The last axis of the result
    holds, in order:

    0. same elevation, previous heading
    1. same elevation, next heading
    2. same heading, elevation - 1 (``-1`` when elevation is already 0)
    3. same heading, elevation + 1 (``-1`` when elevation is already 2)

    The computation is done entirely in int64. Mixing the float result of
    ``floor(angl_ids / 3)`` with an integer output tensor makes the masked
    assignments fail on PyTorch versions where ``/`` is true division.

    :param angl_ids: (batch, cand n) tensor of angle ids (integral values)
    :return: (batch, cand n, 4) int64 tensor on the same device as angl_ids
    """
    angl_ids = angl_ids.long()
    elev_ids = angl_ids % 3
    # Exact division (the remainder was removed first), so the result does
    # not depend on how the installed torch rounds integer division.
    head_ids = (angl_ids - elev_ids) // 3

    missing = torch.full_like(angl_ids, -1)
    prev_heading = ((head_ids - 1) % 12) * 3 + elev_ids
    next_heading = ((head_ids + 1) % 12) * 3 + elev_ids
    # head * 3 + elev == angl_ids, so a vertical neighbor is simply id -/+ 1.
    lower = torch.where(elev_ids != 0, angl_ids - 1, missing)
    upper = torch.where(elev_ids != 2, angl_ids + 1, missing)
    return torch.stack([prev_heading, next_heading, lower, upper], dim=-1)


def get_pic_idx(view_idx, angl):
    """
    Turn (viewpoint index, (heading, elevation)) pairs into flat picture
    indices: ``view * 36 + (heading // 30) * 3 + (elevation // 30 + 1)``.

    :param view_idx: (batch) viewpoint indices
    :param angl: (batch) sequence of (heading, elevation) pairs
    :return: (batch) tensor of picture indices
    """
    base = np.array(view_idx)
    pic_idx = torch.from_numpy(base) * 36  # 36 pictures per viewpoint
    for row, pair in enumerate(angl):
        heading, elevation = pair
        offset = (heading // 30) * 3 + (elevation // 30 + 1)
        pic_idx[row] += offset
    return pic_idx


def id2idx1d(ids, cand_ids):
    """
    Look up, row by row, the position of an id inside that row's candidates.

    If an id occurs several times among the candidates the last position is
    used.

    :param ids: (batch, )
    :param cand_ids: (batch, cand len)
    :return: (batch, ) tensor of positions
    :raises KeyError: if an id is not among its row's candidates
    """
    lookups = [
        {cand: pos for pos, cand in enumerate(row_cands)}
        for row_cands in cand_ids
    ]
    positions = [lookup[wanted] for wanted, lookup in zip(ids, lookups)]
    return torch.from_numpy(np.array(positions))


def id2idx2d(ids, cand_ids):
    """
    Look up, row by row, the positions of several ids inside that row's
    candidates.

    If an id occurs several times among the candidates the last position is
    used.

    :param ids: (batch, n)
    :param cand_ids: (batch, cand len)
    :return: (batch, n) tensor of positions
    :raises KeyError: if an id is not among its row's candidates
    """
    lookups = [
        {cand: pos for pos, cand in enumerate(row_cands)}
        for row_cands in cand_ids
    ]
    per_row = []
    for wanted_ids, lookup in zip(ids, lookups):
        row_positions = [lookup[wanted] for wanted in wanted_ids]
        per_row.append(torch.from_numpy(np.array(row_positions)))
    return torch.stack(per_row, dim=0)  # (batch, n)


def get_next_step_ids(path_dict, cand_ids, end_ids):
    """
    For every candidate fetch the entry stored under (candidate, end).

    :param path_dict: (batch, ) one mapping per row keyed by (from id, to id)
    :param cand_ids: (batch, cand len)
    :param end_ids: (batch, )
    :return: (batch, cand len) nested list; None where the candidate is the
        end itself (no lookup is done in that case)
    """
    return [
        [None if cand == goal else table[(cand, goal)] for cand in row_cands]
        for table, row_cands, goal in zip(path_dict, cand_ids, end_ids)
    ]


def get_next_step_ids_full(path_dict, scan_ids, cand_ids, end_ids):
    """
    For every candidate fetch the entry stored under (candidate, end) in the
    path table of the row's scan.

    :param path_dict: mapping scan id -> mapping keyed by (from id, to id)
    :param scan_ids: (batch, )
    :param cand_ids: (batch, cand len)
    :param end_ids: (batch, )
    :return: (batch, cand len) nested list; None where the candidate is the
        end itself (no lookup is done in that case)
    """
    paths = []
    for scan, row_cands, goal in zip(scan_ids, cand_ids, end_ids):
        table = path_dict[scan]
        paths.append(
            [None if cand == goal else table[(cand, goal)] for cand in row_cands]
        )
    return paths


def get_route_ids(path_dict, scan_ids, cand_ids, end_ids):
    """
    Unroll, for every candidate, the chain of path entries leading to the end.

    Starting at the candidate, the entry under (current, end) is recorded and
    the walk moves on to that entry's "next" id. Once an entry is None the
    walk stays on the end id. Exactly ``Param.max_turns`` entries are
    recorded per candidate.

    :param path_dict: mapping scan id -> mapping keyed by (from id, to id)
    :param scan_ids: (batch, )
    :param cand_ids: (batch, cand len)
    :param end_ids: (batch, )
    :return: nested list (batch, cand len, max turns) of path entries
    """
    routes = []  # (batch, max node num, max turn num)
    for scan, row_cands, goal in zip(scan_ids, cand_ids, end_ids):
        table = path_dict[scan]
        scan_routes = []
        for start in row_cands:
            steps = []
            node = start
            for _ in range(Param.max_turns):
                step = table[(node, goal)]
                steps.append(step)
                node = goal if step is None else step["next"]
            scan_routes.append(steps)
        routes.append(scan_routes)
    return routes

def is_state_same(states1, states2):
    """
    Compare two batches of states element by element.

    Rows are compared pairwise over their common prefix (the shorter row
    decides how many tokens are checked).

    :param states1: (batch, state len)
    :param states2: (batch, state len)
    :return: (batch, ) tensor that is True where no compared token differs
    """
    verdicts = [
        not any(tok1 != tok2 for tok1, tok2 in zip(row1, row2))
        for row1, row2 in zip(states1, states2)
    ]
    return torch.from_numpy(np.array(verdicts))


def get_adj_matrix(graphs, view_ids):
    """
    Build one 0/1 adjacency matrix per row over that row's candidate views.

    Entry [i][j] is 1 when candidate j is listed in ``graph[candidate i]``.

    :param graphs: (batch, ) mappings view id -> iterable of neighbor view ids
    :param view_ids: (batch, cand len) candidate view ids
    :return: tensor (batch, cand len, cand len)
    :raises KeyError: if a candidate is missing from its graph or one of its
        neighbors is not among the candidates
    """
    index_maps = [
        {view: pos for pos, view in enumerate(row_views)}
        for row_views in view_ids
    ]
    matrices = []
    for row_views, graph, index_of in zip(view_ids, graphs, index_maps):
        size = len(row_views)
        matrix = []
        for view in row_views:
            row = [0] * size
            for neighbor in graph[view]:
                row[index_of[neighbor]] = 1
            matrix.append(row)
        matrices.append(matrix)
    return torch.from_numpy(np.array(matrices))


def print_param(task_name):
    """
    Print the task name followed by the main hyper-parameters read from Param,
    one "label = value" line each.

    :param task_name: label printed on the first line
    :return: None
    """
    print("task = {}".format(task_name))
    # (line template, Param attribute) in output order.
    report = (
        ("lr = {}", "lr"),
        ("batch size = {}", "batch_size"),
        ("voc size = {}", "voc_size"),
        ("emb dim = {}", "emb_size"),
        ("sent len = {}", "sent_len"),
        ("max turns = {}", "max_turns"),
        ("codec_type = {}", "codec_type"),
        ("trans nhead = {}", "trans_nhead"),
        ("trans ff = {}", "trans_feedforward"),
        ("trans dropout = {}", "trans_dropout"),
        ("trans layernum = {}", "trans_layernum"),
    )
    for template, attr in report:
        print(template.format(getattr(Param, attr)))
