import argparse
import os
import random

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Defaults reproduce the values this script previously hard-coded, so running
# it without arguments uses the same directories and label rate as before.
DEFAULT_SPLIT = 20
DEFAULT_SETTING_DIR = '/home/zhanw0g/dataset/i10k/20splitCL'
DEFAULT_LABEL_RATIO = 0.01
DEFAULT_SEED = 0


def split_indices(targets, label_rate, rng=None):
    """Partition sample indices into labeled and unlabeled subsets, per class.

    Indices are grouped by class label, each group is shuffled, and the first
    ``int(class_size * label_rate)`` indices of every group are marked labeled.
    Every remaining index is unlabeled.

    Args:
        targets: Sequence of hashable class labels; position ``i`` is the
            label of sample ``i``.
        label_rate: Fraction of each class to mark as labeled.
        rng: Object providing ``shuffle(list)`` (e.g. ``random.Random(seed)``).
            Defaults to the global ``random`` module.

    Returns:
        A ``(labeled, unlabeled)`` tuple of index lists. ``labeled`` is ordered
        class by class (classes in order of first appearance in ``targets``),
        shuffled within a class; ``unlabeled`` is in ascending order.
    """
    if rng is None:
        rng = random

    # Group sample indices by class, keeping first-appearance class order.
    index_cls = {}
    for i, t in enumerate(targets):
        index_cls.setdefault(t, []).append(i)

    labeled = []
    for cls_idx in index_cls.values():
        rng.shuffle(cls_idx)
        # NOTE(review): int() truncates, so a class with fewer than
        # 1 / label_rate samples contributes no labeled samples -- confirm
        # that this is intended for the benchmark.
        labeled.extend(cls_idx[:int(len(cls_idx) * label_rate)])

    # Walk the index range instead of listing a set difference so that the
    # unlabeled order is deterministic (ascending) rather than hash-dependent.
    labeled_set = set(labeled)
    unlabeled = [i for i in range(len(targets)) if i not in labeled_set]
    return labeled, unlabeled


def main(argv=None):
    """Create and save the labeled/unlabeled index split for every task.

    For each task ``t`` in ``range(split)`` the class targets of
    ``<setting_dir>/<t>/train`` are read, split with ``split_indices`` and
    saved as ``torch.long`` tensors under ``<setting_dir>/<t>/<label_ratio>labeled/``.

    Args:
        argv: Command-line arguments to parse; ``None`` lets argparse read
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Make the labeled-unlabeled split')
    parser.add_argument('--setting_dir', default=DEFAULT_SETTING_DIR, metavar='DIR',
                        help='Path to save N split benchmark')
    parser.add_argument('--split', default=DEFAULT_SPLIT, type=int, metavar='N',
                        help='Number of time steps')
    parser.add_argument('--label_ratio', default=DEFAULT_LABEL_RATIO, type=float,
                        help='Label rate per time step')
    parser.add_argument('--seed', default=DEFAULT_SEED, type=int,
                        help='Seed for the shuffle, so the split is reproducible')
    args = parser.parse_args(argv)

    if not 0.0 <= args.label_ratio <= 1.0:
        parser.error('--label_ratio must be between 0 and 1')

    datadir = args.setting_dir
    label_rate = args.label_ratio
    # A dedicated generator keeps the split reproducible without touching the
    # global random state.
    rng = random.Random(args.seed)

    for task in range(args.split):
        taskdir = f'{datadir}/{task}/train'
        # Only the class targets are read, never the images themselves, so no
        # image transform is needed here.
        dataset = datasets.ImageFolder(taskdir)
        labeled, unlabeled = split_indices(dataset.targets, label_rate, rng)
        size = len(dataset.targets)
        print(f'task {task} total length {size}')
        print(f"task {task} labeled length {len(labeled)}")

        # Save the indices of the labeled and the unlabeled samples.
        outdir = f'{datadir}/{task}/{label_rate}labeled'
        os.makedirs(outdir, exist_ok=True)
        torch.save(torch.tensor(labeled, dtype=torch.long),
                   f'{outdir}/labeled_task{task}.buf')
        torch.save(torch.tensor(unlabeled, dtype=torch.long),
                   f'{outdir}/unlabeled_task{task}.buf')


if __name__ == '__main__':
    main()
