from .comlib import *
from .. import task 
from .model_setup import ModelSetup

@util.repr_alias(attr_name=False)
@dataclass
class DnnArg:
    """Hyper-parameters for a model built by ``task.simple_model.create_model``.

    Every dataclass field (``name``, ``layer_num``, ``hidden_num``) is forwarded
    as a keyword argument to ``task.simple_model.create_model`` via ``asdict``.
    """

    name: Literal["default", "embed", "layerNorm"]
    layer_num: int
    hidden_num: int

    def get(self, t, device, random_seed):
        """Create a ``ModelSetup`` for task ``t``.

        Args:
            t: the task to build a model for; only ``task.MnistTask`` is supported.
            device: passed through unchanged to ``ModelSetup``.
            random_seed: passed to ``task.simple_model.Initialize``.

        Returns:
            A ``ModelSetup`` wrapping the created model and its initializer.

        Raises:
            TypeError: if ``t`` is not a ``task.MnistTask``.
        """
        # Fail loudly on unsupported tasks instead of falling through to an
        # unbound local at the return statement.
        if not isinstance(t, task.MnistTask):
            raise TypeError(
                f"DnnArg.get supports only task.MnistTask, got {type(t).__name__}"
            )
        model = task.simple_model.create_model(task=t, **asdict(self))
        return ModelSetup(device, model, task.simple_model.Initialize(random_seed))
    

@util.repr_alias(attr_name=False)
@dataclass
class VggArg:
    """Hyper-parameters for a model built by ``task.cnn.create_model``.

    Every dataclass field (``name``, ``layer_num``, ``min_hidden_num``,
    ``max_hidden_num``) is forwarded as a keyword argument to
    ``task.cnn.create_model`` via ``asdict``.
    """

    name: Literal["default", "embed"]
    layer_num: int
    min_hidden_num: int
    max_hidden_num: int

    def get(self, t, device, random_seed):
        """Create a ``ModelSetup`` for task ``t``.

        Args:
            t: the task to build a model for; only ``task.CifarTask`` is supported.
            device: passed through unchanged to ``ModelSetup``.
            random_seed: passed to ``task.simple_model.Initialize``.

        Returns:
            A ``ModelSetup`` wrapping the created model and its initializer.

        Raises:
            TypeError: if ``t`` is not a ``task.CifarTask``.
        """
        # Fail loudly on unsupported tasks instead of falling through to an
        # unbound local at the return statement.
        if not isinstance(t, task.CifarTask):
            raise TypeError(
                f"VggArg.get supports only task.CifarTask, got {type(t).__name__}"
            )
        model = task.cnn.create_model(task=t, **asdict(self))
        # NOTE(review): the CNN model is initialized with simple_model's
        # Initializer, not a task.cnn one — presumably shared on purpose; confirm.
        return ModelSetup(device, model, task.simple_model.Initialize(random_seed))