import sys
import time

import torch
import torchvision
from torch.cuda.amp import GradScaler
from torchvision import transforms

sys.path.append(".")
sys.path.append("../..")
from datatools.prepare_data import prepare_data_loader
from datatools.const import IMAGENETNORMALIZE
from tools.misc import eval_network, prepare_experiment, get_general_args, train_network, get_cnn_network
from models.prompt import AdvProgramInterAdd

if __name__ == '__main__':
    # Extend the parser returned by get_general_args() with the options specific to this script.
    parser = get_general_args()

    # General settings
    parser.add_argument(
        '--network',
        choices=["resnet18", "resnet50", "resnet101", "instagram"],
        default='resnet18',
    )

    # Settings for activation prompt
    parser.add_argument(
        '--prompt-type', '--pt',
        default='add',
        choices=["add", "ssf", "simple"],
    )
    # Indices into the list of promptable layers; several may be given at once.
    parser.add_argument(
        '--target-block', '--tb',
        type=int,
        default=[],
        nargs="+",
        help="Meaningful value range: [-block_num, block_num - 1]",
    )
    parser.add_argument(
        '--ap-init',
        default='zero',
        choices=['zero', 'randn'],
        help="initialization for visual prompt",
    )

    # Hyper-parameters for training
    parser.add_argument(
        '--ap-lr',
        type=float,
        default=0.01,
        help="learning rate for visual prompt, used for ap and vp",
    )
    parser.add_argument(
        '--lp-lr',
        type=float,
        default=0.01,
        help="learning rate for linear probing, used for ap, vp, lp, and tn",
    )
    parser.add_argument('--wd', type=float, default=0, help="weight decay")

    args = parser.parse_args()

    # Setup logger, get device, set seed
    logging, device = prepare_experiment(args)

    # Data loaders plus dataset metadata; configs['class_names'] sizes the classifier head below.
    loaders, configs = prepare_data_loader(args.dataset, data_path=args.data_path, batch_size=128)

    # Freeze every backbone weight; only the classification head stays trainable.
    network = get_cnn_network(args.network, device)
    network.requires_grad_(False)
    # Assigning to `fc.out_features` would only change the attribute: the weight and bias
    # tensors keep their original shape, so the head would still emit the original number
    # of logits. Swap in a freshly initialised layer of the right width instead.
    num_classes = len(configs['class_names'])
    network.fc = torch.nn.Linear(network.fc.in_features, num_classes)
    network.fc.requires_grad_(True)
    logging.info(network)
    network.eval()
    # Also moves the newly created head (built on the default device) to the target device.
    network.to(device)

    # Input normalization; it is attached to the network as a forward pre-hook further below.
    normalize = transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std']).to(device)

    # Filled by input_collector_hook with the input shape of every promptable layer.
    input_size_collector = []


    def input_collector_hook(module, input):
        """Forward pre-hook that records the shape of the hooked module's first input."""
        # Pre-hooks receive the positional inputs as a tuple; only the first one is inspected.
        first_input = input[0]
        input_size_collector.append(first_input.shape)


    def normalize_hook(module, input):
        """Forward pre-hook that replaces the module input with its normalized version."""
        raw_batch = input[0]
        normalized_batch = normalize(raw_batch)
        return normalized_batch


    # Promptable layers are the first Conv2d encountered plus every residual block.
    # A temporary shape-recording pre-hook is attached to each of them.
    residual_block_types = (torchvision.models.resnet.BasicBlock, torchvision.models.resnet.Bottleneck)
    hook_handle_collector = []
    block_name_collector = []
    for name, m in network.named_modules():
        # Only the very first convolution qualifies: once any hook exists this is False.
        is_first_conv = not hook_handle_collector and isinstance(m, torch.nn.Conv2d)
        if is_first_conv or isinstance(m, residual_block_types):
            handle = m.register_forward_pre_hook(input_collector_hook)
            hook_handle_collector.append(handle)
            block_name_collector.append(name)

    # Format the list into the message itself: passing it as an extra positional argument
    # makes it a %-format argument with no placeholder, which fails to render.
    logging.info(f"The name of all the promptable layers: {block_name_collector}")

    # One forward to let input_size_collector collect the feature map size.
    # Only the shapes are needed, so no autograd graph is built.
    with torch.no_grad():
        network(torch.randn(16, 3, 224, 224).to(device))

    # Remove hooks
    for handle in hook_handle_collector:
        handle.remove()

    # Prepare Visual Prompt in Intermediate Layers
    # Target indices are used as Python list indices, so negative values count from the end.
    # Both bounds are checked here so an invalid index fails with a clear message instead of
    # a bare IndexError later on.
    promptable_layer_num = len(input_size_collector)
    for layer_num in args.target_block:
        if not -promptable_layer_num <= layer_num < promptable_layer_num:
            raise ValueError(
                f"Not supported block number! "
                f"Valid choice from [{-promptable_layer_num}, {promptable_layer_num})")

    # Prepare for the visual prompt in the intermediate layers:
    # one prompt module per requested layer, stored in command-line order.
    prompt_list = []
    for layer_num in args.target_block:
        feature_shape = input_size_collector[layer_num]
        # Only the additive prompt is implemented in this script.
        if args.prompt_type != "add":
            raise NotImplementedError
        # The last two dimensions give the spatial size, the third-to-last the channel count.
        spatial_template = torch.zeros((feature_shape[-2], feature_shape[-1]))
        channel_num = feature_shape[-3]
        layer_prompt = AdvProgramInterAdd(feature_map=spatial_template,
                                          feature_map_num=channel_num,
                                          init=args.ap_init)
        prompt_list.append(layer_prompt.to(device))
        logging.info(f"Activation Prompt in {layer_num} layer: {block_name_collector[layer_num]} is ready!")


    # Hook for visual prompt
    def make_ap_hook(idx):
        """Return a forward pre-hook that routes the first input through prompt_list[idx]."""

        def ap_hook(module, input):
            # Only the first positional input is prompted; any remaining inputs pass through as-is.
            prompted = prompt_list[idx](input[0])
            return (prompted, *input[1:])

        return ap_hook


    # Add normalization hook before the first convolutional layer
    first_conv = next((m for m in network.modules() if isinstance(m, torch.nn.Conv2d)), None)
    if first_conv is not None:
        normalize_hook_handle = first_conv.register_forward_pre_hook(normalize_hook)

    # Install the hook to the target block
    # Every name in block_name_collector belongs to a promptable module (the first
    # convolution or a residual block), so the target is looked up by name. This works for
    # positive and negative indices alike and for any number of promptable layers; testing
    # for a fixed index such as -9 would only reach the first convolution when the list
    # has exactly nine entries.
    modules_by_name = dict(network.named_modules())
    vp_hook_handle_collector = []
    for idx, layer_num in enumerate(args.target_block):
        target_module = modules_by_name[block_name_collector[layer_num]]
        vp_hook_handle = target_module.register_forward_pre_hook(make_ap_hook(idx))
        vp_hook_handle_collector.append(vp_hook_handle)

    # Dry-run forward path, make sure everything is correct.
    # The output is discarded and autograd is disabled, so no graph or activations from
    # this throw-away batch stay alive during training.
    with torch.no_grad():
        network(torch.randn(16, 3, 224, 224).to(device))

    # Optimizer
    # Entry 0 is the linear probing optimizer for the classifier head; the remaining
    # entries hold one optimizer per prompt, in prompt_list order.
    head_optimizer = torch.optim.AdamW(network.fc.parameters(), lr=args.lp_lr, weight_decay=args.wd)
    optimizer_list = [head_optimizer] + [
        torch.optim.AdamW(prompt_module.parameters(), lr=args.ap_lr, weight_decay=args.wd)
        for prompt_module in prompt_list
    ]

    # One cosine schedule per optimizer, spanning the whole run.
    scheduler_list = [
        torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epoch)
        for opt in optimizer_list
    ]

    # Count tunable parameters through the optimizer parameters
    # trainable_params = calculate_trainable_param(optimizer_list)
    # logging.info(f"Total tunable parameters: {trainable_params}")
    # through_put = throughput(network, device=device)
    # logging.info(f"Throughput Evaluation: {through_put} images/s")

    # Train
    scaler = GradScaler()
    best_acc = 0.
    for epoch in range(args.epoch):
        # The network is put back into eval mode at the start of every epoch.
        network.eval()
        time_start = time.time()
        train_acc = train_network(network, loaders['train'], scaler, optimizer_list, scheduler_list, device, epoch,
                                  description="CNN AP Training")
        time_end = time.time()

        # Evaluation runs every `test_interval` epochs, and only from epoch `test_start` on.
        is_test_epoch = (epoch + 1) % args.test_interval == 0 and epoch >= args.test_start
        if not is_test_epoch:
            logging.info(
                f"For Epoch {epoch}, time {time_end - time_start:.2f}s, the training accuracy is {100 * train_acc: .2f}%")
            continue

        # Test
        acc = eval_network(network, loaders['test'], device, epoch)
        best_acc = max(best_acc, acc)

        logging.info(
            f"For Epoch {epoch}, time {time_end - time_start:.2f}s, the training accuracy is {100 * train_acc: .2f}%, the test accuracy is {100 * acc: .2f}%, the best test acc is {100 * best_acc: .2f}%")