import os, sys, shutil
import random
import torch
"""
Usage:
  python stanford40.py /data/stanford40
"""

def read(filename):
    """Load a text file and return its lines, trailing newlines kept."""
    with open(filename) as handle:
        lines = list(handle)
    return lines


def split_data(nclass, ntrain, ntest, outdir='data/split'):
    """Build class-balanced train/val/test index splits and save them with torch.

    Indices are laid out as if every class occupies a contiguous block of
    ``ntrain // nclass`` (train) or ``ntest // nclass`` (test) samples, class 0
    first.  NOTE(review): confirm this matches how the loader enumerates the
    images -- if per-class counts differ (e.g. an unbalanced test split),
    equal-sized blocks will straddle class boundaries.

    From each class's training block, ``m // 10`` indices (``m = ntrain //
    nclass``) are drawn at random without replacement into the validation
    split; the remainder stays in train.  The test split keeps every index of
    each class block.  Any remainder of ``ntrain``/``ntest`` not divisible by
    ``nclass`` is dropped.

    Args:
        nclass: number of classes; must be positive.
        ntrain: total number of training samples.
        ntest: total number of test samples.
        outdir: directory that receives ``stanford40-train``,
            ``stanford40-val`` and ``stanford40-test``; created if missing.

    Returns:
        ``[train, val, test]``, each a ``(classes, indices)`` tuple exactly as
        written to disk.

    Raises:
        ValueError: if ``nclass`` is not positive.
    """
    if nclass <= 0:
        raise ValueError('nclass must be positive, got %r' % (nclass,))

    classes = list(range(nclass))

    m = ntrain // nclass
    train_indices = []
    val_indices = []
    for c in classes:
        indices = list(range(c * m, (c + 1) * m))
        # Move ~10% of this class's samples, chosen at random, into validation.
        for _ in range(m // 10):
            val_indices.append(indices.pop(random.randrange(len(indices))))
        train_indices.extend(indices)

    m2 = ntest // nclass
    test_indices = []
    for c in classes:
        test_indices.extend(range(c * m2, (c + 1) * m2))

    splits = [(classes, train_indices),
              (classes, val_indices),
              (classes, test_indices)]
    # torch.save does not create parent directories.
    os.makedirs(outdir, exist_ok=True)
    for name, split in zip(('train', 'val', 'test'), splits):
        torch.save(split, os.path.join(outdir, 'stanford40-' + name))
    return splits


def main(datadir=None):
    """Copy Stanford40 images into ``<datadir>/{train,test}/<class>/`` folders.

    Reads every ``<class>_train.txt`` / ``<class>_test.txt`` list in
    ``<datadir>/ImageSplits`` and copies each listed file from
    ``<datadir>/JPEGImages`` into the matching split/class directory.
    Safe to re-run: existing destination directories are reused.

    Args:
        datadir: dataset root.  Defaults to ``sys.argv[1]`` so the script can
            be invoked as ``python stanford40.py /data/stanford40``.

    Raises:
        SystemExit: if ``datadir`` is omitted and no command-line argument
            was given.
    """
    if datadir is None:
        if len(sys.argv) < 2:
            sys.exit('Usage: python stanford40.py /data/stanford40')
        datadir = sys.argv[1]
    splits = os.path.join(datadir, "ImageSplits")
    partitions = {'train': [], 'test': []}
    # sorted(): os.listdir order is arbitrary; keep processing deterministic.
    for filename in sorted(os.listdir(splits)):
        if filename.endswith('_test.txt'):
            partitions['test'].append(filename)
        elif filename.endswith('_train.txt'):
            partitions['train'].append(filename)
    print('# of train class:', len(partitions['train']))
    print('# of test class:', len(partitions['test']))
    for split_type in ['train', 'test']:
        nimgs = 0
        for fname in partitions[split_type]:
            # "<class>_train.txt" -> "<class>"; rsplit keeps underscores
            # that belong to the class name itself.
            cls_name = fname.rsplit('_', 1)[0]
            dst_dir = os.path.join(datadir, split_type, cls_name)
            os.makedirs(dst_dir, exist_ok=True)
            # Skip blank lines: an empty name would make the copy source the
            # JPEGImages directory itself.
            img_files = [name for name in
                         (line.strip() for line in read(os.path.join(splits, fname)))
                         if name]
            nimgs += len(img_files)
            for img_file in img_files:
                shutil.copy(os.path.join(datadir, 'JPEGImages', img_file),
                            os.path.join(dst_dir, img_file))
            # One progress dot per class; flush so it appears immediately.
            print('.', end='', flush=True)
        # End the progress-dot line before printing the total.
        print()
        print('# of', split_type, 'images:', nimgs)
    print('Done')


if __name__ == '__main__':
    # Only the index splits are regenerated here.  main() (which reads the
    # dataset root from sys.argv[1] and copies the images) is not invoked.
    split_data(nclass=40, ntrain=4000, ntest=4000)
