
import torch
import torch.nn as nn


class Learner(nn.Module):
    """Fully connected residual MLP.

    Layout of ``self.layers`` (module names in brackets)::

        Linear(idim, hdim) [linear_0] -> ReLU [relu_0]
        -> num_layers x ResLinear(hdim) [res_1 .. res_{num_layers}]
        -> Linear(hdim, odim) [linear_{num_layers + 1}]

    Every stage is ``nn.Linear``-based and therefore acts on the input's
    last dimension, which must equal ``idim``; leading dimensions (e.g. the
    batch) pass through unchanged. Image-shaped batches must be flattened
    to ``idim`` features by the caller before being fed in.
    """

    class ResLinear(nn.Module):
        """Residual block computing ``x + relu(Linear(hdim, hdim)(x))``.

        The input and output share the trailing dimension ``hdim``.
        """

        def __init__(self, hdim=32):
            super().__init__()
            self.hdim = hdim
            self.lin = nn.Linear(hdim, hdim)
            self.relu = nn.ReLU()

        def forward(self, x):
            return x + self.relu(self.lin(x))

    def __init__(self, idim=16, hdim=32, odim=2, num_layers=1):
        """Build the network.

        Args:
            idim: Size of the input's trailing (feature) dimension.
            hdim: Hidden width used by the input projection and by every
                residual block.
            odim: Size of the output's trailing dimension.
            num_layers: Number of ``ResLinear`` blocks between the input
                and output projections; must be >= 0.

        Raises:
            ValueError: If ``num_layers`` is negative.
        """
        super().__init__()
        if num_layers < 0:
            # A negative count would otherwise build a degenerate network
            # (e.g. -1 -> a single Linear(idim, odim) with ReLU on the output).
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.idim = idim
        self.hdim = hdim
        self.odim = odim
        self.num_layers = num_layers

        # Module creation order and names are kept fixed: they determine the
        # parameter-initialisation order under a fixed seed and the
        # state_dict keys.
        self.layers = nn.Sequential()
        self.layers.add_module("linear_0", nn.Linear(idim, hdim))
        self.layers.add_module("relu_0", nn.ReLU())
        for i in range(1, num_layers + 1):
            self.layers.add_module(f"res_{i}", Learner.ResLinear(hdim))
        self.layers.add_module(f"linear_{num_layers + 1}", nn.Linear(hdim, odim))

    def forward(self, x):
        """Run ``x`` through all layers; the output's last dim is ``odim``."""
        return self.layers(x)




        
  