import os
import sys
import logging

import torch
import numpy as np

from src.attacks import pgd_rand
from src.train import train_standard, train_soar
from src.evaluation import test_clean, test_adv
from src.args import get_args
from src.utils_dataset import load_dataset, load_svhn
# from src.utils_log import metaLogger, rotateCheckpoint
from src.utils_general import seed_everything, get_model, get_optim

def train(args, epoch, loader, model, opt, device):
    """Run a single training epoch using the routine chosen by ``args.method``.

    Args:
        args: Parsed options. ``args.method`` selects the routine
            ("standard" or "soar"); the whole namespace is also forwarded
            to the SOAR routine.
        epoch: Index of the current epoch; logged and passed to the routine.
        loader: Data loader to train on (``main`` passes the training loader).
        model: Network being trained.
        opt: Optimizer handed to the training routine.
        device: Device handed to the training routine.

    Returns:
        Whatever the selected routine returns. Element 0 is logged as the
        accuracy and element 1 as the loss.

    Raises:
        NotImplementedError: If ``args.method`` names no known routine.
    """
    method = args.method
    if method == "standard":
        epoch_log = train_standard(epoch, loader, model, opt, device)
    elif method == "soar":
        # SOAR needs its hyper-parameters, so it also receives ``args``.
        epoch_log = train_soar(epoch, loader, model, args, opt, device)
    else:
        raise NotImplementedError("Training method not implemented!")

    accuracy, loss = epoch_log[0], epoch_log[1]
    message = "Epoch: [{0}]\tLoss: {loss:.6f}\tAccuracy: {acc:.2f}".format(
        epoch, loss=loss, acc=accuracy)
    logging.info(message)
    return epoch_log

def main():
    """Train a model, evaluate it after every epoch, and save the final weights.

    Options come from ``get_args()``. Progress is logged both to
    ``<j_dir>/log/log.txt`` and to stdout. The model is trained for
    ``args.epoch`` epochs via ``train``. After each epoch, clean test
    loss/accuracy and PGD-20 loss/accuracy are logged. The final
    ``state_dict`` is written to ``<j_dir>/model/model.pt``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Per-epoch robustness check: L-inf PGD with radius 8/255, step 2/255,
    # 20 iterations and restart=1 (restart semantics are defined by pgd_rand).
    attack_param = {"ord": np.inf, "epsilon": 8. / 255., "alpha": 2. / 255.,
                    "num_iter": 20, "restart": 1}

    args = get_args()

    log_dir = os.path.join(args.j_dir, "log")
    model_dir = os.path.join(args.j_dir, "model")
    # Create the output directories up front. Otherwise basicConfig cannot
    # open the log file, and torch.save would fail only after every epoch
    # has run, losing the trained model.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    logging.basicConfig(
        filename=os.path.join(log_dir, "log.txt"),
        format='%(asctime)s %(message)s', level=logging.INFO)
    # Mirror log output to stdout as well as to the file.
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

    seed_everything(args.seed)

    if args.dataset == "svhn":
        train_loader, test_loader = load_svhn(args.batch_size)
    else:
        train_loader, test_loader = load_dataset(args.batch_size)

    model = get_model(args, device)
    opt, lr_scheduler = get_optim(model, args)

    for _epoch in range(args.epoch):
        # train() logs its own epoch summary; its return value is not needed.
        train(args, _epoch, train_loader, model, opt, device)

        test_log = test_clean(test_loader, model, device)
        logging.info(
            "Test set: Loss: {loss:.6f}\t"
            "Accuracy: {acc:.2f}".format(
                loss=test_log[1],
                acc=test_log[0]))

        adv_log = test_adv(test_loader, model, pgd_rand, attack_param, device)

        logging.info(
            "PGD20: Loss: {loss:.6f}\t"
            "Accuracy: {acc:.2f}".format(
                loss=adv_log[1],
                acc=adv_log[0]))

        # NOTE(review): get_optim presumably returns a falsy value (e.g. None)
        # when no scheduler is configured; otherwise step once per epoch.
        if lr_scheduler:
            lr_scheduler.step()

    torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))

# Entry point: run training only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
