import numpy as np
import random
import torch
import torchvision
from torch.utils.data import Dataset
import time

def fix_seed(seed):
    """Seed every random number generator used by this module.

    Seeds PyTorch (CPU, and the CUDA generators when CUDA is available),
    Python's ``random`` module and NumPy's global generator with ``seed``,
    then sets cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    # Collect the seeding functions in the order they must be applied.
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [random.seed, np.random.seed]

    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: deterministic mode on, benchmark mode off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_median_index(values):
    """Return the index of the element of ``values`` nearest to its median.

    Args:
        values: Sequence of numbers (anything ``np.array`` accepts).

    Returns:
        Index of the first element whose absolute distance to
        ``np.median(values)`` is minimal.
    """
    arr = np.array(values)
    distance_to_median = np.abs(arr - np.median(arr))
    # argmin returns the first position on ties.
    return np.argmin(distance_to_median)


def random_fourier_transform(Y: torch.Tensor, original_dim, new_dim, W_t=None, b_t=None):
    """Map ``Y`` through a random cosine feature map ``cos(Y @ W + b)``.

    Two calling modes:

    * ``W_t`` and ``b_t`` both omitted: fresh parameters are sampled,
      ``W`` as a Gaussian ``(original_dim, new_dim)`` matrix scaled by
      ``sqrt(2 / new_dim)`` and ``b`` uniformly in ``[0, 2*pi)`` with shape
      ``(1, new_dim)``. Returns ``(Z, W, b)``.
    * ``W_t`` and ``b_t`` both given: the supplied parameters are reused and
      only ``Z`` is returned.

    Args:
        Y: Input tensor whose last dimension is ``original_dim``.
        original_dim: Number of input features.
        new_dim: Number of output features.
        W_t: Previously sampled weight matrix, or ``None``.
        b_t: Previously sampled phase offsets, or ``None``.

    Returns:
        ``(Z, W, b)`` when sampling, otherwise ``Z``.

    Raises:
        ValueError: If only one of ``W_t`` / ``b_t`` is supplied.
    """
    if W_t is None and b_t is None:
        # Sample with the CPU generator (so a given seed yields the same
        # draws regardless of device), then move the parameters next to Y
        # instead of unconditionally onto CUDA.
        W = (torch.randn(original_dim, new_dim) * np.sqrt(2 / new_dim)).to(Y.device)
        b = (torch.rand(1, new_dim) * 2 * np.pi).to(Y.device)
        Z = torch.cos(Y @ W + b)
        print("Z.shape: ", Z.shape)
        return Z, W, b
    if W_t is not None and b_t is not None:
        return torch.cos(Y @ W_t + b_t)
    # Exactly one of the two was given: previously this fell through and
    # silently returned None.
    raise ValueError("W_t and b_t must be provided together or both omitted")


def random_matrix(Y: torch.Tensor, original_dim, new_dim, W_t=None, b_t=None):
    """Linearly map ``Y`` with a Gaussian random matrix.

    Two calling modes:

    * ``W_t`` and ``b_t`` both omitted: a standard-normal
      ``(original_dim, new_dim)`` matrix ``W`` is sampled and
      ``(Y @ W, W, 0)`` is returned (the bias is the integer ``0``).
    * ``W_t`` and ``b_t`` both given: returns ``Y @ W_t + b_t``.

    Args:
        Y: Input tensor whose last dimension is ``original_dim``.
        original_dim: Number of input features.
        new_dim: Number of output features.
        W_t: Previously sampled weight matrix, or ``None``.
        b_t: Previously returned bias (``0`` is a valid value), or ``None``.

    Returns:
        ``(Z, W, b)`` when sampling, otherwise ``Z``.

    Raises:
        ValueError: If only one of ``W_t`` / ``b_t`` is supplied.
    """
    if W_t is None and b_t is None:
        # Sample with the CPU generator, then place W on Y's device rather
        # than unconditionally on CUDA.
        W = torch.randn(original_dim, new_dim).to(Y.device)
        b = 0
        Z = Y @ W
        return Z, W, b
    if W_t is not None and b_t is not None:
        return Y @ W_t + b_t
    # Exactly one of the two was given: previously this fell through and
    # silently returned None.
    raise ValueError("W_t and b_t must be provided together or both omitted")

def random_project(Y: torch.Tensor, original_dim, new_dim, W_t=None, b_t=None):
    """Map ``Y`` with a random matrix that has orthonormal rows or columns.

    Two calling modes:

    * ``W_t`` and ``b_t`` both omitted: a Gaussian
      ``(original_dim, new_dim)`` matrix is sampled and its reduced SVD is
      used to build an ``(original_dim, new_dim)`` projection ``P``.
      When ``new_dim >= original_dim`` the rows of ``P`` are orthonormal
      (taken from the right singular vectors); when
      ``new_dim < original_dim`` the columns of ``P`` are orthonormal
      (taken from the left singular vectors). Returns ``(Y @ P, P, 0)``.
    * ``W_t`` and ``b_t`` both given: returns ``Y @ W_t``; ``b_t`` is not
      used in the computation.

    Args:
        Y: Input tensor whose last dimension is ``original_dim``.
        original_dim: Number of input features.
        new_dim: Number of output features.
        W_t: Projection matrix returned by an earlier call, or ``None``.
        b_t: Bias returned by an earlier call (``0``), or ``None``.

    Returns:
        ``(Z, P, 0)`` when sampling, otherwise ``Z``.

    Raises:
        ValueError: If only one of ``W_t`` / ``b_t`` is supplied.
    """
    if W_t is None and b_t is None:
        # Sample with the CPU generator, then place W on Y's device rather
        # than unconditionally on CUDA.
        W = torch.randn(original_dim, new_dim).to(Y.device)
        # torch.svd returns (U, S, V) with W = U @ diag(S) @ V.T; in reduced
        # form U is (original_dim, k) and V is (new_dim, k), k = min(dims).
        U, _, V = torch.svd(W)
        if new_dim >= original_dim:
            basis = V
            Q = basis[:, :original_dim]  # (new_dim, original_dim)
        else:
            # Reducing dimension: V would be square (new_dim, new_dim) and
            # could not multiply Y, so use the left singular vectors.
            basis = U
            Q = basis.T  # (new_dim, original_dim)
        print("W: ", W.shape, " U: ", basis.shape, " Q: ",  Q.shape)
        Z = Y @ Q.T
        print("Z.shape: ", Z.shape)
        return Z, Q.T, 0
    if W_t is not None and b_t is not None:
        # W_t is the (original_dim, new_dim) matrix returned above.
        return Y @ W_t
    # Exactly one of the two was given: previously this fell through and
    # silently returned None.
    raise ValueError("W_t and b_t must be provided together or both omitted")


def index_extraction(len_dataset, p_data):
    """Draw a random subset of dataset indices without repetition.

    Args:
        len_dataset: Total number of samples; indices range over
            ``0 .. len_dataset - 1``.
        p_data: Fraction of the dataset to keep.

    Returns:
        NumPy array of ``int(p_data * len_dataset)`` distinct indices drawn
        with NumPy's global random generator.
    """
    sample_size = int(p_data * len_dataset)
    candidates = np.arange(len_dataset)
    # The shuffle advances the global RNG before the draw below; it is kept
    # so results for a given seed stay the same.
    np.random.shuffle(candidates)
    return np.random.choice(candidates, size=sample_size, replace=False)



def split_dataset(len_dataset, num=3):
    """Randomly partition the indices ``0 .. len_dataset - 1`` into ``num`` parts.

    Args:
        len_dataset: Total number of samples.
        num: Number of parts to produce.

    Returns:
        List of ``num`` NumPy index arrays, as produced by
        ``np.array_split`` on a shuffled index array (sizes differ by at
        most one). Shuffling uses NumPy's global random generator.
    """
    shuffled = np.arange(len_dataset)
    np.random.shuffle(shuffled)
    return np.array_split(shuffled, num)


def split_dataset_with_possible_overlap(len_dataset, p_overlap, num=3, e_round=1):
    """Randomly split dataset indices into ``num`` subsets that may overlap.

    The indices ``0 .. len_dataset - 1`` are shuffled and cut into ``num``
    disjoint partitions. When ``p_overlap > 0`` and ``num > 1``, every subset
    additionally receives ``int(p_overlap * (len_dataset // num))`` indices
    sampled (without repetition) from each of the *other* partitions.

    The shuffle is seeded from the wall clock plus ``e_round``, so the result
    differs between calls made at different times or rounds.

    Args:
        len_dataset: Total number of samples.
        p_overlap: Fraction of a partition's nominal size borrowed from each
            other partition; ``0`` gives a plain disjoint split.
        num: Number of subsets.
        e_round: Round counter added to the time-based seed.

    Returns:
        List of ``num`` NumPy index arrays.
    """
    # A private RandomState yields the same stream as
    # np.random.seed(seed) + np.random.* would, without overwriting the
    # state of NumPy's global generator.
    rng = np.random.RandomState(int(time.time()) + e_round)

    indices = np.arange(len_dataset)
    rng.shuffle(indices)

    # Disjoint base partitions.
    base_parts = np.array_split(indices, num)

    if p_overlap > 0 and num > 1:
        extra_size = int(p_overlap * (len_dataset // num))
        subsets = []
        for k, own in enumerate(base_parts):
            # Borrow only from the other *base* partitions. Sampling from
            # already-extended subsets could hand a subset back its own
            # indices and create duplicates inside it.
            borrowed = [
                rng.choice(other, size=extra_size, replace=False)
                for i, other in enumerate(base_parts)
                if i != k
            ]
            subsets.append(np.hstack([own, *borrowed]))
        return subsets

    return [np.array(part) for part in base_parts]


class CustomSubset(Dataset):
    """View of a dataset restricted to a given list of indices.

    Besides item access, the subset exposes ``data`` and ``targets``
    attributes holding the selected rows of the wrapped dataset's ``data``
    and ``targets``, so the wrapped dataset must provide both attributes and
    ``data`` must support indexing by ``indices``.
    """

    def __init__(self, dataset, indices):
        """Store the wrapped dataset and materialise the selected rows.

        Args:
            dataset: Source dataset with ``data`` and ``targets`` attributes.
            indices: Positions in ``dataset`` that make up the subset.
        """
        self.indices = indices
        self.dataset = dataset
        # Targets are converted to an array first so they can be indexed
        # with the whole index collection at once.
        self.targets = np.array(dataset.targets)[indices]
        self.data = dataset.data[indices]

    def __len__(self):
        """Return the number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair at subset position ``idx``."""
        source_position = self.indices[idx]
        sample, label = self.dataset[source_position]
        return sample, label

def get_subset_indices(dataset, num_per_class, num_classes=10):
    """Collect the first ``num_per_class`` sample indices of every class.

    Args:
        dataset: Object with a ``targets`` attribute holding one integer
            class label per sample.
        num_per_class: Maximum number of indices to keep for each class.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are
            considered. Defaults to 10 (e.g. CIFAR-10).

    Returns:
        List of indices grouped by class in increasing label order; within a
        class the indices follow dataset order. Classes with fewer than
        ``num_per_class`` samples contribute all the samples they have.
    """
    targets = np.array(dataset.targets)
    indices = []
    for class_idx in range(num_classes):
        # Positions of this class, in dataset order.
        class_indices = np.where(targets == class_idx)[0]
        indices.extend(class_indices[:num_per_class])
    return indices

def subset_original(train_dataset, ipc=500):
    """Build a class-balanced subset of ``train_dataset``.

    Args:
        train_dataset: Dataset to sample from; passed to
            ``get_subset_indices`` and ``CustomSubset``.
        ipc: Number of samples requested per class (forwarded to
            ``get_subset_indices`` as its per-class count).

    Returns:
        A ``CustomSubset`` over the selected indices.
    """
    return CustomSubset(train_dataset, get_subset_indices(train_dataset, ipc))