import os
import sys
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import SGD, Adam
import wandb
import numpy as np
import time
import argparse

from utils import D3Model
from ms_attacks import knockoff, jbda  # , knockoff_augment, noise, jbda, ising

from ml_common import set_seed, test, enable_gpu_benchmarking, get_device
from ml_models import get_model, model_choices
from ml_datasets import get_dataloaders, nclasses_dict, ds_choices


# Input value range handed to D3Model and to the jbda attack via ``bounds=``.
# NOTE(review): presumably matches the normalization used by ml_datasets
# (inputs scaled to [-1, 1]) — TODO confirm.
bounds = [-1, 1]

# Global run setup, executed once at import time: fixed seed first, then GPU
# benchmarking. NOTE(review): if enable_gpu_benchmarking turns on cuDNN
# autotuning, results may not be bit-reproducible despite the fixed seed.
set_seed(2020)
enable_gpu_benchmarking()


def _build_parser():
    """Build the command-line parser for the ensemble model-stealing attack.

    Returns:
        argparse.ArgumentParser: parser covering target/surrogate datasets and
        models, optimizer settings, attack selection and hashing options.
    """
    parser = argparse.ArgumentParser(description="attack ensemble")
    parser.add_argument(
        "--dataset_tar",
        type=str,
        default="mnist",
        help="target dataset",
        choices=ds_choices,
    )
    parser.add_argument(
        "--dataset_sur",
        type=str,
        default="fashion",
        help="surrogate dataset used to query the target",
        choices=ds_choices,
    )
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument(
        "--model_tar",
        type=str,
        default="conv3",
        choices=model_choices,
        help="Target model type",
    )
    parser.add_argument(
        "--model_hash",
        type=str,
        default="conv3",
        choices=model_choices,
        help="Hash model type",
    )
    parser.add_argument(
        "--model_sur",
        type=str,
        default="conv3",
        choices=model_choices,
        help="Surrogate/Clone model type",
    )
    parser.add_argument(
        "--opt", type=str, default="sgd", choices=["sgd", "adam"], help="Optimizer"
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    parser.add_argument(
        "--n_models",
        type=int,
        default=1,
        help="number of target models in the ensemble",
    )
    parser.add_argument(
        "--n_seed", type=int, default=100, help="number of seed examples"
    )
    parser.add_argument("--budget", type=int, default=50000, help="Query budget")
    parser.add_argument(
        "--aug_rounds", type=int, default=6, help="number of augmentation rounds"
    )
    parser.add_argument(
        "--exp_id", type=str, default="alpha", help="name of the experiment"
    )
    parser.add_argument(
        "--id_hash", type=str, default="32", help="name of the hash experiment"
    )
    parser.add_argument(
        "--quantize", action="store_true", help="Enable input quantization"
    )
    parser.add_argument(
        "--pretrained", action="store_true", help="use pretrained model"
    )
    parser.add_argument("--disable_pbar", action="store_true", help="Disable pbar")
    # The flag is forwarded as ``augment=`` to get_dataloaders, so it controls
    # data augmentation (the old help text wrongly said "Use Pretrained model").
    parser.add_argument("--augment", action="store_true", help="Use data augmentation")
    parser.add_argument(
        "--attack",
        type=str,
        default="knockoff",
        help="attack",
        choices=["knockoff", "jbda", "jbda-tr"],
    )
    parser.add_argument(
        "--n_adaptive_queries", type=int, default=5, help="number of adaptive queries"
    )
    parser.add_argument(
        "--hash_mode",
        type=str,
        default="dnn",
        help="hash mode",
        choices=["phash", "dnn", "sha1"],
    )
    parser.add_argument(
        "--pred_type",
        type=str,
        default="soft",
        help="specify if the target model outputs hard/soft label predictions",
        choices=["hard", "soft"],
    )
    parser.add_argument(
        "--lamb", type=float, default=0.0, help="jbda lambda",
    )
    parser.add_argument(
        "--adaptive_mode",
        type=str,
        default="none",
        help="Adaptive attack mode",
        choices=["none", "normal", "ideal_attack", "ideal_defense", "normal_sim"],
    )
    return parser


def _load_target_models(args, path_exp, device):
    """Load the target checkpoints ``T0.pt`` .. ``T{n_models-1}.pt`` onto *device*."""
    T_list = []
    for i in range(args.n_models):
        T_path = os.path.join(path_exp, f"T{i}.pt")
        T = get_model(args.model_tar, args.dataset_tar).to(device)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        T.load_state_dict(torch.load(T_path, map_location=device))
        T_list.append(T)
    return T_list


def _load_hash_model(args):
    """Load the DNN hash model when ``--hash_mode dnn``; return None otherwise."""
    if args.hash_mode != "dnn":
        return None
    path_hash = f"./exp/hash/{args.id_hash}.pt"
    # NOTE(review): n_classes=10 is fixed regardless of dataset; presumably the
    # number of hash outputs used to select an ensemble member — TODO confirm.
    H = get_model(args.model_hash, args.dataset_tar, n_classes=10)
    # The module stays on the CPU, as before; map_location="cpu" only makes
    # GPU-saved checkpoints loadable without CUDA.
    # NOTE(review): unlike the targets, H is never moved to `device`; confirm
    # that D3Model handles its placement.
    H.load_state_dict(torch.load(path_hash, map_location="cpu"))
    return H


def _make_optimizer(args, model):
    """Create the clone optimizer and, for SGD only, a cosine LR schedule.

    Returns:
        tuple: ``(optimizer, scheduler)``; ``scheduler`` is None for Adam.

    Raises:
        ValueError: if ``args.opt`` is not "sgd" or "adam".
    """
    if args.opt == "sgd":
        opt = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        return opt, CosineAnnealingLR(opt, args.epochs, last_epoch=-1)
    if args.opt == "adam":
        return Adam(model.parameters(), lr=args.lr), None
    raise ValueError("Unknown optimizer {}".format(args.opt))


def main():
    """Run a model-stealing attack against an ensemble of target models.

    Steps:
      1. Parse CLI arguments (``--n_models`` must be >= 1).
      2. Load the target checkpoints ``T{i}.pt`` from
         ``./exp/<dataset_tar>/<exp_id>/``. For ``--hash_mode dnn``, also load
         a hash model from ``./exp/hash/<id_hash>.pt``. Wrap everything in a
         ``D3Model``.
      3. Print the mean of ``test(T, dataloader_test)`` over the targets.
      4. Build a fresh surrogate/clone ``S`` and train it with the selected
         attack: ``knockoff`` queries with the surrogate dataset, while
         ``jbda``/``jbda-tr`` start from the target training loader.
      5. Save ``S.state_dict()`` to
         ``./exp/<dataset_tar>/<exp_id>/clone/<attack>.pt``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.n_models < 1:
        # An empty ensemble would otherwise crash later when averaging accuracy.
        parser.error("--n_models must be at least 1")
    path_exp = f"./exp/{args.dataset_tar}/{args.exp_id}/"

    device = get_device()
    dataloader_train, dataloader_test = get_dataloaders(
        args.dataset_tar, args.batch_size, augment=args.augment
    )

    # Model construction order (targets, hash, clone) is kept as-is because
    # get_model's random initialisation consumes the seeded RNG.
    T_list = _load_target_models(args, path_exp, device)
    H = _load_hash_model(args)

    T_d3 = D3Model(
        T_list,
        bounds=bounds,
        num_classes=nclasses_dict[args.dataset_tar],
        quantization=args.quantize,
        model_hash=H,
        hash_mode=args.hash_mode,
    )
    acc_tar = sum(test(T, dataloader_test) for T in T_list) / len(T_list)
    print("* Loaded Target Model *")
    print("Target Accuracy: {:.2f}%\n".format(100 * acc_tar))

    S = get_model(args.model_sur, args.dataset_tar, pretrained=args.pretrained).to(
        device
    )  # Clone  (Student)
    opt, sch = _make_optimizer(args, S)

    dataloader_sur, _ = get_dataloaders(
        args.dataset_sur, batch_size=args.batch_size, augment=args.augment
    )

    if args.attack == "knockoff":
        knockoff(
            T_d3,
            S,
            dataloader_sur,
            dataloader_test,
            opt,
            sch,
            acc_tar,
            args.batch_size,
            args.epochs,
            args.disable_pbar,
            args.budget,
            pred_type=args.pred_type,
            adaptive_mode=args.adaptive_mode,
            n_adaptive_queries=args.n_adaptive_queries,
        )
    elif args.attack in ["jbda", "jbda-tr"]:
        jbda(
            T_d3,
            S,
            dataloader_train,
            dataloader_test,
            opt,
            acc_tar,
            num_seed=args.n_seed,
            aug_rounds=args.aug_rounds,
            epochs=args.epochs,
            batch_size=args.batch_size,
            dataset=args.dataset_tar,
            bounds=bounds,
            mode=args.attack,
            lmbda=args.lamb,
            pred_type=args.pred_type,
        )
    else:
        # Unreachable through argparse `choices`; kept as a defensive guard.
        sys.exit("Unknown Attack {}".format(args.attack))

    savedir_clone = os.path.join(path_exp, "clone")
    os.makedirs(savedir_clone, exist_ok=True)
    torch.save(S.state_dict(), os.path.join(savedir_clone, f"{args.attack}.pt"))
    print("* Saved Sur model * ")


if __name__ == "__main__":
    # Measure elapsed wall time for the whole run. perf_counter is monotonic,
    # so clock adjustments (e.g. NTP) cannot skew it the way time.time() can.
    start = time.perf_counter()
    main()
    runtime = time.perf_counter() - start
    print("Runtime: {:.2f} s".format(runtime))
