if __name__ == "__main__":
    import torch
    import torch.nn as nn
    import math
    import numpy as np
    import os
    import random
    from scipy.io import savemat
    from load_args import load_args
    from data_loader import data_loader
    from cifar10_resnet import resnet18
    from cifar100_wideresnet import *
    from train import train
    from evaluate import evaluate

    def _build_model(dataset_name):
        """Return the network used for ``dataset_name``.

        Raises:
            ValueError: if ``dataset_name`` is neither 'CIFAR10' nor 'CIFAR100'.
        """
        if dataset_name == 'CIFAR10':
            return resnet18()
        if dataset_name == 'CIFAR100':
            return wrn(depth=28, num_classes=100, widen_factor=10, dropRate=0.3)
        raise ValueError("Unsupported dataset {0}.".format(dataset_name))

    def _compute_milestones(args, n_train):
        """Append learning-rate milestones (iteration counts) to ``args.milestones``.

        ``args.milestones`` is extended in place, so any later reader of
        ``args`` (e.g. ``train``) sees the computed values.

        Args:
            args: parsed arguments; reads ``train_epochs``, ``optim_method``,
                ``alpha``, ``num_bandloop``, ``interval_mode`` and ``milestones``.
            n_train: number of batches per epoch (``len(train_loader)``).
        """
        max_iter = args.train_epochs * n_train

        # Step-decay methods derive the number of stages from the decay factor.
        if args.optim_method in ('SGD_Step_Decay', 'SGD_Step_Band'):
            decay_rate = 1 / args.alpha
            # Clamp to at least one stage: for short runs log(max_iter) / 2 can
            # truncate to 0, which would make the division below blow up.
            n_outer = max(1, int(math.log(max_iter, decay_rate) / 2))
            n_inner = max_iter // n_outer
            print('train_epoch, n_train, max_iter, n_outer, n_inner', args.train_epochs, n_train, max_iter,  n_outer, n_inner)
            args.milestones.extend(n_inner * (i + 1) for i in range(n_outer))
            print('milestones', args.milestones)

        # Other methods: split the run into `num_bandloop` intervals, unless
        # milestones were already provided or computed above.
        if args.num_bandloop > 1 and args.milestones == []:
            n_loops = args.num_bandloop
            if args.interval_mode == 'fixed':
                # Equal-length intervals.
                n_inner = max_iter // n_loops
                args.milestones.extend(n_inner * (i + 1) for i in range(n_loops))
                print('milestones', args.milestones)
            elif args.interval_mode == 'linear':
                # Interval i has length (i+1) * base, and sum_{i=1..n} i = n(n+1)/2,
                # so the intervals together span roughly max_iter.
                initial_interval = 2 * max_iter // (n_loops * (n_loops + 1))
                boundary = 0
                for i in range(n_loops):
                    boundary += (i + 1) * initial_interval
                    args.milestones.append(boundary)
                print('milestones', args.milestones)
            elif args.interval_mode == 'exp-grow':
                # Interval i has length 2**i * base, and sum_{i<n} 2**i = 2**n - 1.
                initial_interval = max_iter // (2 ** n_loops - 1)
                boundary = 0
                for i in range(n_loops):
                    boundary += 2 ** i * initial_interval
                    args.milestones.append(boundary)
                print('milestones', args.milestones)

    def main():
        """Load data, build the model, train it, and report the results.

        All configuration comes from ``load_args()``.
        """
        args = load_args()

        # Check the availability of GPU.
        use_cuda = args.use_cuda and torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")

        # Set the random seed for reproducibility.
        if args.reproducible:
            torch.manual_seed(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)
            if device != torch.device("cpu"):
                # Disable cuDNN autotuning / nondeterministic kernels.
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

        # Load data. With --validation, split index 1 (presumably the held-out
        # validation split) is used as the "test" set; otherwise split index 2.
        print('Loading data...')
        dataset = data_loader(dataset_name=args.dataset,
                              dataroot=args.dataroot,
                              batch_size=args.batchsize,
                              val_ratio=(args.val_ratio if args.validation else 0))
        train_loader = dataset[0]
        test_loader = dataset[1] if args.validation else dataset[2]

        # Define the model and the loss function.
        net = _build_model(args.dataset)
        net.to(device)
        criterion = nn.CrossEntropyLoss()

        # Compute learning-rate milestones (in iterations) into args.milestones.
        _compute_milestones(args, len(train_loader))

        # Train and evaluate the model.
        print("Training...")
        running_stats = train(args, train_loader, test_loader, net,
                              criterion, device)
        all_train_losses, all_train_accuracies = running_stats[:2]
        all_test_losses, all_test_accuracies = running_stats[2:]

        print("Evaluating...")
        final_train_loss, final_train_accuracy = evaluate(train_loader, net,
                                                          criterion, device)
        final_test_loss, final_test_accuracy = evaluate(test_loader, net,
                                                        criterion, device)

        # Logging results.
        print('Writing the results.')
        print('max_epoch, eta0, alpha, ratio, method, step-mode, epoch_mode, bandloop, interval-mode', args.train_epochs, args.eta0, args.alpha, args.ratio, args.optim_method,  args.step_mode, args.epoch_mode, args.num_bandloop, args.interval_mode)
        print('final train loss, final train accuracy', final_train_loss, final_train_accuracy)
        print('final test loss, final test accuracy', final_test_loss, final_test_accuracy)
        output = {'train_loss': all_train_losses, 'train_accuracy': all_train_accuracies,
                  'test_loss': all_test_losses, 'test_accuracy': all_test_accuracies,
                  'final_train_loss': final_train_loss, 'final_train_accuracy': final_train_accuracy,
                  'final_test_loss': final_test_loss, 'final_test_accuracy': final_test_accuracy}
        # NOTE(review): `output` is never persisted. `savemat` and `os` are
        # imported for this, but the destination path is not visible here.
        # TODO: write `output` (e.g. savemat(path, output)) once the log-folder
        # argument is confirmed.

        print('Finished.')

    main()
