import os

import h5py
import numpy as np
from torch.utils import data


# (category id, human-readable name) pairs.
# The ids are 8-digit strings in the same format as the category ids listed in
# CORRUPTED_MODELS, which ShapeNetSDF compares against directory names.
# Presumably these are ShapeNet synset ids -- TODO confirm.
# NOTE(review): nothing in the visible part of this file reads this list.
category_list = [
    ("04379243", "table"),
    ("03593526", "jar"),
    ("04225987", "skateboard"),
    ("02958343", "car"),
    ("02876657", "bottle"),
    ("04460130", "tower"),
    ("03001627", "chair"),
    ("02871439", "bookshelf"),
    ("02942699", "camera"),
    ("02691156", "airplane"),
    ("03642806", "laptop"),
    ("02801938", "basket"),
    ("04256520", "sofa"),
    ("03624134", "knife"),
    ("02946921", "can"),
    ("04090263", "rifle"),
    ("04468005", "train"),
    ("03938244", "pillow"),
    ("03636649", "lamp"),
    ("02747177", "trash bin"),
    ("03710193", "mailbox"),
    ("04530566", "watercraft"),
    ("03790512", "motorbike"),
    ("03207941", "dishwasher"),
    ("02828884", "bench"),
    ("03948459", "pistol"),
    ("04099429", "rocket"),
    ("03691459", "loudspeaker"),
    ("03337140", "file cabinet"),
    ("02773838", "bag"),
    ("02933112", "cabinet"),
    ("02818832", "bed"),
    ("02843684", "birdhouse"),
    ("03211117", "display"),
    ("03928116", "piano"),
    ("03261776", "earphone"),
    ("04401088", "telephone"),
    ("04330267", "stove"),
    ("03759954", "microphone"),
    ("02924116", "bus"),
    ("03797390", "mug"),
    ("04074963", "remote"),
    ("02808440", "bathtub"),
    ("02880940", "bowl"),
    ("03085013", "keyboard"),
    ("03467517", "guitar"),
    ("04554684", "washer"),
    ("02834778", "bicycle"),
    ("03325088", "faucet"),
    ("04004475", "printer"),
    ("02954340", "cap"),
    # The entries below were added manually (the source of the entries above
    # is not recorded in this file).
    ("02992529", "cellphone"),
    ("03046257", "clock"),
    ("03513137", "helmet"),
    ("03761084", "microwave"),
    ("03991062", "flowerpot"),
]

# (category directory name, model file name) pairs that ShapeNetSDF leaves out
# when it builds its model list. The name suggests these archives are
# unreadable; the exact failure is not recorded here.
CORRUPTED_MODELS = [
    ("03691459", "a76f63c6b3702a4981c9b20aad15512.npz"),
    ("03001627", "a5d21835219c8fed19fb4103277a6b93.npz"),
]


class ShapeNetSDF(data.Dataset):
    """SDF samples stored as one NumPy archive per model.

    The dataset expects ``<root>/<category>/<model file>``, where every
    sub-directory of ``root`` is treated as a category and every entry inside
    it as a model archive holding a ``pts`` and an ``sdf`` array. One item of
    the dataset is one whole model.
    """

    def __init__(
        self,
        root: str,
        seed: int = 42,
        debug: bool = False,
        train: bool = True,
    ):
        """
        Args:
            root (str): dataset root; each sub-directory is one category.
            seed (int): seed of the generator used to shuffle the model list.
            debug (bool): accepted but currently unused.
            train (bool): accepted but currently unused; this class does not
                implement a train/test split.
        """
        # Attributes
        self.root = root

        # Every directory directly below the root is a category.
        # NOTE(review): os.listdir returns entries in arbitrary order, so the
        # category indices and the pre-shuffle model order (and therefore the
        # shuffled order) can differ between machines even with a fixed seed.
        # Sorting here would fix that but would reorder existing indices.
        categories = [
            entry
            for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        ]

        # Per-category metadata; no human-readable name is resolved here.
        self.metadata = {
            category: {"id": category, "name": "n/a", "idx": position}
            for position, category in enumerate(categories)
        }

        # Collect every model of every category, skipping known-bad archives.
        self.models = [
            {
                "category": category,
                "model": model,
            }
            for category in categories
            for model in sorted(os.listdir(os.path.join(self.root, category)))
            if (category, model) not in CORRUPTED_MODELS
        ]

        # Shuffle the samples in place with a seeded generator.
        rng = np.random.default_rng(seed)
        rng.shuffle(self.models)

    def __len__(self) -> int:
        """Return the number of models."""
        return len(self.models)

    def __getitem__(self, idx):
        """Load one model.

        Returns:
            tuple: ``(points, sdf, idx)`` where ``points`` and ``sdf`` are the
            ``pts`` and ``sdf`` arrays of the archive and ``idx`` is the index
            that was passed in. The category label is not returned.
        """
        entry = self.models[idx]
        file_path = os.path.join(self.root, entry["category"], entry["model"])

        # np.load on an .npz archive returns a lazy NpzFile that keeps the
        # file open. Close it deterministically instead of leaking one handle
        # per item; the arrays are materialised on access, so they stay valid
        # after the archive is closed.
        with np.load(file_path) as points_dict:
            points = points_dict["pts"]
            sdf = points_dict["sdf"]
        return points, sdf, idx


