"""Dataset and data loader for ScanNet object classification."""

import os

import numpy as np
import torch
from torch.utils.data import Dataset

from scannet.model_util_scannet import ScannetDatasetConfig
from scannet.scannet_utils import read_label_mapping
from src.scannet_dataset import unpickle_data, save_data

import ipdb
# Debug helper: call ``st()`` to drop into an ipdb breakpoint.
st = ipdb.set_trace

# Class-count setting passed to ScannetDatasetConfig; when it equals 18,
# ScanNetClsDataset reads the 'nyu40id' column of the label mapping instead
# of 'id'.
NUM_CLASSES = 485
# Shared dataset config; its ``nyu40id2class`` mapping is used to filter
# objects and to turn mapped label ids into training labels.
DC = ScannetDatasetConfig(NUM_CLASSES, agnostic=True)
# Presumably the maximum number of objects per scene; not referenced by the
# code in this file (likely consumed by importers — verify).
MAX_NUM_OBJ = 132


class ScanNetClsDataset(Dataset):
    """Dataset utilities for classification on ScanNet.

    Each item is one annotated object of one scan. Its points are
    subsampled to ``num_points``, centred and scaled into the unit sphere,
    randomly augmented (train split only) and concatenated with the
    requested extra per-point features (color, box size, height, normals).
    """

    def __init__(self, split='train', num_points=2048,
                 use_color=False, use_size=False,
                 use_height=False, use_normals=False, overfit=False):
        """Initialize dataset.

        Args:
            split: 'train' selects the training scans; any other value is
                served from the validation scans (without augmentation and
                with deterministic point sampling).
            num_points: number of points sampled per object.
            use_color: append mean-subtracted per-point color.
            use_size: append the normalized bounding-box size to every point.
            use_height: append per-point height above an estimated floor.
            use_normals: append per-point normals loaded from disk.
            overfit: keep only the first 128 annotations (debugging aid).
        """
        self.split = split
        self.num_points = num_points
        self.use_color = use_color
        self.use_size = use_size
        self.use_height = use_height
        self.use_normals = use_normals
        self.overfit = overfit

        self.data_path = './dataset/language_grounding/'
        print('Loading %s files, take a breath!' % split)
        # Every split other than 'train' is served from the 'val' pickle.
        file_split = 'val' if split != 'train' else 'train'
        pkl_path = f'{self.data_path}/{file_split}_v2scans.pkl'
        if not os.path.exists(pkl_path):
            save_data(pkl_path, file_split)
        _, self.scans = unpickle_data(pkl_path)

        self.mean_rgb = np.array([109.8, 97.2, 83.8]) / 256
        # Per-axis (x, y, z) box-extent statistics used by _normalize.
        # size_max must hold the larger values so (size_max - size_min) is
        # positive; the two arrays used to be swapped, which flipped the sign
        # of the size feature (models trained before this fix saw
        # sign-flipped sizes).
        self.size_max = np.asarray([11.12648844, 17.15573245, 3.99937487])
        self.size_min = np.asarray([0.00916539, 0.00956155, 0.00850233])
        self.size_mean = np.asarray([0.77537238, 0.95590994, 0.863069])
        self.label_map = read_label_mapping(
            'scannet/meta_data/scannetv2-labels.combined.tsv',
            label_from='raw_category',
            label_to='nyu40id' if NUM_CLASSES == 18 else 'id'
        )

        self.annos = self.load_annos()

    def load_annos(self):
        """Load annotations.

        Returns:
            A list of ``(scan_id, object_index)`` pairs, keeping only objects
            whose mapped label is a key of ``DC.nyu40id2class``. When
            ``overfit`` is set, only the first 128 pairs are returned.
        """
        split = 'train' if self.split == 'train' else 'val'
        with open('scannet/meta_data/scannetv2_%s.txt' % split) as f:
            # Skip blank lines so they do not become bogus '' scan ids.
            scan_ids = [line.rstrip() for line in f if line.strip()]

        annos = [
            (scan_id, obj)
            for scan_id in scan_ids
            for obj in range(len(self.scans[scan_id].three_d_objects))
            if self.label_map[
                self.scans[scan_id].get_object_instance_label(obj)
            ] in DC.nyu40id2class
        ]
        if self.overfit:
            annos = annos[:128]
        return annos

    def __getitem__(self, index):
        """Return the sample for annotation ``index``.

        Returns:
            A dict with ``scan_ids``, ``object_ids``, ``point_clouds`` (float
            tensor, one row per sampled point: xyz followed by any enabled
            extra features) and ``obj_labels`` (long tensor).
        """
        # Get point cloud features
        (scan_id, obj_id) = self.annos[index]
        scan = self.scans[scan_id]
        point_cloud = np.copy(scan.get_object_pc(obj_id))
        label = DC.nyu40id2class[self.label_map[
            scan.get_object_instance_label(obj_id)
        ]]

        # Add color
        color = None
        if self.use_color:
            color = np.copy(scan.get_object_color(obj_id))

        # Add size: box is read as [min corner (3), max corner (3)]
        size = None
        if self.use_size:
            box = scan.get_object_bbox(obj_id)
            size = box[3:] - box[:3]

        # Add height relative to a floor estimate. Note np.percentile takes
        # percents, so 0.99 means the 0.99th percentile of z (near the
        # lowest points), not the 99th.
        height = None
        if self.use_height:
            floor_height = np.percentile(point_cloud[:, 2], 0.99)
            height = point_cloud[:, 2] - floor_height

        # Add normals: the object's 'points' entry indexes into the
        # per-scan normals array stored on disk.
        normals = None
        if self.use_normals:
            normals = np.load(os.path.join(
                "./dataset/language_grounding/extra/scannet_normals",
                scan_id + '.npy'
            ))
            points_to_use = np.copy(scan.three_d_objects[obj_id]['points'])
            normals = normals[points_to_use]

        # Sample pointcloud and color
        point_cloud, color, height, normals = self._subsample_pc(
            point_cloud, color, height, normals
        )

        # Normalize
        point_cloud, color, size = self._normalize(point_cloud, color, size)

        # Augment
        if self.split == 'train':
            point_cloud, color, size = self._augment(point_cloud, color, size)

        # Concatenate
        point_cloud = self._concat(point_cloud, color, size, height, normals)
        return {
            "scan_ids": scan_id,
            "object_ids": obj_id,
            "point_clouds": torch.from_numpy(point_cloud).float(),
            "obj_labels": torch.as_tensor(label).long()
        }

    def __len__(self):
        """Return the number of (scan, object) samples."""
        return len(self.annos)

    @staticmethod
    def _augment(point_cloud, color, size):
        """Randomly flip, rotate, jitter, scale and shift the points.

        Colors get a random illumination change and sizes a random rescale.
        Uses the global NumPy RNG. Returns ``(point_cloud, color, size)``.
        """
        # Flipping along the YZ plane
        if np.random.random() > 0.5:
            point_cloud[:, 0] = -point_cloud[:, 0]
        # Flipping along the XZ plane
        if np.random.random() > 0.5:
            point_cloud[:, 1] = -point_cloud[:, 1]

        # Random rotation (degrees): small tilt about x/y, any angle about z
        theta = 2 * np.random.rand() - 1
        point_cloud = rot_x(point_cloud, theta)
        theta = 2 * np.random.rand() - 1
        point_cloud = rot_y(point_cloud, theta)
        theta = np.random.rand() * 360
        point_cloud = rot_z(point_cloud, theta)

        # Add noise (uniform in [0, 5e-3) per coordinate)
        noise = np.random.rand(len(point_cloud), 3) * 5e-3
        point_cloud = point_cloud + noise

        # Random scale
        point_cloud = point_cloud * np.random.uniform(0.8, 1.25)
        # Random shift
        shifts = np.random.uniform(-0.1, 0.1, 3)
        point_cloud = point_cloud + shifts[None, :]

        # Change illumination
        if color is not None:
            color = color * np.random.uniform(0.8, 1.25)
            color = color * 0.98 + 0.04*np.random.random((len(color), 3))

        # Rescale size
        if size is not None:
            size = size * np.random.uniform(0.8, 1.25)
        return point_cloud, color, size

    @staticmethod
    def _concat(pc, color=None, size=None, height=None, normals=None):
        """Append the non-None per-point features to ``pc`` column-wise.

        Order: color, size (the same vector tiled onto every point),
        height (one column), normals.
        """
        if color is not None:
            pc = np.concatenate((pc, color), 1)
        if size is not None:
            size = np.tile(size, (len(pc), 1))
            pc = np.concatenate((pc, size), 1)
        if height is not None:
            pc = np.concatenate([pc, np.expand_dims(height, 1)], 1)
        if normals is not None:
            pc = np.concatenate((pc, normals), 1)
        return pc

    def _normalize(self, pc, color=None, size=None):
        """Centre ``pc`` at its mean and scale it into the unit sphere.

        Colors are mean-subtracted and sizes standardized with the dataset
        statistics. Returns ``(pc, color, size)``.
        """
        pc = pc - np.mean(pc, axis=0)
        max_dist = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        # A degenerate cloud (e.g. one point repeated by sampling with
        # replacement) has max_dist == 0; dividing would produce NaNs.
        if max_dist > 0:
            pc = pc / max_dist
        if color is not None:
            color = color - self.mean_rgb
        if size is not None:
            size = (size - self.size_mean) / (self.size_max - self.size_min)
        return pc, color, size

    def _subsample_pc(self, pc, color=None, height=None, normals=None):
        """Sample ``num_points`` rows (with replacement if too few points).

        The same indices are applied to every non-None per-point array.
        Returns ``(pc, color, height, normals)``.
        """
        if self.split != 'train':
            # Fixed seed so evaluation always sees the same points. A private
            # RandomState draws exactly the indices that seeding the global
            # generator did, but without resetting the global NumPy RNG
            # (which otherwise makes later randomness in the same process,
            # e.g. training augmentation, restart from a fixed state).
            rng = np.random.RandomState(1184)
        else:
            rng = np.random
        choices = rng.choice(
            pc.shape[0],
            self.num_points,
            replace=len(pc) < self.num_points
        )
        pc = pc[choices]
        if color is not None:
            color = color[choices]
        if height is not None:
            height = height[choices]
        if normals is not None:
            normals = normals[choices]
        return pc, color, height, normals


