import h5py
import numpy as np
from tqdm import tqdm

def sample_dataset(file_name, num_sample_train, num_test, rng=None, batch_size=1000):
    """Randomly sample rows, without replacement, from an HDF5 file's 'train' and 'test' datasets.

    Rows are drawn along axis 0. The chosen indices are sorted before reading,
    so the returned rows keep their original relative order in the file.

    Args:
        file_name: Path of the HDF5 file; it must contain datasets named
            'train' and 'test'.
        num_sample_train: Number of rows to draw from 'train'.
        num_test: Number of rows to draw from 'test'.
        rng: Optional random source exposing ``choice(a, size, replace=...)``
            (e.g. ``np.random.default_rng(seed)``) for reproducible sampling.
            Defaults to NumPy's global random state.
        batch_size: Number of rows fetched per HDF5 read.

    Returns:
        ``(select_train, select_test)``: in-memory NumPy arrays whose first
        dimension is ``num_sample_train`` / ``num_test`` and whose trailing
        dimensions and dtype match the source datasets.

    Raises:
        ValueError: If a requested sample size exceeds the corresponding
            dataset length, or ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Train is drawn before test from the same source, matching the original
    # order, so results under a global np.random.seed() are unchanged.
    chooser = np.random if rng is None else rng

    def get_samples_efficiently(dataset, indices):
        # Read the rows at `indices` in batches. h5py fancy indexing needs
        # increasing, non-repeating coordinates (guaranteed by the caller) and
        # slows down on very long index lists, hence the batching. Indexing
        # directly avoids building a dataset-sized boolean mask per batch.
        if len(indices) == 0:
            # Preserve trailing shape and dtype instead of a shapeless float array.
            return np.empty((0,) + tuple(dataset.shape[1:]), dtype=dataset.dtype)
        batches = [
            dataset[indices[start:start + batch_size]]
            for start in tqdm(range(0, len(indices), batch_size), desc="")
        ]
        return np.concatenate(batches)

    with h5py.File(file_name, 'r') as file:
        dataset_train = file['train']
        dataset_test = file['test']
        total_train = dataset_train.shape[0]
        total_test = dataset_test.shape[0]
        if num_sample_train > total_train:
            raise ValueError(
                f"num_sample_train ({num_sample_train}) exceeds the number of "
                f"rows in 'train' ({total_train})")
        if num_test > total_test:
            raise ValueError(
                f"num_test ({num_test}) exceeds the number of rows in 'test' ({total_test})")

        random_indices_train = np.sort(chooser.choice(total_train, num_sample_train, replace=False))
        random_indices_test = np.sort(chooser.choice(total_test, num_test, replace=False))

        # Materialise while the file is still open; the results are plain
        # NumPy arrays that outlive the `with` block.
        select_train = get_samples_efficiently(dataset_train, random_indices_train)
        select_test = get_samples_efficiently(dataset_test, random_indices_test)

    return select_train, select_test

def get_neighbor_and_distance(select_train, select_test, k=100, progress=True):
    """Brute-force k nearest neighbours under angular distance.

    For each row of ``select_test``, find the ``k`` rows of ``select_train``
    with the smallest angle ``arccos(cosine similarity)`` to it.

    Args:
        select_train: 2-D array-like of candidate vectors, one per row.
        select_test: 2-D array-like of query vectors, one per row, with the
            same width as ``select_train``.
        k: Neighbours per query. Clamped to ``len(select_train)`` so that a
            small train set yields every row instead of failing.
        progress: Show a tqdm progress bar over the queries.

    Returns:
        ``(neighbors_indices, neighbors_distances)``: an int32 array of shape
        ``(len(select_test), k)`` with row indices into ``select_train``,
        ordered by increasing distance, and a matching float array of angles
        in radians.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    select_train = np.asarray(select_train)
    select_test = np.asarray(select_test)
    k = min(k, len(select_train))

    neighbors_indices = np.zeros((len(select_test), k), dtype=np.int32)
    neighbors_distances = np.zeros((len(select_test), k))

    # Unit-normalise rows so a dot product is the cosine similarity; the
    # epsilon keeps all-zero rows from dividing by zero.
    train_unit = select_train / (np.linalg.norm(select_train, axis=1, keepdims=True) + 1e-10)
    test_unit = select_test / (np.linalg.norm(select_test, axis=1, keepdims=True) + 1e-10)

    rows = range(len(test_unit))
    if progress:
        rows = tqdm(rows, desc="")
    for i in rows:
        # Clip so rounding just outside [-1, 1] cannot make arccos return NaN.
        similarity = np.clip(train_unit @ test_unit[i], -1.0, 1.0)
        # arccos is decreasing, so the k smallest angles are the k largest
        # similarities: select them with an O(n) partition and sort only the
        # k winners, instead of fully sorting every candidate.
        if k < len(similarity):
            top = np.argpartition(-similarity, k - 1)[:k]
        else:
            top = np.arange(len(similarity))
        top = top[np.argsort(-similarity[top])]
        neighbors_indices[i] = top
        neighbors_distances[i] = np.arccos(similarity[top])

    return neighbors_indices, neighbors_distances

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Sample rows from an HDF5 file's 'train'/'test' datasets and "
                    "store them with their exact angular nearest neighbours.")
    # The paths were empty placeholders before; require them so the script
    # fails with a usage message instead of an opaque h5py error on ''.
    parser.add_argument('--file-name', required=True,
                        help="input HDF5 file containing 'train' and 'test' datasets")
    parser.add_argument('--out-file-name', required=True,
                        help='output HDF5 file to create (truncated if it exists)')
    # Defaults reproduce the previously hard-coded sample sizes.
    parser.add_argument('--num-sample-train', type=int, default=500000,
                        help='number of train rows to sample (default: %(default)s)')
    parser.add_argument('--num-sample-test', type=int, default=10000,
                        help='number of test rows to sample (default: %(default)s)')
    args = parser.parse_args()

    select_train, select_test = sample_dataset(
        args.file_name, args.num_sample_train, args.num_sample_test)
    neighbors_indices, neighbors_distances = get_neighbor_and_distance(
        select_train=select_train, select_test=select_test)

    # 'neighbors' holds row indices into the sampled 'train' array written
    # here, not into the input file's 'train' dataset.
    with h5py.File(args.out_file_name, 'w') as out_file:
        out_file.create_dataset('train', data=select_train)
        out_file.create_dataset('test', data=select_test)
        out_file.create_dataset('neighbors', data=neighbors_indices)
        out_file.create_dataset('distances', data=neighbors_distances)

        
        
        
        