import torch.nn as nn
import torch
from torch.utils.data.dataloader import DataLoader

from .info_block import ModuleInfo, ModelInfo
from .pca_designer import PCADesigner



if __name__ == "__main__":
    # Demo: build a ModelInfo from three ModuleInfo entries ("1", "2", "3"),
    # each given nn.Identity, then print every item yielded by iterating it.
    # NOTE(review): the relative imports at the top of this file mean this
    # only works when run as a module (``python -m <package>.<module>``), and
    # the guard conventionally belongs at the end of the file, after
    # give_designer is defined.
    model_info = ModelInfo(
        module_list=[ModuleInfo(label, nn.Identity) for label in ("1", "2", "3")]
    )

    for entry in model_info:
        print(entry)

def give_designer(config, model_info: ModelInfo, data_loader: DataLoader):
    """Construct a ``PCADesigner`` from a configuration object.

    Args:
        config: Object exposing ``recon_rate`` and ``device`` attributes.
            ``recon_rate`` is forwarded as-is; ``device`` is converted with
            ``torch.device`` before being forwarded.
        model_info: Forwarded unchanged to ``PCADesigner``.
        data_loader: Forwarded unchanged to ``PCADesigner``.

    Returns:
        The newly constructed ``PCADesigner`` instance.
    """
    # Read the config fields in the same order the designer arguments list them.
    rate = config.recon_rate
    target_device = torch.device(config.device)
    return PCADesigner(
        model_info=model_info,
        data_loader=data_loader,
        recon_rate=rate,
        device=target_device,
    )