class ShapeNetSDFH5(data.Dataset):
    """Point-level, in-memory view of an HDF5 SDF file.

    The ``points``, ``sdf``, ``indices`` and ``labels`` datasets of
    ``shapenet_train.h5`` / ``shapenet_test.h5`` are read completely into
    memory at construction time. One item of the dataset is a single sample
    point of a single signal.
    """

    def __init__(
        self,
        root: str,
        seed: int = 42,
        debug: bool = False,
        train: bool = True,
    ):
        """
        Args:
            root (str): directory containing the HDF5 files.
            seed (int): accepted but currently unused.
            debug (bool): accepted but currently unused.
            train (bool): read ``shapenet_train.h5`` when true, otherwise
                ``shapenet_test.h5``.
        """
        self.root = root
        file_name = "shapenet_train.h5" if train else "shapenet_test.h5"
        self.file_path = os.path.join(self.root, file_name)

        # Everything is copied into memory, so the file is closed right away.
        # Keeping the handle open would leak it for the lifetime of the
        # dataset and would make the dataset unpicklable, which breaks
        # DataLoader workers that are started with the "spawn" method.
        with h5py.File(self.file_path, "r") as h5_file:
            self.points = h5_file["points"][...]
            self.sdf = h5_file["sdf"][...]
            raw_indices = h5_file["indices"][:]
            self.labels = h5_file["labels"][:]

        # `data` used to be the open h5py.File. It is kept as an in-memory
        # mapping of the datasets read above so that read access such as
        # `dataset.data["points"]` keeps working; other keys of the file and
        # h5py.File methods are no longer available through it.
        self.data = {
            "points": self.points,
            "sdf": self.sdf,
            "indices": raw_indices,
            "labels": self.labels,
        }

        self.num_signals = self.points.shape[0]
        self.points_per_signal = self.points.shape[1]
        self.total_num_points = self.num_signals * self.points_per_signal

        self.indices = raw_indices.astype(int)
        if not train:
            # Shift the test indices so that the smallest one becomes 0.
            self.indices = self.indices - self.indices.min()

    def __len__(self) -> int:
        """Return the total number of sample points over all signals."""
        return self.total_num_points

    def __getitem__(self, idx):
        """Return one sample point.

        The flat index is split into a signal index and a point index inside
        that signal.

        Returns:
            tuple: ``(point, sdf_value, signal_index)`` where ``signal_index``
            is the entry of ``self.indices`` for the signal. The category
            label is not returned.
        """
        signal_idx, point_idx_in_signal = divmod(idx, self.points_per_signal)
        points = self.points[signal_idx, point_idx_in_signal, :]
        sdf = self.sdf[signal_idx, point_idx_in_signal]
        return points, sdf, self.indices[signal_idx]


