from RoutePlan import *
import torch
from torch.utils.data import Dataset, DataLoader
from Util import *
import random
from Param import *


class FullEnvDataset(Dataset):
    """Dataset over whole scan environments; each sample is one scan id.

    Holds per-view features and lazily loads/caches, per scan, the navigation
    graph and the result of ``RoutePlan.floyd`` on it.

    ``self.features`` is indexed as ``features[scan_id][view_id][key]`` (with
    ``key`` a ``(heading, elevation)`` pair in the multi-view collectors) and each
    entry goes through ``torch.from_numpy``, so entries are assumed to be 1-D
    numpy feature vectors — TODO confirm against ``load_features``.
    """

    def __init__(self, usage):
        """
        :param usage: split selector forwarded to ``load_scan_ids``
            (presumably e.g. a train/val split name — confirm against caller)
        """
        self.usage = usage
        self.scan_ids = load_scan_ids(usage)
        self.view_ids, self.features = load_features(self.scan_ids)
        self.graphs = {}       # scan_id -> graph dict with isolated nodes removed
        self.graph_paths = {}  # scan_id -> RoutePlan.floyd(graph) result

    def __getitem__(self, item):
        """Return a uniformly random scan id; ``item`` is ignored.

        As a side effect, the sampled scan's graph and paths are cached.
        """
        cur_scan_id = random.choice(self.scan_ids)
        self._ensure_graph(cur_scan_id)
        return cur_scan_id

    def __len__(self):
        # Nominal epoch length: __getitem__ ignores its index and samples at
        # random, so this only sets how many samples are drawn per epoch.
        return 500

    def _ensure_graph(self, scan_id):
        """Load and cache the graph of ``scan_id`` and its ``RoutePlan.floyd`` result."""
        if scan_id not in self.graphs:
            self.graphs[scan_id] = self.load_graph(scan_id)[scan_id]
        if scan_id not in self.graph_paths:
            # presumably all-pairs shortest paths (Floyd-Warshall) — confirm in RoutePlan
            self.graph_paths[scan_id] = RoutePlan.floyd(self.graphs[scan_id])

    def get_view_ids(self, scan_ids):
        """Return the node (view) ids of each scan's graph, loading graphs on demand.

        :param scan_ids: iterable of scan ids
        :return: list with one list of view ids per scan, in graph key order
        """
        view_ids = []
        for scan_id in scan_ids:
            self._ensure_graph(scan_id)
            view_ids.append(list(self.graphs[scan_id].keys()))
        return view_ids

    def load_graph(self, scan_id):
        """Read ``<Param.graph_dir>/<scan_id>_graph.json`` and drop isolated nodes.

        :param scan_id: scan whose graph file is read
        :return: the parsed JSON dict; in ``graph[scan_id]`` every node whose
            entry is empty (``len == 0``) has been removed
        """
        import json  # json is not imported explicitly at file level

        path = "{}/{}_graph.json".format(Param.graph_dir, scan_id)
        with open(path, 'r', encoding='utf-8') as f:
            graph = json.load(f)
        graph[scan_id] = {
            node: entry for node, entry in graph[scan_id].items() if len(entry) != 0
        }
        return graph

    def collect_single_obs_single_view(self, scan_ids, view_ids, poses):
        """Gather one feature vector per batch element.

        :param scan_ids: (batch)
        :param view_ids: (batch)
        :param poses: (batch) feature keys, e.g. ``(heading, elevation)`` pairs
        :return: (batch, feat len), in the features' own dtype
        """
        obses = [
            torch.from_numpy(self.features[scan_id][view_id][pose])
            for scan_id, view_id, pose in zip(scan_ids, view_ids, poses)
        ]
        return torch.stack(obses, dim=0)  # (batch size, feat len)

    def collect_single_obs_multi_views(self, scan_ids, view_ids, poses):
        """Gather each element's 36-view panorama ordered relative to its pose.

        Headings run ``cur_h, cur_h + 30, ..., cur_h + 330`` (mod 360). For each
        heading, elevations start at ``cur_e`` and step up by 30, wrapping inside
        {-30, 0, 30} when ``cur_e`` is one of those (e.g. ``cur_e == 0`` gives
        0, 30, -30). This differs from the fixed ordering of ``collect_obs``.

        :param scan_ids: (batch)
        :param view_ids: (batch)
        :param poses: (batch) ``(heading, elevation)`` pairs
        :return: (batch, 36, feat len), in the features' own dtype (not cast)
        """
        obses = []
        for cur_scan_id, cur_view_id, (cur_h, cur_e) in zip(scan_ids, view_ids, poses):
            view_feats = self.features[cur_scan_id][cur_view_id]
            cur_view_obs = []
            for h in range(cur_h, cur_h + 360, 30):
                # ``e % 90 - 30`` maps cur_e+30/+60/+90 to cur_e, cur_e+30, cur_e+60
                # wrapped into [-30, 60); same keys as offsets -60/-30/0 mod 90.
                for e in (cur_e + 30, cur_e + 60, cur_e + 90):
                    cur_view_obs.append(torch.from_numpy(view_feats[(h % 360, e % 90 - 30)]))
            obses.append(torch.stack(cur_view_obs, dim=0))  # (36, feat len)
        return torch.stack(obses, dim=0)

    def _panorama(self, scan_id, view_id):
        """Stack the 36 (heading, elevation) features of one view in fixed order.

        Headings 0..330 step 30 (outer), elevations -30, 0, 30 (inner).
        """
        view_feats = self.features[scan_id][view_id]
        return torch.stack(
            [torch.from_numpy(view_feats[(h, e)]) for h in range(0, 360, 30) for e in (-30, 0, 30)],
            dim=0,
        )  # (36, feat len)

    def collect_obs(self, scan_ids, view_ids, cand_n=None):
        """Gather full panoramas for every view of every scan, zero-padded to ``cand_n``.

        :param scan_ids: (batch)
        :param view_ids: (batch, num views) — per-scan view lists may differ in length
        :param cand_n: number of view slots per scan; defaults to
            ``Param.max_num_node``, looked up at call time
        :return: (batch, cand n, 36, feat len) float32 tensor; unused slots are zeros
        :raises ValueError: if a scan has more views than ``cand_n``
        Note: an empty view list makes ``torch.stack`` raise, as before.
        """
        if cand_n is None:
            cand_n = Param.max_num_node
        obses = []
        for cur_scan_id, cur_view_ids in zip(scan_ids, view_ids):
            cur_scan_obs = torch.stack(
                [self._panorama(cur_scan_id, view_id) for view_id in cur_view_ids], dim=0
            )  # (num view, 36, feat len)
            pad_n = cand_n - cur_scan_obs.shape[0]
            if pad_n < 0:
                raise ValueError(
                    "scan {} has {} views but cand_n is {}".format(
                        cur_scan_id, cur_scan_obs.shape[0], cand_n))
            # Pad with zeros matching the real feature length and dtype.
            padding = cur_scan_obs.new_zeros((pad_n, *cur_scan_obs.shape[1:]))
            obses.append(torch.cat([cur_scan_obs, padding], dim=0))
        obses = torch.stack(obses, dim=0)
        return obses.to(torch.float32)
