#!/usr/bin/env python3
""" Evaluation Script
"""
import argparse
import logging
import os
import time
from pathlib import Path
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import math

import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

import numpy as np
import random

from timm import utils
from timm.data import resolve_data_config
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.utils import ApexScaler, NativeScaler
from timm.utils.model import unwrap_model

from cheem.models.wrappers import ContinualMoE
from cheem.models.prompt_wrappers import PromptedContinualMoE
from cheem.models.qkv_wrappers import QKVContinualMoE
from cheem.models.op_factories import op_factory
from cheem.utils.backbone_iterators import BACKBONE_ITERATORS
from cheem.utils.mean_functions import vit_cls_token_mean
from cheem.models.modules.vit import CHEEM_BLOCKS

from cheem.data import create_dataset_v2
from cheem.config_utils import Config
from cheem.timm_custom.data import create_loader_v2
from cheem.timm_custom.utils.distributed import distribute_stats

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

has_compile = hasattr(torch, 'compile')


_logger = logging.getLogger('eval')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')


# Main argument parser. Its defaults may be overridden by the YAML file given via
# -c/--config (see `_parse_args`).
parser = argparse.ArgumentParser(description='PyTorch CHEEM Evaluation')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
group.add_argument('--runtime-config-file', default='', type=str, metavar='RUNTIME_CONFIG',
                   help='YAML runtime config file passed to cheem Config (provides runtime settings such as the task order)')
group.add_argument('--benchmark', default='vdd', type=str, metavar='BENCHMARK',
                   help='Benchmark to run. Should be one of "vdd", "5-datasets"')
group.add_argument('--root_dir', metavar='ROOT',
                   help='root directory of the repository')
group.add_argument("--backbone_weights_path", default=None, type=str,
                   help='ImageNet checkpoint file whose "state_dict" entry holds the backbone weights.')
# main() joins this directory with --cheem_component when --task-idx is 0.
group.add_argument("--backbone_checkpoint_dir", default=None, type=str,
                   help="Directory of ImageNet checkpoints containing mean calculations; "
                        "its <cheem_component> subdirectory is used when --task-idx is 0.")

group.add_argument('--dataset', '-d', metavar='NAME', default='',
                   help='dataset type (default: ImageFolder/ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')

# Model parameters: architecture name, checkpoints, input geometry/normalisation,
# batch sizes and model-level switches. The last three flags are mutually exclusive
# ways of scripting/compiling the model.
group = parser.add_argument_group('Model parameters')
group.add_argument(
    '--model',
    type=str, default='resnet50', metavar='MODEL',
    help='Name of model to train (default: "resnet50")')
group.add_argument(
    '--pretrained',
    default=False, action='store_true',
    help='Start with pretrained version of specified network (if avail)')
group.add_argument(
    '--initial-checkpoint',
    type=str, default='', metavar='PATH',
    help='Initialize model from this checkpoint (default: none)')
