import sys
import time

import timm
import torch
from torch.cuda.amp import GradScaler
from torchvision import transforms

sys.path.append(".")
sys.path.append("../..")
from datatools.prepare_data import prepare_data_loader
from datatools.const import IMAGENETNORMALIZE
from tools.misc import get_vit_network, eval_network, get_general_args, prepare_experiment, \
    calculate_trainable_param, train_network
from models.prompt import AdvProgramInterAdd, AdvProgramInterAddSimple
from analysis.statistics import throughput

if __name__ == '__main__':
    # Entry point: train activation prompts (AP) injected in front of selected ViT
    # blocks together with the classification head, keeping the rest of the backbone
    # frozen. With no --target-block given, only the head is trained (linear probing).
    p = get_general_args()

    # General settings
    p.add_argument('--network', choices=["vit_tiny", "vit_small", "vit_base", "vit_large"], default='vit_tiny')

    # Settings for activation prompt
    p.add_argument('--prompt-type', '--pt', default='add', choices=["add", "vpt", "simple"])
    p.add_argument('--target-block', '--tb', type=int, default=[], nargs="+",
                   help="Meaningful value range: [-block_num, block_num - 1]")
    p.add_argument('--ap-init', default='zero', choices=['zero', 'randn'], help="initialization for visual prompt")

    # Hyper-parameters for training
    p.add_argument('--ap-lr', type=float, default=0.001, help="learning rate for visual prompt, used for ap and vp")
    p.add_argument('--lp-lr', type=float, default=0.001,
                   help="learning rate for linear probing, used for ap, vp, lp, and tn")
    p.add_argument('--wd', type=float, default=0, help="weight decay")

    # For ablation study
    p.add_argument('--per-class-num', '--pcn', type=int, default=None, help="Number of samples per class")

    args = p.parse_args()

    # Setup logger, get device, set seed
    logging, device = prepare_experiment(args)

    # data
    loaders, configs = prepare_data_loader(args.dataset, data_path=args.data_path, batch_size=args.batch_size,
                                           per_class_size=args.per_class_num)

    # Network
    network, dim = get_vit_network(args.network, device)
    logging.info(network)
    network.to(device)

    # Temporarily hook every ViT Block to record the shape of its input, so each
    # prompt can be sized to the tensor entering its target block.
    input_size_collector = []

    def input_collector_hook(module, input):
        # Forward pre-hooks receive positional inputs as a tuple; the tensor is input[0].
        input_size_collector.append(input[0].shape)

    # Keep the handles so these temporary hooks can be removed after the probe pass.
    hook_handle_collector = []
    block_name_collector = []
    # Install the input-size hook on each timm ViT Block, in named_modules() order.
    for name, m in network.named_modules():
        if isinstance(m, timm.models.vision_transformer.Block):
            handle = m.register_forward_pre_hook(input_collector_hook)
            hook_handle_collector.append(handle)
            block_name_collector.append(name)
    # Format the list into the message itself: a logger treats extra positional
    # arguments as %-format args, and this message has no placeholder for them.
    logging.info(f"The name of all the promptable layers: {block_name_collector}")

    # One probe forward lets input_size_collector record each block's input shape.
    # Only shapes are needed, so skip building an autograd graph.
    with torch.no_grad():
        network(torch.randn(16, 3, 224, 224).to(device))

    # Remove the temporary size-collecting hooks
    for handle in hook_handle_collector:
        handle.remove()

    # Validate target blocks up front. Negative indices count back from the last
    # block, as advertised by the --target-block help text.
    num_blocks = len(input_size_collector)
    for layer_num in args.target_block:
        if not -num_blocks <= layer_num < num_blocks:
            raise ValueError(f"Not supported block number {layer_num}! "
                             f"Valid choice from [-{num_blocks}, {num_blocks})")

    # Build one prompt module per target block, shaped like the last two dimensions
    # of that block's recorded input.
    prompt_list = []
    for layer_num in args.target_block:
        size = input_size_collector[layer_num]
        if args.prompt_type == "add":
            prompt_cls = AdvProgramInterAdd
        elif args.prompt_type == "simple":
            prompt_cls = AdvProgramInterAddSimple
        else:
            raise NotImplementedError(f"Prompt type '{args.prompt_type}' is not implemented for ViT activation prompts.")
        adv_program = prompt_cls(feature_map=torch.zeros((size[-2], size[-1])),
                                 feature_map_num=0,
                                 init=args.ap_init).to(device)
        prompt_list.append(adv_program)

    def make_ap_hook(idx):
        # The factory binds idx per hook, avoiding late binding of the loop variable.
        def ap_hook(module, input):
            # Replace the block's first positional input with its prompted version.
            input = list(input)
            input[0] = prompt_list[idx](input[0])
            return tuple(input)
        return ap_hook

    # Normalize raw images with IMAGENETNORMALIZE statistics right before the first
    # Conv2d (presumably the ViT patch-embedding projection — confirm per backbone).
    normalize = transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std']).to(device)

    def normalize_hook(module, input):
        return normalize(input[0])

    for name, m in network.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            normalize_hook_handle = m.register_forward_pre_hook(normalize_hook)
            break

    # Install one activation-prompt pre-hook on each target block. Module names are
    # unique, so a single name -> module map replaces a rescan per target block.
    modules_by_name = dict(network.named_modules())
    ap_hook_handle_collector = []
    for idx, layer_num in enumerate(args.target_block):
        block_name = block_name_collector[layer_num]
        ap_hook_handle = modules_by_name[block_name].register_forward_pre_hook(make_ap_hook(idx))
        ap_hook_handle_collector.append(ap_hook_handle)
        logging.info(f"Activation Prompt in {layer_num} layer: {block_name} is ready!")

    # Dry-run forward pass through normalization + prompts to make sure everything is wired correctly.
    with torch.no_grad():
        network(torch.randn(16, 3, 224, 224).to(device))

    # Optimizer
    optimizer_list = []

    # Freeze the whole backbone, then re-enable only the classification head.
    for param in network.parameters():
        param.requires_grad_(False)
    network.head.requires_grad_(True)
    lp_optimizer = torch.optim.Adam(network.head.parameters(), lr=args.lp_lr, weight_decay=args.wd)
    optimizer_list.append(lp_optimizer)
    # Each prompt gets its own optimizer; prompts are not submodules of `network`.
    for adv_program in prompt_list:
        ap_optimizer = torch.optim.Adam(adv_program.parameters(), lr=args.ap_lr, weight_decay=args.wd)
        optimizer_list.append(ap_optimizer)
    scheduler_list = []
    for optimizer in optimizer_list:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch)
        scheduler_list.append(scheduler)

    if len(args.target_block) > 0:
        logging.info("Activation prompt!")
    else:
        logging.info("Linear probing!")

    # Count tunable parameters through the optimizer parameters
    trainable_params = calculate_trainable_param(optimizer_list)
    logging.info(f"Total tunable parameters: {trainable_params}")
    through_put = throughput(network, device=device)
    logging.info(f"Throughput Evaluation: {through_put} images/s")

    # Train
    best_acc = 0.
    scaler = GradScaler()
    for epoch in range(args.epoch):
        time_start = time.time()
        # NOTE(review): eval() during training looks deliberate (frozen backbone, no
        # dropout); confirm that train_network does not expect train() mode.
        network.eval()
        train_acc = train_network(network, loaders['train'], scaler, optimizer_list, scheduler_list, device, epoch,
                                  description="ViT AP Training")
        time_end = time.time()

        # Test every test_interval epochs once past test_start; track the best test accuracy.
        if (epoch+1) % args.test_interval == 0 and epoch >= args.test_start:
            acc = eval_network(network, loaders['test'], device, epoch)
            if acc > best_acc:
                best_acc = acc

            logging.info(
                f"For Epoch {epoch}, time {time_end - time_start:.2f}s, the training accuracy is {100 * train_acc: .2f}%, the test accuracy is {100 * acc: .2f}%, the best test acc is {100 * best_acc: .2f}%")
        else:
            logging.info(
                f"For Epoch {epoch}, time {time_end - time_start:.2f}s, the training accuracy is {100 * train_acc: .2f}%")