"""
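Generate the CUB FSL class split: 100 seen, 50 val and 50 test classes.

Writes train.txt, val.txt, test.txt and trainval.txt (seen + val classes),
one class name per line.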
"""

import numpy as np
import random


def dump_class(class_list, file_name):
    """Write the entries of ``class_list`` to ``file_name`` in sorted order.

    Entries are separated by a newline; no trailing newline is written.
    The file is opened in write mode, so existing content is replaced.
    ``class_list`` itself is not modified.
    """
    with open(file_name, 'w') as out:
        # Sort a copy so the caller's sequence keeps its order.
        ordered = list(class_list)
        ordered.sort()
        out.write('\n'.join(ordered))


# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences such as "\d" and "\C".
class_tmp = np.loadtxt(r'F:\datasets\CUB\CUB_200_2011\classes.txt', dtype=str)
# The second column of each row is used as the class name
# (first column is presumably the numeric class id -- confirm against the file).
classes = class_tmp[:, 1].tolist()

# Partition the classes by their position in the file:
#   even index      -> seen (training) classes
#   index % 4 == 1  -> validation classes
#   index % 4 == 3  -> test classes
seen_classes = sorted(classes[0::2])
val_classes = sorted(classes[1::4])
test_classes = sorted(classes[3::4])
# trainval holds every seen class plus every validation class.
trainval_classes = sorted(seen_classes + val_classes)


dump_class(seen_classes, 'train.txt')
dump_class(val_classes, 'val.txt')
dump_class(test_classes, 'test.txt')
dump_class(trainval_classes, 'trainval.txt')
