
import os
import argparse
from models import *
from torch.optim.lr_scheduler import CosineAnnealingLR
import pprint
from utils import *
from config import *
from tqdm import tqdm
from datasets.data_loaders import data_loader
from torch.utils.tensorboard import SummaryWriter

parser = argparse.ArgumentParser(description='Variation-Bounded Losses for Learning with Noisy Labels')
# dataset settings
# type=str.lower is applied before the choices check, so e.g. "CIFAR10" is accepted too.
parser.add_argument('--dataset', type=str.lower, default="cifar10", metavar='DATA', choices=['cifar10', 'cifar100'], help='dataset name')
parser.add_argument('--root', type=str, default="../database/", help='the data root')
parser.add_argument('--noise_type', type=str, default='symmetric', choices=['symmetric', 'asymmetric', 'dependent', 'human'], help='label noise type. using clean label by setting noise rate = 0')
# Kept as a string: besides numeric rates it also carries the human-noise label-set names.
parser.add_argument('--noise_rate', type=str, default='0.8', help='the noise rate 0~1. if using human noise, should set in [clean, worst, aggre, rand1, rand2, rand3, clean100, noisy100]')
# Two separate choice strings; a single 'method1, method2' entry rejected both values on the command line.
parser.add_argument('--noise_method', type=str, default='method2', choices=['method1', 'method2'],
                    help='noise generation method for symmetric and asymmetric noise. '
                         'method1: as in "Asymmetric Loss Functions for Learning with Noisy Labels"; '
                         'method2: as in "Active Negative Loss Functions for Learning with Noisy Labels"')
# initialization settings
parser.add_argument('--gpus', type=str, default='0', help='the used gpu id')
parser.add_argument('--seed', type=int, default=None, help='initial seed')
parser.add_argument('--trials', type=int, default=1, help='number of trials')
parser.add_argument('--test_freq', type=int, default=1, help='epoch frequency to evaluate the test set')
parser.add_argument('--save_model', default=False, action="store_true", help='whether to save trained model')
# parameter settings
parser.add_argument('--loss', type=str, default='VCE', help='the loss function: VCE, NCEandVCE ... ')
parser.add_argument('--para', type=float, default=4, help='loss parameter')
parser.add_argument('--alpha', type=float, default=0, help='loss parameter')
parser.add_argument('--beta', type=float, default=10, help='loss parameter')
args = parser.parse_args()
# args.dataset is already lower-cased by its `type=str.lower` converter.

# The noise-generation method is only meaningful for synthetic symmetric/asymmetric noise.
if args.noise_type in ['dependent', 'human']:
    args.noise_method = '-'
# Each dataset lives in its own sub-directory of --root (change root by yourself).
if args.dataset == 'cifar10':
    args.root = os.path.join(args.root, 'CIFAR10')
elif args.dataset == 'cifar100':
    args.root = os.path.join(args.root, 'CIFAR100')

# Restrict CUDA to the requested GPU id(s); this is set before CUDA is queried below.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('We are using', device)
# cuDNN benchmark mode auto-selects the fastest convolution algorithms;
# it helps most when input shapes stay the same between batches.
torch.backends.cudnn.benchmark = True

