from pathlib import Path

import h5py
import torch
import torch.nn.functional as F


# Ten category identifiers kept as strings (leading zeros are significant).
# This list is not referenced by any code in the visible part of this file.
# NOTE(review): the IDs look like ShapeNet synset offsets and the trailing
# numbers look like per-category sample counts in descending order — confirm.
selected_categories = [
    "03691459",  # 8436
    "02828884",  # 6778
    "04530566",  # 4045
    "03636649",  # 3514
    "04090263",  # 3173
    "04256520",  # 2373
    "02958343",  # 2318
    "02691156",  # 1939
    "03001627",  # 1813
    "04379243",  # 1597
]


class NeoMLPDataset(torch.utils.data.Dataset):
    """In-memory dataset of embeddings with integer labels.

    The file ``dataset_dir / dataset_file`` is read with ``torch.load`` and must
    yield a ``(data, targets)`` pair of NumPy arrays. Samples are split
    contiguously along the first axis: the first ``splits[0]`` samples form the
    training set, the next ``splits[1]`` the validation set, and everything
    after that the test set.

    Args:
        dataset_dir: Directory containing ``dataset_file``.
        dataset_file: File name of the serialized ``(data, targets)`` pair.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        splits: Indexable whose first two items are the train and validation
            set sizes. Required.
        num_out_embeddings: Number of trailing rows (along dim 0) of each
            sample that `_permute` leaves in place.
        normalize: If true, `__getitem__` standardizes samples with the
            statistics given to `set_stats`.
        augmentation: If true, `__getitem__` applies `_augment`.
        permutation: If true, `__getitem__` applies `_permute`.
        translation_scale: Stored on the instance; not used by `_augment`.
        rotation_degree: Stored on the instance; not used by `_augment`.
        noise_scale: Noise standard deviation as a fraction of the sample's
            own standard deviation.
        drop_rate: Element-wise dropout probability used by `_augment`.
        resize_scale: Stored on the instance; not used by `_augment`.
        quantile_dropout: Upper bound of the randomly drawn magnitude quantile
            below which elements are zeroed; ``0`` disables this step.

    Raises:
        ValueError: If ``split`` is not a known split name or ``splits`` is
            ``None``.
    """

    # Normalization statistics, filled in by `set_stats`. Defined at class
    # level so the attribute exists even when `__init__` is not run.
    stats = None

    def __init__(
        self,
        dataset_dir,
        dataset_file,
        split="train",
        splits=None,
        num_out_embeddings=3,
        normalize=False,
        augmentation=False,
        permutation=False,
        translation_scale=0.25,
        rotation_degree=45,
        noise_scale=0.1,
        drop_rate=0.01,
        resize_scale=0.2,
        quantile_dropout=0.0,
    ):
        # Validate cheap arguments before touching the disk.
        if split not in ("train", "val", "test"):
            raise ValueError("Invalid split")
        if splits is None:
            raise ValueError("`splits` must provide the train and val set sizes")

        self.split = split

        # NOTE: torch.load unpickles the file, so only load trusted files.
        # NOTE(review): recent PyTorch releases changed the default of
        # `weights_only`; confirm NumPy payloads still load with the pinned
        # torch version.
        data, targets = torch.load(Path(dataset_dir) / dataset_file)
        data = torch.from_numpy(data).float()
        targets = torch.from_numpy(targets).long()

        train_set_size = splits[0]
        val_set_size = splits[1]
        train_val_size = train_set_size + val_set_size

        split_slices = {
            "train": slice(None, train_set_size),
            "val": slice(train_set_size, train_val_size),
            "test": slice(train_val_size, None),
        }
        selection = split_slices[split]

        self.dataset = torch.utils.data.TensorDataset(
            data[selection], targets[selection]
        )

        self.augmentation = augmentation
        self.permutation = permutation
        self.normalize = normalize

        # translation_scale, rotation_degree and resize_scale are kept for
        # interface compatibility; `_augment` does not use them.
        self.translation_scale = translation_scale
        self.rotation_degree = rotation_degree
        self.noise_scale = noise_scale
        self.drop_rate = drop_rate
        self.resize_scale = resize_scale
        self.quantile_dropout = quantile_dropout

        self.num_out_embeddings = num_out_embeddings

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.dataset)

    def set_stats(self, mu, std):
        """Store the mean and standard deviation used by `_normalize`."""
        self.stats = {"mean": mu, "std": std}

    def _normalize(self, embeddings):
        """Return ``(embeddings - mean) / std`` using the stored statistics.

        Raises:
            RuntimeError: If `set_stats` has not been called.
        """
        if self.stats is None:
            raise RuntimeError(
                "Normalization statistics are missing; call set_stats() first"
            )
        return (embeddings - self.stats["mean"]) / self.stats["std"]

    @staticmethod
    def rotation_mat(degree=30.0):
        """Return a 2x2 rotation matrix for an angle drawn from U(-degree, degree).

        ``degree`` is given in degrees.
        """
        angle = torch.empty(1).uniform_(-degree, degree)
        angle_rad = angle * (torch.pi / 180)
        cos_a = torch.cos(angle_rad)
        sin_a = torch.sin(angle_rad)
        # Assemble row-major from the one-element tensors instead of relying on
        # torch.tensor() converting nested tensors to scalars.
        return torch.cat([cos_a, -sin_a, sin_a, cos_a]).reshape(2, 2)

    def _augment(self, embeddings):
        """Return a randomly perturbed copy of ``embeddings``.

        Applies, in order: additive Gaussian noise scaled by
        ``noise_scale * embeddings.std()``, element-wise dropout with
        probability ``drop_rate``, and (if ``quantile_dropout > 0``) zeroing of
        the elements whose magnitude is below a randomly drawn quantile.
        """
        # Noise: draw an independent perturbation per element. Adding the bare
        # `std * noise_scale` term would shift every element by one constant.
        noise = torch.randn_like(embeddings) * embeddings.std() * self.noise_scale
        new_embeddings = embeddings + noise

        # Dropout (F.dropout defaults to training mode, so it is always active)
        new_embeddings = F.dropout(new_embeddings, p=self.drop_rate)

        # Quantile dropout: the threshold is a quantile of the magnitudes, so
        # it must be computed on and compared against the same absolute values.
        if self.quantile_dropout > 0:
            do_q = torch.empty(1).uniform_(0, self.quantile_dropout)
            magnitudes = new_embeddings.abs()
            q = torch.quantile(magnitudes.flatten(), q=do_q)
            new_embeddings = torch.where(magnitudes < q, 0, new_embeddings)

        return new_embeddings

    def _permute(self, embeddings):
        """Randomly reorder the leading rows, keeping the trailing ones fixed.

        The last ``num_out_embeddings`` rows along dim 0 stay in place; all
        rows before them are shuffled with a fresh random permutation.
        """
        num_hidden_embeddings = embeddings.shape[0] - self.num_out_embeddings

        perm = torch.randperm(num_hidden_embeddings)
        # Slice the tail from the front: `embeddings[-0:]` would select every
        # row when num_out_embeddings is 0 and duplicate the sample.
        return torch.cat(
            [embeddings[perm], embeddings[num_hidden_embeddings:]],
            dim=0,
        )

    def __getitem__(self, item):
        """Return ``(embeddings, label)`` for index ``item``.

        Augmentation, normalization and permutation are applied in that order
        when the corresponding flags are set.
        """
        embeddings, label = self.dataset[item]

        if self.augmentation:
            embeddings = self._augment(embeddings)

        if self.normalize:
            embeddings = self._normalize(embeddings)

        if self.permutation:
            embeddings = self._permute(embeddings)

        return embeddings, label


