"""Train models with coresets and per-batch approximations."""
import argparse
import os
import numpy as np
import glob

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from wideresnet import wrn

from torch.utils.data import DataLoader, ConcatDataset
from warnings import simplefilter
from utils import *
# from train_and_record import get_acc

# add logger and tensorboard
import logging
from torch.utils.tensorboard import SummaryWriter

# import subset generator
from select.submod import get_orders_and_weights
from train_helpers import *

# Process-wide noise suppression: drop FutureWarning messages and make NumPy
# ignore every class of floating-point error (divide, over, under, invalid).
simplefilter(action='ignore', category=FutureWarning)
np.seterr(all='ignore')

# Enumerate CUDA devices by PCI bus ID instead of CUDA's default
# "fastest first" ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152

# Alphabetically sorted names of the lowercase, callable ``resnet*`` entries
# exposed by the project-local ``models`` package.  (A name starting with
# "resnet" can never start with "__", so no dunder check is needed.)
model_names = sorted(
    attr_name
    for attr_name, attr in vars(models).items()
    if attr_name.startswith("resnet")
    and attr_name.islower()
    and callable(attr)
)

print(model_names)

# Command-line interface of the training script.
# Help strings use argparse's ``%(default)s`` / ``%(choices)s`` specifiers so
# the documented defaults and choices cannot drift from the real ones.
parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=['resnet18', 'resnet34', 'resnet50', 'wideresnet', 'cnn', 'cnn_cifar'],
                    help='model architecture: %(choices)s (default: %(default)s)')
parser.add_argument('--dataset', default='cifar10', help='dataset',
                    choices=(
                        'cifar10', 'cifar100', 'cifar100sup', 'caltech256', 'cinic10', 'tinyimagenet',
                        'stanfordcars', 'waterbirds', 'cmnist'))
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='epoch at which data selection starts (also the rewind point with --rewind)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', '-m', type=float, metavar='M', default=0.9,
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: %(default)s)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='output directory of an earlier run to load the selected subset '
                         '(and start checkpoint) from (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', default='output/cifar10_transform_resnet20_1.0_128/',
                    help='path to the pre-trained model')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision (16-bit)')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='./save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
                    help='Saves checkpoints at every specified number of epochs',
                    type=int, default=20)
parser.add_argument('--gpu', default='7', type=str, help='The GPU to be used')
parser.add_argument('--prune_method', type=str, help='data pruning method', default='random',
                    choices=[
                        'random', 'upweight', 'subsample', 'subsample-group', 'el2n', 'forget', 'traject',
                        'local_rewind', 'local_next', 'traject-most', 'traject-avg', 'traject-transfer', 'sample-gradnorm'])
parser.add_argument('--local_method', type=str, help='local data selection method', default='traject',
                    choices=['max-grad-norm', 'traject'])
parser.add_argument('--before_select', type=str, help='data to train on before data selection happens', default='full',
                    choices=['full', 'rand', 'easy', 'hard', 'sample-pos', 'sample-neg', 'traject'])
parser.add_argument('--before_select_metric', type=str, help='metric to determine easy vs hard examples', default='acc',
                    choices=['el2n', 'acc'])
# store_false: per-class selection is ON by default; passing the flag turns it off.
parser.add_argument('--per_class', action='store_false',
                    help='disable selecting an equal number of examples per class (enabled by default)')
parser.add_argument('--rewind', action='store_true', help='rewind to the start epoch instead of the initialization')
parser.add_argument('--local_interval', default=4, type=int, help="interval for local data selection")
parser.add_argument('--local_freq', default=1, type=int, help="interval for recording the gradients as parts of trajectories")
parser.add_argument('--num_avg', default=2, type=int, help="number of runs used for averaging")
parser.add_argument('--end_epoch', default=200, type=int, help="when to end local selection")
parser.add_argument('--prune_size', type=float, help='fraction of data to prune at each run', default=0.1)
parser.add_argument('--preselect_size', type=float, help='fraction of data selected for training before selection', default=1.0)
parser.add_argument('--hard_offset_size', type=float, help='fraction of hardest data to exclude', default=0.05)
parser.add_argument('--noisy_label_size', type=float, help='fraction of noisy labels', default=0.)
parser.add_argument('--idn', action='store_true', help='use instance-dependent noise')
# Three separate choices; previously a single string 'sgd, adam, adagrad',
# which made every explicit --ig value fail validation.
parser.add_argument('--ig', type=str, help='optimizer: %(choices)s (default: %(default)s)', default='sgd',
                    choices=['sgd', 'adam', 'adagrad'])
parser.add_argument('--lr_schedule', '-lrs', type=str, help='learning rate schedule', default='plateau',
                    choices=['mile', 'exp', 'cnt', 'step', 'cosine', 'plateau'])
parser.add_argument('--gamma', type=float, default=0.2, help='learning rate decay parameter')
parser.add_argument('--warm', '-w', dest='warm_start', action='store_true', help='warm start learning rate')

parser.add_argument('--subset_weight', default='uniform', choices=['uniform', 'weight', 'minibatch'],
                    help='how the examples of the selected subset are weighted (default: %(default)s)')
parser.add_argument('--no_transforms', action='store_true', help='disable random transformations of the data')
parser.add_argument('--strong_aug', action='store_true', help='use strong data augmentation (RandAug)')
parser.add_argument('--add_strong_aug', action='store_true', help='add strong data augmentation (RandAug)')
parser.add_argument('--match_num_iter_preselect', action='store_true', help='use the same number of iterations for the warmup period')
parser.add_argument('--match_num_iter', action='store_true', help='use the same number of iterations for the entire training')
parser.add_argument('--seed', default=11111111, type=int, help="random seed")
parser.add_argument('--run', default=0, type=int, help="run number")

parser.add_argument('--cl', action='store_true', help='compare to curriculum learning')
parser.add_argument('--longtail', action='store_true', help='use a long-tailed version of the dataset (see --imb_factor)')
parser.add_argument('--imb_factor', type=float, default=0.1, help='imbalance factor')

