import os
import argparse
import torch
from torchvision import transforms
from config import cfg
from dataset import make_dataset, make_data_loader, process_dataset, Compose
from module import save, Stats, makedir_exist_ok, process_control

if __name__ == "__main__":
    # Build each dataset once per experiment seed, accumulate running Stats over
    # its training split, and save them to data/<name>/processed/seed_<seed>/stats.
    parser = argparse.ArgumentParser(description='data_name')
    parser.add_argument('--data_name', default='All', type=str)
    args = parser.parse_args()

    # Single source of truth for the datasets handled when --data_name is 'All'.
    all_data_names = ['SimulateR', 'SimulateC', 'Adult', 'SimulateCM', 'AdultM']
    # Dimension passed to Stats for each batch field that is summarized.
    dim = {'data': 1, 'target': 1}
    # Which datasets get statistics for which batch field: every dataset has its
    # 'data' field summarized; only SimulateR also has its 'target' field summarized.
    data_names_keys = {'data': list(all_data_names),
                       'target': ['SimulateR']}
    cfg['tag'] = 'make_dataset'
    # NOTE(review): presumably populates cfg (seeds, cfg['make_dataset'] optimizer
    # settings, ...) from the tag set above — confirm in module.process_control.
    process_control()

    data_names = all_data_names if args.data_name == 'All' else [args.data_name]

    # Statistics only; no gradients are needed.
    with torch.no_grad():
        for data_name in data_names:
            seeds = range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])
            for seed in seeds:
                cfg['seed'] = seed
                stats_path = os.path.join('data', data_name, 'processed', 'seed_' + str(seed))
                dataset = make_dataset(data_name)
                dataset['train'].transform = Compose([transforms.ToTensor()])
                process_dataset(dataset)
                cfg['step'] = 0
                data_loader = make_data_loader(dataset, cfg[cfg['tag']]['optimizer']['batch_size'], shuffle=False)
                # One Stats accumulator per batch field that applies to this dataset;
                # insertion order ('data' before 'target') matches data_names_keys.
                stats = {key: Stats(dim=dim[key])
                         for key, names in data_names_keys.items()
                         if data_name in names}
                # Walk the (unshuffled) training split once, feeding every active
                # accumulator, instead of re-iterating the loader per field.
                if stats:
                    for batch in data_loader['train']:
                        for key, key_stats in stats.items():
                            key_stats.update(batch[key])
                print(data_name, stats)
                makedir_exist_ok(stats_path)
                save(stats, os.path.join(stats_path, 'stats'))
