import argparse
import copy
import importlib
import json
import logging
import os
import time
import platform
from collections import OrderedDict
from collections import defaultdict
from contextlib import suppress
from datetime import datetime
from functools import partial
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math
# import seaborn as sns
import scipy.sparse
import scipy.sparse.csgraph
from PIL import Image
import pickle
import collections
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
from timm.utils import save_hardware
from timm.results import get_difference_results
from timm.results import get_cluster_size_frequency, get_n_skipped_layer_cluster, get_n_cluster_per_layer
from timm.results import get_cluster_size_frequency_histogram
from timm.results import get_explanation_by_perturbation, get_explanation_by_gradient


# Optional dependencies: each probe records availability in a has_* flag instead of
# failing at import time, so the script still runs when the package is absent.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False


try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False


# Feature-detect torch.compile rather than comparing PyTorch version strings.
has_compile = hasattr(torch, 'compile')


_logger = logging.getLogger('train')

# Stage-one parser: it recognizes only -c/--config. It is meant to be run with
# parse_known_args() so that every other flag passes through untouched. The YAML file
# it names supplies key/value defaults for the main parser defined below.
# add_help=False leaves -h/--help to the main parser.
config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')


# Main argument parser. Every option below is registered on it, either directly or via
# an argument group (plain argument groups only affect how --help output is sectioned).
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (positional is *deprecated*, use --data-dir)')
# NOTE(review): --data-dir, --dataset and the --*-num-samples options are added to `parser`
# rather than `group`, so --help lists them outside the "Dataset parameters" section.
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
                   help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
                   help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
                   help='Dataset key for target labels.')
group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
                   help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                   help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
                   help='Start with pretrained version of specified network (if avail)')
group.add_argument('--pretrained-path', default=None, type=str,
                   help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                   help='Load this checkpoint into model after initialization (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
                   help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
                   help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
                   help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
                   help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
                   help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
                   help='Image input channels (default: None => 3)')
# nargs=3: exactly three ints (channels, height, width) per the help text.
group.add_argument('--input-size', default=None, nargs=3, type=int,
                   metavar='N N N',
                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
                   metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                   help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                   help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                   help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                   help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                   help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
                   help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
                   help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
                   help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
                   help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
                   help='enable experimental fast-norm')
# NOTE(review): no help text. Values are collected by timm's utils.ParseKwargs action;
# presumably key=value pairs forwarded to model creation — confirm in the training code.
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
                   help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
                   help='Head initialization bias value')
group.add_argument('--torchcompile-mode', type=str, default=None,
                    help="torch.compile mode (default: None).")

# scripting / codegen
# --torchscript and --torchcompile share a mutually exclusive group, so argparse rejects a
# command line that supplies both. A bare --torchcompile (no value) yields const='inductor';
# omitting the flag leaves it as None.
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")

# Device & distributed
# Accelerator selection, mixed-precision (AMP) and distributed-training options.
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='torch.cuda.synchronize() end of each step')
# Spelled with an underscore (not a dash), so the flag on the command line is --local_rank.
group.add_argument("--local_rank", default=0, type=int)
# NOTE(review): registered on `parser`, not the device group, so --help lists it outside
# the "Device parameters" section.
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
                    help="Python imports for device backend modules.")

# Optimizer parameters
# NOTE(review): `optimizer_kwargs` (imported from timm.optim) presumably maps these options
# onto create_optimizer_v2 arguments — confirm in the training code.
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                   help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                   help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                   help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
                   help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
                   help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                   help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
                   help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
                   help='layer-wise learning rate decay (default: None)')
# NOTE(review): no help text; values are collected by timm's utils.ParseKwargs action.
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)

# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                   help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,
                   help='Apply LR scheduler step on update instead of epoch end.')
# Per the help strings: an explicit --lr wins; otherwise the LR is derived from --lr-base,
# the global batch size and --lr-base-size, scaled according to --lr-base-scale.
group.add_argument('--lr', type=float, default=None, metavar='LR',
                   help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
                   help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                   help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                   help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                   help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                   help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                   help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                   help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
                   help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
                   help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
                   help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                   help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                   help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
                   help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                   help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                   help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
                   help='Exclude warmup period from decay schedule.')
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                   help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                   help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                   help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
                   help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,
                   help='Crop-mode in train')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                   help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                   help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
                   help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
# Loss-selection switches (JSD / BCE variants); see each help string for how they combine.
group.add_argument('--jsd-loss', action='store_true', default=False,
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
                   help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-sum', action='store_true', default=False,
                   help='Sum over classes when using BCE loss.')
group.add_argument('--bce-target-thresh', type=float, default=None,
                   help='Threshold for binarizing softened BCE targets (default: None, disabled).')
group.add_argument('--bce-pos-weight', type=float, default=None,
                   help='Positive weighting for BCE loss.')
# Random erasing options.
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
                   help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
                   help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
                   help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
                   help='Do not random erase first (clean) augmentation split')
# Mixup / CutMix options; per the help strings, --cutmix-minmax overrides --cutmix alpha.
group.add_argument('--mixup', type=float, default=0.0,
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
                   help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
                   help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
                   help='Training interpolation (random, bilinear, bicubic default: "random")')
# Stochastic regularization rates.
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                   help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                   help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                   help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
                   help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
                   help='BatchNorm epsilon override (if not None)')
# No explicit default: argparse's store_true action defaults this to False.
group.add_argument('--sync-bn', action='store_true',
                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
                   help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
# No explicit default: argparse's store_true action defaults this to False.
group.add_argument('--split-bn', action='store_true',
                   help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
# Options for tracking an exponential moving average (EMA) copy of the model weights.
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
                   help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
                   help='Decay factor for model weights moving average (default: 0.9998)')
# No explicit default: argparse's store_true action defaults this to False.
group.add_argument('--model-ema-warmup', action='store_true',
                   help='Enable warmup for model EMA decay.')

# Misc
# Seeding, logging cadence, checkpointing, data-loading and experiment-tracking options.
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
                   help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
                   help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
                   help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                   help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
                   help='number of checkpoints to keep (default: 1)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
                   help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
                   help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                   help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')
group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
                   help='wandb tags')
group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
                   help='If resuming a run, the id of the run in wandb')

# Logging
# Naming inputs for the experiment/logging output folder (per the help text; the code that
# builds the path is not shown here). Flags use underscores: --vis_dir_prefix, --time_stamp.
group = parser.add_argument_group('Logging')
group.add_argument('--vis_dir_prefix', default='base',
                   help='save directory prefix for experiment folder')
group.add_argument('--time_stamp', default='',
                   help='time_stamp for logging folder')

# clustering
# Options controlling feature clustering. get_cluster() reads args.n_bin (number of rank
# bins) and args.scaling_threshold; the sampling options limit memory use per the help text.
group = parser.add_argument_group('Clustering features')
group.add_argument('--n_accum_cluster', type=int, default=20,
                   help='Unit : args.batch_size * args.grad_accum_steps')
