__author__ = 'Qi'
# Created on 1/10/22.
__author__ = 'Qi'
# Created on 12/3/21.
import argparse
import os
import time
from datetime import datetime
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from mydataset import get_imbalanced_dataset,get_num_classes
import models
from myutils import ResultsLog, model_resume
from preprocess import get_transform_medium_scale_data
from myDataLoader import myDataLoader_imagenet, myDataLoader_iNaturalist18
import numpy as np

# Command-line interface for the feature-extraction / training script.
#
# NOTE(review): every boolean option below uses ``type=eval`` so that the
# strings "True"/"False" become real booleans.  ``eval`` executes arbitrary
# Python, so these flags must never be fed from untrusted input; consider a
# dedicated string-to-bool converter if the script is ever exposed.
parser = argparse.ArgumentParser(description="Pytorch PLCOVER Training")

# ---- Output locations -------------------------------------------------------
parser.add_argument('--results_dir', metavar="RESULTS_DIR", default='./TrainingResults',
                    help='results dir')
parser.add_argument('--saveFolder', metavar='SAVE', default='', help='save folder')
parser.add_argument('--res_filename', default='', type=str, help='results file name')

# ---- Data -------------------------------------------------------------------
parser.add_argument('--dataset', metavar='DATASET', default='cifar10',
                    help='dataset name or folder')
# Optional explicit dataset location; defaults to None so existing command
# lines keep working unchanged.
parser.add_argument('--data_root', default=None, type=str,
                    help='root directory of the dataset on disk')

# ---- Model / hardware -------------------------------------------------------
parser.add_argument('--model', metavar='MODEL', default='resnet', help='model architecture')
parser.add_argument('--type', default='torch.cuda.FloatTensor',
                    help='types of tensor - e.g torch.cuda.FloatTensor')
parser.add_argument('--gpus', default='0', help='gpus used for training - e.g 0,1,2,3')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--batch-size', default=512, type=int, metavar='N',
                    help='mini-batch size (default: 512)')
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
                    help='optimizer function used')
parser.add_argument('--momentum', default=0, type=float, metavar="M",
                    help="momentum parameter of SHB or SNAG")
parser.add_argument('--scale_size', default=32, type=int,
                    help='image scale size for data preprocessing')
parser.add_argument('--input_size', default=32, type=int,
                    help='the size of image. e.g. 32 for cifar10, 224 for imagenet')
parser.add_argument('--works', default=8, type=int,
                    help='number of threads used for loading data')

parser.add_argument('--weight_decay', default=2e-4, type=float, help='weight decay parameters')
parser.add_argument('--print_freq', '-p', default=50, type=int,
                    help='print frequency (default:50)')
# number of restart batches: restart_init_loop * batchsize
parser.add_argument('--restart_init_loop', default=5, type=int,
                    help='restart minibatch size = restart_init_loop * batchsize')
parser.add_argument('--start_training_time', type=float, help='Overall training start time')
parser.add_argument('--lamda', default=5, type=float, help='parameters of regularization')
parser.add_argument('--lamda1', default=5, type=float,
                    help='initial lambda1 for the constraints such that lambda >= lambda1')
parser.add_argument('--lamda0', default=1e-4, type=float,
                    help='lambda0 to make the DRO objective smooth')
parser.add_argument('--beta', default=0.9, type=float, help='momentum parameters for SCCMA')
parser.add_argument('--class_tau', default=0, type=float, help='class level dro')
parser.add_argument('--frozen_aside_fc', default=True, type=eval, choices=[True, False],
                    help='whether frozen the feature layers (First three block)')
parser.add_argument('--frozen_aside_last_block', default=False, type=eval, choices=[True, False],
                    help='whether to keep the last residual block trainable '
                         '(in addition to fc) while freezing the other layers')

parser.add_argument('--pretrained', default=True, type=eval, choices=[True, False],
                    help='Whether use pretrained model')

# boolean variable
parser.add_argument('--nesterov', default=False, type=eval, choices=[True, False],
                    help='This is used to determine whether we use SNAG')
parser.add_argument('--resume', default=False, type=eval, choices=[True, False],
                    help='Training from scratch (False) or from the saved check point')

### Tuning Parameters
parser.add_argument('--epochs', default=0, type=int,
                    help='number of total epochs')
parser.add_argument('--lr', default=0.1, type=float, metavar='WLR',
                    help='initial learning rate of w')

parser.add_argument('--plr', default=0.005, type=float, help='Dual Variable P')
parser.add_argument('--rho', default=1e-4, type=float, help='Constraint of DRO: rho')

# Loading Models Parameters
parser.add_argument('--resumed_epoch', default=0, type=int,
                    help="continuing training from a save check point")