parser.add_argument('--small_only', action='store_true', help='select small clusters first')
parser.add_argument('--sample_inverse', action='store_true', help='sample examples by the inverse of the size of the clusters they are in')
parser.add_argument('--sqrt', action='store_true', help='use the square root of the metric')
parser.add_argument('--sample_power', default=1, type=int, help='power used together with --sample_inverse')
parser.add_argument('--large_only', action='store_true', help='select large clusters first')
parser.add_argument('--upsample_clusters', action='store_true', help='upsample by clusters')
parser.add_argument('--upweight_clusters', action='store_true', help='upweight by clusters')
parser.add_argument('--reweight_cluster_norm', action='store_true', help='reweight by cluster norm')
parser.add_argument('--sample_inverse_full', action='store_true', help='sample full data by the inverse of the size of the clusters they are in')

parser.add_argument('--small_only_all', action='store_true', help='select small clusters first')
parser.add_argument('--sample_inverse_all', action='store_true', help='sample examples by the inverse of the size of the clusters they are in')
parser.add_argument('--prune_largest', action='store_true', help='drop examples from the largest clusters')
parser.add_argument('--prune_largest_iterative', action='store_true', help='drop examples from the largest clusters iteratively')
parser.add_argument('--max_cluster_size', default=2, type=int, help="maximum number of examples in one cluster")

parser.add_argument('-pc', '--p_correlation', type=float, default=0.995,
                    help="Ratio of majority group size to total size")

# Module-level logger; main() hands it to set_logger() for configuration.
logger = logging.getLogger(__name__)

