import os
import os.path as osp

import h5py
import numpy as np
import torch
import tqdm
from torch.utils import data

from utils.data_utils import *
from datasets.pc_ops import *
from utils.realistic_projection import Realistic_Projection

# ModelNet40 category name -> integer label; labels follow the (alphabetical)
# order of the tuple below, 0..39.
modelnet40_label_dict = {
    category: label
    for label, category in enumerate((
        'airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl',
        'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door',
        'dresser', 'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp',
        'laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano',
        'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool',
        'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox',
    ))
}

# ModelNet10 category name -> integer label, 0..9 in tuple order.
modelnet10_label_dict = {
    category: label
    for label, category in enumerate((
        'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand',
        'sofa', 'table', 'toilet',
    ))
}

############################################
# Closed-set class mappings for the ModelNet -> SONN (presumably ScanObjectNN)
# cross-domain experiments. ModelNet40_OOD picks one of these by its name
# ("SR1", "SR2"); keys are ModelNet40 category names, values are labels.

SR1 = dict(chair=0, bookshelf=1, door=2, sink=3, sofa=4)

# NOTE: 'desk' and 'table' share label 2, so SR2 has 5 names but only 4
# distinct labels (ModelNet40_OOD counts classes by distinct label values).
SR2 = dict(bed=0, toilet=1, desk=2, monitor=3, table=2)


# ModelNet40 categories that are always out-of-distribution in the cross-domain
# experiments. Every one maps to 404 instead of a class label; the trailing
# comments keep each category's ModelNet40 label and the original notes.
modelnet_set3 = dict.fromkeys(
    (
        'bathtub',      # 1   - similar to sink?
        'bottle',       # 5
        'bowl',         # 6
        'cup',          # 10
        'curtain',      # 11
        'plant',        # 26  - similar to bin?
        'flower_pot',   # 15  - similar to bin?
        'vase',         # 37  - similar to bin?
        'guitar',       # 17
        'keyboard',     # 18
        'lamp',         # 19
        'laptop',       # 20
        'night_stand',  # 23  - similar to table, hard out-of-distribution?
        'person',       # 24
        'piano',        # 25  - similar to table, hard out-of-distribution?
        'radio',        # 27
        'stairs',       # 31
        'tent',         # 34
        'tv_stand',     # 36  - similar to table, hard out-of-distribution?
    ),
    404,
)


