import os
from utils import get_features

def ds_features(config, dataloader, model, split_name=None):
    """Run ``get_features`` on a dataloader, optionally with a per-split file path.

    When ``split_name`` is given, the directory ``{config.log_dir}features``
    is created if needed and a ``.pth.tar`` path inside it is handed to
    ``get_features``. Otherwise ``path`` is ``None``.

    Args:
        config: Object exposing ``get_features`` (forwarded as ``recompute``)
            and, when ``split_name`` is given, ``log_dir`` and ``frac``.
        dataloader: Forwarded unchanged to ``get_features``.
        model: Forwarded unchanged to ``get_features``.
        split_name: Optional name used as the stem of the feature file.

    Returns:
        Whatever ``get_features`` returns.
    """
    path = None
    if split_name is not None:
        # ``log_dir`` is concatenated as-is, so it presumably already ends
        # with a path separator -- TODO confirm against the config.
        features_dir = f'{config.log_dir}features'
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(features_dir, exist_ok=True)

        # NOTE(review): the previous code did ``id += ...`` on a local that was
        # never initialised, which raised UnboundLocalError for every non-None
        # split_name. If a run/model identifier was meant to be part of the
        # file name, append it to ``suffix`` here.
        suffix = ''
        if config.frac < 1.:
            # Give runs with frac < 1 a file name that includes the fraction.
            suffix = f'_{config.frac}'
        path = f'{features_dir}/{split_name}{suffix}.pth.tar'

    return get_features(dataloader, model, path=path, recompute=config.get_features)

def quick_eval(self, config, datasets, model):
    """Extract train/validation features and optionally standardise them.

    Features come from ``ds_features``. When ``config.eval_sc`` is truthy, a
    ``StandardScaler`` is fitted on the training features and applied to both
    the training and validation features.

    Args:
        self: Object providing ``val_dataloader`` and ``model`` for the
            validation pass.
        config: Object exposing ``eval_set`` and ``eval_sc``; also forwarded
            to ``ds_features``.
        datasets: Mapping of split name to a mapping with ``'dataset'`` and
            ``'loader'`` entries (``'loader'`` is read for ``'train'`` only).
        model: Model used for the training-split feature extraction.
    """
    from sklearn.preprocessing import StandardScaler
    train_set = datasets['train']['dataset']
    val_set = datasets[config.eval_set]['dataset']

    # ds_features expects ``config`` as its first positional argument; the
    # earlier calls omitted it and failed with a TypeError.
    x_train, y_train, train_md = ds_features(config, datasets['train']['loader'], model)
    # NOTE(review): the validation pass reads the loader and model from
    # ``self`` rather than from ``datasets`` / ``model`` -- confirm whether
    # ``datasets[config.eval_set]['loader']`` and ``model`` were intended.
    x_val, y_val, val_md = ds_features(config, self.val_dataloader, self.model)

    if config.eval_sc:
        # Fit on the training features only, then reuse the same statistics
        # for the validation features.
        sc = StandardScaler()
        x_train = sc.fit_transform(x_train)
        x_val = sc.transform(x_val)
    

