import torch
import numpy as np
from collections import Counter

def get_permutation(original_len, device='auto', seed=0):
    """Return a random permutation of the integers ``0 .. original_len - 1``.

    Args:
        original_len (int): Number of elements to permute.
        device: Accepted for interface compatibility; currently unused.
        seed (int): Accepted for interface compatibility; currently unused.

    Returns:
        torch.Tensor: 1-D tensor holding the permutation.
    """
    # No dedicated generator is supplied, so the draw comes from torch's
    # default random number generator (controlled via torch.manual_seed).
    return torch.randperm(original_len, generator=None)


def balance_dataset_indexes(y, shuffle=True):
    """Return sample indices that balance the classes of ``y`` by oversampling.

    Every original index is kept. For each class that has fewer samples than
    the largest class, additional indices of that class are drawn with
    replacement until all classes have the same number of entries.

    Args:
        y: Array-like of labels (anything ``np.asarray`` can convert, e.g. a
            numpy array, a list or a CPU torch tensor).
        shuffle (bool): If True, the resulting indices are shuffled in place
            with ``np.random.shuffle``; otherwise they are grouped by class
            in the order returned by ``np.unique``.

    Returns:
        np.ndarray: 1-D integer array of indices into ``y``. Empty when ``y``
        is empty.
    """
    # Convert once so counting and comparison operate on plain numpy values.
    # Counting tensor elements with a hash-based counter would treat every
    # element as distinct, which silently disables the balancing.
    labels = np.asarray(y)
    if labels.size == 0:
        return np.empty(0, dtype=np.intp)

    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()

    per_class = []
    for k, n_k in zip(classes, counts):
        idx_k = np.where(labels == k)[0]
        if n_k < target:
            # Oversample the class by duplicating some of its samples
            extra = np.random.choice(idx_k, size=target - n_k, replace=True)
            idx_k = np.concatenate([idx_k, extra], axis=0)
        per_class.append(idx_k)

    idx_balanced = np.concatenate(per_class, axis=0)
    if shuffle:
        np.random.shuffle(idx_balanced)
    return idx_balanced


def generate_categorical_samples(n, k, probs=None):
    """Draw ``n`` samples from a categorical distribution over ``k`` classes.

    Args:
        n (int): Number of samples to draw.
        k (int): Number of classes.
        probs: Optional class weights of length ``k`` (list, tuple, numpy
            array or tensor). They do not need to be normalized. When
            omitted, random weights are drawn with ``torch.rand``.

    Returns:
        torch.Tensor: Tensor of shape ``(n,)`` with class indices in
        ``[0, k)``.

    Raises:
        AssertionError: If ``probs`` is given and its length differs from
            ``k``.
    """
    # If no probabilities are provided, generate random probabilities
    if probs is None:
        probs = torch.rand((k,))
    else:
        assert len(probs) == k, "Length of probs must equal k"
        probs = torch.as_tensor(probs)
        # Integer weights (e.g. [0, 0, 1]) cannot be normalized by true
        # division in their own dtype, so promote them to floating point.
        if not torch.is_floating_point(probs):
            probs = probs.to(torch.get_default_dtype())

    # Build an (n, k) tensor with one copy of the weights per sample
    tiled = probs.repeat((n, 1))

    # Normalize each row so that it sums to 1 (out of place on purpose)
    tiled = tiled / torch.sum(tiled, dim=1, keepdim=True)

    # One draw per row; squeeze only the sample dimension so that n == 1
    # still yields a tensor of shape (1,) rather than a 0-d tensor.
    return torch.multinomial(tiled, num_samples=1).squeeze(1)


def random_indices(num_indices, max_idx):
    """Draws num_indices distinct random indices from the range
    0 to max_idx - 1.

    Args:
        num_indices (int): How many indices to return.
        max_idx (int): Exclusive upper bound of the index range.

    Returns:
        torch.Tensor: 1-D tensor with the selected indices.
    """
    # A full random permutation is generated and only its head is kept; the
    # head of a random permutation is itself a sample without replacement.
    return torch.randperm(max_idx)[:num_indices]


