from typing import List
from torch import nn



class Test_hydra:
    """Hold the constructor configuration for a ``Test_hydra`` model.

    Every constructor argument is stored unchanged as an attribute of the
    same name, so later code can read the configuration it was built with.

    NOTE(review): this module imports ``torch.nn``, but this class does not
    subclass ``nn.Module`` and defines no layers here — confirm whether that
    is intentional.

    Args:
        input_dims: Input dimensionality (meaning inferred from the name —
            TODO confirm).
        output_dims: Output dimensionality (inferred from the name — TODO
            confirm).
        kernels: List of ints (per the annotation); presumably kernel sizes —
            TODO confirm.
        length: Integer (per the annotation); presumably a sequence length —
            TODO confirm.
        hidden_dims: Defaults to 64.
        depth: Defaults to 10.
        dim_treatments: Defaults to 10.
        dim_outcome: Defaults to 10.
        mask_mode: Defaults to ``'binomial'``.
    """

    def __init__(self, input_dims, output_dims,
                 kernels: List[int],
                 length: int,
                 hidden_dims=64, depth=10,
                 dim_treatments=10, dim_outcome=10,
                 mask_mode='binomial'):
        # Previously only input_dims was kept and the remaining arguments
        # were silently dropped; keep all of them so the configuration is
        # recoverable from the instance.
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.kernels = kernels
        self.length = length
        self.hidden_dims = hidden_dims
        self.depth = depth
        self.dim_treatments = dim_treatments
        self.dim_outcome = dim_outcome
        self.mask_mode = mask_mode