# NOTE(review): the default below is separated by full-width commas (U+FF0C),
# not ASCII ','.  Left untouched because the code that parses it is not
# visible here -- confirm which separator the consumer expects.
parser.add_argument('--stages', default='1，2，3，4', type=str,
                    help='start epochs of each stages')
parser.add_argument('--start_epochs', default=0, type=int,
                    help="start training epochs: default 0 in common training and start from "
                         "loaded_epochs - 1 after loading the check point ")
parser.add_argument('--ith_init_run', default=0, type=int, help="ith-initial weights")
parser.add_argument('--num_classes', default=10, type=int, help="classes of different datasets")
parser.add_argument('--im_ratio', default=0.2, type=float, help="imbalance ratio of datasets")
parser.add_argument('--DR', default=10, type=int, help='Decay Rate of Different Stages')
parser.add_argument('--binary', default=False, type=eval, choices=[True, False],
                    help='Whether perform binary classification.')
parser.add_argument('--auc', default=False, type=eval, choices=[True, False],
                    help='calculating AUC in binary classification')
parser.add_argument('--curlr', default=0.1, type=float,
                    help='current learning rate')
parser.add_argument('--lrlambda', default=0.1, type=float,
                    help='current lambda rate')
# NOTE(review): help text duplicates --curlr; presumably this is the current
# beta value -- confirm with the training code before rewording.
parser.add_argument('--curbeta', default=0.1, type=float,
                    help='current learning rate')
parser.add_argument('--obj', default='ERM', type=str,
                    help='optimization objective of the loss')
parser.add_argument('--alg', default='PDSGD', type=str, choices=['FastDRO', 'PDSGD', 'SCCMA'],
                    help='The choice of algorithms')

# Constrained DRO
parser.add_argument('--sampleType', default='uniform', type=str, help='Sampling methods')





def _resolve_data_root(args, amax_default):
    """Return the dataset root to use, storing it on ``args.data_root``.

    A value already present on ``args`` wins; otherwise ``amax_default`` is
    used when the host name contains ``'amax'``.  Raises ``ValueError`` when
    neither source yields a path, instead of failing later with an
    ``AttributeError``.
    """
    data_root = getattr(args, 'data_root', None)
    if data_root is None and 'amax' in os.uname()[1]:
        data_root = amax_default
    if data_root is None:
        raise ValueError(
            "No data root configured for dataset '{}': set args.data_root "
            "or run on a host with a known default".format(args.dataset))
    args.data_root = data_root
    return data_root


