import os


def get_model(params, method=None):
    """Build the slot-attention autoencoder selected by ``params.arch``.

    Only the ``'base'`` architecture (case-insensitive) is supported. For a
    128x128 resolution the ``model_large`` variant of the autoencoder is
    used and the broadcast size is set to 8x8; otherwise the broadcast size
    equals the input resolution.

    Args:
        params: Configuration object. The attributes read here are
            ``num_slots``, ``resolution`` (index 0 is used as the height,
            index 1 as the width), ``in_channels``, ``num_iterations`` and
            ``arch``.
        method: Accepted for interface compatibility; not used here.

    Returns:
        The constructed ``SlotAttentionAE`` instance.

    Raises:
        TypeError: If ``params.arch`` is not ``'base'``.
    """
    # Constructor arguments are assembled before the architecture check, so
    # a params object missing one of these attributes fails here first.
    model_kwargs = dict(
        latent_size=64,
        num_slots=params.num_slots,
        width=params.resolution[1],
        height=params.resolution[0],
        input_channels=params.in_channels,
        eps=1e-8,
        mlp_size=128,
        attention_iters=params.num_iterations,
        w_broadcast=params.resolution[1],
        h_broadcast=params.resolution[0],
    )

    if params.arch.lower() != 'base':
        raise TypeError(f'Invalid architecture! [--arch]: {params.arch}')

    from models.slot_attention import SlotAttentionAE, ENCODER_PARAM, DECODER_PARAM

    is_128_square = params.resolution[1] == 128 and params.resolution[0] == 128
    if is_128_square:
        # The large variant rebinds the same three names imported above.
        from models.slot_attention.model_large import SlotAttentionAE, ENCODER_PARAM, DECODER_PARAM
        model_kwargs['w_broadcast'] = 8
        model_kwargs['h_broadcast'] = 8

    model_kwargs['encoder_params'] = ENCODER_PARAM
    model_kwargs['decoder_params'] = DECODER_PARAM

    return SlotAttentionAE(**model_kwargs)


def get_continual_model(params, method=None):
    """Build the base model and wrap it per ``params.continual_arch``.

    The base network always comes from ``get_model(params=params)``. The
    architecture string is matched case-insensitively:

    * ``'base'``  -> ``ContinualOCL``
    * ``'dpr'``   -> ``FreezeRelpay`` with ``n_epochs=params.replay_epochs``
    * ``'pr'``    -> ``PostRelpay`` with ``n_epochs=50``
    * ``'drwt'``  -> ``FreezeRelpay`` with ``n_epochs=50``
    * ``'none'``  -> the base network, unwrapped

    All replay variants also receive ``replay_size=params.replay_size``,
    ``lr=0.0004`` and ``weight_decay=0.0``.

    Args:
        params: Configuration object. Besides what ``get_model`` reads, this
            function reads ``num_task``, ``param_isolation``,
            ``isol_params`` and ``continual_arch``, plus ``replay_size``
            (replay variants) and ``replay_epochs`` (``'dpr'`` only).
        method: Accepted for interface compatibility; not used here and not
            forwarded to ``get_model``.

    Returns:
        The wrapped model, or the base network when the architecture is
        ``'none'``.

    Raises:
        TypeError: If ``params.continual_arch`` is not one of the values
            listed above.
    """
    base_net = get_model(params=params)

    # Built before dispatching, so these attributes are read for every
    # architecture, including 'none'.
    wrapper_kwargs = dict(
        net=base_net,
        num_task=params.num_task,
        isolation=params.param_isolation,
        isolation_parameters=params.isol_params,
    )

    arch = params.continual_arch.lower()

    if arch == 'none':
        return base_net

    if arch == 'base':
        from models.continual.continual_model import ContinualOCL
        return ContinualOCL(**wrapper_kwargs)

    if arch in ('dpr', 'drwt'):
        from models.continual.freeze_replay import FreezeRelpay
        wrapper_cls = FreezeRelpay
        wrapper_kwargs['replay_size'] = params.replay_size
        # Only 'dpr' takes its epoch count from params; 'drwt' uses 50.
        wrapper_kwargs['n_epochs'] = params.replay_epochs if arch == 'dpr' else 50
    elif arch == 'pr':
        from models.continual.post_replay import PostRelpay
        wrapper_cls = PostRelpay
        wrapper_kwargs['replay_size'] = params.replay_size
        wrapper_kwargs['n_epochs'] = 50
    else:
        raise TypeError(f'Invalid continual architecture! [--continual_arch]: {params.continual_arch}')

    # Optimiser settings shared by every replay variant.
    wrapper_kwargs['lr'] = 0.0004
    wrapper_kwargs['weight_decay'] = 0.0

    return wrapper_cls(**wrapper_kwargs)

    