import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import yaml
from pathlib import Path
from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler
from lib.datasets import build_dataset
from supernet_engine_prompt import *
from lib.samplers import RASampler
from lib import utils
from lib.config import cfg, update_config_from_file
from model.supernet_vision_transformer_timm import VisionTransformer

import model as models
from timm.models import load_checkpoint

# SLURM
from mmcv.runner import get_dist_info, init_dist

import os
from timm.utils.clip_grad import dispatch_clip_grad
from collections import OrderedDict
from lib import utils


class MaskScalerTakeModelOpt:
    """AMP loss scaler that re-applies an element-wise freeze mask after each step.

    Wraps ``torch.cuda.amp.GradScaler``. After the optimizer step, every
    state-dict entry that has a mask in ``grad_mask_dict`` keeps its updated
    value only where the mask equals 1; all other elements of that entry are
    reset to the value stored in ``freezed_state_dict``. Entries without a mask
    are left exactly as the optimizer updated them.
    """
    state_dict_key = "amp_scaler"

    def __init__(self, grad_mask_dict=None, freezed_state_dict=None):
        """Store the masks and the reference (frozen) values.

        Args:
            grad_mask_dict: mapping from state-dict key to a mask tensor; elements
                equal to 1 mark positions that may keep their trained value.
            freezed_state_dict: mapping from state-dict key to the tensor used to
                restore positions whose mask is not 1. It must contain every key
                present in ``grad_mask_dict``.
        """
        self._scaler = torch.cuda.amp.GradScaler()
        self.grad_mask_dict = grad_mask_dict
        self.freezed_state_dict = freezed_state_dict

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', model=None, create_graph=False, grad_mask=None):
        """Run backward, optionally clip gradients, step, then restore masked-out weights.

        Args:
            loss: scalar loss tensor to back-propagate.
            optimizer: optimizer stepped through the gradient scaler.
            clip_grad: clipping threshold; ``None`` disables clipping.
            clip_mode: mode forwarded to ``dispatch_clip_grad``.
            model: module whose parameters are clipped and whose state dict is
                rewritten after the step. Required despite the ``None`` default.
            create_graph: forwarded to ``backward``.
            grad_mask: unused; kept so existing call sites stay valid.

        Raises:
            ValueError: if ``model`` is ``None``.
        """
        # Fail before touching gradients or optimizer state rather than after the step.
        if model is None:
            raise ValueError("MaskScalerTakeModelOpt requires the `model` argument")

        self._scaler.scale(loss).backward(create_graph=create_graph)

        if clip_grad is not None:
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(model.parameters(), clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()

        # Build the state dict once: every call to model.state_dict() walks the whole module tree.
        current_state = model.state_dict()
        restored_state = {}
        for key, value in current_state.items():
            if key in self.grad_mask_dict:
                # Use the tensor's own device instead of the process-wide current CUDA device.
                mask = self.grad_mask_dict[key].to(value.device)
                frozen = self.freezed_state_dict[key].to(value.device)
                restored_state[key] = torch.where(mask == 1, value, frozen)
            else:
                # No mask registered for this entry: keep the freshly updated value.
                restored_state[key] = value

        model.load_state_dict(restored_state)

    def state_dict(self):
        """Return the state of the wrapped gradient scaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load a state previously produced by :meth:`state_dict` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)


# Should be the same as the naive scaler, since we freeze the params.
class MaskScalerTakeModelOptCG:
    """AMP loss scaler that resets every non-tunable state-dict entry after each step.

    Wraps ``torch.cuda.amp.GradScaler``. After the optimizer step, entries whose
    key is contained in ``grad_mask_dict`` keep their updated value; every other
    entry of the model state dict is overwritten with the tensor stored under
    the same key in ``freezed_state_dict``.
    """
    state_dict_key = "amp_scaler"

    def __init__(self, grad_mask_dict=None, freezed_state_dict=None):
        """Store the tunable-key container and the reference (frozen) values.

        Args:
            grad_mask_dict: container of state-dict keys that stay trainable. Only
                membership (``key in grad_mask_dict``) is used; values are ignored.
            freezed_state_dict: mapping from state-dict key to the tensor restored
                after each step. It must contain every model state-dict key that
                is not in ``grad_mask_dict``.
        """
        self._scaler = torch.cuda.amp.GradScaler()
        self.grad_mask_dict_cg = grad_mask_dict
        self.freezed_state_dict = freezed_state_dict

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', model=None, create_graph=False, grad_mask=None):
        """Run backward, optionally clip gradients, step, then restore frozen entries.

        Args:
            loss: scalar loss tensor to back-propagate.
            optimizer: optimizer stepped through the gradient scaler.
            clip_grad: clipping threshold; ``None`` disables clipping.
            clip_mode: mode forwarded to ``dispatch_clip_grad``.
            model: module whose parameters are clipped and whose state dict is
                rewritten after the step. Required despite the ``None`` default.
            create_graph: forwarded to ``backward``.
            grad_mask: unused; kept so existing call sites stay valid.

        Raises:
            ValueError: if ``model`` is ``None``.
        """
        # Fail before touching gradients or optimizer state rather than after the step.
        if model is None:
            raise ValueError("MaskScalerTakeModelOptCG requires the `model` argument")

        self._scaler.scale(loss).backward(create_graph=create_graph)

        if clip_grad is not None:
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(model.parameters(), clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()

        # Build the state dict once: every call to model.state_dict() walks the whole module tree.
        current_state = model.state_dict()
        restored_state = {}
        for key, value in current_state.items():
            if key in self.grad_mask_dict_cg:
                # Tunable entry: keep the freshly updated value.
                restored_state[key] = value
            else:
                # Frozen entry: restore the snapshot on the tensor's own device
                # rather than on the process-wide current CUDA device.
                restored_state[key] = self.freezed_state_dict[key].to(value.device)

        model.load_state_dict(restored_state)

    def state_dict(self):
        """Return the state of the wrapped gradient scaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load a state previously produced by :meth:`state_dict` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)


