import os
import yaml

from models import utils


def load_task_config(params):
    """Load and validate the task config named by ``params.task_config``.

    Datasets named ``continual_clevr`` (case-insensitive) are delegated to
    ``load_task_config_clevr``. Otherwise, if the config path exists, the YAML
    file is parsed, validated by ``check_task_config`` (which also writes the
    derived settings onto ``params``) and printed on the main process.

    Args:
        params: Namespace-like object; reads ``dataset`` and ``task_config``.

    Returns:
        ``(params, config)``. ``config`` is the parsed YAML document, or
        ``None`` when no config file is available and task configs are
        expected to be set manually.
    """
    if params.dataset.lower() == 'continual_clevr':
        return load_task_config_clevr(params)

    path = params.task_config

    # os.path.exists(None) raises TypeError rather than returning False, so
    # an unset config path must be checked explicitly.
    if path is None or not os.path.exists(path):
        if utils.is_main_process():
            print('Using manual task configs...')
        return params, None

    # NOTE(review): FullLoader on a user-supplied path; yaml.safe_load is
    # sufficient for a plain config file and safer if the file is untrusted.
    with open(path, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    params = check_task_config(params, config)

    if utils.is_main_process():
        print(f'\nLoading task configs from {path}...')
        print(config, '\n')
    return params, config


def check_task_config(params, config):
    """Validate a task config and copy its settings onto ``params``.

    Per-task entries are looked up under the integer keys
    ``0 .. config['num_task'] - 1``.

    A falsy ``params.num_epochs`` (e.g. ``0``) means "take each task's
    ``epochs`` from the config", producing a per-task list. Any other value is
    treated as an explicit user override and kept unchanged.

    Args:
        params: Namespace-like object. Reads ``num_epochs``; writes
            ``total_colors``, ``num_task``, ``num_epochs``, ``num_shape`` and
            ``num_color``.
        config: Mapping parsed from the task-config YAML file.

    Returns:
        The same ``params`` object, updated in place.

    Raises:
        AssertionError: if a required key is missing, ``shape``/``color`` has
            an unknown mode, or a per-task ``shape``/``color`` spec has the
            wrong length.
    """
    # Only collect epochs from the config when the user gave no override.
    # (Previously the append ran unconditionally, which crashed on an int
    # override and extended a user-supplied list.)
    collect_epochs = not params.num_epochs
    num_epochs = [] if collect_epochs else params.num_epochs

    for key in ('num_task', 'width', 'height', 'num_background_objects',
                'max_num_objects', 'input_channels', 'shape', 'color',
                'num_shape', 'num_color'):
        assert key in config, f"task config is missing required key '{key}'"
    for key in ('shape', 'color'):
        assert config[key] in ('pre_defined', 'random', 'user'), (
            f"task config '{key}' must be 'pre_defined', 'random' or 'user', "
            f"got {config[key]!r}")

    # 'user' shape specs have 2 entries; the other modes have 3.
    shape_len = 2 if config['shape'] == 'user' else 3

    total_colors = 0
    for i in range(config['num_task']):
        assert i in config, f"task config is missing an entry for task {i}"
        task = config[i]
        assert 'dataset' in task, f"task {i} is missing 'dataset'"
        dataset = task['dataset']
        for key in ('_target_', '_function_', 'data_sizes', 'shape', 'color'):
            assert key in dataset, f"task {i} dataset is missing '{key}'"
        assert len(dataset['shape']) == shape_len, (
            f"task {i} dataset 'shape' must have {shape_len} entries "
            f"in {config['shape']!r} mode")
        if config['color'] != 'user':
            assert len(dataset['color']) == 4, (
                f"task {i} dataset 'color' must have 4 entries "
                f"in {config['color']!r} mode")

        assert 'epochs' in task, f"task {i} is missing 'epochs'"
        if collect_epochs:
            num_epochs.append(int(task['epochs']))

        # presumably color[0] is the number of colors this task introduces —
        # confirm against the dataset builders.
        total_colors += dataset['color'][0]

    params.total_colors = int(total_colors)
    params.num_task = config['num_task']
    params.num_epochs = num_epochs
    params.num_shape = config['num_shape']
    params.num_color = config['num_color']
    return params


def load_task_config_clevr(params):
    """Load and validate the continual-CLEVR task config at ``params.task_config``.

    If the config path exists, the YAML file is parsed, validated by
    ``check_task_config_clevr`` (which also writes the derived settings onto
    ``params``) and printed on the main process.

    Args:
        params: Namespace-like object; reads ``task_config``.

    Returns:
        ``(params, config)``. ``config`` is the parsed YAML document, or
        ``None`` when no config file is available and task configs are
        expected to be set manually.
    """
    path = params.task_config

    # os.path.exists(None) raises TypeError rather than returning False, so
    # an unset config path must be checked explicitly.
    if path is None or not os.path.exists(path):
        if utils.is_main_process():
            print('Using manual task configs...')
        return params, None

    # NOTE(review): FullLoader on a user-supplied path; yaml.safe_load is
    # sufficient for a plain config file and safer if the file is untrusted.
    with open(path, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    params = check_task_config_clevr(params, config)

    if utils.is_main_process():
        print(f'\nLoading task configs from {path}...')
        print(config, '\n')
    return params, config


def check_task_config_clevr(params, config):
    """Validate a continual-CLEVR task config and copy its settings onto ``params``.

    Per-task entries are looked up under the string keys
    ``'task0' .. f"task{config['num_task'] - 1}"``.

    A falsy ``params.num_epochs`` (e.g. ``0``) means "take each task's
    ``epochs`` from the config", producing a per-task list. Any other value is
    treated as an explicit user override and kept unchanged.

    Args:
        params: Namespace-like object. Reads ``num_epochs``; writes
            ``num_task`` and ``num_epochs``.
        config: Mapping parsed from the task-config YAML file.

    Returns:
        The same ``params`` object, updated in place.

    Raises:
        AssertionError: if a required top-level or per-task key is missing.
    """
    # Only collect epochs from the config when the user gave no override.
    # (Previously the append ran unconditionally, which crashed on an int
    # override and extended a user-supplied list.)
    collect_epochs = not params.num_epochs
    num_epochs = [] if collect_epochs else params.num_epochs

    for key in ('num_task', 'width', 'height', 'num_background_objects',
                'max_num_objects', 'input_channels'):
        assert key in config, f"task config is missing required key '{key}'"

    for i in range(config['num_task']):
        task_key = f'task{i}'
        assert task_key in config, f"task config is missing entry '{task_key}'"
        task = config[task_key]
        assert 'dataset' in task, f"'{task_key}' is missing 'dataset'"
        for key in ('_target_', '_function_', 'data_sizes'):
            assert key in task['dataset'], (
                f"'{task_key}' dataset is missing '{key}'")

        assert 'epochs' in task, f"'{task_key}' is missing 'epochs'"
        if collect_epochs:
            num_epochs.append(int(task['epochs']))

    params.num_task = config['num_task']
    params.num_epochs = num_epochs
    return params