# -*- coding: utf-8 -*-
import argparse
import copy
import json
import os
import pickle
import pprint
import random
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import attacks
from eval import evaluate, evaluate_auto_attack
from loss import AlignLoss
from models.load_model import load_model
from utils.config_utils import str2bool as s2b
from utils.config_utils import str2strlist as s2sl
from utils.data import get_data_simple
from utils.share_layers import share_layers
from utils.utils import Tee

############################################
################### Args ###################
############################################
formatter = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(description="Training", formatter_class=formatter)
# seed
parser.add_argument("--seed", type=int, default=1, help="Random seed")

# Model configuration
parser.add_argument("--model", type=str, default="cifar_resnet18", help="Model arch")

# dataset
parser.add_argument(
    "--data_root", type=str, default="../data", help="Root directory of data."
)
parser.add_argument("--dataset", type=str, default="cifar10", help="Dataset name.")
parser.add_argument(
    "--num_classes",
    type=int,
    default=10,
    help="Number of classes (overwritten by the value the data loader reports).",
)

# Optimization options
parser.add_argument(
    "--epochs", "-e", type=int, default=110, help="Number of training epochs."
)
parser.add_argument(
    "--learning_rate",
    "-lr",
    type=float,
    default=0.1,
    help="Base learning rate (step-decayed by the epoch schedule).",
)
parser.add_argument("--batch_size", "-b", type=int, default=128, help="Batch size")
parser.add_argument("--test_bs", type=int, default=256)
parser.add_argument("--momentum", type=float, default=0.9, help="Momentum.")
parser.add_argument("--decay", type=float, default=0.0005, help="Weight decay")

# adversarial attack configuration
parser.add_argument(
    "--epsilon", type=float, default=8.0 / 255, help="perturbation bound"
)
parser.add_argument("--num_steps", type=int, default=10, help="perturb number of steps")
parser.add_argument(
    "--step_size", type=float, default=2.0 / 255, help="perturb step size"
)

# Layers shared between the main and the sub model.
# (Default: share all layers, except BN layers.)
parser.add_argument(
    "--share_layer_name_list",
    type=s2sl,
    default=["all"],
    help="Layer names whose weights the sub model shares with the main model.",
)
parser.add_argument(
    "--exclude_layer_name_list",
    type=s2sl,
    default=None,
    help="Layer names excluded from sharing.",
)
parser.add_argument(
    "--separate_bn", type=s2b, default=True, help="Split bn between main and sub model"
)

# Feature regularization loss configuration.
#   Default:
#       - align all ReLU outputs in the last "block" of network.
#       - align features with cosine similarity.
#       - use predictor MLP head (h), with hidden dim being 1/4 of feature dim.
#       - align_type: x->y (stop-grad(y))
parser.add_argument(
    "--feat_align_loss_metric",
    type=str,
    default="cos-sim",
    help="Metric used to align features.",
)
parser.add_argument(
    "--align_features_weight",
    type=float,
    default=30.0,
    help="Align features loss weight (split equally across aligned layers)",
)
parser.add_argument(
    "--is_avg_pool", type=s2b, default=True, help="Apply avg_pool to feature map"
)
parser.add_argument(
    "--is_relu", type=s2b, default=True, help="Apply relu to feature map"
)
parser.add_argument(
    "--is_use_predictor", type=s2b, default=True, help="Use predictor MLP head"
)
parser.add_argument(
    "--pred_dim_ratio", type=float, default=0.25, help="Predictor's hidden dim (rel.)"
)
parser.add_argument(
    "--align_type",
    type=str,
    default="x->y",
    help="Alignment direction; x->y applies stop-grad(y).",
)
parser.add_argument(
    "--align_layers",
    type=s2sl,
    default=[],
    help="Layer names whose features are regularized.",
)
parser.add_argument(
    "--align_layers_name",
    type=str,
    default=None,
    help="Name used only for logging; falls back to align_layers joined by '-'.",
)

# classification losses
# xx ... adv, x ... clean, y ... label
# f ... main model, g ... sub model
parser.add_argument(
    "--ce_loss_fxx_y_weight",
    type=float,
    default=1.0,
    help="Weight of CE(f(xx), y); ignored when auto-balancing.",
)
parser.add_argument(
    "--ce_loss_gx_y_weight",
    type=float,
    default=1.0,
    help="Weight of CE(g(x), y); ignored when auto-balancing.",
)
parser.add_argument(
    "--is_auto_balance_ce_loss",
    type=s2b,
    default=True,
    help="Weight both CE losses by the sub model's clean train accuracy.",
)

