import os
import shutil
from torchvision import datasets
from sklearn.model_selection import StratifiedShuffleSplit

from tqdm import tqdm

# Paths to the original datasets
original_train_dir = '/data/clean/imagenet/train'
# NOTE: despite the variable name, this points at the original `test` folder;
# it is copied below and becomes the test split of the new dataset.
original_val_dir = '/data/clean/imagenet/test'

# Paths to the new dataset
new_dataset_dir = '/data/sam_data/data/imagenet'
new_train_dir = os.path.join(new_dataset_dir, 'train')
new_val_dir = os.path.join(new_dataset_dir, 'val')
new_test_dir = os.path.join(new_dataset_dir, 'test')

# Create new dataset directories. The test directory is deliberately NOT
# created here: shutil.copytree() creates its destination itself and raises
# FileExistsError when that destination already exists.
os.makedirs(new_train_dir, exist_ok=True)
os.makedirs(new_val_dir, exist_ok=True)

# Copy the original held-out set to the new test directory, keeping any
# symlinks inside the source tree as symlinks. The copy is skipped when the
# destination is already present so that re-running the script does not crash;
# delete the directory manually to force a fresh copy (e.g. after an
# interrupted run left it incomplete).
if os.path.lexists(new_test_dir):
    print(f"{new_test_dir} already exists; skipping copy of {original_val_dir}.")
else:
    shutil.copytree(original_val_dir, new_test_dir, symlinks=True)

# Index the original training images and keep the class metadata around
train_dataset = datasets.ImageFolder(original_train_dir)
class_to_idx = train_dataset.class_to_idx
class_names = train_dataset.classes

# Each sample is a (path, class index) pair; split it into two parallel lists
file_paths = [path for path, _ in train_dataset.samples]
labels = [target for _, target in train_dataset.samples]

# One stratified 90/10 shuffle split (fixed seed) so that every class keeps
# the same proportion in both parts
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
split_iter = splitter.split(file_paths, labels)
train_indices, val_indices = next(split_iter)

# Helper that symlinks the selected samples into a per-class directory tree
def create_symlinks(file_paths, labels, indices, target_dir):
    """Symlink the selected samples into ``target_dir/<class name>/``.

    Safe to call again on a directory that was already populated by an
    earlier run: links that already point at the same source are kept.

    Args:
        file_paths: Sequence of source image paths, indexable by the
            entries of ``indices``.
        labels: Sequence of integer class indices parallel to ``file_paths``.
        indices: Positions (into ``file_paths`` / ``labels``) of the samples
            that belong to this split.
        target_dir: Root directory of the split; one sub-directory per class
            is created beneath it on demand.

    Raises:
        FileExistsError: If a link path is already occupied by something
            other than a symlink to the same source file.

    Note:
        Class names are looked up in the module-level ``class_names`` list.
    """
    # Remember which class directories were already ensured so that
    # os.makedirs() runs once per class instead of once per file.
    ensured_class_dirs = set()

    for idx in tqdm(indices, total=len(indices)):
        file_path = file_paths[idx]
        # A relative link target is resolved against the directory that holds
        # the link, not the current working directory, so it would dangle.
        # Always link to the absolute location of the source file.
        source_path = os.path.abspath(file_path)
        class_name = class_names[labels[idx]]

        target_class_dir = os.path.join(target_dir, class_name)
        if target_class_dir not in ensured_class_dirs:
            os.makedirs(target_class_dir, exist_ok=True)
            ensured_class_dirs.add(target_class_dir)

        symlink_path = os.path.join(target_class_dir, os.path.basename(file_path))
        try:
            os.symlink(source_path, symlink_path)
        except FileExistsError:
            # Tolerate a link left behind by a previous run, but only when it
            # refers to this very source file; anything else is a real
            # conflict and must not be silently ignored.
            same_link = (
                os.path.islink(symlink_path)
                and os.readlink(symlink_path) in (source_path, file_path)
            )
            if not same_link:
                raise

# Populate both splits with symlinks: the training share (90%) goes to the new
# training directory, the held-out share (10%) to the new validation directory.
for split_indices, split_dir in (
    (train_indices, new_train_dir),
    (val_indices, new_val_dir),
):
    create_symlinks(file_paths, labels, split_indices, split_dir)

print("Dataset partitioning with symlinks completed.")
