"""Implementations of permutation-like transforms."""

import operator

import torch
import nflows.utils.typechecks as check
from nflows.transforms.permutations import Permutation


class TailRandomPermutation(Permutation):
    """Random but fixed permutation that keeps light- and heavy-tailed features apart.

    The first ``features_light`` features are shuffled among themselves and the
    next ``features_heavy`` features are shuffled among themselves. After the
    permutation, the light-tailed marginals still occupy the first
    ``features_light`` positions and the heavy-tailed marginals the rest.

    Args:
        features_light: number of light-tailed marginals (non-negative integer;
            Python or NumPy integer types are accepted).
        features_heavy: number of heavy-tailed marginals (non-negative integer).
        dim: dimension the permutation is applied along; forwarded to
            ``Permutation``.

    Raises:
        ValueError: if either count is not a non-negative integer, or if both
            counts are zero.
    """

    def __init__(self, features_light, features_heavy, dim=1):
        n_light = self._as_count(features_light, "features_light")
        n_heavy = self._as_count(features_heavy, "features_heavy")
        if n_light + n_heavy == 0:
            raise ValueError(
                "At least one of features_light and features_heavy must be positive."
            )
        # Shuffle each block independently. The heavy block is offset by n_light
        # so its indices address the positions right after the light block.
        permutation = torch.cat(
            [torch.randperm(n_light), torch.randperm(n_heavy) + n_light]
        )
        # Permutation stores the tensor as the registered buffer `_permutation`.
        # No separate copy is kept here, because a copy would drift from the
        # buffer after `.to(device)` / `load_state_dict`.
        super().__init__(permutation, dim)

    @staticmethod
    def _as_count(value, name):
        """Return ``value`` as an int, raising ValueError unless it is a non-negative integer."""
        try:
            # operator.index accepts int-like values (e.g. numpy integers) but rejects floats.
            count = operator.index(value)
        except TypeError as err:
            raise ValueError(f"{name} must be a non-negative integer.") from err
        if count < 0:
            raise ValueError(f"{name} must be a non-negative integer.")
        return count

    @property
    def permutation(self):
        """Permutation indices, read live from the registered ``_permutation`` buffer."""
        return self._permutation

    def get_permutation(self):
        """Return the permutation indices (the same tensor as ``self.permutation``)."""
        return self._permutation

class RandomPermutation(Permutation):
    """Permutes features using a random, but fixed, permutation.

    Args:
        features: number of features to permute (positive integer).
        dim: dimension the permutation is applied along; forwarded to
            ``Permutation``.

    Raises:
        ValueError: if ``features`` is not a positive integer.
    """

    def __init__(self, features, dim=1):
        if not check.is_positive_int(features):
            raise ValueError("Number of features must be a positive integer.")
        # Permutation stores the tensor as the registered buffer `_permutation`.
        # No separate copy is kept here, because a copy would drift from the
        # buffer after `.to(device)` / `load_state_dict`.
        super().__init__(torch.randperm(features), dim)

    @property
    def permutation(self):
        """Permutation indices, read live from the registered ``_permutation`` buffer."""
        return self._permutation

    def get_permutation(self):
        """Return the permutation indices (the same tensor as ``self.permutation``)."""
        return self._permutation