group.add_argument('--layer_sampling_dim', type=int, default=3,
                   help='Clustering over sampled partial layer for reducing memory usage')
group.add_argument('--token_sampling_dim', type=int, default=3,
                   help='Clustering over sampled partial token for reducing memory usage')
group.add_argument('--n_bin', type=int, default=100,
                   help='Binning the feature for calculating n rank based difference between features')
group.add_argument('--n_submodel_clustering', type=int, default=1,
                   help='N times clustering over sampled submodel with layer_sampling_dim and token_sampling_dim')
# Raw strings: '\.' and '\d' are regex escapes, not Python string escapes. The values are
# unchanged, but non-raw literals trigger invalid-escape warnings on newer Pythons.
group.add_argument('--block-pattern', type=str, default=r'blocks\.\d+',
                   help=r'Block pattern (default: blocks\.\d+)')

# Learning with clustering
# Feature switches for training with cluster-derived features (all off by default).
group = parser.add_argument_group('Learning with cluster')
group.add_argument('--use_anchorwise_feature', action='store_true', default=False,
                   help='turn on learning with anchorwise_feature')
group.add_argument('--use_gathering_feature', action='store_true', default=False,
                   help='turn on learning with gathering_feature')
# NOTE(review): the flag name contains a typo ('contiuous'). It is kept because the attribute
# args.use_layerwise_contiuous_feature (and any YAML config key for it) is derived from it.
group.add_argument('--use_layerwise_contiuous_feature', action='store_true', default=False,
                   help='turn on learning with layerwise_contiuous_feature ')
group.add_argument('--use_identity_feature', action='store_true', default=False,
                   help='turn on learning with use_identity_feature; structure loss')
group.add_argument('--use_anchoring_feature', action='store_true', default=False,
                   help='turn on learning with anchoring_feature')
# get_cluster() links two features when their binned L1 difference is
# < n_samples * scaling_threshold.
group.add_argument('--scaling_threshold', type=int, default=1,
                   help='scaling threshold which multiply to the number of samples')

group.add_argument('--untrained_visualization', action='store_true', default=False,
                   help='visualizing the untrained model')
# nargs='+' with no default: these attributes are None when the flag is omitted.
group.add_argument('--selected_layer_start_idx_list', nargs='+', type=int, help='List of integers')
group.add_argument('--selected_token_start_idx_list', nargs='+', type=int, help='List of integers')

# visualization
# Visualization switches: most are opt-out (--no_*), the histogram is opt-in.
group = parser.add_argument_group('Visualization results')
group.add_argument('--sqrt_n_vis_samples', type=int, default=50,
                   help='plot for n**2 samples (default: 50)')
group.add_argument('--no_similarity_visualization', action='store_true', default=False,
                   help='turn off similarity_visualization')
group.add_argument('--no_similarity_statistic', action='store_true', default=False,
                   help='turn off similarity_statistic')
group.add_argument('--no_overall_features_visualization', action='store_true', default=False,
                   help='turn off overall_features_visualization')
group.add_argument('--no_average_features_visualization', action='store_true', default=False,
                   help='turn off average_features_visualization')
group.add_argument('--no_shifted_features_visualization', action='store_true', default=False,
                   help='turn off features_visualization with subtracting average_features')
group.add_argument('--no_features_visualization', action='store_true', default=False,
                   help='turn off features_visualization')
group.add_argument('--n_visualization_per_cluster', type=int, default=3,
                   help='plot for n_visualization_per_cluster; n for each size (default: 3)')
group.add_argument('--cluster-size-frequency-histogram', action='store_true', default=False,
                   help='cluster-size-frequency-histogram')
# explanation (SAG)
# [CVPR 2024] Comparing the Decision-Making Mechanisms by Transformers and CNNs via Explanation Methods
# Perturbation-based (SAG) explanation options, followed by gradient-based (CAM-style) ones.
group = parser.add_argument_group('Explanation results')
group.add_argument('--perturbation-base-model-explanation', action='store_true', default=False,
                   help='turn on perturbation-based model explanation (SAG)')
group.add_argument('--ups', type=int, default=30,
                   help='ratio of generate low resolution mask and upsample (default: 30)')
group.add_argument('--prob-thresh', type=float, default=0.9,
                   help='prob factor, prob_thresh * full_image_probability (default: 0.9)')
group.add_argument('--numCategories', type=int, default=1,
                   help='number of categories (default: 1)')
group.add_argument('--node-prob-thresh', type=int, default=40,
                   help='minimum score threshold to expand a node in the sag (default: 40)')
group.add_argument('--beam-width', type=int, default=3,
                   help='beam width, suggested values [3,5,10,15] (default: 3)')
group.add_argument('--max-num-roots', type=int, default=10,
                   help='upper limit on number of roots obtained via search, suggested values [10,20,30] (default: 10)')
group.add_argument('--overlap-thresh', type=int, default=1,
                   help='number of patches allowed to overlap in roots, suggested values [0,1,2] (default: 1)')
group.add_argument('--numSuccessors', type=int, default=15,
                   help='should be greater than or equal to beam_width, q hyperparam in the paper (default: 15)')
group.add_argument('--num-roots-sag', type=int, default=3,
                   help='max number of roots to be displayed in the sag (default: 3)')
group.add_argument('--maxRootSize', type=int, default=49,
                   help='max number of patches allowed for a root (default: 49)')
# gradient base explanation (e.g. gradCAM); average for multi block version
# An empty string (the default) selects no gradient-based method.
group.add_argument('--gradient-base-model-explanation', type=str, default='',
                   help='select method : gradcam, hirescam, scorecam, gradcam++, ablationcam, xgradcam, eigencam, eigengradcam, layercam, fullgrad, fem, gradcamelementwise, kpcacam, shapleycam, finercam')
group.add_argument('--n-batch-explanation', type=int, default=2,
                   help='number of samples to visualization (default: 2)')
group.add_argument('--gradient-base-model-explanation-statistic', action='store_true', default=False,
                   help='statistic about model explanation')

# statistic
# Opt-in switches for cluster statistics. NOTE(review): presumably these gate the matching
# timm.results helpers imported at the top (get_cluster_size_frequency,
# get_n_skipped_layer_cluster, get_n_cluster_per_layer) — confirm in the calling code.
group = parser.add_argument_group('Statistic results')
group.add_argument('--cluster-size-frequency', action='store_true', default=False,
                   help='turn on cluster-size-frequency')
group.add_argument('--n-skipped-layer-cluster', action='store_true', default=False,
                   help='turn on n-skipped-layer-cluster')
group.add_argument('--n-cluster-per-layer', action='store_true', default=False,
                   help='turn on n-cluster-per-layer')