class NeoMLPDatasetSequential(torch.utils.data.Dataset):
    """Embedding dataset whose train/val data file is chosen per epoch.

    For the ``"train"`` and ``"val"`` splits, `set_dataset` loads
    ``f"{dataset_file}_{k}.pt"`` from ``dataset_dir`` with
    ``k = epoch % num_augmentations``. For any other split name it loads
    ``dataset_file`` itself and uses every sample in it. Each file is read with
    ``torch.load`` and must yield a ``(data, targets)`` pair of NumPy arrays.

    Args:
        dataset_dir: Directory containing the dataset files.
        dataset_file: File name (other splits) or file-name stem (train/val).
        split: ``"train"`` keeps the first ``splits[0]`` samples, ``"val"``
            keeps the remaining ones; other values keep all samples.
        splits: Indexable whose first item is the training set size. Required
            for the ``"train"`` and ``"val"`` splits.
        num_out_embeddings: Number of trailing rows (along dim 0) of each
            sample that `_permute` leaves in place.
        num_augmentations: Number of per-epoch files that are cycled through.
        normalize: If true, `__getitem__` standardizes samples with the
            statistics given to `set_stats`.
        augmentation: If true, `__getitem__` applies `_augment`.
        permutation: If true, `__getitem__` applies `_permute`.
        translation_scale: Stored on the instance; not used by `_augment`.
        rotation_degree: Stored on the instance; not used by `_augment`.
        noise_scale: Noise standard deviation as a fraction of the sample's
            own standard deviation.
        drop_rate: Element-wise dropout probability used by `_augment`.
        resize_scale: Stored on the instance; not used by `_augment`.
        quantile_dropout: Upper bound of the randomly drawn magnitude quantile
            below which elements are zeroed; ``0`` disables this step.
    """

    # Normalization statistics, filled in by `set_stats`.
    stats = None

    def __init__(
        self,
        dataset_dir,
        dataset_file,
        split="train",
        splits=None,
        num_out_embeddings=3,
        num_augmentations=1,
        normalize=False,
        augmentation=False,
        permutation=False,
        translation_scale=0.25,
        rotation_degree=45,
        noise_scale=0.1,
        drop_rate=0.01,
        resize_scale=0.2,
        quantile_dropout=0.0,
    ):
        self.split = split
        self.num_augmentations = num_augmentations
        self.dataset_dir = dataset_dir
        self.dataset_file = dataset_file
        self.splits = splits
        # Load the data for the first epoch right away.
        self.set_dataset(0)

        self.augmentation = augmentation
        self.permutation = permutation
        self.normalize = normalize

        # translation_scale, rotation_degree and resize_scale are kept for
        # interface compatibility; `_augment` does not use them.
        self.translation_scale = translation_scale
        self.rotation_degree = rotation_degree
        self.noise_scale = noise_scale
        self.drop_rate = drop_rate
        self.resize_scale = resize_scale
        self.quantile_dropout = quantile_dropout

        self.num_out_embeddings = num_out_embeddings

    def __len__(self):
        """Return the number of samples currently loaded."""
        return len(self.dataset)

    def set_stats(self, mu, std):
        """Store the mean and standard deviation used by `_normalize`."""
        self.stats = {"mean": mu, "std": std}

    def set_dataset(self, epoch):
        """(Re)load the samples to serve for ``epoch``.

        Raises:
            ValueError: If the split is ``"train"`` or ``"val"`` and no
                ``splits`` were given.
        """
        uses_epoch_files = self.split in ["train", "val"]
        if uses_epoch_files and self.splits is None:
            raise ValueError("`splits` must provide the train set size")

        if uses_epoch_files:
            augmentation = epoch % self.num_augmentations
            path = Path(self.dataset_dir) / f"{self.dataset_file}_{augmentation}.pt"
        else:
            path = Path(self.dataset_dir) / self.dataset_file

        # NOTE: torch.load unpickles the file, so only load trusted files.
        # NOTE(review): recent PyTorch releases changed the default of
        # `weights_only`; confirm NumPy payloads still load with the pinned
        # torch version.
        data, targets = torch.load(path)
        data = torch.from_numpy(data).float()
        targets = torch.from_numpy(targets).long()

        if self.split == "train":
            train_set_size = self.splits[0]
            data = data[:train_set_size]
            targets = targets[:train_set_size]
        elif self.split == "val":
            train_set_size = self.splits[0]
            data = data[train_set_size:]
            targets = targets[train_set_size:]
        # Any other split name keeps every sample of the loaded file.

        self.dataset = torch.utils.data.TensorDataset(data, targets)

    def _normalize(self, embeddings):
        """Return ``(embeddings - mean) / std`` using the stored statistics.

        Raises:
            RuntimeError: If `set_stats` has not been called.
        """
        if self.stats is None:
            raise RuntimeError(
                "Normalization statistics are missing; call set_stats() first"
            )
        return (embeddings - self.stats["mean"]) / self.stats["std"]

    @staticmethod
    def rotation_mat(degree=30.0):
        """Return a 2x2 rotation matrix for an angle drawn from U(-degree, degree).

        ``degree`` is given in degrees.
        """
        angle = torch.empty(1).uniform_(-degree, degree)
        angle_rad = angle * (torch.pi / 180)
        cos_a = torch.cos(angle_rad)
        sin_a = torch.sin(angle_rad)
        # Assemble row-major from the one-element tensors instead of relying on
        # torch.tensor() converting nested tensors to scalars.
        return torch.cat([cos_a, -sin_a, sin_a, cos_a]).reshape(2, 2)

    def _augment(self, embeddings):
        """Return a randomly perturbed copy of ``embeddings``.

        Applies, in order: additive Gaussian noise scaled by
        ``noise_scale * embeddings.std()``, element-wise dropout with
        probability ``drop_rate``, and (if ``quantile_dropout > 0``) zeroing of
        the elements whose magnitude is below a randomly drawn quantile.
        """
        # Noise: draw an independent perturbation per element. Adding the bare
        # `std * noise_scale` term would shift every element by one constant.
        noise = torch.randn_like(embeddings) * embeddings.std() * self.noise_scale
        new_embeddings = embeddings + noise

        # Dropout (F.dropout defaults to training mode, so it is always active)
        new_embeddings = F.dropout(new_embeddings, p=self.drop_rate)

        # Quantile dropout: the threshold is a quantile of the magnitudes, so
        # it must be computed on and compared against the same absolute values.
        if self.quantile_dropout > 0:
            do_q = torch.empty(1).uniform_(0, self.quantile_dropout)
            magnitudes = new_embeddings.abs()
            q = torch.quantile(magnitudes.flatten(), q=do_q)
            new_embeddings = torch.where(magnitudes < q, 0, new_embeddings)

        return new_embeddings

    def _permute(self, embeddings):
        """Randomly reorder the leading rows, keeping the trailing ones fixed.

        The last ``num_out_embeddings`` rows along dim 0 stay in place; all
        rows before them are shuffled with a fresh random permutation.
        """
        num_hidden_embeddings = embeddings.shape[0] - self.num_out_embeddings

        perm = torch.randperm(num_hidden_embeddings)
        # Slice the tail from the front: `embeddings[-0:]` would select every
        # row when num_out_embeddings is 0 and duplicate the sample.
        return torch.cat(
            [embeddings[perm], embeddings[num_hidden_embeddings:]],
            dim=0,
        )

    def __getitem__(self, item):
        """Return ``(embeddings, label)`` for index ``item``.

        Augmentation, normalization and permutation are applied in that order
        when the corresponding flags are set.
        """
        embeddings, label = self.dataset[item]

        if self.augmentation:
            embeddings = self._augment(embeddings)

        if self.normalize:
            embeddings = self._normalize(embeddings)

        if self.permutation:
            embeddings = self._permute(embeddings)

        return embeddings, label