def rot_x(pc, theta):
    """Rotate points about the x-axis.

    Args:
        pc: array with one xyz point per row.
        theta: rotation angle in degrees.

    Returns:
        The rotated points, with the same layout as ``pc``.
    """
    angle = theta * np.pi / 180
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([
        [1.0, 0, 0],
        [0, cos_a, -sin_a],
        [0, sin_a, cos_a],
    ])
    rotated_columns = np.matmul(rotation, pc.T)
    return rotated_columns.T


def rot_y(pc, theta):
    """Rotate points about the y-axis.

    Args:
        pc: array with one xyz point per row.
        theta: rotation angle in degrees.

    Returns:
        The rotated points, with the same layout as ``pc``.
    """
    angle = theta * np.pi / 180
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([
        [cos_a, 0, sin_a],
        [0, 1.0, 0],
        [-sin_a, 0, cos_a],
    ])
    rotated_columns = np.matmul(rotation, pc.T)
    return rotated_columns.T


def rot_z(pc, theta):
    """Rotate points about the z-axis.

    Args:
        pc: array with one xyz point per row.
        theta: rotation angle in degrees.

    Returns:
        The rotated points, with the same layout as ``pc``.
    """
    angle = theta * np.pi / 180
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1.0],
    ])
    rotated_columns = np.matmul(rotation, pc.T)
    return rotated_columns.T