def _parse_args():
    """Parse command-line arguments, optionally seeded from a YAML config file.

    A first pass with ``config_parser`` picks up the config-file option.  If a
    file is given, its top-level YAML mapping is installed as defaults on the
    main ``parser``, so explicit command-line flags still take precedence.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
        ``args_text`` is a YAML dump of ``args.__dict__`` for saving alongside
        experiment outputs.

    Raises:
        TypeError: if the config file's top-level YAML value is not a mapping.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # ``yaml.safe_load`` returns None for an empty document; treat that as
        # "no overrides" instead of failing on ``set_defaults(**None)``.
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f'Config file {args_config.config!r} must contain a YAML mapping, '
                f'got {type(cfg).__name__}')
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def get_cluster(layerwise_batch_features_list, args, n_samples, n_per_bin):
    """Cluster the features of each group by the similarity of their rank bins.

    For every entry of ``layerwise_batch_features_list`` the listed tensors are
    concatenated along dim 0.  Each column (feature) is then rank-binned over
    the rows: the ``k``-th smallest value of a column gets bin
    ``k // n_per_bin`` (at most ``args.n_bin`` bins).  Two features are linked
    when the L1 distance between their bin vectors is below
    ``n_samples * args.scaling_threshold``.  The clusters are the connected
    components of that graph.

    Args:
        layerwise_batch_features_list: iterable of lists of tensors.  Each list
            is joined with ``torch.cat(..., dim=0)``; assumes the result is 2-D
            ``(rows, features)`` — TODO confirm against callers.
        args: namespace providing ``n_bin`` and ``scaling_threshold``.
        n_samples: scale factor for the distance threshold.
        n_per_bin: number of consecutive ranks that share one bin.

    Returns:
        A tuple of five lists with one entry per input group:

        * ``clusters``: lists of feature indices, one per connected component
          with at least two members (each component listed once);
        * ``diff_matrixs``: int64 CPU pairwise L1 distance matrices;
        * ``threshold_diff_matrixs``: boolean adjacency matrices;
        * ``binned_features_list``: float32 ``(features, rows)`` bin-id tensors;
        * ``labels_list``: component label per feature (singletons included).
    """
    clusters = []
    diff_matrixs = []
    threshold_diff_matrixs = []
    binned_features_list = []
    labels_list = []
    diff_threshold = n_samples * args.scaling_threshold
    for layerwise_batch_features in layerwise_batch_features_list:
        features = torch.cat(layerwise_batch_features, dim=0)
        num_rows = features.shape[0]
        sorted_node_feature, sorted_node_indices = torch.sort(features, dim=0)

        # Bin id for each rank position.  ``n_per_bin`` is typically a ceiling
        # division, so n_bin * n_per_bin can exceed the number of rows; trim the
        # surplus so the last bin is just smaller instead of breaking expand_as.
        bins = torch.arange(args.n_bin).repeat_interleave(n_per_bin)[:num_rows]
        bins = bins.view(-1, 1).float().to(features.device)
        bins = bins.expand_as(sorted_node_feature)

        # scatter_ requires self and src to share a dtype, so build the output in
        # the bins' dtype (float32) whatever the feature dtype (fp16, fp64, ...).
        binned_features = torch.zeros_like(sorted_node_feature, dtype=bins.dtype)
        # binned_features[sorted_node_indices[k, j], j] = bins[k, j]: every row
        # receives the bin of its rank within its column.
        binned_features.scatter_(0, sorted_node_indices, bins)
        binned_features = binned_features.T  # one row of bin ids per feature
        binned_features_list.append(binned_features)

        # Pairwise L1 distance between the per-feature bin vectors.
        diff_matrix = torch.cdist(binned_features, binned_features, p=1)
        diff_matrix = diff_matrix.long().cpu()
        diff_matrixs.append(diff_matrix)

        threshold_diff_matrix = diff_matrix < diff_threshold
        threshold_diff_matrixs.append(threshold_diff_matrix)

        # Clustering = connected components of the thresholded adjacency graph.
        sparse_matrix = scipy.sparse.csr_matrix(threshold_diff_matrix.numpy())
        num_clusters, labels = scipy.sparse.csgraph.connected_components(sparse_matrix, directed=False)
        labels_list.append(labels)

        # Group feature indices by component label in a single pass.
        members_by_label = defaultdict(list)
        for feature_idx, label in enumerate(labels.tolist()):
            members_by_label[label].append(feature_idx)
        # A singleton has nothing to be distilled with, so keep only components
        # with at least two features, each exactly once.
        cur_group_cluster = [
            members_by_label[label]
            for label in range(num_clusters)
            if len(members_by_label[label]) > 1
        ]
        clusters.append(cur_group_cluster)
    return clusters, diff_matrixs, threshold_diff_matrixs, binned_features_list, labels_list


def main():
    """Set up model, data and output directory from CLI args, then run one analysis pass.

    The setup mirrors timm's training script: device/distributed init, AMP
    resolution, model creation, optional split/sync BN, TorchScript,
    optimizer, loss scaler, checkpoint resume, EMA, DDP, torch.compile,
    datasets/loaders, mixup and loss functions, output directory, wandb and
    the LR scheduler.  Forward hooks are registered on every module whose name
    fully matches ``args.block_pattern``; each stores its latest output in
    ``activations`` keyed by the module.

    Exactly one analysis is then run on the evaluation loader:

    * ``--gradient-base-model-explanation-statistic``: ``test_detection``,
      summary written to ``detection_summary.csv``;
    * ``args.perturbation_base_model_explanation``: ``test_iteration``
      (image datasets only);
    * otherwise: ``test``, summary written to ``summary.csv``.

    Only the primary process creates an output directory, so summaries are
    written only there.
    """
    utils.setup_default_logging()
    args, args_text = _parse_args()

    if args.untrained_visualization:
        print('Plot not trained figure')

    if args.device_modules:
        for module in args.device_modules:
            importlib.import_module(module)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.grad_accum_steps = max(1, args.grad_accum_steps)
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process. '
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    factory_kwargs = {}
    if args.pretrained_path:
        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
        factory_kwargs['pretrained_cfg_overlay'] = dict(
            file=args.pretrained_path,
            num_classes=-1,  # force head adaptation
        )

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **factory_kwargs,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum(m.numel() for m in model.parameters())}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.selected_layer_start_idx_list:
        assert len(args.selected_layer_start_idx_list) == args.n_submodel_clustering
    if args.selected_token_start_idx_list:
        assert len(args.selected_token_start_idx_list) == args.n_submodel_clustering

    # Each hooked block stores its latest forward output here, keyed by module.
    activations = {}

    def hook_fn(module, input, output):
        activations[module] = output

        # Keys are plain modules, so outputs from several ranks would collide.
        assert not args.distributed, 'hook function should be updated for distributed settings'
        # activations[(rank, module)] = output.detach().float()

    # Hook every module whose fully-qualified name fully matches the pattern;
    # the same modules are handed to the gradient-based explainers.
    pattern = re.compile(args.block_pattern)
    target_layers = []
    for name, module in model.named_modules():
        if pattern.fullmatch(name):
            module.register_forward_hook(hook_fn)
            target_layers.append(module)

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    # The optimizer (and loss scaler below) are not stepped here; they are built
    # so resume_checkpoint can restore their state.
    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )
    if utils.is_primary(args):
        defaults = copy.deepcopy(optimizer.defaults)
        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
        _logger.info(
            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
        )

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        if device.type in ('cuda',) and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler(device=device.type)
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint (skipped when visualizing an untrained model)
    resume_epoch = None
    if args.resume and not args.untrained_visualization:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=args.model_ema_warmup,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if args.torchcompile:
            model_ema = torch.compile(model_ema, backend=args.torchcompile)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode

    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.train_num_samples,
        trust_remote_code=args.dataset_trust_remote_code,
    )

    if args.val_split:
        dataset_eval = create_dataset(
            args.dataset,
            root=args.data_dir,
            split=args.val_split,
            is_training=False,
            class_map=args.class_map,
            download=args.dataset_download,
            batch_size=args.batch_size,
            input_img_mode=input_img_mode,
            input_key=args.input_key,
            target_key=args.target_key,
            num_samples=args.val_num_samples,
            trust_remote_code=args.dataset_trust_remote_code,
        )

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        train_crop_mode=args.train_crop_mode,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        color_jitter_prob=args.color_jitter_prob,
        grayscale_prob=args.grayscale_prob,
        gaussian_blur_prob=args.gaussian_blur_prob,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        device=device,
        use_prefetcher=args.prefetcher,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = None
    if args.val_split:
        eval_workers = args.workers
        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
            eval_workers = min(2, args.workers)
        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size or args.batch_size,
            is_training=False,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=eval_workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
            device=device,
            use_prefetcher=args.prefetcher,
        )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
    distillation_loss_fn = nn.MSELoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
    decreasing_metric = eval_metric == 'loss'
    best_metric = None
    best_epoch = None
    saver = None
    # Stays None on non-primary ranks; every consumer below must tolerate that.
    output_dir = None
    # Save directory
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/results', exp_name)

        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing_metric,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
        save_hardware(output_dir)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            assert not args.wandb_resume_id or args.resume
            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
                       resume='must' if args.wandb_resume_id else None,
                       id=args.wandb_resume_id if args.wandb_resume_id else None)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        if args.warmup_prefix:
            sched_explain = '(warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True'
        else:
            sched_explain = '(epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False'
        _logger.info(
            f'Scheduled epochs: {num_epochs} {sched_explain}. '
            f'LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    # State unpacked by test(): clusters, layerwise features, train/eval layer and
    # token start indices, identity nodes and activation keys (all empty initially).
    remained_list = [[], [], [], [], [], [], [], {}]

    # setting
    # loader = loader_test
    loader = loader_eval

    if args.gradient_base_model_explanation_statistic:
        reduce_factor = args.tta
        metrics = test_detection(
            model,
            loader,
            args,
            gradient_base_model_explanation=args.gradient_base_model_explanation,
            target_layers=target_layers,
            mean=args.mean,
            std=args.std,
            device=device,
            amp_autocast=amp_autocast,
            reduce_factor=reduce_factor,
            output_directory=output_dir,
        )

        # Only the primary rank owns an output directory.
        if output_dir is not None:
            utils.update_summary(
                0,
                {},
                metrics,
                filename=os.path.join(output_dir, 'detection_summary.csv'),
                lr=0,
                write_header=best_metric is None,
                log_wandb=args.log_wandb and has_wandb,
            )
        return

    if args.perturbation_base_model_explanation:
        assert 'toy' not in args.dataset, "perturbation now only operate for image dataset; not toy dataset"
        print('Run: test_iteration; cf. test functions will not work (clustering, visualization)')
        reduce_factor = args.tta
        test_iteration(
            model,
            loader,
            args,
            device=device,
            amp_autocast=amp_autocast,
            reduce_factor=reduce_factor,
            output_directory=output_dir,
        )
        return

    metrics = test(
        model,
        loader_eval,
        validate_loss_fn,
        args,
        device=device,
        amp_autocast=amp_autocast,
        activations=activations,
        use_anchorwise_feature=args.use_anchorwise_feature,
        use_gathering_feature=args.use_gathering_feature,
        use_layerwise_contiuous_feature=args.use_layerwise_contiuous_feature,
        use_identity_feature=args.use_identity_feature,
        use_anchoring_feature=args.use_anchoring_feature,
        distillation_loss_fn=distillation_loss_fn,
        gradient_base_model_explanation=args.gradient_base_model_explanation,
        target_layers=target_layers,
        n_batch_explanation=args.n_batch_explanation,
        mean=args.mean,
        std=args.std,
        remained_list=remained_list,
        output_directory=output_dir
    )

    # Only the primary rank owns an output directory.
    if output_dir is not None:
        utils.update_summary(
            0,
            {},
            metrics,
            filename=os.path.join(output_dir, 'summary.csv'),
            lr=0,
            write_header=best_metric is None,
            log_wandb=args.log_wandb and has_wandb,
        )


def test(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        without_print=False,
        activations=None,
        use_anchorwise_feature=False,
        use_gathering_feature=False,
        use_layerwise_contiuous_feature=False,
        use_identity_feature=False,
        use_anchoring_feature=False,
        distillation_loss_fn=None,
        gradient_base_model_explanation='',
        target_layers=None,
        n_batch_explanation=2,
        mean=0.5,
        std=0.5,
        remained_list=[[], [], [], [], [], [], [], {}],
        output_directory='',
):
    has_no_sync = hasattr(model, "no_sync")

    update_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    losses_c = utils.AverageMeter() # classification
    losses_d = utils.AverageMeter() # distillation
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    if use_anchorwise_feature or use_gathering_feature or use_layerwise_contiuous_feature or \
        use_identity_feature or use_anchoring_feature:
        use_clustering = True
        # n_samples = args.n_accum_cluster * args.batch_size * args.grad_accum_steps
        # n_iteration_for_cluster = args.n_accum_cluster * args.grad_accum_steps
        # n_per_bin = n_samples // args.n_bin + 1 if n_samples % args.n_bin else n_samples // args.n_bin
    else:
        # use_clustering = False
        use_clustering = True

    clusters, layerwise_batch_features_list, layer_start_idx_list_train, token_start_idx_list_train, \
        layer_start_idx_list, token_start_idx_list, identity_nodes, activations_keys = remained_list

    # last_idx = len(loader) - 1
    n_samples = args.sqrt_n_vis_samples ** 2
    accum_steps = n_iteration_for_cluster = last_idx = \
        n_samples // args.batch_size + 1 if n_samples % args.batch_size else n_samples // args.batch_size
    remained_n = n_samples % args.batch_size
    assert len(loader) >= n_iteration_for_cluster
    n_per_bin = n_samples // args.n_bin + 1 if n_samples % args.n_bin else n_samples // args.n_bin
    with torch.no_grad():
        start = time.time()
        inputs = list()
        outputs = list()

        layerwise_batch_features = defaultdict(list)
        sample_count = 0
        update_start_time = time.time() # data_start_time = 
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if batch_idx == (last_idx-1) and remained_n:
                input = input[:remained_n]
                target = target[:remained_n]
            inputs.append(input)
            reduce_factor = args.tta
            if reduce_factor > 1:
                target = target[0:target.size(0):reduce_factor]

            if not args.prefetcher:
                input = input.to(device)
                target = target.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            
            def _forward():
                with amp_autocast():
                    output = model(input)
                    if reduce_factor > 1:
                        output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    loss = loss_fn(output, target)
                    outputs.append(output)
                    if isinstance(output, (tuple, list)):
                        output = output[0]

                    # augmentation reduction
                    if reduce_factor > 1:
                        output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    
                    if use_clustering:
                        for layer_start_idx, token_start_idx, cur_group_cluster in zip(layer_start_idx_list_train, token_start_idx_list_train, clusters):
                            model_features = list()
                            for layer_index in range(layer_start_idx, layer_start_idx+args.layer_sampling_dim):
                                model_features.append(activations[activations_keys[layer_index]][:, token_start_idx:token_start_idx+args.token_sampling_dim,:])
                            model_features = torch.stack(model_features, dim=1).reshape(input.shape[0], -1)
                    distill_loss = []
                    if use_anchorwise_feature:
                        for cluster in clusters:
                            for nodes in cluster:
                                anchorwise_distill_loss = distillation_loss_fn(model_features[:, nodes[0]], model_features[:, nodes].detach().mean(1))
                                if anchorwise_distill_loss < loss:
                                    distill_loss.append(anchorwise_distill_loss)
                    if use_identity_feature:
                        for nodes in identity_nodes:
                            origin_target_features = model_features[:, nodes[0]].detach()
                            for node in range(1, len(nodes)):
                                identity_distill_loss = distillation_loss_fn(model_features[:, nodes[node]], origin_target_features)
                                if identity_distill_loss < loss:
                                    distill_loss.append(identity_distill_loss)
                    if len(distill_loss) > 0:
                        distill_loss = 0.1 * torch.stack(distill_loss).mean()
                        # print(f'loss {loss}, distill loss {distill_loss}')
                        total_loss = loss + distill_loss # loss = loss + distill_loss
                    else:
                        distill_loss = torch.tensor(0)
                        total_loss = loss
                    # use_anchorwise_feature or use_gathering_feature or use_layerwise_contiuous_feature or use_identity_feature or use_anchoring_feature:
                if accum_steps > 1:
                    loss /= accum_steps  
                
                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                return total_loss, loss, distill_loss, acc1, acc5

            if (gradient_base_model_explanation != ""):
                if n_batch_explanation <= batch_idx:
                    break
                with torch.enable_grad():
                    get_explanation_by_gradient(gradient_base_model_explanation, input, target,
                                                model, target_layers, amp_autocast, mean, std, 
                                                output_directory, batch_idx)
                continue

            if has_no_sync:
                with model.no_sync():
                    loss, class_loss, distill_loss, acc1, acc5 = _forward()
            else:
                loss, class_loss, distill_loss, acc1, acc5 = _forward()

            sample_count += input.size(0)

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                reduced_class_loss = utils.reduce_tensor(class_loss.data, args.world_size)
                reduced_distill_loss = utils.reduce_tensor(distill_loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1.data, args.world_size)
                acc5 = utils.reduce_tensor(acc5.data, args.world_size)
                sample_count *= args.world_size
            else:
                reduced_loss = loss.data
                reduced_class_loss = class_loss.data
                reduced_distill_loss = distill_loss.data
                acc1 = acc1.data
                acc5 = acc5.data

            if device.type == 'cuda':
                torch.cuda.synchronize()
            elif device.type == "npu":
                torch.npu.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            losses_c.update(reduced_class_loss.item(), input.size(0))
            losses_d.update(reduced_distill_loss.item(), input.size(0))
            top1_m.update(acc1.item(), target.size(0))
            top5_m.update(acc5.item(), target.size(0))

            if utils.is_primary(args) and last_batch:
                log_name = 'Visualized_val' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}]  '
                    f'Time: {update_time_m.val:.3f}s, {sample_count / update_time_m.val:>7.2f}/s  '
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f})  '
                    f'Loss@c: {losses_c.val:>7.3f} ({losses_c.avg:>6.3f})  '
                    f'Loss@d: {losses_d.val:>7.3f} ({losses_d.avg:>6.3f})  '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f})  '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                )

            if use_clustering:
                if batch_idx == 0:
                    # Activations
                    activations_keys = list(activations.keys())
                    layer_dim = len(activations_keys)
                    for value in activations.values():
                        token_dim = value.shape[1]
                        hidden_dim = value.shape[2]
                        break
                    if use_identity_feature:
                        n_node_per_layer = token_dim * hidden_dim
                        layer_index_list = np.repeat(np.arange(layer_dim), n_node_per_layer)
                        node_index_list = np.tile(np.arange(n_node_per_layer), layer_dim)
                if len(layer_start_idx_list) == 0:
                    # Sampling submodel
                    layerwise_batch_features_list = [list() for _ in range(args.n_submodel_clustering)]
                    
                    if args.selected_layer_start_idx_list:
                        layer_start_idx_list = args.selected_layer_start_idx_list
                    else:
                        layer_start_idx_list = [np.random.randint(0, layer_dim - args.layer_sampling_dim + 1) for _ in range(args.n_submodel_clustering)]
                    if args.selected_token_start_idx_list:
                        token_start_idx_list = args.selected_token_start_idx_list
                    else:
                        token_start_idx_list = [np.random.randint(0, token_dim - args.token_sampling_dim + 1) for _ in range(args.n_submodel_clustering)]
                elif (len(layerwise_batch_features_list[0]) % n_iteration_for_cluster) == 0:
                    # Update clustering base on previous sampled submodel
                    layer_start_idx_list_train = layer_start_idx_list
                    token_start_idx_list_train = token_start_idx_list
                    prev_clusters = clusters

                    break

                with torch.no_grad():
                    for layerwise_batch_features, layer_start_idx, token_start_idx in zip(layerwise_batch_features_list, layer_start_idx_list, token_start_idx_list):
                        iteration_layerwise_batch_features = list()
                        for layer_index in range(layer_start_idx, layer_start_idx+args.layer_sampling_dim):
                            iteration_layerwise_batch_features.append(activations[activations_keys[layer_index]][:, token_start_idx:token_start_idx+args.token_sampling_dim,:])
                        layerwise_batch_features.append(torch.stack(iteration_layerwise_batch_features, dim=1).reshape(input.shape[0], -1))

                time_now = time.time()
                update_time_m.update(time.time() - update_start_time)
                update_start_time = time_now

        if (gradient_base_model_explanation != ""):
            metrics = OrderedDict([('loss', losses_m.avg), ('loss_c', losses_c.avg), ('loss_d', losses_d.avg),
                                ('top1', top1_m.avg), ('top5', top5_m.avg)])
            
            print(f'current do to save behavior gradient_base_model_explanation not work with other results algorithms')
            print(f'Done visualization: {output_directory}\n')

            return metrics

        # clustering for visualization
        clusters, diff_matrixs, threshold_diff_matrixs, binned_features_list, labels_list = \
            get_cluster(layerwise_batch_features_list, args, n_samples, n_per_bin)

        inputs = torch.cat(inputs, 0)
        outputs = torch.cat(outputs, 0)
        _, sorted_outputs_indices = torch.sort(outputs, dim=0)
        n_layer = args.layer_sampling_dim
        n_node_per_layer = args.token_sampling_dim * hidden_dim

        # Save clusters
        output_dir = f"{output_directory}/cluster.pkl"
        with open(output_dir, "wb") as f:
            pickle.dump(clusters, f)

        # TODO : save logging
        print(f'exp_name: {args.resume}\n')
        if args.cluster_size_frequency or args.n_skipped_layer_cluster or args.n_cluster_per_layer or \
            args.cluster_size_frequency_histogram:
            output_dir = f"{output_directory}/cluster"
            layer_index_list = []
            sorted_frequency_list = []
            existing_skipping_layer_list = []
            os.makedirs(output_dir)
                
            for i, (cur_clusters, labels, cur_layer_start_idx, cur_token_start_idx) in enumerate(zip(clusters, labels_list, layer_start_idx_list, token_start_idx_list)):
                cluster_info = f'{i}-{cur_layer_start_idx}-{cur_token_start_idx}'
                layer_index = np.repeat(np.arange(args.layer_sampling_dim), n_node_per_layer) + cur_layer_start_idx
                layer_index_list.append(layer_index)
                if args.cluster_size_frequency or args.cluster_size_frequency_histogram:
                    sorted_frequency = get_cluster_size_frequency(cur_clusters, output_dir, cluster_info)        
                    if args.cluster_size_frequency_histogram:
                        get_cluster_size_frequency_histogram(sorted_frequency, output_dir, cluster_info)
                    sorted_frequency_list.append(sorted_frequency)
                if args.n_skipped_layer_cluster:
                    existing_skipping_layer_list = get_n_skipped_layer_cluster(cur_clusters, layer_index, existing_skipping_layer_list, output_dir, cluster_info)
                if args.n_cluster_per_layer:
                    get_n_cluster_per_layer(n_layer, labels, n_node_per_layer, output_dir, cluster_info)

            print(f'Time cost for clustering and logging text : {time.time() - start}')

        # visualization and statistic
        # similarity (difference)
        get_difference_results(args, output_directory, diff_matrixs, threshold_diff_matrixs,
                               layer_start_idx_list, token_start_idx_list, n_layer, n_node_per_layer)

        # plot layer by layer
        x1, x2 = np.meshgrid(np.linspace(0, 1.0, args.sqrt_n_vis_samples),
                             np.linspace(0, 1.0, args.sqrt_n_vis_samples))
        cmap = matplotlib.colormaps.get_cmap('tab10')
        if not args.no_overall_features_visualization:
            print('plot layer')
            sqrt_n = math.ceil(math.sqrt(n_node_per_layer))
            for i, (binned_features, labels, cur_layer_start_idx, cur_token_start_idx) in enumerate(zip(binned_features_list, labels_list, layer_start_idx_list, token_start_idx_list)):
                model_binned_features = binned_features.T
                for layer_idx in tqdm(range(n_layer)):
                    fig, axs = plt.subplots(sqrt_n, sqrt_n, figsize=(sqrt_n, sqrt_n))

                    axs = np.array([axs]).flatten()
                    start_idx = n_node_per_layer * layer_idx
                    for node_idx in range(n_node_per_layer):
                        ax = axs[node_idx]
                        Z = model_binned_features[:, start_idx + node_idx].reshape(x1.shape).cpu().numpy()
                        ax.contourf(x1, x2, Z, alpha=0.7)
                        ax.text(0.5, 0.9, str(labels[start_idx + node_idx]), color='white', fontsize=12, 
                            ha='center', va='top', fontweight='bold') # , bbox=dict(facecolor='black', alpha=0.5, edgecolor='none'))
                        ax.set_xticks([])
                        ax.set_yticks([])
                        ax.set_xticklabels([])
                        ax.set_yticklabels([])
                        ax.set_xlim(0.0, 1.0)
                        ax.set_ylim(0.0, 1.0)

                        for spine in ax.spines.values():
                            spine.set_edgecolor(cmap.colors[labels[start_idx + node_idx]%10])
                            spine.set_linewidth(5.0)

                    for j in range(n_node_per_layer, len(axs)):
                        fig.delaxes(axs[j])

                    output_dir = f"{output_directory}/layer_{layer_idx}_{cur_layer_start_idx}_{cur_token_start_idx}.png"
                    plt.savefig(output_dir, dpi=300, bbox_inches='tight')
                    plt.close(fig)

        # plot cluster by cluster
        if not args.no_average_features_visualization or not args.no_features_visualization:
            for (layer_index, binned_features, sorted_frequency, existing_skipping_layer, cur_clusters, cur_layer_start_idx, cur_token_start_idx) in zip(layer_index_list, binned_features_list, sorted_frequency_list, existing_skipping_layer_list, clusters, layer_start_idx_list, token_start_idx_list):
                model_binned_features = binned_features.T
                plot_frequency_cluster = {frequency: 0 for frequency in sorted_frequency.keys()}
                plot_clusters = []
                for value in cur_clusters: # sorted_frequency
                    if plot_frequency_cluster[len(value)] < args.n_visualization_per_cluster:
                        plot_frequency_cluster[len(value)] += 1
                        plot_clusters.append(value)
                for i, key in enumerate(existing_skipping_layer):
                    if len(key):
                        plot_clusters.append(cur_clusters[i])

                if not args.no_average_features_visualization:
                    print('plot average cluster')
                    for cluster_index, nodes_index in enumerate(tqdm(plot_clusters)):
                        nodes_feature = model_binned_features[:, nodes_index]
                        node_layer_index = layer_index[nodes_index]
                        n_node_over_cluster = nodes_feature.shape[1]

                        # average feature, count of nodes
                        fig, ax = plt.subplots(figsize=(1, 1))
                        mean_Z = nodes_feature.mean(1).reshape(x1.shape).cpu().numpy()
                        ax.contourf(x1, x2, mean_Z, alpha=0.7)
                        ax.text(0.5, 0.5, str(n_node_over_cluster), color='white', fontsize=12, 
                                ha='center', va='center', fontweight='bold')
                        ax.set_xticks([])
                        ax.set_yticks([])
                        ax.set_xticklabels([])
                        ax.set_yticklabels([])
                        ax.set_xlim(0.0, 1.0)
                        ax.set_ylim(0.0, 1.0)

                        output_dir = f"{output_directory}/cluster_average_{cluster_index}_{cur_layer_start_idx}_{cur_token_start_idx}.png"
                        plt.savefig(output_dir, dpi=300, bbox_inches='tight')
                        plt.close(fig)

                        if not args.no_shifted_features_visualization:
                            print('plot cluster; substracting average')
                            nodes_feature = model_binned_features[:, nodes_index]
                            node_layer_index = layer_index[nodes_index]
                            n_node_over_cluster = nodes_feature.shape[1]

                            nodes_feature = nodes_feature.cpu().numpy()
                            sqrt_n = math.ceil(math.sqrt(n_node_over_cluster))
                            fig, axs = plt.subplots(sqrt_n, sqrt_n, figsize=(sqrt_n, sqrt_n))
                            
                            axs = np.array([axs]).flatten()
                            for i in range(n_node_over_cluster):
                                if i >= len(axs):
                                    break
                                ax = axs[i]
                                Z = nodes_feature[:, i].reshape(x1.shape) - mean_Z
                                ax.contourf(x1, x2, Z, alpha=0.7)
                                ax.text(0.5, 0.9,  str(node_layer_index[i]), color='white', fontsize=12, 
                                    ha='center', va='top', fontweight='bold')
                                ax.set_xticks([])
                                ax.set_yticks([])
                                ax.set_xticklabels([])
                                ax.set_yticklabels([])
                                ax.set_xlim(0.0, 1.0)
                                ax.set_ylim(0.0, 1.0)

                                for spine in ax.spines.values():
                                    spine.set_edgecolor(cmap.colors[node_layer_index[i]%10])
                                    spine.set_linewidth(5.0)

                            for j in range(n_node_over_cluster, len(axs)):
                                fig.delaxes(axs[j])
                            
                            output_dir = f"{output_directory}/shifted_cluster_{cluster_index}_{cur_layer_start_idx}_{cur_token_start_idx}.png"
                            plt.savefig(output_dir, dpi=300, bbox_inches='tight')
                            plt.close(fig)

                if not args.no_features_visualization:
                    print('plot cluster')
                    for cluster_index, nodes_index in enumerate(tqdm(plot_clusters)):
                        nodes_feature = model_binned_features[:, nodes_index]
                        node_layer_index = layer_index[nodes_index]
                        n_node_over_cluster = nodes_feature.shape[1]

                        nodes_feature = nodes_feature.cpu().numpy()
                        sqrt_n = math.ceil(math.sqrt(n_node_over_cluster))
                        fig, axs = plt.subplots(sqrt_n, sqrt_n, figsize=(sqrt_n, sqrt_n))
                        
                        axs = np.array([axs]).flatten()
                        for i in range(n_node_over_cluster):
                            if i >= len(axs):
                                break
                            ax = axs[i]
                            Z = nodes_feature[:, i].reshape(x1.shape)
                            ax.contourf(x1, x2, Z, alpha=0.7)
                            ax.text(0.5, 0.9,  str(node_layer_index[i]), color='white', fontsize=12, 
                                ha='center', va='top', fontweight='bold')
                            ax.set_xticks([])
                            ax.set_yticks([])
                            ax.set_xticklabels([])
                            ax.set_yticklabels([])
                            ax.set_xlim(0.0, 1.0)
                            ax.set_ylim(0.0, 1.0)

                            for spine in ax.spines.values():
                                spine.set_edgecolor(cmap.colors[node_layer_index[i]%10])
                                spine.set_linewidth(5.0)

                        for j in range(n_node_over_cluster, len(axs)):
                            fig.delaxes(axs[j])
                        
                        output_dir = f"{output_directory}/cluster_{cluster_index}_{cur_layer_start_idx}_{cur_token_start_idx}.png"
                        plt.savefig(output_dir, dpi=300, bbox_inches='tight')
                        plt.close(fig)
    
    metrics = OrderedDict([('loss', losses_m.avg), ('loss_c', losses_c.avg), ('loss_d', losses_d.avg),
                           ('top1', top1_m.avg), ('top5', top5_m.avg)])
    
    print(f'Done visualization: {output_directory}\n')

    return metrics


def test_iteration(
        model,
        loader,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        reduce_factor=0,
        output_directory='',
):
    """Run perturbation-based explanation on every sample yielded by ``loader``.

    Each ``(input, target)`` pair is moved to ``device`` (unless
    ``args.prefetcher`` is set) and optionally converted to channels-last.
    It is then handed to ``get_explanation_by_perturbation`` together with
    the perturbation settings read from ``args``. Nothing is returned; any
    artefacts come from that helper, which receives ``output_directory``
    (presumably it writes there; confirm in ``timm.results``).

    Args:
        model: network to explain; switched to eval mode.
        loader: iterable of ``(input, target)`` batches; ``args.batch_size``
            must be 1.
        args: namespace supplying ``batch_size``, ``tta``, ``prefetcher``,
            ``channels_last`` and the perturbation hyper-parameters.
        device: destination device when ``args.prefetcher`` is off.
        amp_autocast: context-manager factory forwarded to the helper.
        reduce_factor: ignored on entry; replaced by ``args.tta`` per batch.
        output_directory: forwarded to the helper.

    Raises:
        AssertionError: if ``args.batch_size`` is not 1.
    """
    assert args.batch_size == 1, "test_iteration only works for batch_size argument as 1"
    model.eval()

    step_timer = utils.AverageMeter()

    with torch.no_grad():
        for n_processed, (images, labels) in enumerate(loader, start=1):
            step_start = time.time()

            # args.tta > 1 (presumably test-time augmentation): keep one label
            # per group of `reduce_factor` consecutive samples.
            reduce_factor = args.tta
            if reduce_factor > 1:
                labels = labels[0:labels.size(0):reduce_factor]

            if not args.prefetcher:
                images = images.to(device)
                labels = labels.to(device)
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            get_explanation_by_perturbation(
                args.perturbation_base_model_explanation,
                images, labels, model, amp_autocast, reduce_factor,
                args.ups, args.prob_thresh, args.numCategories,
                args.node_prob_thresh, args.beam_width, args.max_num_roots,
                args.overlap_thresh, args.numSuccessors, args.num_roots_sag,
                args.maxRootSize, output_directory,
                n_processed - 1,  # zero-based batch index
            )

            step_timer.update(time.time() - step_start)
            if utils.is_primary(args):
                print(f'Iteration {n_processed}, Time: {step_timer.val:.3f}s ')


def test_detection(
        model,
        loader,
        args,
        gradient_base_model_explanation='gradcam++',
        target_layers=None,
        log_suffix='',
        mean=0.5,
        std=0.5,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        reduce_factor=0,
        output_directory='',
):
    """Score gradient-based saliency maps against ground-truth bounding boxes.

    For every batch, the model's top-1/top-5 accuracy is measured. A saliency
    map per sample is then obtained from ``get_explanation_by_gradient``
    (method ``gradient_base_model_explanation``, with
    ``return_instead_save=True``). Each map is scored against its box with an
    energy-based pointing game and a pointing game.

    Args:
        model: classifier to explain; switched to eval mode.
        loader: iterable of ``(input, target)`` pairs. ``target[:, :-1]`` is
            unpacked as the box ``(x1, y1, x2, y2)`` and ``target[:, -1]`` is
            used as the class label.
        args: namespace; reads ``tta``, ``prefetcher``, ``channels_last``,
            ``distributed`` and (when distributed) ``world_size``.
        gradient_base_model_explanation: explanation method name forwarded to
            ``get_explanation_by_gradient``.
        target_layers, mean, std, output_directory: forwarded to
            ``get_explanation_by_gradient``.
        log_suffix: appended to the ``'Visualized_val'`` log prefix.
        device: destination device when ``args.prefetcher`` is off.
        amp_autocast: context-manager factory wrapping the forward pass.
        reduce_factor: ignored on entry; replaced by ``args.tta`` per batch.
            Kept for interface compatibility.

    Returns:
        OrderedDict with running averages ``en_pg_m``, ``pg_m``, ``top1`` and
        ``top5``.
    """

    def _bbox_mask(bbox, saliency_map):
        """Return a 0/1 float mask shaped like 2-D ``saliency_map``, 1 inside ``bbox``.

        ``bbox`` is ``(x1, y1, x2, y2)``; x indexes columns, y indexes rows.
        Coordinates are cast with ``int`` so float or tensor boxes work. They
        are clamped at 0 because a negative slice bound would otherwise wrap
        around to the far edge of the map. Bounds past the map edge are
        clipped by slicing itself.
        """
        x1, y1, x2, y2 = (max(int(v), 0) for v in bbox)
        n_rows, n_cols = saliency_map.shape
        mask = np.zeros((n_rows, n_cols))
        mask[y1:y2, x1:x2] = 1
        return mask

    def energy_point_game(bbox, saliency_map):
        """Return the fraction of total saliency that falls inside ``bbox``."""
        energy_bbox = (saliency_map * _bbox_mask(bbox, saliency_map)).sum()
        # Tiny epsilon avoids a division by zero for an all-zero map.
        energy_whole = saliency_map.sum() + 0.000000000001
        return energy_bbox / energy_whole

    def point_game(bbox, saliency_map):
        """Return 1 if the box-masked maximum equals the global maximum, else 0.

        For a non-negative map this means the saliency peak lies inside
        ``bbox`` (ties count as a hit).
        """
        mask_bbox = saliency_map * _bbox_mask(bbox, saliency_map)
        return 1 if mask_bbox.max() == saliency_map.max() else 0

    def _forward(images, target, batch_idx, reduce_factor):
        """Return ``(acc1, acc5, mean energy-PG, mean PG)`` for one batch."""
        with amp_autocast():
            output = model(images)
            if isinstance(output, (tuple, list)):
                output = output[0]
            # Augmentation reduction: average each group of `reduce_factor`
            # outputs exactly once, matching the stride-sliced `target`.
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
            labels = target[:, -1]
            acc1, acc5 = utils.accuracy(output, labels, topk=(1, 5))

            # The explanation method needs gradients even though the
            # surrounding loop runs under torch.no_grad().
            with torch.enable_grad():
                grayscale_cam = get_explanation_by_gradient(
                    gradient_base_model_explanation, images, labels,
                    model, target_layers, amp_autocast, mean, std,
                    output_directory, batch_idx, return_instead_save=True)

            # NOTE(review): presumably one 2-D map per sample; zip stops at
            # the shorter of boxes/maps -- confirm with reduce_factor > 1.
            bboxes = target[:, :-1]
            en_pg_list, pg_list = [], []
            for cur_bbox, cur_attribution_map in zip(bboxes, grayscale_cam):
                en_pg_list.append(energy_point_game(cur_bbox, cur_attribution_map))
                pg_list.append(point_game(cur_bbox, cur_attribution_map))
            en_pg = float(sum(en_pg_list) / len(en_pg_list))
            pg = float(sum(pg_list) / len(pg_list))

        return acc1, acc5, en_pg, pg

    has_no_sync = hasattr(model, "no_sync")

    update_time_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()
    en_pg_m = utils.AverageMeter()
    pg_m = utils.AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    with torch.no_grad():
        update_start_time = time.time()
        for batch_idx, (images, target) in enumerate(loader):  # target: bbox, label
            last_batch = batch_idx == last_idx
            reduce_factor = args.tta
            if reduce_factor > 1:
                target = target[0:target.size(0):reduce_factor]

            if not args.prefetcher:
                images = images.to(device)
                target = target.to(device)
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            if has_no_sync:
                with model.no_sync():
                    acc1, acc5, en_pg, pg = _forward(images, target, batch_idx, reduce_factor)
            else:
                acc1, acc5, en_pg, pg = _forward(images, target, batch_idx, reduce_factor)

            # Samples processed in this step (across ranks when distributed).
            # Counting per step avoids re-multiplying a running total by
            # world_size on every batch.
            batch_sample_count = images.size(0)
            if args.distributed:
                acc1 = utils.reduce_tensor(acc1.data, args.world_size)
                acc5 = utils.reduce_tensor(acc5.data, args.world_size)
                # NOTE: en_pg / pg are host-side floats and are not
                # all-reduced, so their meters reflect this rank only.
                batch_sample_count *= args.world_size
            else:
                acc1 = acc1.data
                acc5 = acc5.data

            if device.type == 'cuda':
                torch.cuda.synchronize()
            elif device.type == "npu":
                torch.npu.synchronize()

            top1_m.update(acc1.item(), target.size(0))
            top5_m.update(acc5.item(), target.size(0))
            en_pg_m.update(en_pg, target.size(0))
            pg_m.update(pg, target.size(0))

            # Update timing before logging. The log then reports the current
            # batch and never divides by the meter's initial value of 0.
            time_now = time.time()
            update_time_m.update(time_now - update_start_time)
            update_start_time = time_now

            if utils.is_primary(args) and last_batch:
                log_name = 'Visualized_val' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}]  '
                    f'Time: {update_time_m.val:.3f}s, {batch_sample_count / update_time_m.val:>7.2f}/s  '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f})  '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})  '
                    f'en_pg_m: {en_pg_m.val:>7.3f} ({en_pg_m.avg:>7.3f})  '
                    f'pg_m: {pg_m.val:>7.3f} ({pg_m.avg:>7.3f})'
                )

    metrics = OrderedDict([('en_pg_m', en_pg_m.avg), ('pg_m', pg_m.avg),
                           ('top1', top1_m.avg), ('top5', top5_m.avg)])

    print(f'Done visualization: {output_directory}\n')

    return metrics


if __name__ == '__main__':
    # Invoke the CLI entry point only when run as a script, not on import.
    main()
