import os
import os.path as osp
import json
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
from utils.data_utils import *
from datasets.sncore_splits import *

from datasets.pc_ops import *
from utils.realistic_projection import Realistic_Projection


class ShapeNetCore4k(Dataset):
    """ShapeNetCore point-cloud dataset backed by pre-sampled 4096-point clouds.

    Data is read from ``<data_root>/sncore_fps_4096`` as three files per split:
    ``sncore_<split>_4096_sids.json`` (synset id of each shape),
    ``sncore_<split>_4096_mids.json`` (model id of each shape) and
    ``sncore_<split>_4096_points.h5`` (HDF5 dataset ``data`` with the points).

    Two item formats are supported:

    * ``pretrain=True``: ``(points, text, depth, label)``. ``text`` is looked up in
      ``./prompts/encoded_sncore_prompt.pt`` (path relative to the current working
      directory) with key ``sncore_all_synset[synset_id]``. ``depth`` comes from
      ``get_depth_map`` on the raw cloud. ``points`` is normalized and augmented.
    * ``pretrain=False``: only the point cloud, after ``random_sample``,
      ``pc_normalize`` and the optional ``transforms``.

    Args:
        data_root: directory containing ``sncore_fps_4096``.
        split: one of ``"train"``, ``"test"``, ``"val"`` (case-insensitive).
        class_choice: optional list of synset ids to keep; filtering happens
            before the cellphone/telephone merge.
        num_points: value passed to ``random_sample`` when ``pretrain=False``.
        transforms: optional callable applied to the cloud when ``pretrain=False``.
        apply_fix_cellphone: relabel synset ``02992529`` (cellphone) as
            ``04401088`` (telephone).
        pretrain: select the pretraining item format described above.
    """

    def __init__(self,
                 data_root=None,
                 split="train",
                 class_choice=None,
                 num_points=2048,
                 transforms=None,
                 apply_fix_cellphone=True,
                 pretrain=True):

        self.whoami = "ShapeNetCore4k"
        self.points = None  # all point clouds from split
        self.synset_ids = None  # for each shape its synset id
        self.model_ids = None  # for each shape its model id
        assert split.lower() in ['train', 'test', 'val']
        # The check above is case-insensitive, so store the normalized value:
        # the split name is interpolated into on-disk file names.
        self.split = split.lower()
        self.pc_dim = 4096
        self.data_dir = osp.join(data_root, "sncore_fps_4096")
        assert osp.exists(self.data_dir), f"{self.whoami} - {self.data_dir} does not exist"
        self.class_choice = class_choice
        self.num_points = num_points
        self.transforms = transforms
        self.pretrain = pretrain
        self.projector = Realistic_Projection()

        # load data split
        self.load_split()

        # silent pycharm warnings
        assert self.points is not None
        assert self.synset_ids is not None
        assert self.model_ids is not None

        # sub-select point clouds with synset choice
        if self.class_choice:
            # a list of synset Ids is expected for category selection
            assert isinstance(self.class_choice, list), \
                f"{self.whoami} {self.split} - class_choice should be a list of synset ids"
            wanted = set(self.class_choice)
            chosen_idxes = [i for i, s_id in enumerate(self.synset_ids) if s_id in wanted]
            assert len(chosen_idxes) > 0, f"{self.whoami} {self.split} - No samples for class choice"
            self.synset_ids = [self.synset_ids[i] for i in chosen_idxes]
            self.model_ids = [self.model_ids[i] for i in chosen_idxes]
            self.points = self.points[chosen_idxes]

        if apply_fix_cellphone:
            # merge "cellphone" with "telephone"
            cellphone_sid = "02992529"
            telephone_sid = "04401088"
            cell_idxes = [i for i, s_id in enumerate(self.synset_ids) if s_id == cellphone_sid]
            if cell_idxes:
                print(f"{self.whoami} {self.split} - merging cellphone with telephone")
            for j in cell_idxes:
                # substitute synset_id of cellphones with telephone one
                self.synset_ids[j] = telephone_sid

        # Labels are dense indices over the sorted set of synset ids present.
        unique_ids = sorted(set(self.synset_ids))
        self.num_classes = len(unique_ids)
        self.id_2_label = {sid: label for label, sid in enumerate(unique_ids)}
        self.id_2_name = {sid: sncore_all_synset[sid] for sid in unique_ids}
        self.labels = np.asarray([self.id_2_label[s_id] for s_id in self.synset_ids])
        print(f"{self.whoami} {self.split} - points shape: {self.points.shape}, synset_ids: {len(self.synset_ids)}")
        print(f"{self.whoami} {self.split} - id_2_name: {self.id_2_name}")
        print(f"{self.whoami} {self.split} - id_2_label: {self.id_2_label}")
        print(f"{self.whoami} {self.split} - sampled points: {self.num_points}")
        print(f"{self.whoami} {self.split} - transforms: {self.transforms}")

        if self.pretrain:
            self.encoded_prompt = torch.load('./prompts/encoded_sncore_prompt.pt', map_location='cpu')
            # The pretrain path calls tensor methods (``unsqueeze``) on items,
            # so keep the points as a torch tensor.
            self.points = torch.from_numpy(self.points).detach()

    def load_split(self):
        """Load synset ids, model ids and points of ``self.split`` into the instance.

        Raises:
            AssertionError: if any of the three split files is missing.
        """
        s_ids_fn = osp.join(self.data_dir, f"sncore_{self.split}_{self.pc_dim}_sids.json")
        assert osp.exists(s_ids_fn), f"synset ids file {s_ids_fn} does not exist"
        with open(s_ids_fn, 'r') as f:
            self.synset_ids = json.load(f)

        m_ids_fn = osp.join(self.data_dir, f"sncore_{self.split}_{self.pc_dim}_mids.json")
        assert osp.exists(m_ids_fn), f"model ids file {m_ids_fn} does not exist"
        with open(m_ids_fn, 'r') as f:
            self.model_ids = json.load(f)

        points_fn = osp.join(self.data_dir, f"sncore_{self.split}_{self.pc_dim}_points.h5")
        assert osp.exists(points_fn), f"points file {points_fn} does not exist"
        # Context manager closes the HDF5 handle even if reading/casting fails.
        with h5py.File(points_fn, 'r') as f:
            self.points = f['data'][:].astype('float32')

    def get_depth_map(self, pc):
        """Project ``pc`` with ``self.projector`` and resize the result to 224x224.

        Bilinear ``interpolate`` requires a 4-D tensor, so the projector output is
        expected to be batched; the result is detached from the autograd graph.
        """
        depth = self.projector.get_img(pc).detach()
        depth = torch.nn.functional.interpolate(depth, size=(224, 224), mode='bilinear', align_corners=True)
        return depth

    def __getitem__(self, item):
        point_set = self.points[item]

        # pretrain with CLIP
        if self.pretrain:
            synset_category = sncore_all_synset[self.synset_ids[item]]
            text = self.encoded_prompt[synset_category]
            # The depth map is rendered from the raw cloud, before normalization
            # and augmentation.
            depth = self.get_depth_map(point_set.unsqueeze(0))
            # data augmentation
            point_set = normalize(point_set)
            point_set = random_shift(point_set)
            point_set = random_rotate(point_set)
            point_set = random_jitter(point_set)
            point_set = random_rotate_perturbation(point_set)
            return point_set, text, depth, self.labels[item]

        # train an autoencoder
        point_set = random_sample(point_set, num_points=self.num_points)
        # unit cube normalization
        point_set = pc_normalize(point_set)
        # data augmentation
        if self.transforms:
            point_set = self.transforms(point_set)
        return point_set

    def __len__(self):
        return len(self.points)


