import torch
import numpy as np
import torch.nn as nn
import os
import sys
import argparse
import time
import random
import logging
from tqdm import tqdm
from models import get_train_val_loaders
from models import TET_loss
from models import ResNet18_super

import warnings
# Silence every Python warning for the whole run.
# NOTE(review): this also hides library deprecation/misuse warnings.
warnings.filterwarnings('ignore')


# Reproducibility: pick deterministic cuDNN kernels (no autotuning) and seed
# each RNG source used by the script with one fixed value.
_seed_ = 2023
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_seed_)
del _seed_fn


if __name__ == '__main__':
    # Train the ResNet18 supernet with TET_loss and periodically checkpoint it.
    parser = argparse.ArgumentParser()
    # NOTE(review): -T, -num_class, -batch_size, -size, -num_workers and -if_static
    # are not read in this script; they presumably reach get_train_val_loaders /
    # ResNet18_super through `args` — confirm there.
    parser.add_argument('-T', type=int, default=4)
    parser.add_argument('-num_class', type=int, default=10)
    parser.add_argument('-max_epoch', type=int, default=100)
    parser.add_argument('-T_max', type=int, default=100)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-size', type=int, default=16)
    # `choices` rejects unsupported values before any directory, log file or
    # network is created, and exits with a non-zero status.
    parser.add_argument('-optimizer', type=str, default='adamw', choices=['sgd', 'adamw'],
                        help='[sgd, adamw]')
    parser.add_argument('-scheduler', type=str, default='cosine', choices=['step', 'cosine'],
                        help='[step, cosine]')
    parser.add_argument('-learning_rate', type=float, default=0.01)
    parser.add_argument('-weight_decay', type=float, default=0.02)
    parser.add_argument('-num_workers', type=int, default=4)
    # store_false: value is True unless the flag is given on the command line.
    parser.add_argument('-if_static', action='store_false', default=True)

    # path of datasets and results
    parser.add_argument('-dataset_name', type=str, default='CIFAR10')
    parser.add_argument('-log_dir_prefix', type=str, default='./checkpoints/supernet')
    parser.add_argument('-dir_name', type=str, default='cifar10_t4')

    args = parser.parse_args()

    max_epoch = args.max_epoch
    dataset_name = args.dataset_name

    # path setting: logs in <prefix>/<dir_name>, checkpoints in <prefix>/pt_<dir_name>
    log_dir = os.path.join(args.log_dir_prefix, args.dir_name)
    pt_dir = os.path.join(args.log_dir_prefix, 'pt_' + args.dir_name)
    print(f'log location: {log_dir}')
    print(f'model.pt save location: {pt_dir}')
    # makedirs also creates missing parents (e.g. the default ./checkpoints/supernet).
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # logging setting: write to stdout and to a timestamped file inside log_dir
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    # Format only the file name, so braces in log_dir are never treated as placeholders.
    log_file = os.path.join(log_dir, 'log_{}.txt'.format(time.strftime("%Y%m%d-%H%M%S")))
    fh = logging.FileHandler(log_file)
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # logging args as dot-padded "name....value" lines
    for arg, val in args.__dict__.items():
        logging.info(arg + '.' * (100 - len(arg) - len(str(val))) + str(val))

    # dataloader setting for cifar and cifardvs (the second, validation loader is unused here)
    train_data_loader, _ = get_train_val_loaders(args, dataset_name, search=True)

    # network
    net = ResNet18_super(args=args).cuda()

    # optimizer
    if args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    else:  # 'sgd'
        # NOTE: SGD uses fixed hyper-parameters; -learning_rate and -weight_decay are ignored.
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    # scheduler, stepped once per epoch
    if args.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max)
    else:  # 'step'
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=[60, 120, 160, 200, ],
                                                         gamma=0.1)

    criterion = nn.CrossEntropyLoss().cuda()  # no label_smoothing

    # network training (epochs are 1-based)
    logging.info('Start Training!')
    for epoch in range(1, max_epoch + 1):
        # get_last_lr() is the lr actually in effect this epoch. get_lr() is meant
        # to be called only inside step(); called directly it re-applies the update
        # rule to the already-updated lr and reports a wrong value.
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch:{}\t lr:{}\t'.format(epoch, lr))
        start_time = time.time()
        train_loss = 0.0
        correct_train_sum = 0.0
        train_sum = 0.0

        # training
        net.train()
        with tqdm(train_data_loader, desc='training process:') as pbar:
            for img, label in pbar:
                img = img.cuda()
                label = label.cuda()
                optimizer.zero_grad()
                # `output` drives the accuracy metric, `outputs` drives TET_loss
                # (presumably per-timestep logits — confirm in models).
                output, outputs = net(img)
                loss = TET_loss(outputs, label, criterion)
                train_loss += loss.item()
                train_sum += label.numel()
                correct_train_sum += (output.argmax(dim=1) == label).float().sum().item()
                loss.backward()
                optimizer.step()

        train_accuracy = correct_train_sum / train_sum
        logging.info('[epoch: {}/{}]\t batch_avg_train_loss={:.6f}\t train_accuracy={:.6f}\n'
                     .format(epoch, max_epoch, train_loss / len(train_data_loader), train_accuracy))

        # Advance the schedule before checkpointing so the saved scheduler state
        # is consistent with the saved model/optimizer (ready for epoch + 1).
        scheduler.step()

        # Save after every 10th completed epoch and after the final one;
        # 'epoch' and the file name both hold the number of completed epochs.
        if epoch % 10 == 0 or epoch == max_epoch:
            torch.save({
                'epoch': epoch,
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, os.path.join(pt_dir, 'checkpoint_epoch_{}.pt'.format(epoch)))

        speed_per_epoch = time.time() - start_time
        logging.info('speed per epoch:{}'.format(speed_per_epoch))
