import torch.utils.data as data
import os.path as osp
from datasets.pc_ops import *
from utils.data_utils import *
from utils.realistic_projection import Realistic_Projection
import h5py

# All ScanObjectNN labels: category name -> original integer label id (0..14).
# The integer ids defined here are the keys used by every remapping dict below.
SONN_label_dict = {
    "bag": 0, "bin": 1, "box": 2,
    "cabinet": 3, "chair": 4, "desk": 5,
    "display": 6, "door": 7, "shelf": 8,
    "table": 9, "bed": 10, "pillow": 11,
    "sink": 12, "sofa": 13, "toilet": 14
}

# All ScanObjectNN categories, with desk and table merged into the same class.
# Keys are original label ids (see SONN_label_dict); values are the remapped,
# contiguous ids 0..13. Ids after "table" shift down by one because of the merge.
sonn_all = {
    0: 0,  # "bag"
    1: 1,  # "bin"
    2: 2,  # "box"
    3: 3,  # "cabinet"
    4: 4,  # "chair"
    5: 5,  # "desk" (merged with table)
    6: 6,  # "display"
    7: 7,  # "door"
    8: 8,  # "shelf"
    9: 5,  # "table" (merged with desk)
    10: 9,  # "bed"
    11: 10,  # "pillow"
    12: 11,  # "sink"
    13: 12,  # "sofa"
    14: 13  # "toilet"
}

# ScanObjectNN <-> modelnet_set1 correspondence.
# Keys are original ScanObjectNN label ids (see SONN_label_dict); values are the
# remapped class ids 0..4 (presumably the modelnet_set1 class indices - confirm
# against the ModelNet dataset definition).
sonn_2_mdSet1 = {
    4: 0,  # chair
    8: 1,  # shelf
    7: 2,  # door
    12: 3,  # sink
    13: 4  # sofa
}

# ScanObjectNN <-> modelnet_set2 correspondence.
# Keys are original ScanObjectNN label ids (see SONN_label_dict); values are the
# remapped class ids 0..3 (presumably the modelnet_set2 class indices - confirm
# against the ModelNet dataset definition). Desk and table share class id 2.
sonn_2_mdSet2 = {
    10: 0,  # bed
    14: 1,  # toilet
    5: 2,  # desk
    6: 3,  # display
    9: 2  # table
}

# Common OOD (out-of-distribution) set.
# These are categories with poor mapping between ModelNet and ScanObjectNN.
# Every original label id listed here is mapped to the same value, 404
# (presumably a placeholder "unknown class" label - confirm with the consumers).
sonn_ood_common = {
    0: 404,  # bag
    1: 404,  # bin
    2: 404,  # box
    3: 404,  # cabinet
    11: 404  # pillow
}

################################
# for real -> real experiments #
################################
# SR12: original ScanObjectNN label id -> class id 0..8.
# The first group uses the same source labels as sonn_2_mdSet1, the second group
# the same source labels as sonn_2_mdSet2 (desk and table share class id 7).
SR12 = {
    4: 0,  # chair
    8: 1,  # shelf
    7: 2,  # door
    12: 3,  # sink
    13: 4,  # sofa
    ######
    10: 5,  # bed
    14: 6,  # toilet
    5: 7,  # desk
    9: 7,  # table
    6: 8,  # display
}

# SR13: original ScanObjectNN label id -> class id 0..9.
# The first group uses the same source labels as sonn_2_mdSet1, the second group
# the same source labels as sonn_ood_common (here each gets its own class id).
SR13 = {
    4: 0,  # chair
    8: 1,  # shelf
    7: 2,  # door
    12: 3,  # sink
    13: 4,  # sofa
    ######
    0: 5,  # bag
    1: 6,  # bin
    2: 7,  # box
    3: 8,  # cabinet
    11: 9  # pillow
}

# SR23: original ScanObjectNN label id -> class id 0..8.
# The first group uses the same source labels as sonn_2_mdSet2 (desk and table
# share class id 2), the second group the same source labels as sonn_ood_common
# (here each gets its own class id).
SR23 = {
    10: 0,  # bed
    14: 1,  # toilet
    5: 2,  # desk
    9: 2,  # table
    6: 3,  # display
    ######
    0: 4,  # bag
    1: 5,  # bin
    2: 6,  # box
    3: 7,  # cabinet
    11: 8,  # pillow
}


################################