# evaluation configuration
parser.add_argument(
    "--is_eval_auto_attack",
    type=s2b,
    default=True,
    help="Run AutoAttack evaluation after training.",
)

# Checkpoints
parser.add_argument(
    "--save_root_dir", type=str, default="./ckpts", help="Folder to save checkpoints."
)
parser.add_argument(
    "--mark", type=str, default="ARAT", help="Save dir name (method name)"
)
parser.add_argument(
    "--save_interval", type=int, default=10, help="Epoch interval to save checkpoints."
)
parser.add_argument(
    "--eval_interval", type=int, default=5, help="Epoch interval to evaluate."
)

# Experiment configuration
parser.add_argument(
    "--log_dir", type=str, default="./logs", help="Folder to save logs."
)
parser.add_argument(
    "--overwrite",
    type=s2b,
    default=False,
    help="Re-run even if the save dir already holds a finished run.",
)

args = parser.parse_args()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device: {}".format(device))

# The training / evaluation code calls .cuda() directly, so a GPU is mandatory.
if not torch.cuda.is_available():
    print("CUDA not available. Exit.")
    # Non-zero status so job schedulers / shell scripts see this as a failure
    # (the builtin exit() used before reported success with status 0).
    sys.exit(1)

# Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels so
# runs with the same --seed are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


############################################
################ Set Args ##################
############################################
# Logging name for the aligned layers: keep an explicit value, otherwise join
# the layer names with "-" (e.g. ["a", "b"] -> "a-b").
args.align_layers_name = (
    "-".join(args.align_layers)
    if args.align_layers_name is None
    else args.align_layers_name
)


############################################
################ Set Loggers ###############
############################################
SAVE_NAME = args.mark
SAVE_DIR = os.path.join(args.save_root_dir, args.dataset, args.model, SAVE_NAME)

# A finished run leaves a "done" marker in SAVE_DIR; such a run is skipped
# unless --overwrite is given.
is_done_path = os.path.join(SAVE_DIR, "done")
if os.path.exists(is_done_path) and not args.overwrite:
    print("Save directory already exists:", SAVE_DIR)
    exit()
os.makedirs(SAVE_DIR, exist_ok=True)

# TensorBoard writer; its directory mirrors SAVE_DIR underneath --log_dir.
# NOTE(review): if SAVE_DIR is absolute (absolute --save_root_dir),
# os.path.join discards --log_dir and the event files land under SAVE_DIR.
writer_path = os.path.join(args.log_dir, SAVE_DIR)
writer = SummaryWriter(writer_path)

# Per-epoch CSV log. The header is only written for a new file, so a re-run
# into an existing SAVE_DIR appends its rows after the old ones.
df_path = os.path.join(SAVE_DIR, SAVE_NAME + ".csv")
if not os.path.exists(df_path):
    columns = "epoch,time(s),train_loss,test_loss,test_acc,adv_acc,test_acc_sub".split(
        ","
    )
    with open(df_path, "w") as f:
        f.write(",".join(columns) + "\n")

# Persist the full configuration next to the checkpoints.
with open(os.path.join(SAVE_DIR, "args.json"), "w") as f:
    json.dump(vars(args), f, indent=4)

# Duplicate stdout / stderr into files inside SAVE_DIR (stdout first, as before).
sys.stdout = Tee(os.path.join(SAVE_DIR, "out.txt"))
sys.stderr = Tee(os.path.join(SAVE_DIR, "err.txt"))

for _label, _value in (("SAVE_NAME: ", SAVE_NAME), ("SAVE_DIR: ", SAVE_DIR)):
    print(_label, _value)
state = dict(args._get_kwargs())
print(state)


############################################
################### Data ###################
############################################
train_loader, test_loader, num_classes = get_data_simple(args)
# The data loader is the source of truth for the class count.
args.num_classes = num_classes


############################################
############# Create Models ################
############################################
# Main model (net) and sub model (net_sub): same architecture, same hooked
# layers (args.align_layers) so their features can be aligned.
_model_kwargs = dict(
    extract_layers=args.align_layers,
    is_avg_pool=args.is_avg_pool,
    is_relu=args.is_relu,
)
net = load_model(args, args.model, num_classes, **_model_kwargs)
net_sub = load_model(args, args.model, num_classes, **_model_kwargs)
net, net_sub = net.to(device), net_sub.to(device)


