import argparse
import datetime

import torch.backends.cudnn as cudnn
import json
import yaml
from pathlib import Path

from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler

from timm.utils import NativeScaler
from lib.datasets import build_dataset
from engine import *
from lib.samplers import RASampler

from lib.config import cfg, update_config_from_file
from model.vision_transformer_timm import VisionTransformerSepQKV

import model as models
from timm.models import load_checkpoint

#SLUMRM
from mmcv.runner import get_dist_info, init_dist

import os
from timm.utils.clip_grad import dispatch_clip_grad
from collections import OrderedDict
from lib import utils


class SPT_scaler:
    """AMP loss scaler that enforces SPT's partial fine-tuning after each step.

    Every call runs the usual AMP backward/step/update sequence. It then
    rewrites the model's state dict so that only the intended entries keep
    their freshly updated values:

    * keys in ``fully_ft_list`` keep the optimizer's update unchanged;
    * keys in ``unstructured_ft_dict`` keep the update only where the mask is
      ``1`` and are reset to the frozen snapshot everywhere else;
    * every other key is reset to its value in ``freezed_state_dict``.

    Reloading the tuned values after the step is used instead of masking
    gradients because it is cheaper.

    Args:
        fully_ft_list: state-dict keys that are trained without restriction.
            ``None`` is treated as empty.
        unstructured_ft_dict: mapping from state-dict key to a mask tensor of
            the same shape; entries equal to 1 are trainable. ``None`` is
            treated as empty.
        freezed_state_dict: snapshot of the original state dict, used to
            reset frozen keys and frozen entries.
    """

    state_dict_key = "amp_scaler"

    def __init__(self, fully_ft_list=None, unstructured_ft_dict=None, freezed_state_dict=None):
        self._scaler = torch.cuda.amp.GradScaler()
        self.fully_ft_list = fully_ft_list
        self.unstructured_ft_dict = unstructured_ft_dict
        self.freezed_state_dict = freezed_state_dict

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', model=None, create_graph=False):
        """Backpropagate, optionally clip, step, then restore the frozen weights.

        Args:
            loss: loss tensor to backpropagate (scaled by the AMP scaler).
            optimizer: optimizer to step through the AMP scaler.
            clip_grad: clipping threshold, or ``None`` to disable clipping.
            clip_mode: clipping mode forwarded to ``dispatch_clip_grad``.
            model: module whose parameters are clipped and whose state dict
                is restored after the step. Required despite the default.
            create_graph: forwarded to ``backward``.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)

        if clip_grad is not None:
            # Unscale in place so the clip threshold applies to the real gradient values.
            self._scaler.unscale_(optimizer)
            dispatch_clip_grad(model.parameters(), clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()

        fully_tuned = set(self.fully_ft_list or ())
        masks = self.unstructured_ft_dict or {}
        # Build the state dict once: calling model.state_dict() per key is quadratic
        # in the number of tensors and runs on every training step.
        current = model.state_dict()

        restored = {}
        for key, value in current.items():
            if key in fully_tuned:
                # Fully fine-tuned: keep the optimizer's update as-is.
                restored[key] = value
            elif key in masks:
                # Unstructured tuning: keep updated entries where mask == 1, reset the rest.
                mask = masks[key].to(value.device)
                frozen = self.freezed_state_dict[key].to(value.device).detach()
                restored[key] = torch.where(mask == 1, value, frozen)
            else:
                # Fully frozen: reset to the snapshot taken before training.
                restored[key] = self.freezed_state_dict[key].to(value.device).detach()

        model.load_state_dict(restored)

    def state_dict(self):
        """Return the wrapped GradScaler's state (for checkpointing)."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler's state."""
        self._scaler.load_state_dict(state_dict)


def get_args_parser():
    """Build the command-line parser for training and evaluation.

    The parser is created with ``add_help=False`` so it can be used as a
    ``parents=[...]`` entry of another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser holding all of the script's options.
    """
    parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    # config file
    parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)

    # custom parameters
    parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'],
                        help='Platform name (itp, pai or aml)')
    parser.add_argument('--teacher_model', default='', type=str,
                        help='Name of teacher model')
    parser.add_argument('--relative_position', action='store_true')
    parser.add_argument('--gp', action='store_true')
    parser.add_argument('--change_qkv', action='store_true')
    parser.add_argument('--max_relative_position', type=int, default=14, help='max distance in relative position embedding')

    # Model parameters
    parser.add_argument('--model', default='', type=str, metavar='MODEL',
                        help='Name of model to train')
    # AutoFormer config
    parser.add_argument('--mode', type=str, default='super', choices=['super', 'vp', 'retrain', 'search'], help='mode of AutoFormer')
    parser.add_argument('--input-size', default=224, type=int)
    parser.add_argument('--patch_size', default=16, type=int)

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    # --model-ema / --no-model-ema share dest 'model_ema'; the first action's default (False) wins.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
    parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct'])
    parser.add_argument('--post_norm', action='store_true')
    parser.add_argument('--no_abs_pos', action='store_true')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--lr-power', type=float, default=1.0,
                        help='power of the polynomial lr scheduler')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Dataset parameters
    parser.add_argument('--data-path', default='./data/imagenet/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', type=str, help='Dataset name')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='./',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # --amp / --no-amp share dest 'amp'; the first action's default (False) wins.
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--no-amp', action='store_false', dest='amp')

    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

    parser.add_argument('--no_aug', action='store_true')
    parser.add_argument('--val_interval', default=1, type=int, help='validation interval (in epochs)')
    parser.add_argument('--inception', action='store_true')
    parser.add_argument('--direct_resize', action='store_true')

    # Our params
    parser.add_argument('--exp-name', default='', type=str)
    parser.add_argument('--freeze_stage', action='store_true')
    parser.add_argument('--sensitivity_path', default='', type=str,)
    parser.add_argument('--scaler', default='naive', type=str,)
    parser.add_argument('--low_rank_dim', default=8, type=int,)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--save_best', action='store_true')
    # NOTE: parsed as a string ('True'/'False'), not a bool; consumers receive the raw string.
    parser.add_argument('--freeze_others', default='True', type=str)
    parser.add_argument('--freeze_selected', action='store_true')
    parser.add_argument('--block', type=str, default='BlockSPT')
    parser.add_argument('--get_sensitivity', action='store_true')
    # NOTE: parsed as a string ('True'/'False'), not a bool; consumers receive the raw string.
    parser.add_argument('--structured_vector', default='True', type=str)

    return parser


def _build_samplers(args, dataset_train, dataset_val):
    """Return ``(sampler_train, sampler_val)`` for a single-process or distributed run."""
    if not args.distributed:
        return torch.utils.data.RandomSampler(dataset_train), torch.utils.data.SequentialSampler(dataset_val)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    return sampler_train, sampler_val


def _split_qkv(target, name, tensor):
    """Split a fused qkv tensor along dim 0 into q/k/v entries of ``target`` (``qkv`` in ``name`` is replaced)."""
    n = tensor.shape[0]
    target[name.replace('qkv', 'q')] = tensor[:n // 3]
    target[name.replace('qkv', 'k')] = tensor[n // 3:-n // 3]
    target[name.replace('qkv', 'v')] = tensor[-n // 3:]


def _load_pretrained(model, args):
    """Load the checkpoint at ``args.resume`` into ``model`` and set ``args.suffix``.

    Checkpoints whose path starts with ``mae``, ``linear`` or ``vit-b-300ep.pth``
    have fused ``attn.qkv`` weights, which are split into separate q/k/v entries.
    Any other ``.pth`` file goes through timm's ``load_checkpoint`` (non-strict).
    Non-``.pth`` paths are loaded strictly.
    """
    if '.pth' not in args.resume:
        load_checkpoint(model, args.resume)
        if args.nb_classes != model.head.weight.shape[0]:
            model.reset_classifier(args.nb_classes)
        args.suffix = 'vit'
        return

    if args.resume.startswith('mae'):
        state_dict = torch.load(args.resume, map_location='cpu')['model']
        new_dict = OrderedDict()
        for name, tensor in state_dict.items():
            if 'attn.qkv.' in name:
                _split_qkv(new_dict, name, tensor)
            else:
                new_dict[name] = tensor

        args.suffix = 'mae'
        msg = model.load_state_dict(new_dict, strict=False)
        print('Resuming from MAE model: ', msg)

    elif args.resume.startswith('linear'):
        state_dict = torch.load(args.resume, map_location='cpu')['state_dict']
        new_dict = OrderedDict()
        for name, tensor in state_dict.items():
            if 'attn.qkv.' in name:
                _split_qkv(new_dict, name.split('module.')[1], tensor)
            elif 'head.' in name:
                # The classifier head is not transferred.
                continue
            else:
                new_dict[name.split('module.')[1]] = tensor

        args.suffix = 'moco'
        msg = model.load_state_dict(new_dict, strict=False)
        print('Resuming from MoCo model: ', msg)

    elif args.resume.startswith('vit-b-300ep.pth'):
        state_dict = torch.load(args.resume, map_location='cpu')['state_dict']
        new_dict = OrderedDict()
        for name, tensor in state_dict.items():
            # Only the base-encoder weights are kept.
            if not name.startswith('module.base_encoder'):
                continue
            new_name = name[len("module.base_encoder."):]
            if 'attn.qkv.' in name:
                _split_qkv(new_dict, new_name, tensor)
            elif 'head.' in name:
                continue
            else:
                new_dict[new_name] = tensor

        args.suffix = 'moco'
        msg = model.load_state_dict(new_dict, strict=False)
        print('Resuming from MoCo model: ', msg)

    else:
        if args.nb_classes != model.head.weight.shape[0]:
            model.reset_classifier(args.nb_classes)
        incompatible_keys = load_checkpoint(model, args.resume, strict=False)
        print('Resuming from .pth model: ', incompatible_keys)


def _build_spt_scaler(model, param_info, tuned_vectors):
    """Create the ``SPT_scaler`` for ``model`` and freeze every non-tuned parameter.

    Fully tuned keys are: the sensitive vectors, the classifier head, and every
    state-dict key containing ``'structured'``. Keys listed in
    ``param_info['unstructured_name_shapes']`` are partially tuned through a 0/1
    mask whose ones sit at ``param_info['unstructured_indexes']`` (indexes into
    the concatenation of those tensors, flattened). All other keys are frozen.
    """
    # Build the state dict once; the original per-key model.state_dict() calls were quadratic.
    state = model.state_dict()
    freezed_state_dict = {key: tensor.clone().detach() for key, tensor in state.items()}

    fully_fine_tuned = list(tuned_vectors)
    fully_fine_tuned.extend(['head.weight', 'head.bias'])
    fully_fine_tuned.extend(key for key in state if 'structured' in key)

    # Setting up unstructured tuning: one flat mask over all partially tuned tensors.
    unstructured_name_shapes = param_info['unstructured_name_shapes']
    unstructured_indexes = param_info['unstructured_indexes']
    flat = torch.cat([state[key].flatten() for key in unstructured_name_shapes])
    grad_mask = torch.zeros_like(flat)
    del flat
    grad_mask[unstructured_indexes] = 1.

    # Cut the flat mask back into per-tensor pieces; np.prod also handles 0-d shapes.
    sizes = [int(np.prod(list(shape))) for shape in unstructured_name_shapes.values()]
    chunks = grad_mask.split(sizes)
    grad_mask_dict = {
        key: chunk.view(shape)
        for chunk, (key, shape) in zip(chunks, unstructured_name_shapes.items())
    }

    loss_scaler = SPT_scaler(fully_ft_list=fully_fine_tuned,
                             unstructured_ft_dict=grad_mask_dict,
                             freezed_state_dict=freezed_state_dict)

    # Frozen parameters. Sorted so the call order is deterministic across runs.
    # NOTE(review): under DDP, `model` is the wrapper, so keys carry a 'module.' prefix and
    # DistributedDataParallel has no `freeze_selected_params`; SPT is only exercised single-process here.
    frozen_keys = sorted(set(state) - set(fully_fine_tuned) - set(grad_mask_dict))
    model.freeze_selected_params(frozen_keys)
    return loss_scaler


def main(args):
    """Set up data, model, optimizer and scaler from ``args``, then train and evaluate.

    Depending on the flags, this computes parameter sensitivities
    (``--get_sensitivity``), evaluates once (``--eval``), or runs the training
    loop. Validation results are appended to ``<output_dir>/log.txt`` on the
    main process.
    """
    update_config_from_file(args.cfg)
    if args.launcher == 'none':
        args.distributed = False
    else:
        args.distributed = True
        init_dist(launcher=args.launcher)
        # get_dist_info() works for every launcher; SLURM_PROCID only exists under slurm.
        args.rank, _ = get_dist_info()
        args.gpu = args.rank % torch.cuda.device_count()

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args,)
    dataset_val, _ = build_dataset(is_train=False, args=args,)

    sampler_train, sampler_val = _build_samplers(args, dataset_train, dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=int(2 * args.batch_size),
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False
    )

    print(f"{args.data_set} dataset, train: {len(dataset_train)}, evaluation: {len(dataset_val)}")

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    print('mixup_active', mixup_active)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print("Creating SuperVisionTransformer")
    print(cfg)

    # Defaults for runs without a sensitivity file (previously a NameError).
    # NOTE(review): assumes the model treats empty lists as "nothing extra to tune" — confirm.
    param_info = None
    tuned_matrices = []
    tuned_vectors = []

    if args.get_sensitivity:
        model = models.__dict__[cfg.MODEL_NAME](img_size=args.input_size,
                                                drop_rate=args.drop,
                                                drop_path_rate=args.drop_path,
                                                freeze_backbone=args.freeze_stage,
                                                num_classes=args.nb_classes
                                                )
    else:
        if args.scaler == 'spt_scaler':
            param_info = torch.load(args.sensitivity_path, map_location='cpu')
            tuned_vectors = param_info['tuned_vectors']
            tuned_matrices = param_info['tuned_matrices']

            print('Both structured and unstructured tuning', )
            print('Sensitive matrices: ', tuned_matrices)
            print('Sensitive vectors: ', tuned_vectors)
            print('Total params: {0:.2f} M'.format(param_info['params'].item()))

        model = models.__dict__[cfg.MODEL_NAME](img_size=args.input_size,
                                                drop_rate=args.drop,
                                                drop_path_rate=args.drop_path,
                                                freeze_backbone=args.freeze_stage,
                                                structured_list=tuned_matrices,
                                                tuned_vectors=tuned_vectors,
                                                low_rank_dim=args.low_rank_dim,
                                                freeze_others=args.freeze_others,
                                                block=args.block,
                                                num_classes=args.nb_classes
                                                )

    train_engine = train_one_epoch
    test_engine = evaluate

    # The checkpoint-specific branches below overwrite this; it must exist for get_sensitivity.
    args.suffix = ''
    if args.resume:
        _load_pretrained(model, args)

    model.to(device)
    model_ema = None
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.get_sensitivity:
        get_sensitivity(
            model, criterion, data_loader_train, device, epoch=0,
            amp=args.amp, dataset=args.data_set, nb_classes=args.nb_classes,
            suffix=args.suffix, structured_vector=args.structured_vector
        )
        return

    optimizer = utils.build_optimizer(args, model_without_ddp)
    if args.scaler == 'naive':
        loss_scaler = NativeScaler()
    elif args.scaler == 'spt_scaler':
        if not args.resume:
            raise ValueError('--scaler spt_scaler requires --resume (a pretrained checkpoint to fine-tune)')
        loss_scaler = _build_spt_scaler(model, param_info, tuned_vectors)
    else:
        raise NotImplementedError(f'Unknown scaler: {args.scaler}')

    lr_scheduler, _ = create_scheduler(args, optimizer)

    output_dir = Path(args.output_dir)
    # Save config for later experiments; only one process writes, and an empty output_dir disables saving.
    if args.output_dir and utils.is_main_process():
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "config.yaml", 'w') as f:
            f.write(args_text)

    if args.eval:
        test_stats = test_engine(data_loader_val, model, device, amp=args.amp)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    # param_info['params'] is printed with an "M" unit above; the fallback counts trainable
    # parameters in millions to match.
    if param_info is not None:
        n_parameters = param_info['params'].item()
    else:
        n_parameters = sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6

    print("Start training")

    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_engine(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            amp=args.amp, scaler=args.scaler
        )

        lr_scheduler.step(epoch)

        # Validate every `val_interval` epochs, and on each of the last 10 epochs.
        if epoch % args.val_interval == 0 or epoch >= args.epochs - 10:
            test_stats = test_engine(data_loader_val, model, device, amp=args.amp)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(
                f"[{args.exp_name}] Max accuracy on the {args.data_set} dataset {len(dataset_val)} with ({args.opt}, {args.lr}, {args.weight_decay}), {max_accuracy:.2f}%")

            # NOTE: --save_best is currently not acted upon (best-checkpoint saving is disabled).

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Wrap the shared option definitions in a top-level parser that also provides -h/--help.
    cli_parser = argparse.ArgumentParser(
        'AutoFormer training and evaluation script',
        parents=[get_args_parser()],
    )
    args = cli_parser.parse_args()
    if args.output_dir:
        # Make sure the run directory exists before main() writes into it.
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    main(args)