def compute_split_idx(original_len, split_sizes, k_fold=-1):
    """Partition ``range(original_len)`` into consecutive index splits.

    Each split except the last receives ``int(size * original_len)`` indices;
    the last split receives every index that remains.

    Args:
        original_len (int): Total number of samples.
        split_sizes (list[float]): Fraction of the data for each split, each
            strictly between 0 and 1. A single-element list returns all
            indices as one split without any validation.
        k_fold (int): If negative, indices stay in their natural order. If
            ``>= 0``, indices are shuffled with a fixed seed (42) and then
            rotated by ``k_fold`` times the amount of data left after the
            first split, so that successive folds hold out different
            samples.

    Returns:
        list[torch.Tensor]: One 1-D index tensor per entry in ``split_sizes``.

    Raises:
        AssertionError: If a split size is not a float in the open interval
            (0, 1).
    """
    all_idx = torch.arange(original_len)
    if len(split_sizes) == 1:
        return [all_idx]

    if k_fold >= 0:
        generator = torch.Generator(device='cpu')
        generator.manual_seed(42)
        perm = torch.randperm(original_len, generator=generator)
        all_idx = all_idx[perm]
        # Amount of data outside the first split (possibly fractional).
        held_out = len(perm) * (1 - split_sizes[0])
        # Binary floating point makes e.g. 10 * (1 - 0.8) evaluate to
        # 1.9999999999999996; plain truncation would then shift by 1 instead
        # of 2 and the folds would overlap. Round away the representation
        # error first, then truncate as before.
        shift = int(round(held_out * k_fold, 6))
        all_idx = torch.roll(all_idx, shifts=shift)

    all_idx_splits = []
    last = len(split_sizes) - 1
    start_idx = 0
    for i, size in enumerate(split_sizes):
        assert isinstance(size, float)
        assert 0 < size
        assert 1 > size
        end_idx = start_idx + int(size * original_len)
        if i == last:
            # The final split absorbs whatever the truncations left over.
            all_idx_splits.append(all_idx[start_idx:])
        else:
            all_idx_splits.append(all_idx[start_idx:end_idx])
        start_idx = end_idx

    return all_idx_splits

def compute_split_idx_random(original_len, split_sizes, k_fold=-1):
    """Partition ``range(original_len)`` into consecutive index splits,
    optionally after a fold-specific shuffle.

    Each split except the last receives ``int(size * original_len)`` indices;
    the last split receives every index that remains.

    Args:
        original_len (int): Total number of samples.
        split_sizes (list[float]): Fraction of the data for each split, each
            strictly between 0 and 1. A single-element list returns all
            indices as one split without any validation.
        k_fold (int): If negative, indices stay in their natural order. If
            ``>= 0``, it is used as the seed of the permutation applied to
            the indices before splitting.

    Returns:
        list[torch.Tensor]: One 1-D index tensor per entry in ``split_sizes``.
    """
    indices = torch.arange(original_len)
    if len(split_sizes) == 1:
        return [indices]

    if k_fold >= 0:
        # The fold number doubles as the seed, so each fold gets its own
        # reproducible shuffle.
        rng = torch.Generator(device='cpu')
        rng.manual_seed(k_fold)
        indices = indices[torch.randperm(original_len, generator=rng)]

    final_position = len(split_sizes) - 1
    chunks = []
    offset = 0
    for position, fraction in enumerate(split_sizes):
        assert isinstance(fraction, float)
        assert 0 < fraction
        assert 1 > fraction
        stop = offset + int(fraction * original_len)
        # Only the final chunk is open-ended so it picks up the remainder.
        chunk = indices[offset:] if position == final_position else indices[offset:stop]
        chunks.append(chunk)
        offset = stop

    return chunks


def num_classes_fn(data) -> int:
    r"""Returns the number of classes described by :obj:`data.y`.

    * :obj:`0` if :obj:`data.y` is :obj:`None`.
    * If there is one label per entry along the first dimension: the largest
      label plus one for non-floating-point labels, or the number of distinct
      values for floating-point labels.
    * Otherwise the size of the last dimension of :obj:`data.y`.
    """
    labels = data.y
    if labels is None:
        return 0

    # True when the tensor carries exactly one value per first-dim entry.
    single_label_per_item = labels.numel() == labels.size(0)
    if not single_label_per_item:
        return labels.size(-1)

    if torch.is_floating_point(labels):
        return torch.unique(labels).numel()
    return int(labels.max()) + 1


def normal(loc, scale, shape, device='auto', seed=0):
    """Sample from a normal distribution without tracking gradients.

    Args:
        loc: Mean of the distribution (number or tensor).
        scale: Standard deviation of the distribution (number or tensor).
        shape: Requested sample shape; it is extended through the
            distribution's ``_extended_shape`` helper before sampling.
        device: Accepted for interface compatibility; currently unused.
        seed (int): Accepted for interface compatibility; currently unused.

    Returns:
        torch.Tensor: The drawn samples.
    """
    dist = torch.distributions.Normal(loc=loc, scale=scale)
    full_shape = dist._extended_shape(shape)
    # No dedicated generator is supplied, so the draw comes from torch's
    # default random number generator.
    with torch.no_grad():
        mean = dist.loc.expand(full_shape)
        std = dist.scale.expand(full_shape)
        return torch.normal(mean, std, generator=None)


def bernoulli(probs, device='auto', seed=0):
    """Draw independent Bernoulli samples and return them as booleans.

    Args:
        probs (torch.Tensor): Success probability for each element.
        device: Accepted for interface compatibility; currently unused.
        seed (int): Accepted for interface compatibility; currently unused.

    Returns:
        torch.Tensor: Boolean tensor with the same shape as ``probs``.
    """
    # No dedicated generator is supplied, so the draw comes from torch's
    # default random number generator.
    draws = torch.bernoulli(probs, generator=None)
    return draws.to(torch.bool)
