import sys
import time

import torch
import torchvision
from torch.cuda.amp import GradScaler
from torchvision import transforms

sys.path.append(".")
sys.path.append("../..")
from datatools.prepare_data import prepare_data_loader
from datatools.const import IMAGENETNORMALIZE
from tools.misc import eval_network, prepare_experiment, get_general_args, train_network, get_cnn_network, \
    calculate_trainable_param
from models.prompt import AdvProgramInterAdd, AdvProgramInterAddSimple
from analysis.statistics import throughput

def collect_block_input_shapes(network, sample_input, block_types=None):
    """Record the input shape of every prompt-able block during one forward pass.

    Prompting only takes place right before the forward of a residual block
    (or before the first convolutional layer, which then becomes the original
    visual prompting); an arbitrary ``Conv2d`` cannot be prompted because of
    the residual connection. To size the prompt for each block we need the
    exact shape of the feature map entering it. This function therefore
    installs a forward pre-hook on every matching module, runs ``network``
    once on ``sample_input`` under ``torch.no_grad()``, and removes the hooks
    again (even if the forward pass raises).

    Args:
        network: module to probe; it is called exactly once.
        sample_input: tensor passed to ``network`` for the probing pass.
        block_types: class or tuple of classes whose inputs are recorded.
            Defaults to torchvision's ResNet ``BasicBlock`` and ``Bottleneck``.

    Returns:
        List of ``(name, shape)`` tuples in the order the hooks fired, where
        ``name`` is the qualified module name from ``named_modules()`` and
        ``shape`` is the ``torch.Size`` of the module's first positional input.
        A module invoked several times contributes one entry per call.
    """
    if block_types is None:
        block_types = (torchvision.models.resnet.BasicBlock, torchvision.models.resnet.Bottleneck)

    records = []

    def make_hook(name):
        # Bind ``name`` per module so each shape stays paired with its own block,
        # independent of how named_modules() order relates to execution order.
        def hook(module, inputs):
            # ``inputs`` is the tuple of positional args; the feature map is the first one.
            records.append((name, inputs[0].shape))
        return hook

    handles = []
    try:
        for name, module in network.named_modules():
            if isinstance(module, block_types):
                handles.append(module.register_forward_pre_hook(make_hook(name)))
        # Only shapes are needed, so skip building the autograd graph.
        with torch.no_grad():
            network(sample_input)
    finally:
        # Always detach the hooks so the network is left unmodified,
        # even when the probing forward pass fails.
        for handle in handles:
            handle.remove()
    return records


if __name__ == '__main__':
    # Network
    network = get_cnn_network('resnet101', 'cpu')
    network.eval()

    # Batch size 1, so each numel() below equals the per-sample element count.
    sample_input = torch.randn(1, 3, 224, 224)
    block_inputs = collect_block_input_shapes(network, sample_input)

    # One entry per prompting location: the raw image first (original visual
    # prompting), then the input feature map of each residual block in forward
    # order. Each entry is that tensor's element count, which determines the
    # size of the prompt at that location.
    ap_parameters = [sample_input.numel()]
    ap_parameters.extend(shape.numel() for _, shape in block_inputs)

    for count in ap_parameters:
        print(f"{count}")