def _extract_features(model, loader, device, tag):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    The model is expected to return a 3-tuple whose last element is the
    feature tensor to store (as the call sites in ``main`` unpack it).

    Args:
        model: callable network returning ``(_, _, feature)``.
        loader: iterable yielding ``(inputs, targets)`` tensor pairs.
        device: ``torch.device`` the inputs are moved to.
        tag: split name used as prefix of the progress print.

    Returns:
        ``(features, labels)`` as two numpy arrays concatenated along axis 0.

    Raises:
        ValueError: if ``loader`` yields no batch.
    """
    feat_chunks, label_chunks = [], []
    n_feat, n_label = 0, 0
    # No backward pass is ever run here, so autograd bookkeeping is pure
    # overhead; disabling it does not change the forward outputs.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            _, _, lb_feature = model(inputs)
            feats = lb_feature.detach().cpu().numpy()
            labels = targets.numpy()

            # Collect chunks and concatenate once at the end: concatenating
            # inside the loop re-copies everything gathered so far per batch.
            feat_chunks.append(feats)
            label_chunks.append(labels)
            n_feat += feats.shape[0]
            n_label += labels.shape[0]
            # Same progress line as before: running shapes of both arrays.
            print(tag + ">>>>", (n_feat,) + feats.shape[1:], (n_label,) + labels.shape[1:])

    if not feat_chunks:
        raise ValueError("'{}' loader produced no batches".format(tag))
    return np.concatenate(feat_chunks), np.concatenate(label_chunks)


def main():
    """Extract last-block features of a (partially frozen) model and save them.

    Parses the command line, builds the requested model, creates the data
    loaders for ImageNet or iNaturalist18, runs the model over the train, val
    and (when available) test splits, and stores features and labels as
    ``.npy`` files under ``/data/qiqi/constrainedDRO/lb_feat_<dataset>/``.

    Raises:
        ValueError: for an unsupported dataset, a missing data root, or an
            empty data loader.

    NOTE(review): the model is left in its default (training) mode during
    extraction, exactly as before; switching to ``model.eval()`` would change
    the saved features, so confirm whether that is intended.
    """
    torch.manual_seed(123)
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    args.start_training_time = time.time()

    # ``is ''`` compared identity with a literal; test emptiness instead.
    if not args.saveFolder:
        args.saveFolder = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    args.results_dir = os.path.join(args.results_dir, args.saveFolder)  # root_dir + save Folder
    os.makedirs(args.results_dir, exist_ok=True)
    results_file = os.path.join(args.results_dir, args.res_filename + '_results.csv')
    # Not used further in this function; kept because constructing the log
    # may have side effects (presumably creating the results file).
    results = ResultsLog(results_file)

    if 'cuda' in args.type:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        args.gpus = [int(i) for i in args.gpus.split(',')]
        cudnn.benchmark = True
    else:
        args.gpus = None
    device = torch.device('cuda' if args.gpus else 'cpu')

    args.num_classes = get_num_classes(args)
    model = models.__dict__[args.model]
    model_new = model(pretrained=args.pretrained, num_classes=args.num_classes, data=args.dataset)
    print("length of model:", len(model_new.state_dict().keys()))

    # Inputs are moved to ``device`` below, so the parameters must live there
    # too (a no-op if the model constructor already placed them).
    model_new = model_new.to(device)
    if args.gpus:
        model_new = torch.nn.DataParallel(model_new)

    if 'imagenet' in args.dataset:
        data_root = _resolve_data_root(args, "/data/imagenet/imagenet/")
        train_loader = myDataLoader_imagenet(args, data_root, args.batch_size, 'train',
                                             num_workers=0, shuffle=False)
        val_loader = myDataLoader_imagenet(args, data_root, args.batch_size, 'val',
                                           num_workers=0, shuffle=False)
        test_loader = myDataLoader_imagenet(args, data_root, args.batch_size, 'test',
                                            num_workers=0, shuffle=False)
    elif 'iNaturalist18' in args.dataset:
        data_root = _resolve_data_root(args, "/data/iNaturalist2018/")
        train_loader = myDataLoader_iNaturalist18(args, data_root, args.batch_size, 'train',
                                                  num_workers=0, shuffle=False)
        val_loader = myDataLoader_iNaturalist18(args, data_root, args.batch_size, 'val',
                                                num_workers=0, shuffle=False)
        test_loader = None
    else:
        # Previously this fell through and crashed with a NameError on
        # ``train_loader`` (e.g. for the default dataset 'cifar10').
        raise ValueError(
            "Unsupported dataset '{}': expected a name containing "
            "'imagenet' or 'iNaturalist18'".format(args.dataset))

    network_frozen(args, model_new)

    # Create the output directory up front so a missing folder is detected
    # before the extraction runs rather than at the first np.save.
    out_dir = '/data/qiqi/constrainedDRO/lb_feat_' + args.dataset
    os.makedirs(out_dir, exist_ok=True)

    # Each split is saved as soon as it is finished so that a failure in a
    # later split does not discard completed work.
    train_feat, train_label = _extract_features(model_new, train_loader, device, "train")
    np.save(os.path.join(out_dir, 'train_feat.npy'), train_feat)
    np.save(os.path.join(out_dir, 'train_label.npy'), train_label)

    val_feat, val_label = _extract_features(model_new, val_loader, device, "val")
    np.save(os.path.join(out_dir, 'val_feat.npy'), val_feat)
    np.save(os.path.join(out_dir, 'val_label.npy'), val_label)

    if test_loader is not None:
        test_feat, test_label = _extract_features(model_new, test_loader, device, "test")
        np.save(os.path.join(out_dir, 'test_feat.npy'), test_feat)
        np.save(os.path.join(out_dir, 'test_label.npy'), test_label)

def network_frozen(args, model):
    """Freeze model parameters in place and report what stays trainable.

    Parameters whose name contains ``'fc'`` are never touched.  Every other
    parameter gets ``requires_grad = False`` unless
    ``args.frozen_aside_last_block`` is set and the parameter name contains
    the last-block pattern for ``args.model``:

        resnet152 / resnet50 -> 'layer4.2'
        resnet32             -> 'layer3.4'
        anything else        -> 'layer4.0'

    Afterwards the names of all parameters that still require gradients are
    printed, followed by their count.

    Args:
        args: namespace providing ``model`` and ``frozen_aside_last_block``.
        model: object exposing ``named_parameters()``.
    """
    # Index of the final residual block inside layer4 for the known depths.
    final_block_index = {"resnet152": 2, 'resnet50': 2, 'resnet10': 0}.get(args.model, 0)

    if args.model == 'resnet32':
        keep_pattern = 'layer3.4'
    else:
        keep_pattern = 'layer4.' + str(final_block_index)

    for name, tensor in model.named_parameters():
        if 'fc' in name:
            # classifier parameters always stay trainable
            continue
        if args.frozen_aside_last_block and keep_pattern in name:
            # last block is kept trainable on request
            continue
        tensor.requires_grad = False

    trainable_names = []
    for name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        trainable_names.append(name)
        print(name)
    print("{} number of layers".format(len(trainable_names)))



# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