############################################
############# Share parameters #############
############################################
# net_sub reuses weight parameters of net; share_layers decides which ones.
# By default, ARAT shares all layers except BN layers (--separate_bn).
# net / net_sub are feature-extractor wrappers; `.model` is the wrapped DNN.
shared_main, shared_sub = share_layers(
    net.model,
    net_sub.model,
    args.share_layer_name_list,
    separate_bn=args.separate_bn,
    exclude_layer_name_list=args.exclude_layer_name_list,
)
net.model = shared_main
net_sub.model = shared_sub
# Re-register the feature hooks on the (possibly replaced) wrapped models.
net.set_hook_layers()
net_sub.set_hook_layers()


############################################
###### Regularization Loss function ########
############################################
# feature regularization loss
#  Default:
#   - align all ReLU outputs in the last "block" of network.
#   - align features with cosine similarity.
#   - use predictor MLP head (h), with hidden dim being 1/4 of feature dim.
#   - align_type: x->y (stop-grad(y))

feature_pair_loss_dict = {}
# Read the feature dimension (dim 1) of every extracted feature from one dummy
# forward pass. The probe runs in eval mode because, even under no_grad, a
# train-mode BatchNorm would fold this random input into its running
# statistics (and may reject a single-sample batch). The previous mode is
# restored afterwards.
# NOTE(review): the 1x3x32x32 probe assumes CIFAR-sized RGB inputs.
_net_was_training = net.training
net.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 32, 32).to(device)
    _, feat_dict = net(x, get_feat=True)
    dims = {name: feat.shape[1] for name, feat in feat_dict.items()}
    print("dims: ", dims)
net.train(_net_was_training)

# One AlignLoss per aligned layer. The predictor h is created inside AlignLoss
# and has parameters to be optimized.
for align_layers_name in args.align_layers:
    if align_layers_name not in dims:
        # Fail early with the available names instead of a bare KeyError.
        raise KeyError(
            "align layer '{}' is not among the extracted features: {}".format(
                align_layers_name, list(dims)
            )
        )
    feat_dim = dims[align_layers_name]
    hidden_dim = int(feat_dim * args.pred_dim_ratio)
    print(
        "{} -> feat_dim: {}, hidden_dim: {}".format(
            align_layers_name, feat_dim, hidden_dim
        )
    )

    feature_pair_loss_dict[align_layers_name] = AlignLoss(
        args,
        loss_metric=args.feat_align_loss_metric,
        is_use_predictor=args.is_use_predictor,
        feat_dim=feat_dim,
        hidden_dim=hidden_dim,
        align_type=args.align_type,
    ).to(device)


############################################
################ Optimizer #################
############################################
# Parameters of net, then those of net_sub not already listed (shared layers
# presumably reuse the very same Parameter objects), then the predictor heads.
to_optimize_params = list(net.parameters())
print("Add subnet parameters to optimizer. ")
# Identity set gives an O(1) membership test; the previous `all(... is not ...)`
# scan over the growing list was quadratic in the number of parameters.
_added_param_ids = {id(p) for p in to_optimize_params}
for name, param_sub in net_sub.named_parameters():
    if id(param_sub) in _added_param_ids:
        continue  # shared with net, already listed
    _added_param_ids.add(id(param_sub))
    to_optimize_params.append(param_sub)
    print("added:", name, type(param_sub))

if args.is_use_predictor:
    print("Add predictor parameters to optimizer. ")
    for _align_loss in feature_pair_loss_dict.values():
        to_optimize_params += list(_align_loss.parameters())

# optimizer
optimizer = torch.optim.SGD(
    to_optimize_params,
    args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.decay,
)


############################################
################# Scheduler ################
############################################
def get_lr(epoch):
    """Return the step-decayed learning rate for ``epoch`` (1-based).

    Multiplier applied to ``args.learning_rate``:
        epoch <= 75        -> 1
        75 < epoch <= 90   -> 0.1
        90 < epoch <= 100  -> 0.01
        100 < epoch <= 110 -> 0.001
        epoch > 110        -> 0.0001

    Stages are selected by upper bound only, so every epoch value (including
    non-integers such as 75.5, which the old closed-interval checks left with
    ``lr`` unbound) maps to exactly one stage.
    """
    init_lr = args.learning_rate
    # (last epoch of the stage, decay factor), in increasing epoch order.
    for stage_end, factor in ((75, 1.0), (90, 0.1), (100, 0.01), (110, 0.001)):
        if epoch <= stage_end:
            return init_lr * factor
    return init_lr * 0.0001


