from typing import Sequence
import os
import pickle

import cole as cl


def get_split_tiny_imagenet(task_labels: Sequence[Sequence[int]], joint=False, joint_test=False,
                            train_transform=None, test_transform=None, path=None):
    """Build a task-split Tiny ImageNet benchmark from pre-pickled data.

    Loads ``train.pkl`` and ``val.pkl`` from ``path``. Each file must unpickle
    to a mapping with ``'data'`` and ``'labels'`` entries. For every group of
    labels in ``task_labels``, the function keeps only the samples whose label
    is in that group. The ``val`` file serves as the test set.

    Args:
        task_labels: One sequence of class labels per task.
        joint: If True, merge all tasks into a single task containing every
            label listed in ``task_labels``.
        joint_test: If True, return a single test set holding the *entire*
            ``val`` split (unfiltered) instead of one filtered test set per task.
        train_transform: Passed as ``transform`` to each training ``cl.XYDataset``.
        test_transform: Passed as ``transform`` to each test ``cl.XYDataset``.
        path: Directory containing ``train.pkl`` and ``val.pkl``. Defaults to
            ``"../data/tiny-imagenet-200/"``, resolved against the current
            working directory.

    Returns:
        ``cl.DataSplit(train_sets, None, test_sets)``. ``train_sets`` holds one
        ``cl.XYDataset`` per task. ``test_sets`` holds one per task, or exactly
        one when ``joint_test`` is True.
    """
    if path is None:
        path = "../data/tiny-imagenet-200/"

    # NOTE(review): pickle.load can execute arbitrary code during unpickling;
    # only point `path` at trusted, locally generated files.
    with open(os.path.join(path, 'train.pkl'), 'rb') as f:
        train_raw = pickle.load(f)

    with open(os.path.join(path, 'val.pkl'), 'rb') as f:
        val_raw = pickle.load(f)

    if joint:
        # Collapse every task into one task covering all requested labels.
        task_labels = [[label for task in task_labels for label in task]]

    # Presumably array-like (numpy arrays or tensors), since they are indexed
    # with a boolean mask below. TODO confirm against the pickling script.
    train_x, train_y = train_raw['data'], train_raw['labels']
    test_x, test_y = val_raw['data'], val_raw['labels']

    train_sets, test_sets = [], []

    for labels in task_labels:
        # Per-sample boolean mask selecting this task's classes. `y in labels`
        # is a linear scan of `labels`. It is kept instead of a set lookup
        # because the label element type (e.g. 0-d tensors, which hash by
        # identity) is not visible here.
        train_mask = [y in labels for y in train_y]
        train_sets.append(cl.XYDataset(train_x[train_mask], train_y[train_mask],
                                       transform=train_transform))

        if not joint_test:
            test_mask = [y in labels for y in test_y]
            test_sets.append(cl.XYDataset(test_x[test_mask], test_y[test_mask],
                                          transform=test_transform))

    if joint_test:
        # One test set with every validation sample, regardless of task labels.
        test_sets.append(cl.XYDataset(test_x, test_y, transform=test_transform))

    return cl.DataSplit(train_sets, None, test_sets)