class NeoMLPDatasetH5(NeoMLPDataset):
    """`NeoMLPDataset` variant that reads its samples from an HDF5 file.

    The file must contain the datasets ``"embeddings"`` and ``"labels"``. Only
    samples belonging to the 10 most frequent labels are kept, and their labels
    are replaced by the label's rank in descending-frequency order (ties are
    ordered however ``Tensor.sort`` orders them). The filtered samples are then
    split contiguously: the first ``splits[0]`` for training, the next
    ``splits[1]`` for validation, and the remainder for testing.

    The parent ``__init__`` is intentionally not called because it would load
    the file with ``torch.load``; all instance attributes are set here instead.

    Args:
        dataset_dir: Directory containing ``dataset_file``.
        dataset_file: Name of the HDF5 file.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        splits: Indexable whose first two items are the train and validation
            set sizes. Required.
        num_out_embeddings: Number of trailing rows (along dim 0) of each
            sample that `_permute` leaves in place.
        normalize: If true, samples are standardized in `__getitem__`.
        augmentation: If true, samples are augmented in `__getitem__`.
        permutation: If true, samples are permuted in `__getitem__`.
        translation_scale: Stored on the instance.
        rotation_degree: Stored on the instance.
        noise_scale: Stored on the instance for use by `_augment`.
        drop_rate: Stored on the instance for use by `_augment`.
        resize_scale: Stored on the instance.
        quantile_dropout: Stored on the instance for use by `_augment`.

    Raises:
        ValueError: If ``split`` is not a known split name or ``splits`` is
            ``None``.
    """

    def __init__(
        self,
        dataset_dir,
        dataset_file,
        split="train",
        splits=None,
        num_out_embeddings=3,
        normalize=False,
        augmentation=False,
        permutation=False,
        translation_scale=0.25,
        rotation_degree=45,
        noise_scale=0.1,
        drop_rate=0.01,
        resize_scale=0.2,
        quantile_dropout=0.0,
    ):
        # Validate cheap arguments before touching the disk.
        if split not in ("train", "val", "test"):
            raise ValueError("Invalid split")
        if splits is None:
            raise ValueError("`splits` must provide the train and val set sizes")

        self.split = split

        # Read everything into memory inside a context manager so the file
        # handle is released as soon as the arrays have been copied out.
        with h5py.File(Path(dataset_dir) / dataset_file, 'r') as all_data:
            data = torch.from_numpy(all_data['embeddings'][:]).float()
            targets = torch.from_numpy(all_data['labels'][:]).long()

        # Keep the (up to) 10 most frequent labels.
        unique_labels, counts = targets.unique(return_counts=True)
        keep_labels = unique_labels[counts.sort(descending=True).indices[:10]]

        # label_match[i, j] is True when sample i carries keep_labels[j]; each
        # row has at most one True entry because keep_labels are distinct.
        label_match = targets[:, None] == keep_labels
        keep_indices = label_match.any(1).nonzero().flatten()

        data = data[keep_indices]
        # The column index of the match becomes the new label; rows come back
        # in ascending order, matching the order of `keep_indices`.
        targets = torch.where(label_match)[1]

        train_set_size = splits[0]
        val_set_size = splits[1]
        train_val_size = train_set_size + val_set_size

        split_slices = {
            "train": slice(None, train_set_size),
            "val": slice(train_set_size, train_val_size),
            "test": slice(train_val_size, None),
        }
        selection = split_slices[split]

        self.dataset = torch.utils.data.TensorDataset(
            data[selection], targets[selection]
        )

        self.augmentation = augmentation
        self.permutation = permutation
        self.normalize = normalize

        self.translation_scale = translation_scale
        self.rotation_degree = rotation_degree
        self.noise_scale = noise_scale
        self.drop_rate = drop_rate
        self.resize_scale = resize_scale
        self.quantile_dropout = quantile_dropout

        self.num_out_embeddings = num_out_embeddings
