import torch as t

class Penultimate():
    """Callable that runs a model's child modules up to a named layer.

    The entries of ``model._modules`` are applied one after another to the
    input, stopping just before the entry whose name equals ``break_layers``.
    The result is then reshaped to two dimensions: ``(size(0), size(1))``.

    Note: if no entry is named ``break_layers``, every child module is
    applied (nothing is skipped).
    """

    def __init__(self, model, break_layers):
        """Store the model and the name of the child module to stop before.

        Args:
            model: object exposing a ``_modules`` mapping of name -> callable
                (presumably a ``torch.nn.Module`` -- confirm with callers).
            break_layers: name of the child module at which to stop; that
                module and everything after it is not executed.
        """
        self.break_layers = break_layers
        self.model = model

    def __call__(self, x):
        """Return the activations produced just before the break layer.

        Args:
            x: input passed to the first child module.

        Returns:
            The intermediate output reshaped to ``(size(0), size(1))``.
            The reshape only succeeds when the output holds exactly
            ``size(0) * size(1)`` elements, e.g. trailing dimensions of 1.
        """
        stop_name = self.break_layers
        features = x
        for layer_name, layer in self.model._modules.items():
            if layer_name == stop_name:
                # Stop before executing the break layer itself.
                break
            features = layer(features)
        # Collapse to 2-D using the first two dimensions of the output.
        target_shape = (features.size(0), features.size(1))
        return t.reshape(features, target_shape)



def gen_pen(model,break_layer='fc'):
    """Build a ``Penultimate`` feature extractor for ``model``.

    If ``model`` is wrapped in ``DataParallel`` or
    ``DistributedDataParallel``, the wrapped network (``model.module``) is
    used, so that the extractor iterates over the real child modules rather
    than over the wrapper's single ``module`` child.

    Args:
        model: the network, optionally wrapped in a data-parallel container.
        break_layer: name of the child module to stop before
            (defaults to ``'fc'``).

    Returns:
        A ``Penultimate`` instance bound to the unwrapped network.
    """
    # Decide by inspecting the model itself instead of counting visible GPUs:
    # a multi-GPU host may hold an unwrapped model (no ``.module`` attribute),
    # and a wrapped model may be used on a single-GPU or CPU-only host.
    parallel_wrappers = (
        t.nn.DataParallel,
        t.nn.parallel.DistributedDataParallel,
    )
    if isinstance(model, parallel_wrappers):
        model = model.module
    return Penultimate(model=model, break_layers=break_layer)
