import torch
from abc import ABC, abstractmethod

class BaseModel(torch.nn.Module, ABC):
    """Abstract interface shared by the models defined in this module.

    Concrete subclasses must implement both ``encode`` and ``forward``;
    because both are declared abstract, ``BaseModel`` itself cannot be
    instantiated.
    """

    def __init__(self):
        # Zero-argument super() follows the same MRO as the explicit
        # two-argument form, so torch.nn.Module is initialised as before.
        super().__init__()

    @abstractmethod
    def encode(self, x):
        """Encode input ``x``; concrete subclasses define the representation."""

    @abstractmethod
    def forward(self, x):
        """Compute the model's output for input ``x``."""

class BottleneckLinear(BaseModel):
    """Linear model that predicts through a low-dimensional bottleneck.

    ``encoder`` projects inputs of width ``input_dim`` down to
    ``bottleneck_dim`` features. ``predictor`` then maps those features to
    ``output_dim`` outputs. Neither layer has a bias term.

    Args:
        bottleneck_dim: Width of the intermediate (bottleneck) layer.
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the prediction.
    """

    def __init__(self, bottleneck_dim, input_dim, output_dim):
        super().__init__()
        # Both projections are bias-free. NOTE: the Takagi ridge models this
        # mirrors don't use a bias term.
        self.encoder = torch.nn.Linear(input_dim, bottleneck_dim, bias=False)
        self.predictor = torch.nn.Linear(bottleneck_dim, output_dim, bias=False)

    def encode(self, x):
        """Project ``x`` into the ``bottleneck_dim``-wide space."""
        return self.encoder(x)

    def forward(self, x):
        """Encode ``x`` and map the bottleneck features to predictions."""
        return self.predictor(self.encode(x))

def setup_model(model_type, bottleneck_dim, inpt_dim, outpt_dim):
    """Build the model named by ``model_type``.

    Args:
        model_type: Name of the model class to build. Only
            ``'BottleneckLinear'`` is currently supported.
        bottleneck_dim: Width of the bottleneck layer.
        inpt_dim: Size of the last dimension of the model input.
        outpt_dim: Size of the last dimension of the model output.

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: If ``model_type`` is not a supported model name.
            ``ValueError`` subclasses ``Exception``, so callers that catch
            ``Exception`` keep working.
    """
    if model_type == 'BottleneckLinear':
        return BottleneckLinear(bottleneck_dim, inpt_dim, outpt_dim)
    # Report the offending value so misconfigurations are easy to diagnose.
    raise ValueError(
        f"Unknown model_type {model_type!r} (define a valid opt.bottleneck); "
        f"supported: 'BottleneckLinear'"
    )