def load_h5_data_label(h5_path):
    """Read the ``'data'`` and ``'label'`` datasets of one HDF5 file into memory.

    Args:
        h5_path: path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(data, label)`` of NumPy arrays; ``data`` is converted to
        ``float32`` and ``label`` to ``int64``.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if the ``'data'`` or ``'label'`` dataset is missing.
    """
    # The context manager closes the file even when a dataset is missing or a
    # read fails part-way, which a trailing f.close() would not.
    with h5py.File(h5_path, 'r') as f:
        curr_data = f['data'][:]
        curr_label = f['label'][:]
    return np.asarray(curr_data, dtype=np.float32), np.asarray(curr_label, dtype=np.int64)


def load_h5_data_label_list(h5_paths):
    """Read and concatenate the ``'data'``/``'label'`` datasets of several HDF5 files.

    Files are read in the given order and their contents are joined along the
    first axis.

    Args:
        h5_paths: iterable of HDF5 file paths, each opened read-only.

    Returns:
        Tuple ``(data, label)`` of NumPy arrays; ``data`` is converted to
        ``float32`` and ``label`` to ``int64``. Both are empty 1-D arrays when
        ``h5_paths`` is empty.

    Raises:
        OSError: if a file cannot be opened.
        KeyError: if a file lacks the ``'data'`` or ``'label'`` dataset.
    """
    data_chunks = []
    label_chunks = []
    for curr_h5 in h5_paths:
        # The context manager closes each file even if a read raises.
        with h5py.File(curr_h5, 'r') as f:
            data_chunks.append(f['data'][:])
            label_chunks.append(f['label'][:])

    if not data_chunks:
        # np.concatenate rejects an empty sequence; keep the empty-array result.
        return np.asarray([], dtype=np.float32), np.asarray([], dtype=np.int64)

    # Join whole arrays at once instead of extending a Python list row by row.
    all_data = np.concatenate(data_chunks, axis=0)
    all_label = np.concatenate(label_chunks, axis=0)
    return np.asarray(all_data, dtype=np.float32), np.asarray(all_label, dtype=np.int64)


