import argparse
import logging
import os
import random
import sys
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

from models import *
from arc_param import *

import warnings
warnings.filterwarnings('ignore')


# Reproducibility: turn off the cuDNN autotuner (which can pick different
# kernels between runs) and require deterministic kernels, then seed every
# RNG the script relies on (NumPy, Python `random`, torch CPU and CUDA)
# with one shared value.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
_seed_ = 2023
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_seed_)
del _seed_fn


if __name__ == '__main__':
    # Entry point: parse hyper-parameters, set up output directories and
    # logging, build ResNet18_bw_search, then run the train/test loop while
    # checkpointing the newest and the best (by test accuracy) model.
    parser = argparse.ArgumentParser()
    # -T, -num_class, -size, -if_static are not used directly here; they are
    # handed to the model through `args` (their meaning lives in `models`).
    parser.add_argument('-T', type=int, default=4)
    parser.add_argument('-num_class', type=int, default=10)
    parser.add_argument('-max_epoch', type=int, default=300)
    parser.add_argument('-T_max', type=int, default=300)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-size', type=int, default=16)
    # `choices` rejects unsupported values at parse time, before any
    # directories are created or the network is built.
    parser.add_argument('-optimizer', type=str, default='adamw', choices=['sgd', 'adamw'],
                        help='[sgd, adamw]')
    parser.add_argument('-scheduler', type=str, default='cosine', choices=['step', 'cosine'],
                        help='[step, cosine]')
    parser.add_argument('-learning_rate', type=float, default=0.01)
    parser.add_argument('-weight_decay', type=float, default=0.02)
    parser.add_argument('-num_workers', type=int, default=4)
    # store_false: the value is True unless the flag is passed.
    parser.add_argument('-if_static', action='store_false', default=True)
    parser.add_argument('-dataset_name', type=str, default='CIFAR10')
    parser.add_argument('-log_dir_prefix', type=str, default='./checkpoints/test')
    parser.add_argument('-dir_name', type=str, default='res18bc_cifar10_t4')

    args = parser.parse_args()

    max_epoch = args.max_epoch
    T_max = args.T_max
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    num_workers = args.num_workers

    dataset_name = args.dataset_name
    # NOTE(review): not passed to get_transforms below; the dataset location
    # is presumably resolved inside get_transforms — confirm.
    dataset_dir = '/data_smr/dataset/' + dataset_name
    log_dir_prefix = args.log_dir_prefix
    dir_name = args.dir_name

    # Output paths: text logs go to <prefix>/<dir_name>, checkpoints to
    # <prefix>/pt_<dir_name>. makedirs also creates a missing prefix.
    log_dir = os.path.join(log_dir_prefix, dir_name)
    pt_dir = os.path.join(log_dir_prefix, 'pt_' + dir_name)
    print(f'log location: {log_dir}')
    print(f'model.pt save location: {pt_dir}')
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Logging goes both to stdout and to a timestamped file in log_dir.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    log_file_name = 'log_{}.txt'.format(time.strftime("%Y%m%d-%H%M%S"))
    fh = logging.FileHandler(os.path.join(log_dir, log_file_name))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # Record every argument as a dot-padded "name.....value" line.
    for arg, val in args.__dict__.items():
        logging.info(arg + '.' * (100 - len(arg) - len(str(val))) + str(val))

    # Data loaders. The test loader uses a 4x batch, since no gradients are kept.
    train_dataset, test_dataset = get_transforms(dataset_name)
    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    num_workers=num_workers,
                                                    drop_last=False,
                                                    pin_memory=True)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                   batch_size=4 * batch_size,
                                                   shuffle=False,
                                                   num_workers=num_workers,
                                                   drop_last=False,
                                                   pin_memory=True)

    # Network.
    fb_mat = fb3
    net = ResNet18_bw_search(args=args, fb_mat=fb_mat)
    net = net.cuda()

    # Optimizer.
    if args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif args.optimizer == 'sgd':
        # NOTE(review): the SGD recipe is hard-coded and ignores
        # -learning_rate / -weight_decay.
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    else:  # unreachable because of `choices`; kept as a defensive guard
        parser.error('unsupported optimizer: {}'.format(args.optimizer))

    # LR scheduler, stepped once per epoch.
    if args.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
    elif args.scheduler == 'step':
        # Fixed milestones tuned for ~200-epoch (ANN-style) runs; they do not
        # scale with -max_epoch.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=[60, 120, 160, 200, ],
                                                         gamma=0.1)
    else:  # unreachable because of `choices`; kept as a defensive guard
        parser.error('unsupported scheduler: {}'.format(args.scheduler))

    criterion = nn.CrossEntropyLoss().cuda()  # no label_smoothing

    logging.info('Start Training!')
    best_acc = 0
    best_acc_epoch = 0
    for epoch in range(1, max_epoch + 1):
        # get_last_lr() reports the lr actually in use this epoch. The
        # deprecated get_lr() recomputes it and is wrong outside step().
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch:{}\t lr:{}\t'.format(epoch, lr))
        start_time = time.time()
        train_loss = 0.0
        correct_train_sum = 0.0
        train_sum = 0.0

        # Training pass. The loss uses `outputs` via TET_loss, while accuracy
        # uses `output`. Both are returned by the network; their exact shapes
        # are defined in `models` (TODO confirm).
        net.train()
        with tqdm(train_data_loader, desc='training process:') as pbar:
            for img, label in pbar:
                img = img.cuda()
                label = label.cuda()
                optimizer.zero_grad()
                output, outputs = net(img)
                loss = TET_loss(outputs, label, criterion, )
                train_loss += loss.item()
                train_sum += label.numel()
                correct_train_sum += (output.argmax(dim=1) == label).float().sum().item()
                loss.backward()
                optimizer.step()
        train_accuracy = correct_train_sum / train_sum
        logging.info('[epoch: {}/{}]\t batch_avg_train_loss={:.6f}\t train_accuracy={:.6f}\n'
                     .format(epoch, max_epoch, train_loss / len(train_data_loader), train_accuracy))
        # Resumable checkpoint. 'epoch' stores the NEXT epoch to run.
        torch.save({
            'epoch': epoch + 1,
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, os.path.join(pt_dir, 'newest_check_point.pt'))
        scheduler.step()

        # Evaluation pass.
        net.eval()
        test_sum = 0
        correct_sum = 0
        with torch.no_grad():
            with tqdm(test_data_loader, desc='testing process:') as pbar:
                for img, label in pbar:
                    img = img.cuda()
                    label = label.cuda()
                    output, _ = net(img)
                    correct_sum += (output.argmax(dim=1) == label).float().sum().item()
                    test_sum += label.numel()
        test_accuracy = correct_sum / test_sum
        # `<=` means the latest epoch wins on ties.
        if best_acc <= test_accuracy:
            best_acc = test_accuracy
            best_acc_epoch = epoch
            torch.save({
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_acc': best_acc,
                'best_acc_epoch': best_acc_epoch
            }, os.path.join(pt_dir, 'best_model.pt'))
        logging.info('[epoch: {}/{}]\t test_accuracy={:.6f}\t best_acc={:.6f}(at epoch {})'
                     .format(epoch, max_epoch, test_accuracy, best_acc, best_acc_epoch))
        speed_per_epoch = time.time() - start_time
        logging.info('speed per epoch:{}'.format(speed_per_epoch))
