import torch

from .env_info import get_env_info
from .logger import create_logger
from .utils import save_config, set_seed, setup_cudnn
from .defaults import get_default_config
from .defaults_multi import get_default_config as get_default_config_multi
from .defaults_multi_loss import get_default_config as get_default_config_multi_loss
from .defaults_vit import get_default_config_vit
from .defaults_norm import get_default_config_norm
from .defaults_norm_moredata import get_default_config_moredata
from .defaults_view_sigmoid import get_default_config as get_default_config_view_sigmoid
from .defaults_view_sigmoid_aux import get_default_config as get_default_config_view_sigmoid_aux

from .defaults_norm_moredata_stage1 import get_default_config_moredata as get_default_config_moredata_stage1
from .defaults_view_sigmoid_aux_stage2 import get_default_config as get_default_config_view_sigmoid_aux_stage2
from .defaults_norm_moredata_aux_stage1 import get_default_config_moredata as get_default_config_moredata_aux_stage1


from .optimizer import build_optimizer, build_scheduler
from .dists import get_rank,world_info_from_env
from .losses import cross_entropy
from .factory import create_loss, create_loss_gather, create_multi_loss, SigLipLossTest

def update_config(config):
    """Adapt a loaded configuration to the current runtime environment.

    When ``torch`` reports that no CUDA device is available,
    ``config.device`` is overwritten with ``'cpu'``. Otherwise the
    configuration is left untouched.

    Args:
        config: Configuration object whose ``device`` attribute may be
            assigned in place. Presumably one of the ``get_default_config*``
            results re-exported by this package; TODO confirm the exact type.

    Returns:
        The same ``config`` object (mutated in place when CUDA is absent),
        so callers can chain or reassign the result.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        # A GPU is present, so keep whatever device the config requested.
        return config

    # Without CUDA, any GPU device setting would fail later, so fall back to CPU.
    config.device = 'cpu'
    return config