class CustomLR(torch.optim.lr_scheduler._LRScheduler):
    """Epoch-indexed schedule that delegates to the module-level ``get_lr``.

    Until the first explicit ``step(epoch)`` call every param group runs at
    ``initial_lr`` (0.02 by default; NOTE: independent of --learning_rate).
    ``step(epoch)`` sets every param group's lr to ``get_lr(epoch)``. In this
    script ``step(epoch)`` is called after training epoch ``epoch``, so the
    new lr takes effect from the following epoch.
    """

    def __init__(self, optimizer, last_epoch=-1, initial_lr=0.02):
        """Create the scheduler; epochs are assumed to start from 1.

        Args:
            optimizer: wrapped optimizer.
            last_epoch: index of the last finished epoch (-1 = fresh start).
            initial_lr: lr used until the first ``step(epoch)`` call.
        """
        # The base constructor performs one initial step() (which sets the lr
        # from get_lr); that value is then replaced by the warm-start lr.
        super(CustomLR, self).__init__(optimizer, last_epoch)

        for param_group in self.optimizer.param_groups:
            param_group["lr"] = initial_lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def step(self, epoch=None):
        """Set every param group's lr to ``get_lr(epoch)``.

        Args:
            epoch: epoch that just finished (1-based). When None, advance
                ``last_epoch`` by one, as the stock scheduler API does.
        """
        if epoch is None:
            self.last_epoch += 1
            epoch = self.last_epoch
        else:
            self.last_epoch = epoch

        lr = get_lr(epoch)
        # Use the optimizer this scheduler was built with, not a global name.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        # Keep get_last_lr() / state_dict() in sync with the real lr.
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
        print("scheduler step(): epoch={}, lr={}".format(epoch, lr))


# Epoch-indexed LR schedule (see get_lr / CustomLR).
scheduler = CustomLR(optimizer)

############################################
############### Adv. attack ################
############################################
# L-inf PGD attack; used in train() to craft adversarial inputs for net.
_pgd_kwargs = {
    "epsilon": args.epsilon,
    "num_steps": args.num_steps,
    "step_size": args.step_size,
}
adversary = attacks.PGD_linf(**_pgd_kwargs).cuda()


############################################
############# Training Function ############
############################################
def train(epoch, args=args, analysis_meter=None, prev_train_acc=None):
    """Run one training epoch of ARAT over ``train_loader``.

    Per batch: craft PGD examples against ``net``; sum the weighted
    CE(net(adv), y), CE(net_sub(clean), y) and the per-layer feature-alignment
    losses between net(adv) and net_sub(clean); take one optimizer step.

    Args:
        epoch: current epoch; unused, kept for the call signature.
        args: parsed command-line arguments.
        analysis_meter: unused, kept for backward compatibility.
        prev_train_acc: clean accuracy (fraction) of ``net_sub`` measured
            before this epoch. Required when ``args.is_auto_balance_ce_loss``
            is set, where it weights the two CE losses.

    Returns:
        dict with ``"train_loss"``: mean total loss per batch. The same value
        is also stored in the global ``state["train_loss"]``.

    Raises:
        ValueError: auto-balancing is enabled but ``prev_train_acc`` is None
            (previously this surfaced as a TypeError mid-batch).
    """
    if args.is_auto_balance_ce_loss and prev_train_acc is None:
        raise ValueError(
            "prev_train_acc is required when is_auto_balance_ce_loss is enabled"
        )

    # CE weights are constant over the epoch, so compute them once.
    # Auto-balance: the more (clean) accurate net_sub already is, the more
    # weight on the adversarial CE of net and the less on net_sub's clean CE.
    if args.is_auto_balance_ce_loss:
        w_fxx, w_gx = prev_train_acc, 1 - prev_train_acc
    else:
        w_fxx, w_gx = args.ce_loss_fxx_y_weight, args.ce_loss_gx_y_weight
    # Each aligned layer gets an equal share of the total alignment weight.
    w_align = (
        args.align_features_weight / len(args.align_layers)
        if args.align_layers
        else 0.0
    )

    train_info_dict = {}

    net.train()
    net_sub.train()

    loss_total = 0.0
    num_batches = 0
    # tqdm wraps the loader (not enumerate) so it can show the total length.
    for batch_idx, (bx, by) in enumerate(tqdm(train_loader)):
        bx, by = bx.cuda(), by.cuda()

        loss_dict = {}
        with torch.cuda.amp.autocast():
            # attack: PGD examples against the main net
            adv_bx = adversary(net, bx, by)

            # forward: adversarial inputs -> net, clean inputs -> net_sub
            logits_fxx, feat_dict_fxx = net(adv_bx, get_feat=True)
            logits_gx, feat_dict_gx = net_sub(bx, get_feat=True)

            # classification losses (xx = adv, x = clean; f = net, g = net_sub)
            loss_dict["ce_loss_fxx_y"] = F.cross_entropy(logits_fxx, by) * w_fxx
            loss_dict["ce_loss_gx_y"] = F.cross_entropy(logits_gx, by) * w_gx

            # feature regularization loss between net(adv) and net_sub(clean)
            for layer_name in args.align_layers:
                reg_loss = (
                    feature_pair_loss_dict[layer_name](
                        feat_dict_fxx[layer_name], feat_dict_gx[layer_name]
                    )
                    * w_align
                )
                loss_dict[f"feat_loss_fxx_gx-{layer_name}"] = reg_loss

        # NOTE(review): autocast is used without torch.cuda.amp.GradScaler; with
        # float16 activations small gradients can underflow to zero.
        loss = sum(loss_dict.values())
        if torch.isnan(loss):
            print("nan loss, force exit.")
            sys.exit(1)  # non-zero: a diverged run is a failure

        loss_total += loss.item()
        num_batches += 1

        # update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the individual loss terms of the first batch
        if batch_idx == 0:
            pprint.pprint(loss_dict)

    # Guard against an empty loader instead of dividing by zero.
    train_loss = loss_total / max(num_batches, 1)
    state["train_loss"] = train_loss
    train_info_dict["train_loss"] = train_loss
    return train_info_dict