class ModelNet40_OOD(data.Dataset):
    """
    ModelNet40 normal resampled (10k sampled points per shape), restricted to the
    categories of one "SRx" closed-set mapping.

    The first time a (class_choice, split) pair is requested, shapes are parsed
    from the per-shape txt files and cached to
    ``<data_root>/modelnet40_normal_resampled/ood_sets_cache/<class_choice>_<split>.h5``.
    Later instances read that cache directly.

    Args:
        data_root: directory that contains ``modelnet40_normal_resampled``.
        num_points: points drawn per sample in ``__getitem__``; capped at 10000.
        transforms: optional callable applied to the sampled cloud when
            ``pretrain`` is False.
        train: use the "train" split if True, otherwise the "test" split.
        class_choice: name of a module-level dict (e.g. "SR1", "SR2") mapping
            category names to labels; must start with "SR".
        pretrain: if True, also load ``encoded_<class_choice>.pt`` (resolved
            relative to the current working directory), convert the points to a
            torch tensor and make ``__getitem__`` return
            ``(points, text, depth, label)``.

    Raises:
        FileNotFoundError: if ``modelnet40_normal_resampled`` is missing.
        NameError: if ``class_choice`` does not name a module-level object.
        ValueError: if the split contains no shape of the chosen categories.
    """

    def __init__(self, data_root, num_points, transforms=None, train=True, class_choice="SR1", pretrain=True):
        super().__init__()
        self.whoami = "ModelNet40_OOD"
        self.split = "train" if train else "test"
        self.num_points = min(int(1e4), num_points)
        self.transforms = transforms
        self.pretrain = pretrain
        self.projector = Realistic_Projection()

        assert isinstance(class_choice, str) and class_choice.startswith('SR'), f"{self.whoami} - class_choice must be SRX name"
        # Resolve the mapping by name among this module's globals (which also
        # include star-imported names) instead of eval()-ing the string, so no
        # arbitrary expression can ever be executed.
        try:
            self.class_choice = globals()[class_choice]
        except KeyError:
            raise NameError(f"{self.whoami} - unknown class_choice '{class_choice}'") from None
        assert isinstance(self.class_choice, dict)
        # Count distinct label values, not keys: several names may share one
        # label (SR2 maps both 'desk' and 'table' to 2).
        self.num_classes = len(set(self.class_choice.values()))

        self.data_dir = osp.join(data_root, "modelnet40_normal_resampled")
        if not osp.exists(self.data_dir):
            raise FileNotFoundError(f"{self.whoami} - {self.data_dir} does not exist")

        cache_dir = osp.join(self.data_dir, "ood_sets_cache")  # directory containing cache files
        cache_fn = osp.join(cache_dir, f'{class_choice}_{self.split}.h5')  # path to cache file
        if osp.exists(cache_fn):
            print(f"{self.whoami} - Reading data from h5py file: {cache_fn}")
            self.points, self.labels = self._read_cache(cache_fn)
        else:
            # Parse the txt files once and build the cache for later runs.
            self._load_from_txt()
            self._write_cache(cache_dir, cache_fn)
            print(f"{self.whoami} - Cache built for split: {self.split}, set: {self.class_choice} - "
                  f"datas: {self.points.shape} labels: {self.labels.shape} ")

        print(f"{self.whoami} {self.split} - points shape: {self.points.shape}")
        print(f"{self.whoami} {self.split} - categories: {self.class_choice}")
        print(f"{self.whoami} {self.split} - sampled points: {self.num_points}")
        print(f"{self.whoami} {self.split} - transforms: {self.transforms}")

        if self.pretrain:
            self.encoded_prompt = torch.load(f'encoded_{class_choice}.pt', map_location='cpu')
            self.points = torch.from_numpy(self.points).detach()

    @staticmethod
    def _read_cache(cache_fn):
        """Return the ``(points float32, labels int64)`` arrays stored in an h5 cache file."""
        with h5py.File(cache_fn, 'r') as f:
            points = np.asarray(f['data'][:], dtype=np.float32)
            labels = np.asarray(f['label'][:], dtype=np.int64)
        return points, labels

    @staticmethod
    def _read_shape_ids(split_file):
        """Return the lines of a split file with trailing whitespace stripped."""
        with open(split_file, encoding="utf-8") as fh:
            return [line.rstrip() for line in fh]

    @staticmethod
    def _shape_name(shape_id):
        """Return the category of a shape id by dropping its last "_" field ('night_stand_0002' -> 'night_stand')."""
        return shape_id.rpartition("_")[0]

    def _load_from_txt(self):
        """Fill self.points / self.labels (and shape_ids, shape_names, datapath) from the txt files."""
        # split_file is already rooted at data_dir, so it is opened as-is. Joining
        # it with data_dir again would double the prefix for a relative data_root.
        split_file = osp.join(self.data_dir, f"modelnet40_{self.split}.txt")
        shape_ids = self._read_shape_ids(split_file)
        shape_names = [self._shape_name(sid) for sid in shape_ids]

        # Keep only shapes whose category belongs to the chosen closed set.
        chosen = [i for i, name in enumerate(shape_names) if name in self.class_choice]
        self.shape_ids = [shape_ids[i] for i in chosen]
        self.shape_names = [shape_names[i] for i in chosen]
        self.datapath = [
            (name, osp.join(self.data_dir, name, sid) + ".txt")
            for name, sid in zip(self.shape_names, self.shape_ids)
        ]
        if not self.datapath:
            raise ValueError(f"{self.whoami} - no shapes of {sorted(self.class_choice)} "
                             f"found in {split_file}")

        points, labels = [], []
        for name, path in tqdm.tqdm(self.datapath, desc=f"{self.whoami} loading data from txts",
                                    dynamic_ncols=True):
            point_set = np.loadtxt(path, delimiter=",").astype(np.float32)
            points.append(point_set[:, 0:3])  # keep the first 3 columns, drop normals
            labels.append(self.class_choice[name])

        self.points = np.stack(points, axis=0)  # [num_samples, num_pts, 3]
        self.labels = np.asarray(labels, dtype=np.int64)  # [num_samples, ]

    def _write_cache(self, cache_dir, cache_fn):
        """Save self.points / self.labels to the h5 file ``cache_fn``."""
        os.makedirs(cache_dir, exist_ok=True)
        print(f"Saving h5py dataset to: {cache_fn}")
        # Write under a temporary name and rename at the end, so an interrupted
        # run cannot leave a truncated cache that later runs would read as valid.
        tmp_fn = cache_fn + ".tmp"
        with h5py.File(tmp_fn, "w") as f:
            f.create_dataset(name='data', data=self.points, dtype=np.float32, chunks=True)
            f.create_dataset(name='label', data=self.labels, dtype=np.int64, chunks=True)
        os.replace(tmp_fn, cache_fn)

    def get_depth_map(self, pc):
        """Render ``pc`` with ``self.projector`` and resize the result to 224x224 (bilinear).

        ``pc`` is passed with a leading batch dimension by ``__getitem__``; the
        projector's output layout is assumed to be (N, C, H, W), since
        ``interpolate`` with a 2-element size needs 4-D input (TODO confirm).
        """
        depth = self.projector.get_img(pc).detach()
        depth = torch.nn.functional.interpolate(depth, size=(224, 224), mode='bilinear', align_corners=True)
        return depth

    def __getitem__(self, idx):
        """Return one sample.

        pretrain=True: ``(points, text, depth, label)``. ``text`` is
        ``encoded_prompt[label]`` and ``depth`` is rendered from the full cloud
        before sampling/augmentation.
        pretrain=False: only the sampled, normalized (optionally transformed) cloud.
        """
        point_set = self.points[idx]
        lbl = self.labels[idx]

        # pretrain with CLIP
        if self.pretrain:
            text = self.encoded_prompt[lbl]
            depth = self.get_depth_map(point_set.unsqueeze(0))
            # sampling + data augmentation
            point_set = random_sample(point_set, num_points=self.num_points)
            point_set = normalize(point_set)
            point_set = random_shift(point_set)
            point_set = random_rotate(point_set)
            point_set = random_jitter(point_set)
            point_set = random_rotate_perturbation(point_set)
            return point_set, text, depth, lbl

        # train an autoencoder
        point_set = random_sample(point_set, num_points=self.num_points)
        point_set = pc_normalize(point_set)  # unit cube normalization
        if self.transforms:
            point_set = self.transforms(point_set)
        return point_set

    def __len__(self):
        """Number of samples (length of ``self.labels``)."""
        return len(self.labels)


