from torchvision.models import resnet18
from easydict import EasyDict
import torch

from fling.utils import Logger
from fling.utils.visualize_utils import plot_conv_kernels
from fling.utils.registry_utils import MODEL_REGISTRY

if __name__ == '__main__':
    # Script purpose: build a model from the registry, restore weights from a
    # saved checkpoint, and pass its ``pre_conv`` layer to the kernel plotter.

    # Step 1: prepare the model.
    # ``name`` selects the registry entry; the remaining keys are forwarded to
    # ``MODEL_REGISTRY.build`` as keyword arguments.
    model_arg = EasyDict(dict(
        name='resnet8',
        input_channel=3,
        class_number=100,
    ))
    model_name = model_arg.pop('name')
    model = MODEL_REGISTRY.build(model_name, **model_arg)

    # Directory holding the checkpoint; the logger is also pointed at a
    # sub-path of it below.
    path_head = './visualize/no_warm_part5'

    # Map every tensor to the CPU while deserializing so that a checkpoint
    # saved from a CUDA run still loads on a machine without a GPU.
    # ``load_state_dict`` copies the values into the model's existing
    # parameters, so the restored weights are the same either way.
    state_dict = torch.load(path_head + '/before_model.ckpt', map_location='cpu')
    model.load_state_dict(state_dict)

    # Step 2: prepare the logger.
    logger = Logger(path_head + '/kernels')

    # Step 3: save the kernels.
    plot_conv_kernels(logger, model.pre_conv, name='pre-conv')