group.add_argument(
    '--resume',
    type=str, default='', metavar='PATH',
    help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument(
    '--no-resume-opt',
    default=False, action='store_true',
    help='prevent resume of optimizer state when resuming model')
group.add_argument(
    '--num-classes',
    type=int, default=None, metavar='N',
    help='number of label classes (Model default if None)')
group.add_argument(
    '--gp',
    type=str, default=None, metavar='POOL',
    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument(
    '--img-size',
    type=int, default=None, metavar='N',
    help='Image size (default: None => model default)')
group.add_argument(
    '--in-chans',
    type=int, default=None, metavar='N',
    help='Image input channels (default: None => 3)')
group.add_argument(
    '--input-size',
    type=int, nargs=3, default=None, metavar='N N N',
    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument(
    '--crop-pct',
    type=float, default=None, metavar='N',
    help='Input image center crop percent (for validation only)')
group.add_argument(
    '--mean',
    type=float, nargs='+', default=None, metavar='MEAN',
    help='Override mean pixel value of dataset')
group.add_argument(
    '--std',
    type=float, nargs='+', default=None, metavar='STD',
    help='Override std deviation of dataset')
group.add_argument(
    '--interpolation',
    type=str, default='', metavar='NAME',
    help='Image resize interpolation type (overrides model)')
group.add_argument(
    '-b', '--batch-size',
    type=int, default=128, metavar='N',
    help='Input batch size for training (default: 128)')
group.add_argument(
    '-vb', '--validation-batch-size',
    type=int, default=None, metavar='N',
    help='Validation batch size override (default: None)')
group.add_argument(
    '--channels-last',
    default=False, action='store_true',
    help='Use channels_last memory layout')
group.add_argument(
    '--fuser',
    type=str, default='',
    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument(
    '--grad-checkpointing',
    default=False, action='store_true',
    help='Enable gradient checkpointing through model blocks/stages')
group.add_argument(
    '--fast-norm',
    default=False, action='store_true',
    help='enable experimental fast-norm')
# Parsed by timm's ParseKwargs action (presumably key=value pairs -> dict).
group.add_argument(
    '--model-kwargs',
    nargs='*', default={}, action=utils.ParseKwargs)

# At most one of TorchScript / torch.compile / AOT autograd may be requested.
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument(
    '--torchscript',
    dest='torchscript', action='store_true',
    help='torch.jit.script the full model')
scripting_group.add_argument(
    '--torchcompile',
    type=str, nargs='?', default=None, const='inductor',
    help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument(
    '--aot-autograd',
    default=False, action='store_true',
    help="Enable AOT Autograd support.")

# Augmentation & regularization parameters.
# NOTE(review): these appear to mirror timm's training options; the evaluation code
# visible in this file does not read most of them.
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
                   help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                   help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                   help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
                   help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
# (A stray trailing comma used to follow this call, turning the statement into a
# throw-away tuple expression.)
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
                   help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
                   help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
                   help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
                   help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
                   help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
                   help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
                   help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
                   help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
                   help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                   help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                   help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                   help='Drop block rate (default: None)')

# BatchNorm options. main() forwards bn_momentum/bn_eps to create_model, converts the
# model for split_bn / sync_bn, and clears dist_bn when sync_bn is active.
group = parser.add_argument_group(
    title='Batch norm parameters',
    description='Only works with gen_efficientnet based models currently.',
)
group.add_argument(
    '--bn-momentum',
    default=None, type=float,
    help='BatchNorm momentum override (if not None)',
)
group.add_argument(
    '--bn-eps',
    default=None, type=float,
    help='BatchNorm epsilon override (if not None)',
)
group.add_argument(
    '--sync-bn',
    action='store_true',
    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.',
)
group.add_argument(
    '--dist-bn',
    default='reduce', type=str,
    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")',
)
group.add_argument(
    '--split-bn',
    action='store_true',
    help='Enable separate BN layers per augmentation split.',
)

# Model EMA options. main() builds utils.ModelEmaV2 from these when --model-ema is
# set (on CPU if --model-ema-force-cpu).
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument(
    '--model-ema',
    default=False, action='store_true',
    help='Enable tracking moving average of model weights',
)
group.add_argument(
    '--model-ema-force-cpu',
    default=False, action='store_true',
    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.',
)
group.add_argument(
    '--model-ema-decay',
    default=0.9998, type=float,
    help='decay factor for model weights moving average (default: 0.9998)',
)

# Misc: seeding, logging cadence, data loading, AMP, DDP and output options.
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
                   help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
                   help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
                   help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                   help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
                   help='number of checkpoints to keep (default: 1)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many data loading worker processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input batches every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
                   help='disable fast prefetcher')
# NOTE: main() overwrites args.output with the run's log directory before reading it,
# so a value passed here currently has no effect.
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
                   help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                   help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')

# << 
# Runtime settings: experiment tracking and run-level switches. main() forwards
# wandb_project / wandb_username / exp_name to cheem's Config, and uses exp_name as
# the WandB run group.
group = parser.add_argument_group('Runtime settings')
group.add_argument('--wandb-project', type=str, default=None, help='WandB project name')
group.add_argument('--wandb-username', type=str, default=None, help='WandB username')
group.add_argument('--exp-name', type=str, default=None, help='WandB run group.')
group.add_argument('--auto-scale-const', type=float, default=512.0,
                   help='Scaling constant in Auto-scale lr: lr * total_bsz / auto_scale_const')
group.add_argument('--unit-test', action='store_true', default=False,
                   help='Unit Test flag for skipping full runs.')
# A positive value makes main() evaluate on the "val" split instead of "test".
group.add_argument('--dev-percent', type=float, default=0.0,
                   help='Percent data used for development set')
#  >>

# << 
group = parser.add_argument_group('Modified settings')
# main() handles this flag by seeding torch, CUDA, NumPy and `random` with 0
# (--seed is not used in that case) and enabling cudnn deterministic mode.
group.add_argument('--use-deterministic',
                   action='store_true',
                   default=False,
                   help='seed all RNGs with 0 (ignores --seed) and set cudnn to deterministic mode')
#  >>

# << 
lll_group = parser.add_argument_group('Arguments for LLL settings')
# --task-idx has no usable default: main() compares it with 0 and indexes the runtime
# task order with it, so it must be supplied (on the command line or via --config).
lll_group.add_argument("--task-idx",
                       default=None,
                       type=int,
                       help="Index (into the runtime task order) of the task to evaluate.")
lll_group.add_argument("--load-single-task",
                       action="store_true",
                       default=False,
                       help="Create dataset only for the current task.")
lll_group.add_argument("--offset_task_labels",
                       action="store_true",
                       default=False,
                       help="Add offsets to the class labels in the dataloaders. Used for single-head setting")
# Used as a key into BACKBONE_ITERATORS / CHEEM_BLOCKS; this script branches on the
# values "attn_proj", "ffn", "query", "key" and "value".
lll_group.add_argument("--cheem_component",
                       default=None,
                       type=str,
                       help="Component of the backbone used as CHEEM.")


data_group = parser.add_argument_group('Arguments for Data settings')
data_group.add_argument('--data_root', metavar='DIR',
                        help='Root path to the dataset')
data_group.add_argument('--imagenet_root', metavar='DIR',
                        help='Root path to the ImageNet dataset')
data_group.add_argument("--cache_data",
                        action="store_true",
                        default=False,
                        help="Cache the dataset metadata.")
data_group.add_argument('--cache_root', default=None,
                        type=str,
                        help='Root path to the cache')
#  >>

def _parse_args():
    """Parse the command line, optionally seeding defaults from a YAML file.

    Parsing happens in two stages. ``config_parser`` first extracts
    ``-c/--config``. If a file is given, its top-level key/values are installed as
    defaults of the main ``parser``, so options passed explicitly on the command
    line still win. The remaining arguments are then parsed by ``parser``.

    Returns:
        tuple: ``(args, args_text)``. ``args`` is the parsed ``argparse.Namespace``.
        ``args_text`` is a YAML dump of ``vars(args)``, cached so the effective
        configuration can be saved alongside the run outputs.

    Raises:
        TypeError: if the YAML config's top level is not a mapping.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(args_config.config, 'r', encoding='utf-8') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML file loads as None: treat it as "no overrides" instead of
        # crashing inside set_defaults(**None).
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f"Config file {args_config.config!r} must contain a YAML mapping of "
                f"argument names to values, got {type(cfg).__name__}.")
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def get_model(backbone, benchmark_config, runtime, args):
    """Wrap ``backbone`` in the continual-learning model used for ``args.task_idx``.

    Args:
        backbone: model returned by ``create_model``. Its ``num_features`` sizes the
            per-task linear heads.
        benchmark_config: benchmark description. ``task_config[name]['num_classes']``
            gives each task's head width.
        runtime: runtime settings. Reads ``config_dir``, ``task_order``,
            ``explore_probability``, ``prompt_config`` and ``op_factory_params``.
        args: parsed CLI namespace. Reads ``task_idx`` and ``cheem_component``.

    Returns:
        ``PromptedContinualMoE`` when ``runtime.prompt_config["use_prompts"]`` is
        truthy. Otherwise ``ContinualMoE`` for the "attn_proj"/"ffn" components and
        ``QKVContinualMoE`` for any other component.
    """
    arch_config = {}
    if args.task_idx > 0:
        # Tasks after index 0 read their per-task architecture config from YAML.
        config_file = Path(runtime.config_dir, f"task_{args.task_idx}.yaml")
        with open(config_file, "r") as fh:
            arch_config = yaml.safe_load(fh)

    def head_factory(backbone):
        # One linear classifier per task in task_order[1 : task_idx + 1]; index 0 is
        # skipped (presumably the ImageNet base task). task_idx and task_order are
        # read when the factory is called, not when it is defined.
        width = backbone.num_features
        heads = []
        for name in runtime.task_order[1:args.task_idx + 1]:
            heads.append(nn.Linear(width, benchmark_config.task_config[name]['num_classes']))
        return nn.ModuleList(heads)

    iter_func = BACKBONE_ITERATORS[args.cheem_component]

    # Only attach the CLS-token mean statistic when explore_probability < 1.
    stat_funcs = None
    if runtime.explore_probability < 1.:
        stat_funcs = {"mean": vit_cls_token_mean}

    prompts = runtime.prompt_config
    if not prompts["use_prompts"]:
        if args.cheem_component in ["attn_proj", "ffn"]:
            model_class = ContinualMoE
        else:
            model_class = QKVContinualMoE
        return model_class(
            backbone, arch_config, iter_func, head_factory, op_factory,
            stat_funcs=stat_funcs, **runtime.op_factory_params)

    return PromptedContinualMoE(
        args.task_idx, backbone, arch_config, iter_func, head_factory, op_factory,
        prompt_len=prompts["prompt_len"],
        transfer_prompts=prompts["transfer_prompts"],
        stat_funcs=stat_funcs,
        **runtime.op_factory_params,
    )


def main():

    dash_line = '-' * 60 + '\n'
    args, args_text = _parse_args()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    if args.task_idx == 0:
        checkpoint_path = Path(args.backbone_checkpoint_dir, args.cheem_component)
    else:
        checkpoint_path = None

    _config = Config(
        args.benchmark, args.runtime_config_file, args.root_dir, args.exp_name, 
        wandb_username=args.wandb_username, wandb_project=args.wandb_project, 
        checkpoint_path=checkpoint_path)
    runtime = _config.runtime
    benchmark_config = _config.benchmark_config

    # Handle task specific
    args.task_name = runtime.task_order[args.task_idx]

    # Override the parameters
    args.img_size = runtime.img_size
    args.num_classes = benchmark_config.task_config["imagenet12"]['num_classes']
    args.drop_path = runtime.drop_path

    # Set up logging
    if args.task_idx > 0:
        supernet_config_file = Path(runtime.supernet_config_dir, f"{args.task_name}.yaml")
        with open(supernet_config_file, "r") as f:
            supernet_metadata = yaml.safe_load(f)
        _time = supernet_metadata["log_base"].split("/")[0]
    else:
        _time = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    log_dir = Path(runtime.log_dir, _time, "eval")
    log_dir.mkdir(parents=True, exist_ok=True)
    logfile = str(Path(log_dir, f"{args.task_name}.log"))

    args.output = log_dir

    utils.setup_default_logging(log_path=logfile)
    
    output_dir = None

    args.prefetcher = not args.no_prefetcher
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Evaluating in distributed mode with multiple processes, 1 device per process.'
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Evaluating with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model)
            ])
        output_dir = utils.get_outdir(args.output if args.output else f'{args.root_dir}/artifacts/{args.benchmark}', exp_name)
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    if utils.is_primary(args) and args.log_wandb:

        if has_wandb:
            assert args.wandb_project is not None
            assert args.wandb_username is not None
            wandb_run_id = wandb.util.generate_id()
            run = wandb.init(project=runtime.wandb["project"], entity=runtime.wandb["entity"], group=args.exp_name, name=f"{args.task_name}_eval", config=args, id=wandb_run_id)
            run.config.update(_config.config_dict)
        else: 
            _logger.warning("You've requested to log metrics to wandb but package not found. "
                            "Metrics not being logged to wandb, try `pip install wandb`")

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    # << 
    if args.use_deterministic:
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        np.random.seed(0)
        random.seed(0)
        torch.backends.cudnn.deterministic = True
    else:
        #  >>
        utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    pretrained_strict = not args.cheem_component in ["query", "key", "value"]

    backbone = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        block_fn=CHEEM_BLOCKS[args.cheem_component],
        pretrained_strict=pretrained_strict,
        **runtime.model_kwargs
    )

    backbone_state_dict_path = args.backbone_weights_path
    # Load the weights of the base checkpoint
    if utils.is_primary(args):
        _logger.info(f"Loading backbone checkpoint from {backbone_state_dict_path}.")
    state_dict = torch.load(backbone_state_dict_path, map_location="cpu")["state_dict"]
    backbone.load_state_dict(state_dict, strict=pretrained_strict)

    backbone.requires_grad_(False)

    model = get_model(backbone, benchmark_config, runtime, args)

    if utils.is_primary(args):
        _logger.info(f"{model}")

    model.set_task_idx(args.task_idx)

    # Load the weights if not the base task
    if args.task_idx > 0:
        # Finetune will save to temp_checkpoint_{task_idx}
        arch_checkpoint = Path(runtime.checkpoint_dir, f"temp_checkpoint_{args.task_idx}.pth.tar")
        _logger.info(f"Loading checkpoint from {arch_checkpoint}")
        experts_state_dict = torch.load(arch_checkpoint, map_location="cpu")

        model.load_state_dict(experts_state_dict)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    # Actual parameter counts
    args.mparams = sum([m.numel() for m in model.parameters()])
    # Number of parameters without the head
    num_params_head = model.backbone.num_features * args.num_classes + args.num_classes
    params_without_head = args.mparams - num_params_head

    if utils.is_primary(args):
        if args.unit_test:
            _logger.info("Running in Unit Test mode. Full run will be skipped.")
        _logger.info(f"Using img_size={args.img_size}\n{dash_line}")
        
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{args.mparams}. param count without head: {params_without_head}')
        _logger.info(f'{args.mparams/1e6:.3f}M Params. Without Head: {params_without_head/1e6:.3f}M Params\n{dash_line}')

        if args.pretrained:
            _logger.info("Initializing from pretrained weights.")

    data_config = resolve_data_config(vars(args), model=backbone, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)
    elif args.torchcompile:
        # FIXME dynamo might need move below DDP wrapping? TBD
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if len([m for m in model.parameters() if m.requires_grad]) == 0:
            if utils.is_primary(args):
                _logger.info("No trainable parameters found, skipping data parallel.")
        elif has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    eval_set = "val" if args.dev_percent > 0. else "test"

    # create the train and eval datasets
    datasets = create_dataset_v2(
        args.benchmark, 
        data_root=args.data_root,
        task_names=runtime.task_order,
        imagenet_root=args.imagenet_root, 
        task_idx=args.task_idx,
        load_single_task=args.load_single_task,
        cache=args.cache_data, 
        cache_root=args.cache_root,
        dev_percent=args.dev_percent, 
        offset_task_labels=args.offset_task_labels)

    dataset_idx = 0 if args.load_single_task else args.task_idx
    dataset_train = datasets["train"][dataset_idx]
    dataset_eval = datasets[eval_set][dataset_idx]

    if utils.is_primary(args):
        _logger.info(f"Eval Set: {eval_set}")
        train_size = len(dataset_train)
        eval_size = len(dataset_eval)
        _logger.info(f"Train Set Size: {train_size}, Eval Set Size: {eval_size}")

    # create data loaders w/ augmentation pipeiine
    eval_transforms = benchmark_config.eval_transforms[args.task_name]
    loader_train = create_loader_v2(
        args.task_name,
        dataset_train,
        input_size=data_config['input_size'],
        transform_list=eval_transforms,
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        device=device,
    )

    loaders_eval = []
    eval_workers = args.workers
    if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
        # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
        eval_workers = min(2, args.workers)
    for _task_idx, task_name in enumerate(runtime.task_order[:args.task_idx+1]):

        eval_transforms = benchmark_config.eval_transforms[task_name]
        _dataset_eval = datasets[eval_set][_task_idx]

        total_num_samples = len(_dataset_eval)
        batch_size = args.batch_size
        validation_batch_size = args.validation_batch_size
        if total_num_samples // args.batch_size < runtime.min_finetune_batches:
            if utils.is_primary(args):
                _logger.info(f"{task_name}: Reducing batch size for smaller dataset.")
            batch_size = batch_size // 8
            if validation_batch_size:
                validation_batch_size = validation_batch_size // 8

        loader_eval = create_loader_v2(
            task_name,
            _dataset_eval,
            input_size=data_config['input_size'],
            transform_list=eval_transforms,
            batch_size=validation_batch_size or batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
            device=device,
            drop_last=True,
            shuffle=False
        )

        loaders_eval.append(loader_eval)

    try:

        # Calculate the running means
        _model = unwrap_model(model)
        _model.set_task_idx(args.task_idx)
        if runtime.explore_probability < 1.0:
            
            if utils.is_primary(args):
                _logger.info("Calculating running means.")
            # Set the mean hooks
            _model.register_stat_hooks(args.task_idx)
            for running_mean_epoch in range(runtime.running_mean_epochs):
                # Just evaluate the model
                eval_metrics = evaluate(model, loader_train, args, amp_autocast=amp_autocast, unit_test=args.unit_test)
                if args.unit_test:
                    _logger.info("Unit Test. Exiting after 1 epoch.")
                    break
            # Make sure to remove the mean hooks to avoid updating them during predictions
            _model.remove_stat_hooks()

            if args.distributed:
                _logger.info(f"[Worker ID: {args.rank}] Distributing stats across devices")
                distribute_stats(model, _model.stat_funcs.keys(), args.world_size, reduce=True)

            if _model.stat_funcs is not None:
                for layer, (layer_name, stat_layer, dim) in enumerate(model.iter_backbone(model.backbone, mode="statistics")):
                    for stat in _model.stat_funcs:
                        model_stat = getattr(_model, stat)[layer]
                        for key in model_stat.expert_keys:
                            _stat = getattr(model_stat, key)
                            _logger.info(
                                    f"{layer} | {key} - Magnitude: {torch.sqrt((_stat**2).sum()):.4f}, Sum: {_stat.sum()}")

        top1_accs = ["top1"]
        top5_accs = ["top5"]
        columns = ["Metric"]
        if utils.is_primary(args):
            _logger.info("Accuracies")
        for idx, task_name in enumerate(runtime.task_order[:args.task_idx+1]):
            
            _model.set_task_idx(idx)

            if utils.is_primary(args):
                _logger.info(f"Evaluating task {task_name}, idx: {idx}")
            eval_metrics = evaluate(model, loaders_eval[idx], args, device=device, amp_autocast=amp_autocast)
            top1 = eval_metrics["top1"]
            top5 = eval_metrics["top5"]
            if utils.is_primary(args):
                _logger.info(f"Top1: {top1}")
                _logger.info(f"Top5: {top5}\n{dash_line}")

            top1_accs.append(top1)
            top5_accs.append(top5)
            columns.append(task_name)

        # Average
        avg_top1 = sum(top1_accs[1:]) / (len(top1_accs) - 1)
        avg_top5 = sum(top5_accs[1:]) / (len(top5_accs) - 1)
        columns.append("Average")
        top1_accs.append(avg_top1)
        top5_accs.append(avg_top5)

        # Log
        if utils.is_primary(args) and args.log_wandb and has_wandb:
            wandb.log({"test_accuracies": wandb.Table(columns=columns, data=[top1_accs, top5_accs])})
            wandb.log({"avg_top1": avg_top1, "avg_top5": avg_top5})
        if utils.is_primary(args):
            _logger.info(f"Average Top 1: {avg_top1}")
            _logger.info(f"Average Top 5: {avg_top5}")

        if utils.is_primary(args):
            # Checkpoint the model
            checkpoint_path = Path(runtime.checkpoint_dir, f"checkpoint_{args.task_idx}.pth.tar")
            _logger.info(f"Saving checkpoint to {checkpoint_path}")
            
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(_model.state_dict(), checkpoint_path)

            temp_checkpoint = Path(runtime.checkpoint_dir, f"temp_checkpoint_{args.task_idx}.pth.tar")

            if args.task_idx > 0:
                # Remove the older checkpoint
                _logger.info(f"Removing old checkpoint from {temp_checkpoint}")
                temp_checkpoint.unlink()

            if args.task_idx > 1:

                prev_task_checkpoint = Path(runtime.checkpoint_dir, f"checkpoint_{args.task_idx-1}.pth.tar")
                prev_task_checkpoint.unlink()

    except KeyboardInterrupt:
        pass


def evaluate(
        model,
        loader,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        unit_test=False,
        **kwargs
):
    """Run ``model`` over every batch of ``loader`` without gradients and report accuracy.

    The model is put in eval mode and each batch is processed as follows:
    forward pass, optional test-time-augmentation averaging, top-1/top-5
    accuracy, optional cross-rank averaging. Progress is logged on the
    primary process.

    Args:
        model: Network to evaluate. ``model.eval()`` is called before the loop.
        loader: Iterable of ``(input, target)`` batches that also supports ``len()``.
        args: Namespace read for ``prefetcher``, ``channels_last``, ``tta``,
            ``distributed``, ``world_size`` and ``log_interval``.
        device: Target for ``.to(device)`` when ``args.prefetcher`` is falsy.
            Also decides whether ``torch.cuda.synchronize()`` is called after
            each batch.
        amp_autocast: Zero-argument context-manager factory that wraps the
            forward pass. The default ``suppress`` adds no autocast.
        log_suffix: Appended to ``'Test'`` in the progress log line.
        unit_test: Accepted for caller compatibility. This function does not read it.
        **kwargs: Forwarded unchanged to every ``model(...)`` call.

    Returns:
        OrderedDict with keys ``'top1'`` and ``'top5'``. Each holds the ``.avg``
        of a ``utils.AverageMeter`` that was updated with ``logits.size(0)``
        as the per-batch weight.
    """
    # Full template for the periodic progress line.
    progress_fmt = (
        '{0}: [{1:>4d}/{2}]  '
        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'
    )

    timer = utils.AverageMeter()
    acc1_meter = utils.AverageMeter()
    acc5_meter = utils.AverageMeter()

    model.eval()

    tick = time.time()
    final_idx = len(loader) - 1
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            is_last = step == final_idx

            if not args.prefetcher:
                # Presumably the prefetching loader already yields device
                # tensors; otherwise move the batch here.
                images, labels = images.to(device), labels.to(device)
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                logits = model(images, **kwargs)
                if isinstance(logits, (tuple, list)):
                    # Keep only the first element when the model returns a sequence.
                    logits = logits[0]

                tta = args.tta
                if tta > 1:
                    # Test-time augmentation: average each run of `tta`
                    # consecutive rows and keep one label per run.
                    # NOTE(review): assumes logits are (batch, classes) and
                    # augmented copies of a sample are adjacent. Confirm
                    # against the loader.
                    logits = logits.unfold(0, tta, tta).mean(dim=2)
                    labels = labels[0:labels.size(0):tta]

            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))

            if args.distributed:
                # Mean-reduce the per-rank accuracies across the process group.
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)

            if device.type == 'cuda':
                torch.cuda.synchronize()

            n_samples = logits.size(0)
            acc1_meter.update(acc1.item(), n_samples)
            acc5_meter.update(acc5.item(), n_samples)

            timer.update(time.time() - tick)
            tick = time.time()

            # Only the primary process logs progress.
            if not utils.is_primary(args):
                continue
            if is_last or step % args.log_interval == 0:
                _logger.info(progress_fmt.format(
                    'Test' + log_suffix, step, final_idx,
                    batch_time=timer, top1=acc1_meter, top5=acc5_meter))

    return OrderedDict(top1=acc1_meter.avg, top5=acc5_meter.avg)


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()
