from torchvision.models import resnet18
from easydict import EasyDict
import torch

from fling.utils.visualize_utils import ActivationMaximizer
from fling.utils.registry_utils import MODEL_REGISTRY


def am_demo(name, layer, channel_id, device='cuda'):
    """Run activation maximization for one channel of one layer of a saved resnet8.

    Builds a ``resnet8`` (3 input channels, 100 classes) via ``MODEL_REGISTRY``,
    loads the weights from ``./visualize/<name>/before_model.ckpt`` and runs
    ``ActivationMaximizer`` (1000 iterations, ``tv_weight=1``, learning rate 0.1,
    3x32x32 input image) with working directory
    ``./visualize/<name>/resnet8_am_1/<layer>``.

    Args:
        name: Sub-directory of ``./visualize`` that holds ``before_model.ckpt``.
        layer: Dotted layer name to visualize (e.g. ``'layers.0.0.conv1'``).
        channel_id: Index of the channel/unit within ``layer`` to maximize.
        device: Device string forwarded to ``activation_maximization``.
            Defaults to ``'cuda'`` (the previously hard-coded value).
    """
    import os

    path_head = os.path.join('./visualize', name)

    model_arg = EasyDict(dict(
        name='resnet8',
        input_channel=3,
        class_number=100,
    ))
    # The registry key is separated from the constructor kwargs.
    model_name = model_arg.pop('name')
    model = MODEL_REGISTRY.build(model_name, **model_arg)

    # Without map_location, torch.load restores tensors onto the device they were
    # saved from (e.g. a specific GPU index), which fails if that device is absent.
    # load_state_dict copies the values into the model's own parameters regardless.
    ckpt_path = os.path.join(path_head, 'before_model.ckpt')
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    activation_maximizer = ActivationMaximizer(
        iteration=1000,
        working_dir=os.path.join(path_head, 'resnet8_am_1', layer),
        tv_weight=1,
    )
    activation_maximizer.activation_maximization(
        model, layer, channel_id=channel_id,
        image_shape=[3, 32, 32], device=device, learning_rate=1e-1,
    )


# Dotted layer names handed one at a time to ``am_demo`` by the ``__main__`` sweep.
# NOTE(review): presumably these match the submodule paths of the resnet8 model
# built in ``am_demo`` — verify against the model definition.
layers = [
    'pre_conv',
    'layers.0.0.conv1',
    'layers.0.0.conv2',
    'layers.1.0.conv1',
    'layers.1.0.conv2',
    'layers.1.0.downsample.0',
    'layers.2.0.conv1',
    'layers.2.0.conv2',
    'layers.2.0.downsample.0',
    'fc',
]


if __name__ == '__main__':
    # Other runs previously visualized: 'avg100', 'no_warm', 'warm5_part1', 'warm5_part5'.
    run_names = ['no_warm_part5']
    # Sweep every (run, layer, channel) triple: runs outermost, channels innermost,
    # covering channel indices 0..63 of each layer.
    jobs = (
        (run, layer_name, ch)
        for run in run_names
        for layer_name in layers
        for ch in range(64)
    )
    for run, layer_name, ch in jobs:
        am_demo(run, layer_name, ch)
    

# rnd2str = {
#     0: ['pre_conv', 'bn1'],  # 1.5
#     1: ['layers.0.0.conv1', 'layers.0.0.bn1'],  # 5.898
#     2: ['layers.0.0.conv2', 'layers.0.0.bn2'],  # 5.898
#     3: ['layers.1.0.conv1', 'layers.1.0.bn1'],  # 11.8
#     4: ['layers.1.0.conv2', 'layers.1.0.bn2'],  # 23.59
#     5: ['layers.1.0.downsample.0', 'layers.1.0.downsample.1'],  # 1.311
#     6: ['layers.2.0.conv1', 'layers.2.0.bn1'],  # 47.19
#     7: ['layers.2.0.conv2', 'layers.2.0.bn2'],  # 94.37
#     8: ['layers.2.0.downsample.0', 'layers.2.0.downsample.1'],  # 5.243
#     9: ['fc']  # 0.411
# }