def main(args):
    global best_prec1, best_loss

    # if args.longtail:
    #     args.per_class = False

    if args.prune_method[:5] == 'local':
        grd = f'_{args.prune_method}_{args.local_method}_{args.prune_size:.4f}'
        grd += f'_prop' if not args.per_class else ''
        if args.small_only:
            grd += f'_small-only'
        if args.large_only:
            grd += f'_large-only'
        if args.sample_inverse:
            grd += f'_sample-inverse'
            if args.sample_power > 1:
                grd += f'_power{args.sample_power}'
        if args.small_only_all:
            grd += f'_small-only-all'
        if args.sample_inverse_all:
            grd += f'_sample-inverse-all'
        if args.prune_largest:
            grd += f'_prune-largest-{args.max_cluster_size}'
        if args.prune_largest_iterative:
            grd += f'_prune-largest-iterative'
        if args.sqrt:
            grd += '_sqrt'
        if args.upsample_clusters:
            grd += f'_upsample-clusters'
        if args.upweight_clusters:
            grd += f'_upweight-clusters'
        if args.reweight_cluster_norm:
            grd += f'_reweight_cluster_norm'
        if args.sample_inverse_full:
            grd += f'_sample-inverse-full'
        grd += f'_start_{args.start_epoch}_interval_{args.local_interval}'
        grd += f'_every_{args.local_freq}' if (args.local_freq != 1) else ''
        grd += f'_end_{args.end_epoch}' if (args.end_epoch != args.epochs) else ''
        grd += f'_{args.subset_weight}' if ((args.local_method == 'traject') and (args.subset_weight != 'uniform')) else ''
    else:
        grd = f'_{args.prune_method}'
        if args.prune_method == 'traject-avg':
            grd += f'_{args.num_avg}runs'
        grd += f'_start_{args.start_epoch}'
        grd += f'_interval_{args.local_interval}' if (args.local_interval < args.end_epoch-args.start_epoch) else ''
        grd += f'_every_{args.local_freq}' if (args.local_freq != 1) else ''
        grd += f'_end_{args.end_epoch}' if (args.end_epoch != args.epochs) else ''
    if args.resume:
        args.save_dir = os.path.join(args.save_dir, args.resume.split('/')[-1])
        if args.prune_method == 'traject-transfer':
            args.save_dir += f'_{args.prune_method}_seed_{args.seed}'
        else:
            args.save_dir += '_train_selected_from_init'
            if args.start_epoch > 0 and args.rewind:
                args.save_dir += f'_rewind_{args.start_epoch}'
        train_size = 1.0 - args.prune_size
    else:
        folder = f'/{args.dataset}'
        folder += f'_pc_{args.p_correlation:.3f}' if (args.dataset == 'cmnist') else ''
        folder += f'_longtail_{args.imb_factor:.3f}' if args.longtail else ''
        folder += f'_transform' if not args.no_transforms else ''
        folder += f'_strong' if args.strong_aug else ''
        folder += f'_add_strong' if args.add_strong_aug else ''
        if args.idn:
            idn_str = '_idn'
        else:
            idn_str = ''
        folder += f'{idn_str}_noisy_{args.noisy_label_size:.1f}' if args.noisy_label_size > 0 else ''
        folder += f'_{args.before_select_metric}' if (args.before_select not in ['full', 'rand', 'traject']) and (args.before_select_metric != 'el2n') else ''
        folder += f'_{args.before_select}_{args.preselect_size:.3f}' if args.before_select != 'full' else ''
        folder += f'_match_before' if args.match_num_iter_preselect else ''
        folder += f'_match' if args.match_num_iter else ''
        folder += f'_offset_{args.hard_offset_size:.2f}' if args.hard_offset_size > 0 and (args.before_select not in ['full', 'rand', 'traject']) else ''
        train_size = 1.0 - args.prune_size
        args.save_dir += f'{folder}_{args.arch}_ts_{train_size:.4f}_bs_{args.batch_size}_{args.ig}_lr_{args.lr:.3f}_{args.lr_schedule}'
        args.save_dir += f'_gamma_{args.gamma:.1f}' if args.gamma != 0.2 else ''
        args.save_dir += f'_epochs_{args.epochs}_seed_{args.seed}{grd}'
    

    print(args.save_dir)
    os.makedirs(args.save_dir)

    args.logger = set_logger(args, logger)
    args.logger.info(f'--------- prune_size: {args.prune_size}, method: {args.ig}, moment: {args.momentum}, '
          f'lr_schedule: {args.lr_schedule}---------------')

    args.writer = SummaryWriter(args.save_dir)

    set_seed(args)

    if args.dataset == 'cifar100':
        args.class_num = 100
    elif args.dataset == 'cifar100sup':
        args.class_num = 20
    elif args.dataset == 'caltech256':
        args.class_num = 257
    elif args.dataset == 'stanfordcars':
        args.class_num = 196
    elif args.dataset == 'tinyimagenet':
        args.class_num = 200
    elif args.dataset == 'waterbirds':
        args.class_num = 2
    elif args.dataset == 'cmnist':
        args.class_num = 5
    else:
        args.class_num = 10
    
    if args.arch == 'wideresnet':
        model = torch.nn.DataParallel(wrn(args.class_num))
    elif args.dataset == 'caltech256':
        from torchvision.models import resnet18
        model = torch.nn.DataParallel(resnet18(pretrained=True))
        logger.info("==> Reinitializing the classifier..")
        num_ftrs = model.module.fc.in_features
        model.module.fc = nn.Linear(num_ftrs, args.class_num)
    elif args.dataset in ['tinyimagenet', 'stanfordcars']:
        from torchvision.models import resnet18
        model = torch.nn.DataParallel(resnet18(num_classes=args.class_num))
    elif args.dataset in ['waterbirds']:
        from torchvision.models import resnet50
        model = resnet50(pretrained=True)
        model.fc = nn.Linear(2048, args.class_num)
        model = torch.nn.DataParallel(model)
    else:
        model = torch.nn.DataParallel(models.__dict__[args.arch](num_classes=args.class_num))
    model.cuda()

    cudnn.benchmark = True

    train_transform, val_transform, strong_train_transform = get_transforms(args)
    if args.no_transforms:
        indexed_dataset = IndexedDataset(args.dataset, train=True, transform=val_transform, noisy=args.noisy_label_size, seed=args.seed, idn=args.idn, args=args)
    else:
        indexed_dataset = IndexedDataset(args.dataset, train=True, transform=train_transform, noisy=args.noisy_label_size, seed=args.seed, idn=args.idn, args=args)

    args.train_num = len(indexed_dataset)

    indexed_loader = DataLoader(
        indexed_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.dataset == 'waterbirds':
        val_loader = [torch.utils.data.DataLoader(
            IndexedDataset(args.dataset, transform=val_transform, train=False, args=args, group=group),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True) 
            for group in [[0,0], [0,1], [1,0], [1,1]]]
    else:
        val_loader = torch.utils.data.DataLoader(
            IndexedDataset(args.dataset, transform=val_transform, train=False, args=args),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    train_val_loader = torch.utils.data.DataLoader(
        IndexedDataset(args.dataset, transform=val_transform, train=True, noisy=args.noisy_label_size, seed=args.seed, idn=args.idn, args=args),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.half:
        model.half()
        train_criterion.half()
        val_criterion.half()

    if args.evaluate:
        if args.dataset == 'waterbirds':
            validate_by_group(args, val_loader, model, val_criterion)
        else:
            validate(val_loader, model, val_criterion)
        return

    optimizer, lr_scheduler = get_optimizer_and_scheduler(args, model)

    if args.warm_start:
        args.logger.info('Warm start learning rate')
        lr_scheduler_f = GradualWarmupScheduler(optimizer, 1.0, 20, lr_scheduler)
    else:
        args.logger.info('No Warm start')
        lr_scheduler_f = lr_scheduler        
        
    best_prec1, best_loss = 0, 1e10
    start_epoch = 0
    
    # load the model initialization and selected examples from an existing experiment
    if args.resume and (args.prune_method != 'traject-transfer'):
        args.logger.info(f'Training on {args.prune_method} selected subset')
        if args.start_epoch > 0 and args.rewind:
            args.logger.info(f'Rewinding the network to a checkpoint at epoch {args.start_epoch}')
            ckpt_path = os.path.join(args.resume, 'start.pth')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            best_prec1 = checkpoint['best_prec1']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            lr_scheduler_f.load_state_dict(checkpoint['scheduler_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(ckpt_path, checkpoint['epoch']))
            start_epoch = args.start_epoch
        if args.prune_method == 'traject':
            labels = indexed_dataset.dataset.targets
            traject = []
            for checkpoint_epoch in range(args.start_epoch, args.epochs, 4):
                subset_file = f'output_record/{args.dataset}_transform_resnet18_1.0_128_seed_11111111_lr_0.1_record/p-y_epoch_{checkpoint_epoch}_0.pth'
                print("=> loading stats from '{}'".format(subset_file))
                stats = torch.load(subset_file)
                traject.append(stats)
            traject = np.concatenate(traject, axis=1)
            print(np.shape(traject))
            subset, subset_weight, _, _, _, _ = get_orders_and_weights(
                B, traject, 'euclidean', smtk=0, no=0, y=labels, stoch_greedy=0,
                equal_num=args.per_class)
            torch.save({'subset': subset, 'subset_weight': subset_weight}, os.path.join(args.save_dir, f'subset.pth'))
        elif args.prune_method[:5] == 'local':
            selected = np.zeros(50000)
            if os.path.isdir(args.resume):
                for subset_file in glob.glob(f'{args.resume}/subset_epoch_*.pth'):
                    subset = torch.load(subset_file)
                    selected[subset] += 1
                args.logger.info("=> loaded subset from {}".format(args.resume))
                subset = np.where(selected>0)[0]
            else:
                raise NotADirectoryError("=> no checkpoint found at '{}'".format(args.resume))
        else:
            raise NotImplementedError
        indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
        train_loader = DataLoader(
            indexed_subset,
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=True)
        end_epoch = args.epochs
        train_size = len(subset)
        args.logger.info(f'{train_size} examples in the loaded subset')
        torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_0.pth'))
    else:
        args.logger.info('Training on all the data')
        train_loader = indexed_loader
        end_epoch = args.epochs

    # select a subset of training data for the burn-in stage
    if args.before_select != 'full':
        preselect_B = int(args.train_num * args.preselect_size)
        args.logger.info(f"=> Using {preselect_B} examples for the warm-up period")
        if args.before_select == 'rand':
            args.logger.info("=> randomly selecting examples")
            subset = np.arange(0, args.train_num)
            np.random.shuffle(subset)
            subset = subset[:preselect_B]
        elif args.before_select == 'traject':
            dirname = f'output_local_traject_warmup_{args.start_epoch}/cifar10_transform_resnet18_ts_{train_size:.3f}_bs_128_sgd_lr_0.1_mile_epochs_200_seed_{args.seed}_local_next_traject_{args.prune_size:.3f}_start_{args.start_epoch}_interval_{args.local_interval}'
            print(dirname)
            for subset_file in glob.glob(f'{dirname}/subset_epoch_{args.start_epoch}.pth'):
                args.logger.info("=> loading subset from '{}'".format(subset_file))
                subset = torch.load(subset_file)
        else:
            score = np.zeros(args.train_num)
            if args.before_select_metric == 'el2n':
                for subset_file in glob.glob(f'output_record_el2n/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_record/subset_epoch_20_0.pth'):
                    print("=> loading stats from '{}'".format(subset_file))
                    stats = torch.load(subset_file)
                    score += stats['el2n']
                if args.before_select == 'hard':
                    subset = np.argsort(-score)
                    if args.hard_offset_size > 0:
                        subset = subset[int(len(subset)*args.hard_offset_size):]
                elif args.before_select == 'easy':
                    subset = np.argsort(score)
                elif args.before_select == 'sample-pos':
                    subset = np.random.choice(args.train_num, preselect_B, replace=False, p=score/np.sum(score))
                else:
                    score = np.amax(score) - score + 1e-6
                    subset = np.random.choice(args.train_num, preselect_B, replace=False, p=score/np.sum(score))
            else:
                for subset_file in glob.glob(f'output_record_acc/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_record/acc_epoch_*.pth'):
                    print("=> loading stats from '{}'".format(subset_file))
                    stats = torch.load(subset_file)
                    score += stats
                if args.before_select == 'hard':
                    subset = np.argsort(score)
                    if args.hard_offset_size > 0:
                        subset = subset[int(len(subset)*args.hard_offset_size):]
                elif args.before_select == 'easy':
                    subset = np.argsort(-score)
                elif args.before_select == 'sample-neg':
                    subset = np.random.choice(args.train_num, preselect_B, replace=False, p=score/np.sum(score))
                else:
                    score = np.amax(score) - score + 1e-6
                    subset = np.random.choice(args.train_num, preselect_B, replace=False, p=score/np.sum(score))
            subset = subset[:preselect_B]
        torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{start_epoch}.pth'))
        indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
        train_loader = DataLoader(
            indexed_subset,
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=True)
    weight = None
    train_loss, test_loss = np.zeros(end_epoch), np.zeros(end_epoch)
    train_acc, test_acc = np.zeros(end_epoch), np.zeros(end_epoch)
    train_time, data_time = np.zeros(end_epoch), np.zeros(end_epoch)
    labels = indexed_dataset.targets

    epoch = start_epoch
    rewinded = 0
    stop_iterative = 0
    resample_idx = np.arange(args.train_num)
    grad_norm = np.zeros(len(indexed_dataset))
    traject = []
    B = int(train_size * args.train_num)
    first_learned = np.ones(args.train_num) * args.epochs
    times_forgot = np.zeros(args.train_num)
    learned = np.zeros(args.train_num)
    while epoch < end_epoch:
        if args.cl:
            if epoch < args.start_epoch:
                args.subset_size = 1
                steps = [5, 10, 15, 20, 30, 40, 60, 90, 140, 210, 300]
            else:
                args.subset_size = max(0.2, 0.85 ** (epoch-args.start_epoch+1))
                if args.lr_schedule == 'cosine':
                    if epoch in steps[:-1]:
                        i = steps.index(epoch)
                        lr_scheduler_f = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps[i+1]-epoch, eta_min=5e-4)
                args.local_interval = args.start_epoch
                args.local_freq = args.start_epoch
        if (args.prune_method[:5] == 'local') and not args.resume:
            if not os.path.exists(os.path.join(args.save_dir, 'start.pth')) and (epoch == args.start_epoch):
                args.logger.info(f'Starting local training at epoch {epoch}...')
                save_checkpoint({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': lr_scheduler_f.state_dict(),
                    'best_prec1': best_prec1,
                    'best_loss': best_loss,
                    }, filename=os.path.join(args.save_dir, 'start.pth'))
                if args.prune_method == 'local_rewind':
                    args.logger.info(f'Saving checkpoint at epoch {epoch} for rewinding...')
                    save_checkpoint({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler_f.state_dict(),
                        'best_prec1': best_prec1,
                        'best_loss': best_loss,
                        }, filename=os.path.join(args.save_dir, 'checkpoint.pth'))
            if epoch >= (args.start_epoch - args.local_interval) and epoch <= args.end_epoch:
                if not rewinded or (args.prune_method == 'local_next' and args.start_epoch == 0):                  
                    if epoch > (args.start_epoch - args.local_interval) and ((epoch-args.start_epoch) % args.local_freq == 0):
                        save_checkpoint({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler_f.state_dict(),
                        'best_prec1': best_prec1,
                        'best_loss': best_loss,
                        }, filename=os.path.join(args.save_dir, f'checkpoint_epoch_{epoch}.pth'))
                        args.logger.info(f'Recording gradients after training for {epoch} epoch(s) ...')
                        _, preds = get_losses_and_preds(args, train_val_loader, model, train_criterion)
                        preds -= np.eye(args.class_num)[labels]
                        if epoch > 0:
                            traject.append(preds)
                        else:
                            torch.save(preds, os.path.join(args.save_dir, f'p-y_init.pth'))
                        el2n = np.linalg.norm(preds, axis=-1)
                        grad_norm = np.maximum(grad_norm, el2n)
                    if (epoch >= args.start_epoch) and ((epoch-args.start_epoch) % args.local_interval == 0) and (epoch > 0):
                        args.logger.info(f'Pruning data before epoch {epoch}...')
                        traject = np.concatenate(traject, axis=1)
                        if args.local_method == 'max-grad-norm':
                            subset = np.argsort(-grad_norm)[:B]
                        elif args.local_method == 'traject':
                            equal = True if args.dataset != 'stanfordcars' or args.per_class else False
                            subset, subset_weight, _, _, _, _, cluster = get_orders_and_weights(
                                B, traject[resample_idx], 'euclidean', y=np.array(labels)[resample_idx], equal_num=equal, return_cluster=True, args=args)
                            subset = resample_idx[subset]      
                            if args.dataset == 'waterbirds':
                                groups = np.array(indexed_dataset.groups)[subset]
                                groups, sizes = np.unique(groups, return_counts=True)
                                args.logger.info(f'Training sizes')
                                for group, size in zip(groups, sizes):
                                    args.logger.info(f'Group {group}: {size}')
                        if args.prune_largest:
                            subset = []
                            cluster = np.concatenate(cluster)
                            cluster_idx, cluster_sizes = np.unique(cluster, return_counts=True)
                            for c, cs in zip(cluster_idx, cluster_sizes):
                                print(c, cs)
                                if cs > args.max_cluster_size:
                                    subset.append(np.random.choice(np.where(cluster == c)[0], args.max_cluster_size, replace=False))
                                else:
                                    subset.append(np.where(cluster==c)[0])
                            subset = np.concatenate(subset)
                            args.logger.info(f'Subset size {len(subset)}')
                        elif args.prune_largest_iterative:
                            resample = []
                            for c in range(len(cluster)):
                                clusters_c = cluster[c]
                                cluster_idx, cluster_sizes = np.unique(clusters_c, return_counts=True)
                                ranked_cluster_idx = cluster_idx[np.argsort(-cluster_sizes)]     
                                resample.append(np.where(clusters_c != ranked_cluster_idx[0])[0] + len(clusters_c)*c)
                                resample.append(np.random.choice(np.where(clusters_c == ranked_cluster_idx[0])[0], cluster_sizes[np.argsort(-cluster_sizes)[1]], replace=False) + len(clusters_c)*c)
                            resample = np.concatenate(resample)
                            resample_idx = resample_idx[resample]
                            logger.info(f'Size after pruning: {len(resample_idx)}')
                            if len(resample_idx) < B:
                                stop_iterative = 1
                                try:
                                    subset = subset_last
                                except:
                                    subset = resample_idx
                                subset_weight = np.ones(len(subset), dtype=np.float64)
                            else:
                                subset_last = resample_idx
                        elif args.upsample_clusters:
                            subset = []
                            cluster = np.concatenate(cluster)
                            cluster_idx, cluster_sizes = np.unique(cluster, return_counts=True)
                            for c, cs in zip(cluster_idx, cluster_sizes):
                                if cs > (1./train_size):
                                    subset.append(np.where(cluster == c)[0])
                                else:
                                    subset.append(np.random.choice(np.where(cluster == c)[0], int(1./train_size), replace=True))
                            subset = np.concatenate(subset)
                            args.logger.info(f'Subset size {len(subset)}')
                        elif args.sample_inverse_full:
                            cluster = np.concatenate(cluster)
                            cluster_idx, cluster_sizes = np.unique(cluster, return_counts=True)
                            prob = 1. / cluster_sizes[cluster]
                            indices_by_class = np.concatenate([np.where(np.array(labels)==c)[0] for c in range(args.class_num)])
                            subset = np.random.choice(indices_by_class, len(prob), p=prob/np.sum(prob), replace=True)
                            args.logger.info(f'Subset size {len(subset)}')
                            if args.dataset == 'waterbirds':
                                groups = np.array(indexed_dataset.groups)[subset]
                                groups, sizes = np.unique(groups, return_counts=True)
                                args.logger.info(f'Training sizes')
                                for group, size in zip(groups, sizes):
                                    args.logger.info(f'Group {group}: {size}')

                        if args.subset_weight == 'weight':
                            weight = np.zeros(args.train_num)
                            weight[subset] = subset_weight / np.sum(subset_weight) * len(subset)
                            weight = torch.from_numpy(weight).cuda()
                        if args.upweight_clusters:
                            subset = np.arange(args.train_num)
                            cluster = np.concatenate(cluster)
                            cluster_idx, cluster_sizes = np.unique(cluster, return_counts=True)
                            prob = 1. / cluster_sizes[cluster]
                            indices_by_class = np.concatenate([np.where(np.array(labels)==c)[0] for c in range(args.class_num)])
                            weight = prob[np.argsort(indices_by_class)]
                            weight = weight / np.sum(weight) * len(subset)
                            weight = torch.from_numpy(weight).cuda()
                        elif args.reweight_cluster_norm:
                            subset = np.arange(args.train_num)
                            cluster = np.concatenate(cluster)
                            cluster_idx, cluster_sizes = np.unique(cluster, return_counts=True)
                            indices_by_class = np.concatenate([np.where(np.array(labels)==c)[0] for c in range(args.class_num)])
                            weight = np.zeros(args.train_num)
                            for ci, cs in zip(cluster_idx, cluster_sizes):
                                cluster_indices = indices_by_class[np.where(cluster==ci)[0]]
                                norm_sum = np.sum(el2n[cluster_indices])
                                weight[cluster_indices] = 1 / norm_sum
                            weight = weight / np.sum(weight) * len(subset)
                            weight = torch.from_numpy(weight).cuda()

                        if args.start_epoch == 0 and (not os.path.exists(os.path.join(args.save_dir, 'subset_epoch_0.pth'))):
                            torch.save(traject, os.path.join(args.save_dir, f'p-y_epoch_0.pth'))
                            torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_0.pth'))
                            if args.local_method == 'traject':
                                torch.save(weight, os.path.join(args.save_dir, f'weight_epoch_0.pth'))
                                torch.save(cluster, os.path.join(args.save_dir, f'cluster_epoch_0.pth'))
                        else:
                            torch.save(traject, os.path.join(args.save_dir, f'p-y_epoch_{epoch}.pth'))
                            torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{epoch}.pth'))                            
                            if args.local_method == 'traject':
                                torch.save(weight, os.path.join(args.save_dir, f'weight_epoch_{epoch}.pth'))
                                torch.save(cluster, os.path.join(args.save_dir, f'cluster_epoch_0.pth'))

                        if args.prune_method == 'local_rewind' or (args.prune_method == 'local_next' and args.start_epoch == 0 and (not os.path.exists(os.path.join(args.save_dir, f'subset_epoch_{args.local_interval}.pth')))):
                            if args.prune_method == 'local_rewind':
                                ckpt_path = os.path.join(args.save_dir, 'checkpoint.pth')
                            else:
                                ckpt_path = os.path.join(args.save_dir, 'start.pth')
                            checkpoint = torch.load(ckpt_path, map_location='cpu')
                            epoch = checkpoint['epoch']
                            best_prec1 = checkpoint['best_prec1']
                            best_loss = checkpoint['best_loss']
                            model.load_state_dict(checkpoint['state_dict'])
                            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                            lr_scheduler_f.load_state_dict(checkpoint['scheduler_state_dict'])
                            if (args.prune_largest_iterative) and (not stop_iterative):
                                traject = []
                            else:
                                rewinded = 1
                            args.logger.info(f'Rewinded back to epoch {epoch}...')

                        if args.strong_aug and (epoch == args.start_epoch):
                            logger.info("Use strong augmentations")
                            indexed_dataset = IndexedDataset(args.dataset, train=True, transform=strong_train_transform, noisy=args.noisy_label_size, seed=args.seed, args=args)
                        indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
                        if args.add_strong_aug and (epoch == args.start_epoch):
                            logger.info("Add strong augmentations")
                            indexed_subset = ConcatDataset([indexed_subset, torch.utils.data.Subset(IndexedDataset(args.dataset, train=True, transform=strong_train_transform, noisy=args.noisy_label_size, seed=args.seed, args=args), indices=subset)])
                        train_loader = DataLoader(
                            indexed_subset,
                            batch_size=args.batch_size, shuffle=True,
                            num_workers=args.workers, pin_memory=True, drop_last=True)
                        grad_norm = np.zeros(len(indexed_dataset))
                        traject = []
                elif rewinded and (epoch > args.start_epoch) and ((epoch-args.start_epoch) % args.local_interval == 0):
                    args.logger.info(f'Rewinded training completed.')
                    rewinded = 0
                    args.logger.info(f'Saving checkpoint at epoch {epoch} for rewinding...')
                    save_checkpoint({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler_f.state_dict(),
                        'best_prec1': best_prec1,
                        'best_loss': best_loss,
                    }, filename=os.path.join(args.save_dir, 'checkpoint.pth'))
            if epoch == args.end_epoch:
                args.logger.info(f'Ending local training at epoch {epoch}...')
        elif args.prune_method[-4:] == 'comb' and epoch in [0,10,20]:
            if epoch == 0:
                subset_dict = {}
                score = np.zeros(args.train_num)
                for subset_file in glob.glob(f'output_record_acc/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_record/acc_epoch_*.pth'):
                    print("=> loading stats from '{}'".format(subset_file))
                    stats = torch.load(subset_file)
                    score += stats
                subset = np.argsort(score)
                if args.hard_offset_size > 0:
                    subset = subset[int(len(subset)*args.hard_offset_size):]
                subset_dict['hard'] = subset[:int(B/2)]
                subset = np.argsort(-score)
                subset_dict['easy'] = subset[:int(B/2)]
                subset = subset_dict[args.prune_method[:4]]
            elif epoch == 10:
                subset = subset_dict[args.prune_method[5:9]]
            elif epoch == 20:
                subset = np.concatenate([subset_dict['hard'], subset_dict['easy']])
            indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
            train_loader = DataLoader(
                indexed_subset,
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True, drop_last=True)
        elif (epoch >= args.start_epoch) and ((epoch-args.start_epoch) % args.local_interval == 0):
            if args.prune_method == 'random':
                args.logger.info(f'Training on {B} randomly selected examples')
                subset = np.arange(0, args.train_num)
                np.random.shuffle(subset)
                subset = subset[:B]
            elif args.prune_method == 'subsample':
                args.logger.info("=> randomly selecting examples equally from each class")
                indices = []
                for c in np.unique(labels):
                    class_indices = np.where(labels == c)[0]
                    indices_per_class = np.random.choice(class_indices, size=int(B / args.class_num), replace=False)
                    indices.append(indices_per_class)
                subset = np.concatenate(indices)
            elif args.prune_method == 'upweight':
                args.logger.info("=> weighting each sample by the inverse of group sizes")
                subset = np.arange(0, args.train_num)
                weight = torch.tensor([1./np.sum(np.array(indexed_dataset.groups) == g) for g in indexed_dataset.groups]).cuda()
            elif args.prune_method == 'subsample-group':
                groups, sizes = np.unique(indexed_dataset.groups, return_counts=True)
                smallest_size = np.amin(sizes)
                args.logger.info(f"=> randomly selecting {smallest_size} examples equally from each class")
                indices = []
                for c in np.unique(groups):
                    class_indices = np.where(indexed_dataset.groups == c)[0]
                    indices_per_class = np.random.choice(class_indices, size=smallest_size, replace=False)
                    indices.append(indices_per_class)
                subset = np.concatenate(indices)
            elif args.prune_method == 'sample-gradnorm':
                args.logger.info(f"=> sampling {B} examples proportional to the norm of gradient")
                _, preds = get_losses_and_preds(args, train_val_loader, model, train_criterion)
                preds -= np.eye(args.class_num)[labels]
                el2n = np.linalg.norm(preds, axis=-1)
                subset = np.random.choice(len(el2n), size=B, replace=False, p=el2n/np.sum(el2n))
            elif args.prune_method == 'el2n':
                args.logger.info(f'Training on {B} examples selected by EL2N')
                score = np.zeros(args.train_num)
                if args.dataset[:5] == 'cifar':
                    if (args.start_epoch == 0) and (args.local_interval == 200):
                        el2n_epoch = 20
                    elif epoch == 0:
                        el2n_epoch = args.local_interval
                    else:
                        el2n_epoch = epoch
                    if args.noisy_label_size > 0:
                        for subset_file in glob.glob(f'output_{args.dataset}/record/{args.dataset}_transform{idn_str}_noisy_{args.noisy_label_size:.1f}_resnet18_1.0_128_seed_{args.seed}_lr_0.1_mile_epochs_200_record/subset_epoch_{el2n_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['el2n']
                    else:
                        for subset_file in glob.glob(f'output_cifar10/output_record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_record/subset_epoch_{el2n_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['el2n']
                elif args.dataset == 'cinic10':
                    if (args.start_epoch == 0) and (args.local_interval == 200):
                        el2n_epoch = 20
                    elif epoch == 0:
                        el2n_epoch = args.local_interval
                    else:
                        el2n_epoch = epoch
                    for subset_file in glob.glob(f'output_{args.dataset}/record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_mile_epochs_200_record/subset_epoch_{el2n_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['el2n']
                else:
                    for subset_file in glob.glob(f'output_caltech/record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.001_cnt_epochs_40_record/subset_epoch_{args.local_interval}_0.pth'):
                        args.logger.info("=> loading stats from '{}'".format(subset_file))
                        stats = torch.load(subset_file)
                        score += stats['el2n']
                subset = np.argsort(-score)
                subset = subset[:B]
                torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{epoch}.pth'))
            elif args.prune_method == 'forget':
                args.logger.info(f'Training on {B} examples selected by forgettability')
                score = np.zeros(args.train_num)
                if args.dataset[:5] == 'cifar':
                    if (args.start_epoch == 0) and (args.local_interval == 200):
                        forget_epoch = 200
                    elif epoch == 0:
                        forget_epoch = args.local_interval
                    else:
                        forget_epoch = epoch
                    if args.noisy_label_size > 0:
                        for subset_file in glob.glob(f'output_{args.dataset}/record/{args.dataset}_transform{idn_str}_noisy_{args.noisy_label_size:.1f}_resnet18_1.0_128_seed_{args.seed}_lr_0.1_mile_epochs_200_record/subset_epoch_{forget_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['forget']
                    else:
                        for subset_file in glob.glob(f'output_cifar10/output_record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_record/subset_epoch_{forget_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['forget']
                elif args.dataset == 'cinic10':
                    if (args.start_epoch == 0) and (args.local_interval == 200):
                        forget_epoch = 200
                    elif epoch == 0:
                        forget_epoch = args.local_interval
                    else:
                        forget_epoch = epoch
                    for subset_file in glob.glob(f'output_{args.dataset}/record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.1_mile_epochs_200_record/subset_epoch_{forget_epoch}_0.pth'):
                            args.logger.info("=> loading stats from '{}'".format(subset_file))
                            stats = torch.load(subset_file)
                            score += stats['forget']
                else:
                    for subset_file in glob.glob(f'output_caltech/record/{args.dataset}_transform_resnet18_1.0_128_seed_*_lr_0.001_cnt_epochs_40_record/subset_epoch_{args.local_interval}_0.pth'):
                        args.logger.info("=> loading stats from '{}'".format(subset_file))
                        stats = torch.load(subset_file)
                        score += stats['forget']
                subset = np.argsort(-score)
                subset = subset[:B]
                torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{epoch}.pth'))
            elif (args.prune_method == 'traject-most') and (epoch == args.start_epoch):
                args.logger.info(f'Training on {B} examples selected by most selected by trajectory')
                score = np.zeros(50000)
                for subset_file in glob.glob(f'output_local_traject_warmup_4/cifar10_transform_resnet18_ts_{train_size:.3f}_bs_128_sgd_lr_0.1_mile_epochs_200_seed_*_local_next_traject_{args.prune_size:.3f}_start_4_interval_{args.local_interval}/subset_epoch_*.pth'):
                    args.logger.info("=> loading subset from '{}'".format(subset_file))
                    subset = torch.load(subset_file)
                    score[subset] += 1
                subset = np.argsort(-score)
                subset = subset[:B]
                torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{epoch}.pth'))
            elif (args.prune_method == 'traject-avg'):
                args.logger.info(f'Selecting {B} examples by averaged trajectory')
                traject = []
                num_avg = 0
                if args.end_epoch == args.local_interval:
                    for subset_file in glob.glob(f'output_cifar10/output_compress/cifar10_transform_match_resnet18_ts_{train_size:.3f}_bs_128_sgd_lr_0.1_mile_epochs_200_seed_*_local_rewind_traject_{args.prune_size:.3f}_start_0_interval_{args.local_interval}_end_{args.local_interval}/p-y_epoch_0.pth'):
                        if num_avg == args.num_avg:
                            break
                        args.logger.info("=> loading gradient from '{}'".format(subset_file))
                        grad = torch.load(subset_file)
                        traject.append(grad)
                        num_avg += 1
                else:
                    for subset_file in glob.glob(f'output_cifar10/output_local_traject_warmup_{args.start_epoch}_single/cifar10_transform_resnet18_ts_{train_size:.3f}_bs_128_sgd_lr_0.1_mile_epochs_200_seed_*_local_next_traject_{args.prune_size:.3f}_start_{args.start_epoch}_interval_{args.local_interval}_every_{args.local_interval}/p-y_epoch_{epoch}.pth'):
                        if num_avg == args.num_avg:
                            break
                        args.logger.info("=> loading gradient from '{}'".format(subset_file))
                        grad = torch.load(subset_file)
                        traject.append(grad)
                        num_avg += 1
                traject = np.mean(np.stack(traject), axis=0)
                subset, subset_weight, _, _, _, _ = get_orders_and_weights(
                    B, traject, 'euclidean', y=labels, 
                    equal_num=args.per_class)
                torch.save(subset, os.path.join(args.save_dir, f'subset_epoch_{epoch}.pth'))
            elif (args.prune_method == 'traject-transfer'):
                args.logger.info(f'Training on {B} examples transferred from ResNet-18')
                if args.resume:
                    subset_file = args.resume
                else:
                    subset_file = f'output_cifar10/output_local_traject_warmup_{args.start_epoch}'
                    subset_file += f'_single' if (args.local_freq == args.local_interval) else ''
                    subset_file += f'/cifar10_transform_resnet18_ts_{train_size:.3f}_bs_128_sgd_lr_0.1_mile_epochs_200_seed_{args.seed}_local_next_traject_{args.prune_size:.3f}_start_4_interval_{args.local_interval}'
                    subset_file += f'_every_{args.local_freq}' if args.local_freq > 1 else ''
                subset_file += f'/subset_epoch_{epoch}.pth'
                args.logger.info("=> loading subset from '{}'".format(subset_file))
                subset = torch.load(subset_file)
            if args.strong_aug and (epoch == args.start_epoch):
                logger.info("Use strong augmentations")
                indexed_dataset = IndexedDataset(args.dataset, train=True, transform=strong_train_transform, noisy=args.noisy_label_size, seed=args.seed, args=args)
            indexed_subset = torch.utils.data.Subset(indexed_dataset, indices=subset)
            if args.add_strong_aug and (epoch == args.start_epoch):
                logger.info("Add strong augmentations")
                indexed_subset = ConcatDataset([indexed_subset, torch.utils.data.Subset(IndexedDataset(args.dataset, train=True, transform=strong_train_transform, noisy=args.noisy_label_size, seed=args.seed, args=args), indices=subset)])
            train_loader = DataLoader(
                indexed_subset,
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True, drop_last=True)
            

        data_time[epoch], train_time[epoch], acc = train(args, train_loader, model, train_criterion, optimizer, epoch, weight)
        first_learned[acc>0] = np.minimum(first_learned[acc>0], np.ones(np.sum(acc, dtype=int)) * epoch)
        times_forgot[acc - learned < 0] += 1
        learned = acc

        torch.save({'forget': times_forgot, 'first_learned':first_learned}, os.path.join(args.save_dir, f'forget_learned.pth'))

        # evaluate on validation set
        if args.dataset == 'waterbirds':
            prec1, worst, loss = validate_by_group(args, val_loader, model, val_criterion)
        elif args.dataset == 'cmnist':
            loss, prec1, worst = evaluate(args, val_loader, model, val_criterion)
        else:
            prec1, loss, accs = validate(args, val_loader, model, val_criterion)
        # acc = get_acc(args, train_val_loader, model)
        # torch.save(acc, os.path.join(args.save_dir, f'acc_epoch_{epoch}.pth'))
        best_prec1 = max(prec1, best_prec1)
        test_acc[epoch], test_loss[epoch] = prec1, loss

        if args.lr_schedule == 'plateau':
            lr_scheduler_f.step(loss)
        else:
            lr_scheduler_f.step()

        ta, tl, _ = validate(args, train_val_loader, model, val_criterion)
        best_loss = min(tl, best_loss)
        train_acc[epoch], train_loss[epoch] = ta, tl

        if end_epoch == args.epochs:
            args.writer.add_scalar('test/1.test_acc_by_epoch', prec1, epoch)
            args.writer.add_scalar('test/2.test_loss_by_epoch', loss, epoch)
            args.writer.add_scalar('train/1.train_acc_by_epoch', ta, epoch)
            args.writer.add_scalar('train/2.train_loss_by_epoch', tl, epoch)

            args.writer.add_scalar('test/3.test_acc_by_time', prec1, int(np.sum(train_time)+np.sum(data_time)))
            args.writer.add_scalar('test/4.test_loss_by_time', loss, int(np.sum(train_time)+np.sum(data_time)))
            args.writer.add_scalar('train/3.train_acc_by_time', ta, int(np.sum(train_time)+np.sum(data_time)))
            args.writer.add_scalar('train/4.train_loss_by_time', tl, int(np.sum(train_time)+np.sum(data_time)))

            if args.dataset in ['waterbirds', 'cmnist']:
                args.writer.add_scalar('test/5.test_worst_acc_by_epoch', worst, epoch)

        args.logger.info(f'train_size: {train_size}, epoch: {epoch}, prec1: {prec1}, loss: {tl:.3f}, '
            f'lr: {args.lr:.3f}, best_prec1: {best_prec1}, best_loss: {best_loss:.3f}')

        epoch += 1

    if end_epoch == args.epochs:
        save_checkpoint({'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler_f.state_dict(),
            'best_prec1': best_prec1,
            'best_loss': best_loss,
            }, filename=os.path.join(args.save_dir, 'final.pth'))
        return
    else:
        raise RuntimeError

if __name__ == '__main__':
    # Parse the command line once. The namespace is kept in the module-global
    # name `args` and also handed to main() explicitly.
    args = parser.parse_args()
    # Export the requested GPU selection into the environment before main()
    # runs, so it is already in place when main() starts issuing CUDA work.
    os.environ.update(CUDA_VISIBLE_DEVICES=args.gpu)
    main(args)