############################################
############### Test Function ##############
############################################
def test(loader):
    """Evaluate ``net`` and ``net_sub`` on clean batches from ``loader``.

    Both models are left in eval mode.

    Returns:
        tuple ``(loss, acc, acc_sub)``: the mean of net's per-batch
        cross-entropy, and the accuracies of net and net_sub as fractions of
        ``len(loader.dataset)``.
    """
    net.eval()
    net_sub.eval()

    batch_losses = []
    hits_main = 0
    hits_sub = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.cuda()
            targets = targets.cuda()

            with torch.cuda.amp.autocast():
                out_main = net(inputs)
                out_sub = net_sub(inputs)

            batch_losses.append(float(F.cross_entropy(out_main, targets).data))

            hits_main += out_main.data.max(1)[1].eq(targets.data).sum().item()
            hits_sub += out_sub.data.max(1)[1].eq(targets.data).sum().item()

    total = len(loader.dataset)
    return sum(batch_losses) / len(loader), hits_main / total, hits_sub / total


def test_single(loader, net, max_n=1000):
    net.eval()
    loss_avg = 0.0
    correct = 0
    n = 0
    with torch.no_grad():
        for bx, by in loader:
            bx, by = bx.cuda(), by.cuda()

            logits = net(bx)
            loss = F.cross_entropy(logits, by)

            # accuracy
            pred = logits.data.max(1)[1]
            correct += pred.eq(by.data).sum().item()

            # test loss average
            loss_avg += float(loss.data)

            n += bx.size(0)
            if n >= max_n:
                break
    return loss_avg / n, correct / n


############################################
############## Training Loop ###############
############################################
st = time.time()
print("Beginning Training\n")