class ShapeNetCorrupted(Dataset):
    """ShapeNetCore point clouds with a named corruption at one or more severities.

    Files are read from
    ``<data_root>/sncore_fps_4096/sncore_corrupted_v2/<corruption>/`` as
    ``sncore_<corruption>_sev<k>_<split>_{sids.json,mids.json,points.h5}`` for each
    requested severity ``k``. The data of all severities is concatenated.

    Each item is ``(points, label, model_id, synset_id)``. ``points`` is passed
    through ``pc_normalize`` and then ``transforms`` (if given).

    Args:
        data_root: directory containing ``sncore_fps_4096``.
        split: one of ``"train"``, ``"test"``, ``"val"`` (case-insensitive).
        class_choice: optional list of synset ids to keep; filtering happens
            before the cellphone/telephone merge.
        transforms: optional callable applied to each point cloud.
        apply_fix_cellphone: relabel synset ``02992529`` (cellphone) as
            ``04401088`` (telephone).
        corruption: corruption name, used as sub-directory and file-name component.
        sev: severity level(s) to load. Either an int or a list/tuple of ints;
            ``None`` means ``[1, 2, 3, 4, 5]``.
        num_points: unused, kept for signature compatibility with ``ShapeNetCore4k``.
    """

    def __init__(self,
                 data_root=None,
                 split="test",
                 class_choice=None,
                 transforms=None,
                 apply_fix_cellphone=True,
                 corruption="lidar",
                 sev=None,
                 num_points=-1  # unused, compatibility with ShapeNetCore4k
                 ):

        self.whoami = "ShapeNetCorrupted"
        if sev is None:
            sev = [1, 2, 3, 4, 5]
        self.sev = sev
        self.corruption = corruption

        assert split.lower() in ['train', 'test', 'val']
        # The check above is case-insensitive, so store the normalized value:
        # the split name is interpolated into on-disk file names.
        self.split = split.lower()
        self.data_dir = osp.join(data_root, "sncore_fps_4096", "sncore_corrupted_v2")
        assert osp.exists(self.data_dir), f"{self.whoami} - {self.data_dir} does not exist"
        self.class_choice = class_choice
        self.transforms = transforms
        self.points = None  # all point clouds from split
        self.synset_ids = None  # for each shape its synset id
        self.model_ids = None  # for each shape its model id

        # load data split
        self.load_split_sev()

        # sub-select point clouds with synset choice
        if self.class_choice:
            # a list of synset Ids is expected for category selection
            assert isinstance(self.class_choice, list), \
                f"{self.whoami} {self.split} - class_choice should be a list of synset ids"
            wanted = set(self.class_choice)
            chosen_idxes = [i for i, s_id in enumerate(self.synset_ids) if s_id in wanted]
            assert len(chosen_idxes) > 0, f"{self.whoami} {self.split} - No samples for class choice"
            self.synset_ids = [self.synset_ids[i] for i in chosen_idxes]
            self.model_ids = [self.model_ids[i] for i in chosen_idxes]
            self.points = self.points[chosen_idxes]

        if apply_fix_cellphone:
            # merge "cellphone" with "telephone"
            cellphone_sid = "02992529"
            telephone_sid = "04401088"
            cell_idxes = [i for i, s_id in enumerate(self.synset_ids) if s_id == cellphone_sid]
            if cell_idxes:
                print(f"{self.whoami} {self.split} - merging cellphone with telephone")
            for j in cell_idxes:
                # substitute synset_id of cellphones with telephone one
                self.synset_ids[j] = telephone_sid

        # Labels are dense indices over the sorted set of synset ids present.
        unique_ids = sorted(set(self.synset_ids))
        self.num_classes = len(unique_ids)
        self.id_2_label = {sid: label for label, sid in enumerate(unique_ids)}
        self.labels = np.asarray([self.id_2_label[s_id] for s_id in self.synset_ids])
        print(f"{self.whoami} {self.split} - corruption: {self.corruption}")
        print(f"{self.whoami} {self.split} - points: {self.points.shape}, synset_ids: {len(self.synset_ids)}")
        print(f"{self.whoami} {self.split} - id_2_label: {self.id_2_label}")
        print(f"{self.whoami} {self.split} - transforms: {self.transforms}")

    def load_split_sev(self):
        """Load and concatenate the data of every severity level in ``self.sev``.

        ``self.sev`` may be an int or a list/tuple of ints and is normalized to a list.

        Raises:
            AssertionError: if ``self.sev`` has an unsupported type or a file is missing.
            ValueError: if no severity level is requested.
        """
        if isinstance(self.sev, int):
            self.sev = [self.sev]
        elif isinstance(self.sev, tuple):
            self.sev = list(self.sev)
        assert isinstance(self.sev, list), \
            f"{self.whoami} - sev should be an int or a list of ints, got {self.sev!r}"
        if not self.sev:
            raise ValueError(f"{self.whoami} - at least one severity level is required")

        corruption_dir = osp.join(self.data_dir, self.corruption)
        points, synset_ids, model_ids = [], [], []

        for curr_sev in self.sev:
            prefix = f"sncore_{self.corruption}_sev{curr_sev}_{self.split}"
            s_ids_fn = osp.join(corruption_dir, f"{prefix}_sids.json")
            m_ids_fn = osp.join(corruption_dir, f"{prefix}_mids.json")
            points_fn = osp.join(corruption_dir, f"{prefix}_points.h5")

            assert osp.exists(s_ids_fn), f"synset ids file {s_ids_fn} does not exist"
            with open(s_ids_fn, 'r') as f:
                synset_ids.extend(json.load(f))

            assert osp.exists(m_ids_fn), f"model ids file {m_ids_fn} does not exist"
            with open(m_ids_fn, 'r') as f:
                model_ids.extend(json.load(f))

            assert osp.exists(points_fn), f"points file {points_fn} does not exist"
            # Context manager closes the HDF5 handle even if reading/casting fails.
            with h5py.File(points_fn, 'r') as f:
                points.append(f['data'][:].astype('float32'))

        self.synset_ids = synset_ids
        self.model_ids = model_ids
        self.points = np.concatenate(points, 0)

    def __getitem__(self, item):
        point_set = self.points[item]
        lbl = self.labels[item]
        synset_id = self.synset_ids[item]
        model_id = self.model_ids[item]

        # unit cube normalization
        point_set = pc_normalize(point_set)

        # data augmentation
        if self.transforms:
            point_set = self.transforms(point_set)

        return point_set, lbl, model_id, synset_id

    def __len__(self):
        return len(self.points)


