import pickle
import os
import random

from pathlib import Path

# Path to the 'dataset' directory located two levels above this module's
# directory. The '..' segments are joined verbatim; the path is not normalized.
root_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..', 'dataset')
# File names of the pickled train / test splits inside a saved data directory.
train_data_file = 'train_data.pkl'
test_data_file = 'test_data.pkl'

def save_transformed_data(save_path, train_tasks, test_tasks, args):
    """Pickle the train and test tasks into a fresh, numbered directory.

    The directory is
    ``<save_path>/<raw_data>_<transform>/numtask_<num_tasks>_seed_<seed>_<n>``
    where ``n`` is the smallest non-negative integer for which no such path
    exists yet. The chosen path is printed to stdout.

    Args:
        save_path: Root directory under which the data directory is created.
        train_tasks: Picklable object written to ``train_data_file``.
        test_tasks: Picklable object written to ``test_data_file``.
        args: Object exposing the attributes ``raw_data``, ``transform``,
            ``num_tasks`` and ``seed``.

    Raises:
        OSError: If the directory or the files cannot be created.
    """
    task_name = args.raw_data + "_" + args.transform
    base_dir = os.path.join(save_path, task_name)
    run_prefix = f"numtask_{args.num_tasks}_seed_{args.seed}"

    num = 0
    while True:
        target_dir = os.path.join(base_dir, f"{run_prefix}_{num}")
        try:
            # Create-and-claim in one step: checking os.path.exists() first
            # would let two concurrent writers pick the same directory.
            os.makedirs(target_dir)
        except FileExistsError:
            num += 1
        else:
            break
    print(target_dir)

    with open(os.path.join(target_dir, train_data_file), 'wb') as f:
        pickle.dump(train_tasks, f)
    with open(os.path.join(target_dir, test_data_file), 'wb') as f:
        pickle.dump(test_tasks, f)

def exists_continuum_data(root_path, raw_data, transform, num_tasks, seed):
    """Report whether saved data exists for the given task configuration.

    Looks inside ``<root_path>/<raw_data>_<transform>`` for an entry named
    ``numtask_<num_tasks>_seed_<seed>`` or starting with that name followed by
    ``_`` (the numbered form ``numtask_<num_tasks>_seed_<seed>_<n>``).

    Args:
        root_path: Root directory that contains the per-task directories.
        raw_data: First half of the task directory name.
        transform: Second half of the task directory name.
        num_tasks: Number of tasks encoded in the entry name.
        seed: Seed encoded in the entry name.

    Returns:
        ``True`` if at least one matching entry exists, ``False`` otherwise,
        including when the task directory is missing or cannot be listed.
    """
    task_name = raw_data + "_" + transform
    task_path = os.path.abspath(os.path.join(root_path, task_name))
    data_info = f"numtask_{num_tasks}_seed_{seed}"
    # Require the "_" separator after the seed so that seed 1 does not match
    # entries written for seed 10, 11, ...
    numbered_prefix = data_info + "_"

    try:
        entries = os.listdir(task_path)
    except OSError:
        # Missing / unreadable / non-directory task path: nothing saved here.
        return False

    return any(
        name == data_info or name.startswith(numbered_prefix)
        for name in entries
    )
        
def get_continuum_data(root_path, raw_data, transform, num_tasks, seed, index=None):
    """Return the path of one saved data directory for a task configuration.

    Candidates are the entries of ``<root_path>/<raw_data>_<transform>`` named
    ``numtask_<num_tasks>_seed_<seed>`` or starting with that name followed by
    ``_`` (the numbered form ``numtask_<num_tasks>_seed_<seed>_<n>``).

    Args:
        root_path: Root directory that contains the per-task directories.
        raw_data: First half of the task directory name.
        transform: Second half of the task directory name.
        num_tasks: Number of tasks encoded in the entry name.
        seed: Seed encoded in the entry name.
        index: ``None`` to pick one candidate at random. Otherwise the suffix
            ``<n>`` of the wanted entry, i.e. the entry named
            ``numtask_<num_tasks>_seed_<seed>_<index>`` is returned.
            NOTE(review): the original code left this branch unimplemented;
            the suffix interpretation follows the naming used when saving.

    Returns:
        A ``pathlib.Path`` pointing at the selected entry.

    Raises:
        OSError: If the task directory does not exist or cannot be listed.
        IndexError: If no entry matches, or ``index`` names a missing entry.
    """
    task_name = raw_data + "_" + transform
    task_dir = Path(os.path.abspath(os.path.join(root_path, task_name)))
    data_info = f"numtask_{num_tasks}_seed_{seed}"
    # Require the "_" separator after the seed so that seed 1 does not match
    # entries written for seed 10, 11, ...
    numbered_prefix = data_info + "_"

    # Match on the entry name only (not the whole path) and sort so that a
    # seeded ``random`` yields a reproducible choice.
    target_data = sorted(
        d for d in task_dir.iterdir()
        if d.name == data_info or d.name.startswith(numbered_prefix)
    )
    if not target_data:
        raise IndexError(f"no saved data matching {data_info!r} in {task_dir}")

    if index is None:
        return random.choice(target_data)

    wanted_name = f"{numbered_prefix}{index}"
    for candidate in target_data:
        if candidate.name == wanted_name:
            return candidate
    raise IndexError(f"no saved data named {wanted_name!r} in {task_dir}")

def decode_data(data_path, data_type):
    """Load the pickled train or test split stored in a data directory.

    Args:
        data_path: Directory containing the pickle files, given as a ``str``
            or any path-like object.
        data_type: ``'train'`` to read ``train_data_file`` or ``'test'`` to
            read ``test_data_file``.

    Returns:
        The object unpickled from the selected file.

    Raises:
        ValueError: If ``data_type`` is neither ``'train'`` nor ``'test'``.
        OSError: If the file cannot be opened.
    """
    if data_type == 'train':
        file_name = train_data_file
    elif data_type == 'test':
        file_name = test_data_file
    else:
        raise ValueError(
            f"Wrong type: {data_type!r} (expected 'train' or 'test')"
        )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only call this on data directories produced by a trusted source.
    with open(Path(data_path) / file_name, 'rb') as f:
        return pickle.load(f)



