import math
import torch
import socket
import argparse
import os
import numpy as np
import random

from dataloader.sprite import Sprite
import pickle

def load_dataset(opt, data_path="../dataset/Sprite/data.pkl"):
    """Build the train/validation datasets selected by ``opt.dataset``.

    Only ``'Sprite'`` is supported. For it, a pickled dict is read from
    ``data_path``. Its ``*_train`` / ``*_val`` entries are passed to ``Sprite``
    as the ``data``, ``A_label``, ``D_label``, ``c_aug`` and ``m_aug`` arguments.

    Args:
        opt: Options object. Only its ``dataset`` attribute is read here.
        data_path: Path to the pickled Sprite data. The default is the original
            hard-coded path and is resolved relative to the current working
            directory.

    Returns:
        A ``(train_data, val_data)`` tuple of ``Sprite`` instances.

    Raises:
        ValueError: If ``opt.dataset`` is not ``'Sprite'``.
    """
    if opt.dataset != 'Sprite':
        raise ValueError('unknown dataset')

    # Use a context manager so the file handle is closed even if unpickling fails.
    # NOTE(review): pickle.load can execute arbitrary code. Only load trusted files.
    with open(data_path, "rb") as f:
        data = pickle.load(f)

    X_train, X_val = data['X_train'], data['X_val']
    A_train, A_val = data['A_train'], data['A_val']
    D_train, D_val = data['D_train'], data['D_val']
    c_augs_train, c_augs_val = data['c_augs_train'], data['c_augs_val']
    m_augs_train, m_augs_val = data['m_augs_train'], data['m_augs_val']

    print("finish loading!")

    train_data = Sprite(train=True, data=X_train, A_label=A_train,
                        D_label=D_train, c_aug=c_augs_train, m_aug=m_augs_train)
    val_data = Sprite(train=False, data=X_val, A_label=A_val,
                      D_label=D_val, c_aug=c_augs_val, m_aug=m_augs_val)

    return train_data, val_data


def clear_progressbar():
    """Emit terminal control sequences that erase a previously printed line.

    Prints, each on its own ``print`` call (so each is followed by a newline),
    the ANSI sequences ``ESC[2A`` (cursor up two lines), ``ESC[2K`` (erase
    the entire line) and ``ESC[2A`` again. Judging by the name, this is meant
    for clearing a progress bar before redrawing it.
    """
    control_sequences = ("\033[2A", "\033[2K", "\033[2A")
    for seq in control_sequences:
        print(seq)

