import torch

from models.inr.metasiren import MetaSiren, MetaReLU
from models.wrapper import MetaWrapper

# Module-wide default device: CUDA whenever PyTorch reports it is usable,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_inr(P):
    """Build the implicit-neural-representation decoder selected by ``P.decoder``.

    Args:
        P: Configuration namespace. Reads ``decoder`` (``'siren'`` or
            ``'relu'``) and the constructor settings ``dim_in``,
            ``dim_hidden``, ``dim_out``, ``num_layers``, ``w0``,
            ``data_size``, ``data_type`` and ``w0_type``.

    Returns:
        A ``MetaSiren`` instance when ``P.decoder == 'siren'``, or a
        ``MetaReLU`` instance when ``P.decoder == 'relu'``.

    Raises:
        ValueError: If ``P.decoder`` is neither ``'siren'`` nor ``'relu'``.
    """
    if P.decoder == 'siren':
        model_cls = MetaSiren
    elif P.decoder == 'relu':
        model_cls = MetaReLU
    else:
        # Name the offending value so a config typo is obvious from the traceback.
        raise ValueError(
            f"no such model exists, mate: unknown decoder {P.decoder!r} "
            "(expected 'siren' or 'relu')."
        )

    # Both decoder classes are called with the same arguments, so they are
    # written out once. ``P.w0`` is forwarded as both ``w0`` and
    # ``w0_initial``; no separate initial value is configurable here.
    return model_cls(
        P.dim_in, P.dim_hidden, P.dim_out, P.num_layers,
        w0=P.w0, w0_initial=P.w0, data_size=P.data_size,
        data_type=P.data_type, w0_type=P.w0_type,
    )


def get_model(P):
    """Build the decoder described by ``P`` and wrap it in a ``MetaWrapper``.

    Args:
        P: Configuration namespace. ``P.data_type`` must be ``'img'`` or
            ``'video'``. The whole ``P`` is also passed to ``get_inr`` and
            to ``MetaWrapper``.

    Returns:
        ``MetaWrapper(P, decoder)`` where ``decoder = get_inr(P)``.

    Raises:
        NotImplementedError: If ``P.data_type`` is not ``'img'`` or ``'video'``.
        ValueError: Propagated from ``get_inr`` for an unknown ``P.decoder``.
    """
    # Validate the data type before building the decoder so an unsupported
    # configuration fails without constructing a model first.
    if P.data_type not in ('img', 'video'):
        raise NotImplementedError(
            f"unsupported data_type {P.data_type!r}; expected 'img' or 'video'."
        )

    decoder = get_inr(P)
    # Images and videos currently use the same wrapper.
    return MetaWrapper(P, decoder)