class ScanObject(data.Dataset):
    """ScanObjectNN point clouds loaded fully into memory from HDF5 files.

    Files are read from ``<data_root>/ScanObjectNN/h5_files/<sonn_split>/`` and
    are named ``training_<h5_file>`` and ``test_<h5_file>``.

    With ``pretrain=True`` an item is ``(point_set, text, depth, label)``:
    the augmented points, the encoded prompt for the label, a depth-map
    rendering resized to 224x224, and the label. With ``pretrain=False`` an
    item is only the sampled, normalized (and optionally transformed) points.

    Args:
        data_root: directory that contains ``ScanObjectNN/h5_files``.
        sonn_split: sub-directory of ``h5_files`` to read from.
        h5_file: file-name suffix shared by the training and test files.
        split: ``'train'``, ``'test'`` or ``'all'`` (both files).
        class_choice: ``None`` to keep all labels, or a dict mapping original
            label ids to new ids (samples with other labels are dropped), or a
            string naming such a dict defined in this module.
        num_points: number of points sampled per cloud when ``pretrain=False``;
            at most 2048.
        transforms: optional callable applied to the points when
            ``pretrain=False``.
        pretrain: selects the item format described above.

    Raises:
        TypeError: if ``data_root`` is ``None``.
        ValueError: if ``class_choice`` does not resolve to a dict.
        AssertionError: on an invalid ``split``, missing data directory,
            ``num_points`` above 2048, or no sample matching ``class_choice``.
    """

    # Encoded-prompt file for each class_choice name supported in pretrain mode.
    # Paths are relative to the current working directory.
    _PROMPT_FILES = {
        "sonn_2_mdSet1": './prompts/encoded_sonn_SR1.pt',
        "sonn_2_mdSet2": './prompts/encoded_sonn_SR2.pt',
        "SR12": './prompts/encoded_SR12.pt',
        "SR13": './prompts/encoded_SR13.pt',
        "SR23": './prompts/encoded_SR23.pt',
    }

    def __init__(
            self,
            data_root=None,
            sonn_split="main_split",
            h5_file="objectdataset.h5",
            split="train",
            class_choice=None,
            num_points=2048,
            transforms=None,
            pretrain=True):
        self.whoami = "ScanObject"
        assert split in ['train', 'test', 'all']
        self.split = split
        if data_root is None:
            # osp.join(None, ...) would fail with a cryptic TypeError; say why.
            raise TypeError(f"{self.whoami} - data_root must be a path, got None")
        self.data_dir = osp.join(data_root, "ScanObjectNN/h5_files")
        assert osp.exists(self.data_dir), f"{self.whoami} - {self.data_dir} does not exist"
        self.num_points = num_points
        assert self.num_points <= 2048, "num_points must be at most 2048"
        self.transforms = transforms
        self.sonn_split = sonn_split
        self.h5_file = h5_file
        self.class_choice = class_choice
        self.pretrain = pretrain
        self.projector = Realistic_Projection()

        train_path = osp.join(self.data_dir, sonn_split, f"training_{h5_file}")
        test_path = osp.join(self.data_dir, sonn_split, f"test_{h5_file}")
        if self.split == "train":
            h5_file_path = [train_path]
        elif self.split == "test":
            h5_file_path = [test_path]
        elif self.split == "all":
            h5_file_path = [train_path, test_path]
        else:
            # Still reachable when asserts are stripped with -O.
            raise ValueError(f"Wrong ScanObjectNN split: {self.split}")

        # LOAD ALL DATA IN MEMORY (h5_file_path is always a list here)
        self.points, self.labels = load_h5_data_label_list(h5_file_path)

        # CLASS CHOICE
        if self.class_choice is not None:
            if isinstance(self.class_choice, str):
                # Prefer a plain lookup of a module-level mapping by name.
                known = globals().get(self.class_choice)
                if isinstance(known, dict):
                    self.class_choice = known
                else:
                    # NOTE(security): eval executes arbitrary code; kept only as a
                    # backward-compatible fallback. Never pass untrusted strings.
                    self.class_choice = eval(self.class_choice)
            if not isinstance(self.class_choice, dict):
                raise ValueError(f"{self.whoami} - cannot load conversion dict with class_choice: {class_choice}")

            # Keep only samples whose original label has a mapping, then remap.
            chosen_idxs = [index for index, value in enumerate(self.labels) if value in self.class_choice]
            assert len(chosen_idxs) > 0, f"{self.whoami} - no samples match class_choice"
            self.points = self.points[chosen_idxs]
            self.labels = [self.class_choice[self.labels[idx]] for idx in chosen_idxs]
            self.labels = np.asarray(self.labels, dtype=np.int64)
            print("label space:", list(set(self.labels)))
            self.num_classes = len(set(self.class_choice.values()))
        else:
            self.num_classes = len(SONN_label_dict.keys())

        print(f"ScanObject - "
              f"num_points: {self.num_points}, "
              f"sonn_split: {self.sonn_split}, "
              f"split: {self.split}, "
              f"class_choice: {self.class_choice}, "
              f"points shape: {self.points.shape}")

        # Always defined so __getitem__ can report a missing prompt file clearly
        # instead of failing with an AttributeError.
        self.encoded_prompt = None
        if self.pretrain:
            # Prompts are selected by the *name* originally passed in; a dict
            # (unhashable) or an unknown name has no prompt file.
            prompt_file = self._PROMPT_FILES.get(class_choice) if isinstance(class_choice, str) else None
            if prompt_file is not None:
                self.encoded_prompt = torch.load(prompt_file, map_location='cpu')
            self.points = torch.from_numpy(self.points).detach()

    def get_depth_map(self, pc):
        """Render ``pc`` with the projector and resize the result to 224x224.

        Args:
            pc: point cloud tensor passed to ``Realistic_Projection.get_img``
                (called from ``__getitem__`` with a leading batch axis of 1).

        Returns:
            The projected images, bilinearly interpolated to 224x224.
        """
        depth = self.projector.get_img(pc).detach()
        depth = torch.nn.functional.interpolate(depth, size=(224, 224), mode='bilinear', align_corners=True)
        return depth

    def __getitem__(self, index):
        """Return the sample at ``index`` (format depends on ``self.pretrain``).

        Raises:
            AssertionError: if the stored cloud does not have 2048 points.
            RuntimeError: in pretrain mode when no encoded prompts were loaded.
        """
        point_set = self.points[index]
        assert len(point_set) == 2048, "ScanObjectNN: expected 2048-points input shape"
        label = self.labels[index]

        # pretrain with CLIP
        if self.pretrain:
            if self.encoded_prompt is None:
                raise RuntimeError(
                    "ScanObject - pretrain=True requires class_choice to be one of "
                    f"{sorted(self._PROMPT_FILES)} so that encoded prompts can be loaded")
            text = self.encoded_prompt[label]
            # The depth map is rendered from the raw points, before augmentation.
            depth = self.get_depth_map(point_set.unsqueeze(0))
            # data augmentation
            point_set = normalize(point_set)
            point_set = random_shift(point_set)
            point_set = random_rotate(point_set)
            point_set = random_jitter(point_set)
            point_set = random_rotate_perturbation(point_set)
            return point_set, text, depth, label

        # train an autoencoder: only the points are returned
        # sampling
        point_set = random_sample(points=point_set, num_points=self.num_points)
        # unit cube normalization
        point_set = pc_normalize(point_set)
        # data augmentation
        if self.transforms:
            point_set = self.transforms(point_set)
        return point_set

    def __len__(self):
        """Return the number of point clouds held in memory."""
        return len(self.points)


