import os, sys, shutil
import torch
import random
"""
Usage:
  python mit67.py /data/mit67
"""

def read(filename):
    """Return all lines of the text file *filename* as a list.

    Line terminators are kept, exactly as the file object yields them.
    """
    with open(filename) as handle:
        lines = list(handle)
    return lines

def split_data(nclass, ntrain, ntest, out_dir='data/split'):
    """Build train/val/test index splits and save them with ``torch.save``.

    Indices are generated arithmetically: class ``c`` owns the contiguous
    range ``[c * m, (c + 1) * m)`` with ``m = ntrain // nclass`` (and likewise
    for the test set with ``ntest // nclass``).
    NOTE(review): this presumes the dataset is ordered class by class with the
    same number of images per class -- confirm against the dataset loader.

    For every class, ``m // 10`` training indices are drawn at random with the
    ``random`` module (seed it beforehand for reproducible splits) and held
    out as the validation split; the remaining indices stay in the training
    split in ascending order.

    Three files are written to ``out_dir``: ``mit67-train``, ``mit67-val`` and
    ``mit67-test``.  Each holds a ``(classes, indices)`` tuple of two lists.

    Args:
        nclass: Number of classes; must be positive.
        ntrain: Total number of training images; must be non-negative.
        ntest: Total number of test images; must be non-negative.
        out_dir: Directory receiving the split files; created if missing.

    Raises:
        ValueError: If ``nclass`` is not positive or a count is negative.
    """
    if nclass <= 0:
        raise ValueError('nclass must be positive, got {}'.format(nclass))
    if ntrain < 0 or ntest < 0:
        raise ValueError('ntrain and ntest must be non-negative')

    # All classes are used; sub-sampling of classes is not performed.
    classes = list(range(nclass))

    per_class_train = ntrain // nclass
    per_class_val = per_class_train // 10
    train_indices = []
    val_indices = []
    for c in classes:
        class_range = range(c * per_class_train, (c + 1) * per_class_train)
        held_out = random.sample(class_range, per_class_val)
        held_out_set = set(held_out)
        val_indices.extend(held_out)
        train_indices.extend(i for i in class_range if i not in held_out_set)

    per_class_test = ntest // nclass
    test_indices = []
    for c in classes:
        # extend() keeps this linear; repeated list concatenation is quadratic.
        test_indices.extend(range(c * per_class_test, (c + 1) * per_class_test))

    # torch.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    named_splits = (('train', train_indices),
                    ('val', val_indices),
                    ('test', test_indices))
    for split_name, indices in named_splits:
        torch.save((classes, indices),
                   os.path.join(out_dir, 'mit67-' + split_name))


def _parse_image_list(list_path):
    """Return the ``(class_name, file_name)`` pairs listed in *list_path*.

    Each non-blank line must have the form ``class_name/file_name``; blank
    lines (e.g. a trailing newline at the end of the file) are skipped.
    """
    entries = []
    with open(list_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            class_name, file_name = line.split('/')
            entries.append((class_name, file_name))
    return entries


def main():
    """Copy the listed images into ``train/`` and ``test/`` class folders.

    The dataset directory is taken from ``sys.argv[1]``.  It must contain
    ``TrainImages.txt`` and ``TestImages.txt`` (one ``class/file`` entry per
    line) and an ``Images/<class>/<file>`` tree.  Every listed image is copied
    to ``<datadir>/<split>/<class>/<file>``.

    The function can be re-run on the same directory: existing class folders
    are reused and files are copied over again.

    Raises:
        SystemExit: If no dataset directory was given on the command line.
        ValueError: If a list entry is not of the form ``class/file``.
    """
    if len(sys.argv) < 2:
        sys.exit('Usage: python mit67.py /data/mit67')
    datadir = sys.argv[1]

    partitions = {
        'train': _parse_image_list(os.path.join(datadir, 'TrainImages.txt')),
        'test': _parse_image_list(os.path.join(datadir, 'TestImages.txt')),
    }
    print('# of training images:', len(partitions['train']))
    print('# of test images:', len(partitions['test']))

    for split_type in ('train', 'test'):
        entries = partitions[split_type]

        # exist_ok lets an interrupted or repeated run proceed.
        for cls_name in {cls_name for cls_name, _ in entries}:
            os.makedirs(os.path.join(datadir, split_type, cls_name),
                        exist_ok=True)

        for ct, (cls_name, file_name) in enumerate(entries, 1):
            shutil.copy(os.path.join(datadir, 'Images', cls_name, file_name),
                        os.path.join(datadir, split_type, cls_name, file_name))
            if ct % 500 == 0:
                # Flush so the progress dots show up while copying.
                print('.', end='', flush=True)
        print('# of ', split_type,' images:', len(entries))
    print('Done')


if __name__ == '__main__':
    # The image-copy step is currently disabled; only the index splits are
    # regenerated when the script is run.
    # main()
    class_count = 67
    # 80 training and 20 test images per class.
    split_data(class_count, class_count * 80, class_count * 20)