def run(args, i):
    """Train and evaluate one model for a single trial.

    The architecture and optimisation hyper-parameters are fixed per dataset
    (``CNN`` for CIFAR-10, ``ResNet34`` for CIFAR-100). Training uses SGD with
    momentum 0.9, a cosine-annealed learning rate, a manual L1 penalty on all
    parameters and gradient-norm clipping at 5.

    Relies on module-level globals: ``device``, plus ``logger``,
    ``summary_writer`` and ``results_path``. The last three are only created in
    the ``__main__`` section of this script.

    Args:
        args: Parsed command-line namespace (dataset, loss, seed, test_freq,
            save_model, ...). It is also passed to ``data_loader`` and
            ``get_loss_config``.
        i: Zero-based trial index. It is added to ``args.seed`` so that trials
            use different seeds.

    Returns:
        The first value returned by ``evaluate`` on the test set after the
        final epoch (presumably top-1 accuracy; confirm in ``utils``).

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported dataset.
    """
    # `is not None` so that an explicit --seed 0 is honoured (0 is falsy).
    if args.seed is not None:
        seed_everything(args.seed + i)
    if args.dataset == 'cifar10':
        epochs = 120
        lr = 0.01
        batch_size = 128
        weight_decay = 5e-5
        model = CNN(type='CIFAR10').to(device)
    elif args.dataset == 'cifar100':
        epochs = 200
        lr = 0.1
        batch_size = 128
        weight_decay = 5e-6
        model = ResNet34(num_classes=100).to(device)
    else:
        raise NotImplementedError

    logger.info('\n' + pprint.pformat(args))

    train_loader, test_loader = data_loader(args=args, train_batch_size=batch_size, test_batch_size=batch_size*2, train_persistent=True, test_persistent=True)
    criterion = get_loss_config(args)
    # Regularisation is added to the loss manually in the loop below, not via SGD's weight_decay.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

    for epoch in tqdm(range(epochs), ncols=60, desc=args.loss + ' ' + args.dataset):
        model.train()
        total_loss = 0.  # summed (not averaged) over all mini-batches of the epoch
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            out = model(batch_x)
            loss = criterion(out, batch_y)
            if weight_decay:
                # L1 penalty: sum of absolute values of every parameter.
                decay = sum(p.abs().sum() for p in model.parameters())
                # Out-of-place add. An in-place `+=` on the criterion's output
                # can break autograd if that tensor is saved for backward.
                loss = loss + weight_decay * decay
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        # Always evaluate after the last epoch, so the returned accuracy is the
        # final one even when epochs is not a multiple of test_freq
        # (test_freq <= 0 means "evaluate only at the end").
        is_last_epoch = epoch + 1 == epochs
        is_test_epoch = args.test_freq > 0 and (epoch + 1) % args.test_freq == 0
        if is_test_epoch or is_last_epoch:
            test_acc1, _ = evaluate(test_loader, model, device)
            logger.info('Iter {}: loss={:.4f}, test_acc={:.4f}'.format(epoch, total_loss, test_acc1))
            summary_writer.add_scalar('test_acc1', test_acc1, epoch+1)
            # Log the epoch's summed loss (matching the text log above), not
            # just the loss of the last mini-batch.
            summary_writer.add_scalar('loss', total_loss, epoch+1)
    if args.save_model:
        # NOTE(review): the path has no trial index, so with --trials > 1 each
        # trial overwrites the previous model file.
        torch.save(model, results_path + '/model.pth')

    return test_acc1
    
if __name__ == "__main__":

    # Results are grouped by dataset / loss / noise setting / loss hyper-parameters.
    results_path = os.path.join('./results/', args.dataset, args.loss, args.noise_type + '_' + args.noise_rate)
    tag = f"/alpha={args.alpha}_beta={args.beta}_para={args.para}"
    results_path = results_path + tag
    os.makedirs(results_path, exist_ok=True)
    # `logger` and `summary_writer` are module globals read inside run().
    logger = get_logger(results_path + '/result.log')
    summary_writer = SummaryWriter(log_dir=results_path)

    try:
        accs = [run(args, i) for i in range(args.trials)]
    finally:
        # Flush pending TensorBoard events even if a trial crashes.
        summary_writer.close()
    accs = torch.asarray(accs) * 100
    # torch.std is unbiased (N-1 denominator) by default and yields nan for a
    # single trial, so report 0 spread in that case.
    std = accs.std().item() if accs.numel() > 1 else 0.0
    logger.info(args.dataset+' '+args.loss+': %.2f±%.2f \n' % (accs.mean(), std))


    