class MaskScalerTakeModelOptCGFG:
    """AMP loss scaler combining whole-entry and element-wise freezing after each step.

    Wraps ``torch.cuda.amp.GradScaler``. After the optimizer step each entry of
    the model state dict falls into exactly one of three cases:

    * key in ``fully_ft_list``: the updated value is kept as is;
    * key in ``partial_ft_dict``: the updated value is kept where the mask equals
      1 and reset to the ``freezed_state_dict`` value elsewhere;
    * any other key: the whole entry is reset to its ``freezed_state_dict`` value.
    """
    state_dict_key = "amp_scaler"

    def __init__(self, fully_ft_list=None, partial_ft_dict=None, freezed_state_dict=None):
        """Store the key groups and the reference (frozen) values.

        Args:
            fully_ft_list: container of state-dict keys that are fully trainable.
            partial_ft_dict: mapping from state-dict key to a mask tensor; elements
                equal to 1 mark positions that may keep their trained value.
            freezed_state_dict: mapping from state-dict key to the tensor used for
                restoring. It must contain every model state-dict key that is not
                in ``fully_ft_list``.
        """
        self._scaler = torch.cuda.amp.GradScaler()
        self.fully_ft_list = fully_ft_list
        self.partial_ft_dict = partial_ft_dict
        self.freezed_state_dict = freezed_state_dict

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', model=None, create_graph=False, grad_mask=None):
        """Run backward, optionally clip gradients, step, then restore frozen weights.

        Args:
            loss: scalar loss tensor to back-propagate.
            optimizer: optimizer stepped through the gradient scaler.
            clip_grad: clipping threshold; ``None`` disables clipping.
            clip_mode: mode forwarded to ``dispatch_clip_grad``.
            model: module whose parameters are clipped and whose state dict is
                rewritten after the step. Required despite the ``None`` default.
            create_graph: forwarded to ``backward``.
            grad_mask: unused; kept so existing call sites stay valid.

        Raises:
            ValueError: if ``model`` is ``None``.
        """
        # Fail before touching gradients or optimizer state rather than after the step.
        if model is None:
            raise ValueError("MaskScalerTakeModelOptCGFG requires the `model` argument")

        self._scaler.scale(loss).backward(create_graph=create_graph)

        if clip_grad is not None:
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            dispatch_clip_grad(model.parameters(), clip_grad, mode=clip_mode)

        self._scaler.step(optimizer)
        self._scaler.update()

        # Build the state dict once: every call to model.state_dict() walks the whole module tree.
        current_state = model.state_dict()
        restored_state = {}
        for key, value in current_state.items():
            if key in self.fully_ft_list:
                # Fully tuned: keep the freshly updated value.
                restored_state[key] = value
            elif key in self.partial_ft_dict:
                # Partially tuned: keep only the positions selected by the mask.
                # Tensors are moved to the entry's own device, not the current CUDA device.
                mask = self.partial_ft_dict[key].to(value.device)
                frozen = self.freezed_state_dict[key].to(value.device).detach()
                restored_state[key] = torch.where(mask == 1, value, frozen)
            else:
                # Fully frozen: restore the whole snapshot.
                restored_state[key] = self.freezed_state_dict[key].to(value.device).detach()

        model.load_state_dict(restored_state)

    def state_dict(self):
        """Return the state of the wrapped gradient scaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load a state previously produced by :meth:`state_dict` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)


def get_args_parser():
    """Build the command-line parser for the training and evaluation script.

    The parser is created with ``add_help=False``. ``--cfg`` is the only
    required option. Where a positive and a negative flag share a destination
    (``model_ema``, ``amp``), the default comes from the flag registered first,
    so both default to ``False``; ``pin_mem`` is explicitly defaulted to
    ``True`` through ``set_defaults``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    # config file
    parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)

    # custom parameters
    parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'],
                        help='Name of platform')
    parser.add_argument('--teacher_model', default='', type=str,
                        help='Name of teacher model to train')
    parser.add_argument('--relative_position', action='store_true')
    parser.add_argument('--gp', action='store_true')
    parser.add_argument('--change_qkv', action='store_true')
    parser.add_argument('--max_relative_position', type=int, default=14,
                        help='max distance in relative position embedding')

    # Model parameters
    parser.add_argument('--model', default='', type=str, metavar='MODEL',
                        help='Name of model to train')
    # AutoFormer config
    parser.add_argument('--mode', type=str, default='super', choices=['super', 'vp', 'retrain', 'search'],
                        help='mode of AutoFormer')
    parser.add_argument('--input-size', default=224, type=int)
    parser.add_argument('--patch_size', default=16, type=int)

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    # --model-ema is registered first, so model_ema defaults to False.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
    parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct'])
    parser.add_argument('--post_norm', action='store_true')
    parser.add_argument('--no_abs_pos', action='store_true')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
    parser.add_argument('--lr-power', type=float, default=1.0,
                        help='power of the polynomial lr scheduler')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The help text used to be a single literal containing a stray `" + \` line
    # continuation; it is now two adjacent literals joined by the parser.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Dataset parameters
    parser.add_argument('--data-path', default='./data/imagenet/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', type=str, help='Dataset name')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='./',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # --amp is registered first, so amp defaults to False.
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--no-amp', action='store_false', dest='amp')

    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

    parser.add_argument('--is_visual_prompt_tuning', action='store_true')
    parser.add_argument('--is_adapter', action='store_true')
    parser.add_argument('--is_LoRA', action='store_true')
    parser.add_argument('--is_prefix', action='store_true')

    parser.add_argument('--no_aug', action='store_true')

    parser.add_argument('--val_interval', default=1, type=int, help='validation interval')

    parser.add_argument('--drop_rate_LoRA', type=float, default=0.1)
    parser.add_argument('--drop_rate_prompt', type=float, default=0.1)
    parser.add_argument('--drop_rate_adapter', type=float, default=0.1)

    parser.add_argument('--few-shot-seed', type=int, default=0)
    parser.add_argument('--few-shot-shot', type=int, default=2)

    parser.add_argument('--inception', action='store_true')
    parser.add_argument('--direct_resize', action='store_true')

    parser.add_argument('--IS_not_position_VPT', action='store_true')

    # Our params
    parser.add_argument('--exp-name', default='', type=str)
    parser.add_argument('--get-grad', action='store_true')
    parser.add_argument('--get-grad-cg', action='store_true')
    parser.add_argument('--freeze_stage', action='store_true')
    parser.add_argument('--grad_mask_path', default='', type=str)
    parser.add_argument('--scaler', default='naive', type=str)
    parser.add_argument('--IECG_dim', default=8, type=int)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--save_best', action='store_true')
    parser.add_argument('--tune_no_norm', action='store_true')
    # NOTE: parsed as a string, so the value 'False' is still a non-empty string.
    parser.add_argument('--freeze_others', default='True', type=str)
    parser.add_argument('--freeze_selected', action='store_true')
    parser.add_argument('--ft_cls_token', action='store_true')
    parser.add_argument('--block', type=str, default='BlockCGSepQKV')
    parser.add_argument('--tune_no_fg', action='store_true')
    parser.add_argument('--orth_loss', action='store_true')
    parser.add_argument('--mmd_loss', action='store_true')
    parser.add_argument('--orth_loss_sigma', type=float, default=0.01)
    parser.add_argument('--loss_type', type=str, default='crs_entropy')

    return parser


def main(args):

    # utils.init_distributed_mode(args)
    update_config_from_file(args.cfg)
    if args.launcher == 'none':
        args.distributed = False
    else:
        args.distributed = True
        init_dist(launcher=args.launcher)
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args,is_individual_prompt=(args.is_visual_prompt_tuning or args.is_adapter or args.is_LoRA or args.is_prefix))
    dataset_val, _ = build_dataset(is_train=False, args=args,is_individual_prompt=(args.is_visual_prompt_tuning or args.is_adapter or args.is_LoRA or args.is_prefix))

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=int(2 * args.batch_size),
        sampler=sampler_val, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False
    )

    print(f"{args.data_set} dataset, train: {len(dataset_train)}, evaluation: {len(dataset_val)}")

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    print('mixup_active',mixup_active)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating SuperVisionTransformer")
    print(cfg)

    if args.scaler == 'mask_take_model_opt_cg':
        param_info = torch.load(args.grad_mask_path, map_location='cpu')
        bias_names = param_info['bias_names']

        if args.tune_no_norm:
            bias_names = [name for name in bias_names if 'norm' not in name]

        weight_names = param_info['weight_names']

        print('Solely CG Tunable: ', bias_names)
        print(weight_names)
        print('Total params: ', param_info['params'])

    # Use naive scaler while load bias, weight names for trainable params
    elif args.scaler == 'mask_take_model_opt_cg_new':
        param_info = torch.load(args.grad_mask_path, map_location='cpu')
        bias_names = param_info['bias_names']

        if args.tune_no_norm:
            bias_names = [name for name in bias_names if 'norm' not in name]

        weight_names = param_info['weight_names']

        print('Solely CG Tunable: ', bias_names)
        print(weight_names)
        print('Total params: ', param_info['params'])

    elif args.scaler == 'mask_take_model_opt_cg_fg' or args.scaler == 'mask_take_model_opt_cg_fgv2':
        param_info = torch.load(args.grad_mask_path, map_location='cpu')
        bias_names = param_info['bias_names']

        weight_names = param_info['weight_names']
        # bias_names.extend(utils.flags2weight_list(weight_names))

        # print('Solely CG Tunable: ', bias_names)
        print('Both CG and FG tuning', )
        print('weight names: ', weight_names)
        print('bias names: ', bias_names)

        print('Total params: ', param_info['params'])

    elif args.scaler == 'mask_take_model_opt_cg_fgv3':
        param_info = torch.load(args.grad_mask_path, map_location='cpu')
        weight_names = [list(param_info['grad_indexes'].values())[i:i + 6] for i in range(0, len(list(param_info['grad_indexes'].values())), 6)]
        bias_names = []
        print('Total params: ', param_info['params'])

    elif args.scaler == 'mask_take_model_opt_cg_fgv3_norm':
        param_info = torch.load(args.grad_mask_path, map_location='cpu')
        weight_names = [list(param_info['grad_indexes'].values())[i:i + 16] for i in range(0, len(list(param_info['grad_indexes'].values())), 16)]
        bias_names = []
        print('Total params: ', param_info['params'])

    model = models.__dict__[cfg.MODEL_NAME](img_size=args.input_size,
                                            drop_rate=args.drop,
                                            drop_path_rate=args.drop_path,
                                            super_prompt_tuning_dim=cfg.SUPERNET.VISUAL_PROMPT_DIM,super_LoRA_dim=cfg.SUPERNET.LORA_DIM,super_adapter_dim=cfg.SUPERNET.ADAPTER_DIM,super_prefix_dim=cfg.SUPERNET.PREFIX_DIM,
                                            drop_rate_LoRA=args.drop_rate_LoRA,drop_rate_prompt=args.drop_rate_prompt,drop_rate_adapter=args.drop_rate_adapter,
                                            IS_not_position_VPT = args.IS_not_position_VPT, freeze_backbone=args.freeze_stage,
                                            IECG_list=weight_names, bias_names=bias_names, IECG_dim=args.IECG_dim, freeze_others=args.freeze_others, block=args.block
                                            )

    choices = {'depth': cfg.SUPERNET.DEPTH,
               'super_prompt_tuning_dim':cfg.SUPERNET.VISUAL_PROMPT_DIM,
               'super_LoRA_dim':cfg.SUPERNET.LORA_DIM,
               'super_adapter_dim':cfg.SUPERNET.ADAPTER_DIM,
               'super_prefix_dim':cfg.SUPERNET.PREFIX_DIM,
               'visual_prompt_dim':cfg.SEARCH_SPACE.VISUAL_PROMPT_DIM,
               'lora_dim':cfg.SEARCH_SPACE.LORA_DIM,
               'adapter_dim':cfg.SEARCH_SPACE.ADAPTER_DIM,
               'prefix_dim':cfg.SEARCH_SPACE.PREFIX_DIM,
               'visual_prompt_depth':cfg.SEARCH_SPACE.VISUAL_PROMPT_DEPTH,
               'lora_depth':cfg.SEARCH_SPACE.LORA_DEPTH,
               'adapter_depth':cfg.SEARCH_SPACE.ADAPTER_DEPTH,
               'prefix_depth':cfg.SEARCH_SPACE.PREFIX_DEPTH,
               }

    train_engine = train_one_epoch
    test_engine = evaluate

    if args.resume:
        if '.pth' in args.resume:

            if args.resume.startswith('mae'):
                state_dict = torch.load(args.resume, map_location='cpu')['model']
                new_dict = OrderedDict()
                for name in state_dict.keys():
                    if 'attn.qkv.' in name:
                        new_dict[name.replace('qkv', 'q')] = state_dict[name][:state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'k')] = state_dict[name][state_dict[name].shape[0] // 3:-state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'v')] = state_dict[name][-state_dict[name].shape[0] // 3:]
                    else:
                        new_dict[name] = state_dict[name]

                args.suffix = 'mae'
                msg = model.load_state_dict(new_dict, strict=False)
                print('Resuming from MAE model: ', msg)

            elif args.resume.startswith('linear'):
                state_dict = torch.load(args.resume, map_location='cpu')['state_dict']
                new_dict = OrderedDict()
                for name in state_dict.keys():
                    if 'attn.qkv.' in name:
                        new_dict[name.replace('qkv', 'q').split('module.')[1]] = state_dict[name][:state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'k').split('module.')[1]] = state_dict[name][state_dict[name].shape[0] // 3:-state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'v').split('module.')[1]] = state_dict[name][-state_dict[name].shape[0] // 3:]
                    elif 'head.' in name:
                        continue
                    else:
                        new_dict[name.split('module.')[1]] = state_dict[name]

                args.suffix = 'moco'
                msg = model.load_state_dict(new_dict, strict=False)
                print('Resuming from Moco model: ', msg)

            elif args.resume.startswith('swin'):

                state_dict = torch.load(args.resume, map_location='cpu')['model']
                new_dict = OrderedDict()
                for name in state_dict.keys():
                    if 'attn.qkv.' in name:
                        new_dict[name.replace('qkv', 'q')] = state_dict[name][:state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'k')] = state_dict[name][state_dict[name].shape[0] // 3:-state_dict[name].shape[0] // 3]
                        new_dict[name.replace('qkv', 'v')] = state_dict[name][-state_dict[name].shape[0] // 3:]
                    elif 'head.' in name:
                        continue
                    else:
                        new_dict[name] = state_dict[name]

                if args.nb_classes != model.head.weight.shape[0]:
                    model.reset_classifier(args.nb_classes)

                args.suffix = 'swin'
                msg = model.load_state_dict(new_dict, strict=False)
                print('Resuming from Swin model: ', msg)

                train_engine = train_one_epoch_swin
                test_engine = evaluate_swin

            else:

                if args.nb_classes != model.head.weight.shape[0]:
                    model.reset_classifier(args.nb_classes)
                incompatible_keys = load_checkpoint(model, args.resume,strict=False)
                print('Resuming from .pth model: ', incompatible_keys)

        else:
            load_checkpoint(model, args.resume)
            if args.nb_classes != model.head.weight.shape[0]:
                model.reset_classifier(args.nb_classes)

            args.suffix = 'vit'

        if args.scaler == 'mask_take_model_opt' or args.scaler == 'mask_take_model_opt_only_qkv':
            param_info = torch.load(args.grad_mask_path, map_location='cpu')
            shapes = param_info['grad_shapes']
            shapes_int = param_info['grad_shapes_int']
            indexes = param_info['indexes']

            # freezed_state_dict used to reload the not-trained parameters, grad_mask is used to select index
            freezed_state = torch.cat([model.state_dict()[key].flatten() for key in model.state_dict().keys() if not ('head' in key or 'cls_token' in key or 'adapter' in key or 'LoRA' in key or 'prefix' in key)])
            assert param_info['total_params'] == freezed_state.shape[0]
            freezed_state = freezed_state.split([shape_int for shape_int in shapes_int.values()])

            grad_mask = torch.zeros(param_info['total_params'])
            grad_mask[indexes] = 1.
            grad_mask = grad_mask.split([shape_int for shape_int in shapes_int.values()])

            grad_mask_dict = {}
            freezed_state_dict = {}
            for i, key in enumerate(shapes_int.keys()):
                freezed_state_dict[key] = freezed_state[i].view(shapes[key])
                grad_mask_dict[key] = grad_mask[i].view(shapes[key])

            # print('qkv params: ', torch.cat([(grad_mask_dict[key].flatten() == 1.) for key in shapes_int.keys() if '.qkv' in key]).sum())

    model.to(device)
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # linear lr
    # linear_scaled_lr =  args.lr * args.batch_size * utils.get_world_size() / 128.0
    # args.lr = linear_scaled_lr

    optimizer = utils.build_optimizer(args, model_without_ddp)
    if args.scaler == 'naive' or args.scaler == 'mask_take_model_opt_cg_fgv3' or args.scaler == 'mask_take_model_opt_cg_fgv3_norm' or args.scaler == 'mask_take_model_opt_cg_new':
        loss_scaler = NativeScaler()

    # elif args.scaler == 'mask_take_model_opt':
    #     loss_scaler = MaskScalerTakeModelOpt(grad_mask_dict=grad_mask_dict, freezed_state_dict=freezed_state_dict)

    elif args.scaler == 'mask_take_model_opt_cg':
        assert args.resume
        freezed_state_dict = {key: model.state_dict()[key] for key in model.state_dict().keys()}
        bias_names.extend(['cls_token', 'head.weight', 'head.bias'])
        bias_names.extend([key for key in model.state_dict().keys() if 'IECG' in key])
        loss_scaler = MaskScalerTakeModelOptCG(grad_mask_dict=bias_names, freezed_state_dict=freezed_state_dict)

    elif args.scaler == 'mask_take_model_opt_cg_fg':
        assert args.resume
        freezed_state_dict = {key: model.state_dict()[key].clone().detach() for key in model.state_dict().keys()}
        bias_names.extend(['head.weight', 'head.bias'])

        if args.ft_cls_token:
            bias_names.append('cls_token')

        bias_names.extend([key for key in model.state_dict().keys() if 'IECG' in key])

        # FG settings
        # param_info = torch.load(args.grad_mask_path, map_location='cpu')
        fg_name_shapes = param_info['fg_name_shapes']
        # shapes_int = param_info['grad_shapes_int']
        fg_indexes = param_info['fg_indexes']
        tmp_large_tensor = torch.cat([model.state_dict()[key].flatten() for key in fg_name_shapes.keys()])
        fg_grad_mask = torch.zeros_like(tmp_large_tensor)
        del tmp_large_tensor
        fg_grad_mask[fg_indexes] = 1.
        fg_grad_mask = fg_grad_mask.split([np.cumprod(list(shape))[-1] for shape in fg_name_shapes.values()])
        grad_mask_dict = {}
        for i, key in enumerate(fg_name_shapes.keys()):
            grad_mask_dict[key] = fg_grad_mask[i].view(fg_name_shapes[key])

        if args.tune_no_norm:

            new_grad_mask_dict = {}
            for key in grad_mask_dict:
                if 'norm' not in key:
                    new_grad_mask_dict[key] = grad_mask_dict[key]
            # new_grad_mask_dict = {key: grad_mask_dict[key] for key in grad_mask_dict if 'norm' not in key}
            # grad_mask_dict = new_grad_mask_dict

            bias_names = [key for key in bias_names if 'norm' not in key]

        if args.tune_no_fg:
            grad_mask_dict = {}

        loss_scaler = MaskScalerTakeModelOptCGFG(fully_ft_list=bias_names,
                                                 partial_ft_dict=grad_mask_dict,
                                                 freezed_state_dict=freezed_state_dict)

        if args.freeze_selected:
            freezed_keys = list(set(model.state_dict().keys()) - set(bias_names) - set(grad_mask_dict.keys()))
            model.freeze_selected_params(freezed_keys)

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    # save config for later experiments
    with open(output_dir / "config.yaml", 'w') as f:
        f.write(args_text)

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {'visual_prompt_dim':cfg.RETRAIN.VISUAL_PROMPT_DIM,'lora_dim':cfg.RETRAIN.LORA_DIM,'adapter_dim':cfg.RETRAIN.ADAPTER_DIM,'prefix_dim':cfg.RETRAIN.PREFIX_DIM,}

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device,  mode = args.mode, retrain_config=retrain_config,is_visual_prompt_tuning=args.is_visual_prompt_tuning,is_adapter=args.is_adapter,is_LoRA=args.is_LoRA,is_prefix=args.is_prefix)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    if args.get_grad_cg:
        train_one_epoch_get_grad_cg(
            model, criterion, data_loader_train,
            optimizer, device, 0, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            amp=args.amp, teacher_model=teacher_model,
            teach_loss=teacher_loss,
            choices=choices, mode=args.mode, retrain_config=retrain_config,
            is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter,
            is_LoRA=args.is_LoRA, is_prefix=args.is_prefix, dataset=args.data_set, nb_classes=args.nb_classes,
        )
        return

    print("Start training")

    # return
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.loss_type == 'orth':
            train_stats = train_one_epoch_orth(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode=args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter, is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler, orth_loss_sigma=args.orth_loss_sigma
            )

        elif args.loss_type == 'corrv2':
            train_stats = train_one_epoch_corrv2(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode=args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter, is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler, orth_loss_sigma=args.orth_loss_sigma
            )

        elif args.loss_type == 'coral':
            imn_features = torch.load("imn_averaged_features_b256.pth", map_location='cpu')
            auxiliary_loss = utils.CORAL

            train_stats = train_one_epoch_auxiliary_last(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode=args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter, is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler, sigma=args.orth_loss_sigma, imn_features=imn_features,
                auxiliary_loss=auxiliary_loss
            )

        elif args.loss_type == 'coral_pl':
            imn_features = torch.load("imn_averaged_features_b256_10k_per_layer.pth", map_location='cpu')
            auxiliary_loss = utils.CORAL

            train_stats = train_one_epoch_auxiliary_pl(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode=args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter, is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler, sigma=args.orth_loss_sigma, imn_features=imn_features,
                auxiliary_loss=auxiliary_loss
            )

        elif args.loss_type == 'mmd':
            imn_features = torch.load("imn_averaged_features_b256.pth", map_location='cpu')
            auxiliary_loss = utils.mmd

            train_stats = train_one_epoch_auxiliary_last(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode=args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning, is_adapter=args.is_adapter, is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler, sigma=args.orth_loss_sigma, imn_features=imn_features,
                auxiliary_loss=auxiliary_loss
            )

        else:
            train_stats = train_engine(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn,
                amp=args.amp, teacher_model=teacher_model,
                teach_loss=teacher_loss,
                choices=choices, mode = args.mode, retrain_config=retrain_config,
                is_visual_prompt_tuning=args.is_visual_prompt_tuning,is_adapter=args.is_adapter,is_LoRA=args.is_LoRA,
                is_prefix=args.is_prefix, scaler=args.scaler
            )

        lr_scheduler.step(epoch)

        if epoch % args.val_interval == 0 or epoch >= args.epochs-10:
            test_stats = test_engine(data_loader_val, model, device, amp=args.amp, choices=choices, mode = args.mode, retrain_config=retrain_config,is_visual_prompt_tuning=args.is_visual_prompt_tuning,is_adapter=args.is_adapter,is_LoRA=args.is_LoRA,is_prefix=args.is_prefix)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            # print(f'Max accuracy: {max_accuracy:.2f}%')
            print(
                f"[{args.exp_name}] Max accuracy on the {args.data_set} dataset {len(dataset_val)} with ({args.opt}, {args.lr}, {args.weight_decay}), {max_accuracy:.2f}%")

            # if max_accuracy == test_stats["acc1"] and args.save_best:
            #     if args.output_dir:
            #         checkpoint_paths = [output_dir / 'best_checkpoint.pth']
            #         for checkpoint_path in checkpoint_paths:
            #             utils.save_on_master({
            #                 'model': model_without_ddp.state_dict(),
            #                 # 'optimizer': optimizer.state_dict(),
            #                 # 'lr_scheduler': lr_scheduler.state_dict(),
            #                 'epoch': epoch,
            #                 # 'scaler': loss_scaler.state_dict(),
            #                 'args': args,
            #             }, checkpoint_path)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                        **{f'test_{k}': v for k, v in test_stats.items()},
                        'epoch': epoch,
                        'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Script entry point: build the CLI on top of the options declared by
    # get_args_parser(), parse the command line, and start main().
    parser = argparse.ArgumentParser(
        prog='AutoFormer training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # Create the output directory (with any missing parents) before main()
    # runs; a falsy output_dir skips this step.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