start_epoch = 1
adv_acc = -1
# Initialised up front: the sub-model PGD eval below runs only on some epochs,
# but the final metrics read this value unconditionally (was a NameError).
adv_acc_sub = -1
train_info_dict_list = []
for epoch in range(start_epoch, args.epochs + 1):
    state["epoch"] = epoch

    begin_epoch = time.time()

    ###################
    #### Train #####
    ###################
    # Clean accuracy of net_sub on a subset of the training data; it drives
    # the auto-balanced CE weights inside train().
    _, prev_train_acc = test_single(train_loader, net_sub)
    print("prev_train_acc", prev_train_acc)

    train_info_dict = train(epoch, prev_train_acc=prev_train_acc)
    train_info_dict_list.append(train_info_dict)

    train_loss = train_info_dict["train_loss"]

    test_loss, test_acc, test_acc_sub = test(test_loader)
    state["test_loss"] = test_loss
    state["test_accuracy"] = test_acc
    state["test_acc_sub"] = test_acc_sub

    ###################
    #### Scheduler ####
    ###################
    # Sets lr = get_lr(epoch), which is used from the next epoch on.
    scheduler.step(epoch)

    ###################
    #### Save model ###
    ###################
    if epoch % args.save_interval == 0:
        model_save_path = os.path.join(SAVE_DIR, f"epoch_{epoch}.pt")
        optim_save_path = os.path.join(SAVE_DIR, f"optim_{epoch}.pt")
        # Save the wrapped DNNs (without the feature-extractor wrapper). The
        # sub path is built directly rather than via str.replace(".pt", ...),
        # which would also rewrite any ".pt" occurring inside SAVE_DIR.
        torch.save(net.model.state_dict(), model_save_path)
        torch.save(
            net_sub.model.state_dict(),
            os.path.join(SAVE_DIR, f"epoch_{epoch}_sub.pt"),
        )
        optims = {
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        torch.save(optims, optim_save_path)

    ###################
    #### Evaluation ###
    ###################
    # Adversarial accuracy via eval.evaluate. -1 marks "not evaluated this
    # epoch" in the CSV; the sentinel is no longer written to TensorBoard.
    if epoch % args.eval_interval == 0:
        _, adv_acc = evaluate(args, net, test_loader)
        writer.add_scalar("adv_acc", adv_acc, epoch)
        if epoch % 5 == 0:
            _, adv_acc_sub = evaluate(args, net_sub, test_loader)
            writer.add_scalar("adv_acc_sub", adv_acc_sub, epoch)
            print("adv_acc_sub:", adv_acc_sub)
    else:
        adv_acc = -1

    ###################
    #### Logging ######
    ###################
    log_data = (
        epoch,
        int(time.time() - begin_epoch),
        state["train_loss"],
        state["test_loss"],
        state["test_accuracy"] * 100.0,
        adv_acc * 100.0,
        state["test_acc_sub"] * 100.0,
    )
    with open(df_path, "a") as f:
        f.write("%03d,%05d,%0.6f,%0.5f,%0.2f,%0.2f,%0.2f\n" % log_data)
    print(
        "\nEpoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | "
        "Test Acc {4:.2f} | Adv Acc {5:.2f} | Test Acc Sub {6:.2f}".format(*log_data)
    )

    writer.add_scalar("test_acc", test_acc, epoch)
    writer.add_scalar("test_acc_sub", test_acc_sub, epoch)


##################################
######## Save Last Model #########
##################################
# Final weights of both networks (without the feature-extractor wrapper),
# followed by the optimizer / scheduler state.
model_save_path = os.path.join(SAVE_DIR, "model_last.pt")
optim_save_path = os.path.join(SAVE_DIR, "optim_last.pt")
for _module, _path in (
    (net.model, model_save_path),
    (net_sub.model, model_save_path.replace(".pt", "_sub.pt")),
):
    torch.save(_module.state_dict(), _path)

optims = dict(optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict())
torch.save(optims, optim_save_path)


##################################
######## Eval Auto Attack ########
##################################
# -1 means "AutoAttack evaluation disabled" in the final metrics.
adv_acc_AA = -1
if args.is_eval_auto_attack:
    _, adv_acc_AA = evaluate_auto_attack(args, net, test_loader)
    writer.add_scalar("adv_acc_AA", adv_acc_AA, epoch)
    print("adv_acc_AA:", adv_acc_AA)


##################################
##### Save Final Metrics #########
##################################
metrics = {
    "test_acc": test_acc,
    "adv_acc": adv_acc,
    "adv_acc_AA": adv_acc_AA,
    "test_acc_sub": test_acc_sub,
    "adv_acc_sub": adv_acc_sub,
}

# TensorBoard hparams accept only scalar-like values, so list-valued args are
# stringified. Build a copy: vars(args) is the namespace's own __dict__, and
# mutating it (as before) silently rewrote `args` itself.
args_dict = {
    k: str(v) if isinstance(v, list) else v for k, v in vars(args).items()
}
print(args_dict)
writer.add_hparams(
    args_dict,
    metrics,
)
writer.close()

d = {
    "args": args_dict,
    "metrics": metrics,
}
with open(os.path.join(SAVE_DIR, "hparams_metrics.pkl"), "wb") as f:
    pickle.dump(d, f)

ed = time.time()
print("Total time: {} / min".format((ed - st) / 60))

# "done" marker: checked at start-up to skip runs that already finished.
with open(is_done_path, "a") as f:
    f.write("done")
print("\nDone")