class H5_Dataset(data.Dataset):
    """
    Simple H5 dataset for ModelNet-C.

    Loads the full ``data`` and ``label`` arrays of ``h5_file`` into memory.
    Some files contain dirty samples, which are dropped on load (see
    ``_DIRTY_SAMPLES``), e.g. index 1800 of
    ``/ModelNet40_corrupted/occlusion/modelnet_set2_train_occlusion_sev2.h5``.

    Args:
        h5_file: path to the h5 file.
        num_points: points drawn per sample by ``random_sample`` in ``__getitem__``.
        class_choice: suffix of the prompt file
            ``./prompts/encoded_<class_choice>.pt`` loaded when ``pretrain`` is True.
        transforms: optional callable applied to the sampled cloud when
            ``pretrain`` is False.
        pretrain: if True, points are converted to a torch tensor and
            ``__getitem__`` returns ``(points, text, depth, label)``.

    Raises:
        FileNotFoundError: if ``h5_file`` does not exist.
    """

    # File-name fragment -> index of a dirty sample to remove from any h5 file
    # whose path contains that fragment.
    _DIRTY_SAMPLES = {"modelnet_set2_train_occlusion_sev2.h5": 1800}

    def __init__(self, h5_file, num_points, class_choice="SR1", transforms=None, pretrain=True):
        super().__init__()
        if not os.path.exists(h5_file):
            raise FileNotFoundError(h5_file)
        self.h5_file = h5_file
        self.num_points = num_points
        self.transforms = transforms
        self.pretrain = pretrain
        self.projector = Realistic_Projection()

        print(f"Reading data from hdf5 file: {self.h5_file}", end='')
        # The context manager closes the file even if reading fails.
        with h5py.File(self.h5_file, 'r') as f:
            points = np.asarray(f['data'][:], dtype=np.float32)
            labels = np.asarray(f['label'][:], dtype=np.int64)
        self.points, self.labels = self._drop_dirty_samples(self.h5_file, points, labels)
        print(f" datas: {self.points.shape}, labels: {self.labels.shape}")

        if self.pretrain:
            self.encoded_prompt = torch.load(f'./prompts/encoded_{class_choice}.pt', map_location='cpu')
            self.points = torch.from_numpy(self.points).detach()

    @classmethod
    def _drop_dirty_samples(cls, h5_file, points, labels):
        """Return copies of ``points``/``labels`` without the rows listed for ``h5_file``.

        ``h5_file`` may be a str or a path-like object.
        """
        path = os.fspath(h5_file)
        for fragment, idx in cls._DIRTY_SAMPLES.items():
            if fragment in path:
                points = np.delete(points, idx, axis=0)
                labels = np.delete(labels, idx, axis=0)
        return points, labels

    def get_depth_map(self, pc):
        """Render ``pc`` with ``self.projector`` and resize the result to 224x224 (bilinear).

        The projector's output is assumed to be 4-D (N, C, H, W), as required by
        ``interpolate`` with a 2-element size (TODO confirm).
        """
        depth = self.projector.get_img(pc).detach()
        depth = torch.nn.functional.interpolate(depth, size=(224, 224), mode='bilinear', align_corners=True)
        return depth

    def __getitem__(self, idx):
        """Return one sample.

        pretrain=True: ``(points, text, depth, label)``. ``depth`` is rendered
        from the full cloud before augmentation; the augmented cloud is then
        subsampled to ``num_points``.
        pretrain=False: only the sampled, normalized (optionally transformed) cloud.
        """
        point_set = self.points[idx]
        lbl = self.labels[idx]

        # pretrain with CLIP
        if self.pretrain:
            text = self.encoded_prompt[lbl]
            depth = self.get_depth_map(point_set.unsqueeze(0))
            # data augmentation
            point_set = normalize(point_set)
            point_set = random_shift(point_set)
            point_set = random_rotate(point_set)
            point_set = random_jitter(point_set)
            point_set = random_rotate_perturbation(point_set)
            # Honor the requested num_points (previously the helper's default
            # was used here, ignoring the constructor argument).
            point_set = random_sample(point_set, num_points=self.num_points)
            return point_set, text, depth, lbl

        # train an autoencoder
        point_set = random_sample(point_set, num_points=self.num_points)
        point_set = pc_normalize(point_set)  # unit cube normalization
        if self.transforms:
            point_set = self.transforms(point_set)
        return point_set

    def __len__(self):
        """Number of samples (length of ``self.labels``)."""
        return len(self.labels)