class ChunkedShapeNetSDFH5(data.Dataset):
    """Signal-level view of an HDF5 SDF file, read lazily from disk.

    Only ``indices`` and ``labels`` are loaded at construction time; the
    ``points`` and ``sdf`` rows are read from the file on every item access.
    One item of the dataset is one whole signal.
    """

    def __init__(
        self,
        root: str,
        seed: int = 42,
        debug: bool = False,
        train: bool = True,
    ):
        """
        Args:
            root (str): directory containing the HDF5 files.
            seed (int): accepted but currently unused.
            debug (bool): accepted but currently unused.
            train (bool): read ``shapenet_train.h5`` when true, otherwise
                ``shapenet_test.h5``.
        """
        self.root = root
        file_name = "shapenet_train.h5" if train else "shapenet_test.h5"
        self.file_path = os.path.join(self.root, file_name)

        # Lazily opened handle and the process that opened it (see `data`).
        self._data = None
        self._data_pid = None

        # Read the metadata through a short-lived handle so that no open file
        # exists at the moment DataLoader workers are created.
        with h5py.File(self.file_path, "r") as h5_file:
            self.num_signals = h5_file["points"].shape[0]
            self.points_per_signal = h5_file["points"].shape[1]
            self.indices = h5_file["indices"][:].astype(int)
            self.labels = h5_file["labels"][:]
        self.total_num_points = self.num_signals * self.points_per_signal

        if not train:
            # Shift the test indices so that the smallest one becomes 0.
            self.indices = self.indices - self.indices.min()

    @property
    def data(self):
        """Read-only ``h5py.File`` handle owned by the current process.

        The file is opened on first use and reopened when the process id has
        changed, so a handle inherited through ``fork`` is never reused by a
        worker process.
        """
        pid = os.getpid()
        if self._data is None or self._data_pid != pid:
            self._data = h5py.File(self.file_path, "r")
            self._data_pid = pid
        return self._data

    def __getstate__(self):
        """Drop the file handle when pickling; h5py objects cannot be pickled."""
        state = self.__dict__.copy()
        state["_data"] = None
        state["_data_pid"] = None
        return state

    def __len__(self) -> int:
        """Return the number of signals."""
        return self.num_signals

    def __getitem__(self, idx):
        """Read one signal from disk.

        Returns:
            tuple: ``(points, sdf, indices)`` where ``points`` and ``sdf`` are
            row ``idx`` of the respective datasets and ``indices`` is the
            signal's entry of ``self.indices`` repeated
            ``points_per_signal`` times. The category label is not returned.
        """
        h5_file = self.data
        points = h5_file["points"][idx]
        sdf = h5_file["sdf"][idx]
        # One index entry per point of the signal.
        signal_indices = np.repeat(self.indices[idx], self.points_per_signal)
        return points, sdf, signal_indices

class RandomBatchSampler(data.Sampler[list[int]]):
    """Endless sampler that draws whole batches of indices with replacement.

    Every batch is an array of ``batch_size`` integers drawn uniformly from
    ``range(size)`` by a generator seeded with ``seed``. Iteration never
    stops by itself. ``drop_last`` only influences the value reported by
    ``len()``; every yielded batch has exactly ``batch_size`` entries.
    """

    def __init__(self, size: int, batch_size: int, seed: int = 0, drop_last: bool = True):
        self.size = size
        self.batch_size = batch_size
        self.seed = seed
        self.drop_last = drop_last

        # Single generator shared by all iterations of this sampler.
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        """Return the nominal number of batches needed to cover ``size`` items."""
        if not self.drop_last:
            # Round up so that a trailing partial batch is counted.
            padded_size = self.size + self.batch_size - 1
            return padded_size // self.batch_size
        return self.size // self.batch_size

    def __iter__(self):
        """Yield index batches forever."""
        while True:
            batch = self.rng.choice(self.size, size=self.batch_size, replace=True)
            yield batch


if __name__ == "__main__":
    # Manual smoke test: build the datasets, then plot the negative-SDF
    # samples of the first 16 signals into a 4x4 grid saved as "test.png".
    import matplotlib.pyplot as plt

    train_dataset = ShapeNetSDF(root="datasets/ShapeNetSDF", seed=0)

    # Alternative point-level pipelines, kept for reference:
    # batch_size = 1024
    # h5_dataset = ShapeNetSDFH5(root="datasets/ShapeNetSDF")
    # sampler = RandomBatchSampler(len(h5_dataset), batch_size, drop_last=True)
    # dataloader = data.DataLoader(h5_dataset, batch_sampler=sampler, num_workers=18)
    # sampler = data.RandomSampler(h5_dataset, replacement=True)
    # dataloader = data.DataLoader(h5_dataset, sampler=sampler, batch_size=batch_size, num_workers=18)

    chunked_train_dataset = ChunkedShapeNetSDFH5(root="datasets/ShapeNetSDF")
    chunked_loader = data.DataLoader(
        chunked_train_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )
    print("Labels:", chunked_train_dataset.labels[:16])

    fig = plt.figure()
    for cell in range(1, 17):
        ax = fig.add_subplot(4, 4, cell, projection="3d")
        points, sdf, idx = chunked_train_dataset[cell - 1]

        # Keep only the samples whose SDF value is negative.
        negative = sdf < 0
        points, sdf = points[negative], sdf[negative]

        # The second and third coordinates are swapped for plotting.
        ax.scatter(points[:, 0], points[:, 2], points[:, 1], c=sdf, s=5, cmap="coolwarm")
        print("Index:", idx)
        print("Points:", points.shape)
        print("SDF:", sdf.shape)

    plt.savefig("test.png")
