# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import argparse
import datetime
import json
import numpy as np
import os
import time
import sys

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

import timm

# assert timm.__version__ == "0.3.2" # version check
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

import util.lr_decay as lrd
import util.misc as misc
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import model_ViT

from engine_finetune import train_one_epoch, evaluate
from prepare_dataset import *

from util.optimizer import build_optimizer
import baseline_models as bm


def get_args_parser():
    """Build the command-line parser shared by the training entry points.

    The parser is created with ``add_help=False`` (no ``-h/--help`` option is
    registered).  It bundles the MAE fine-tuning options (optimizer,
    augmentation, mixup, distributed settings) with the federated-learning
    options that the inline comment below attributes to ``main.py``.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args``
        on it to obtain the namespace.
    """
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')

    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_base_patch16_enc_nodp', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')

    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Optimizer parameters
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--layer_decay', type=float, default=0.75,
                        help='layer-wise lr decay from ELECTRA/BEiT')

    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    # The help text previously contained a stray '" + "' left over from a
    # broken string concatenation, and the call ended with a trailing comma.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    # --global_pool defaults to True; --cls_token writes False to the same dest.
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=20, type=int,
                        help='number of the classification types')

    parser.add_argument('--output_dir', default='./checkpoint/global/',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./checkpoint/logs/',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    # NOTE: store_true combined with default=True means this flag is always
    # True; there is no command-line switch that turns it off.
    parser.add_argument('--dist_eval', action='store_true', default=True,
                        help='Enabling distributed evaluation (recommended during training for faster monitor')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # arguments from main.py
    parser.add_argument('-rd', '--rounds', default=200, type=int,
                        help='The number of rounds for federated learning')
    parser.add_argument('-tt', '--training_times', default=3, type=int,
                        help='The bounded number of times of training on each worker')
    parser.add_argument('-sb', '--stale_bound', default=5, type=int,
                        help='The bounded number of train round gap between worker stored local model and visitor model')
    parser.add_argument('-nw', '--num_total_workers', default=100, type=int,
                        help='The total number of workers in the network')
    parser.add_argument('-nm', '--num_models', default=1, type=int,
                        help='The number of autoencoders being trained')
    parser.add_argument('-le', '--num_of_local_epochs', default=5, type=int,
                        help='The number of epochs in local training')
    parser.add_argument('-c', '--connectivity', default=0.15, type=float,
                        help='The connectivity of network')
    parser.add_argument('-hcp', '--highest_cp', default=5, type=int,
                        help='The highest computing power for workers')
    parser.add_argument('-bd', '--base_depth', default=1, type=int,
                        help='The depth of base transformer')

    parser.add_argument('-sp', '--save_path', default='./checkpoint/', type=str, help='checkpoint save path')
    parser.add_argument('-gp', '--graph_path', default='./graph/network_G.adjlist', type=str, help='network graph path')
    parser.add_argument('-dsp', '--data_split_path', default='./graph/data_split.pkl', type=str, help='network data split path')
    parser.add_argument('-cpp', '--computing_power_path', default='./graph/computing_power.pkl', type=str, help='network computing powers of workers path')
    parser.add_argument('-swp', '--starting_workers_path', default='./graph/starting_workers.pkl', type=str, help='network starting workers path')

    parser.add_argument('-d', '--dataset', default=4, type=int,
                        help='The id of dataset to use: 0 - CIFAR10; 1 - CIFAR100; 2 - Food101; 3 - ImageNet; 4 - Mini-ImageNet; 5 - Road-Sign, 6 - Mini-INAT;')
    parser.add_argument('-ra', '--ratio', default=1, type=float,
                        help='The ratio of labelled images')
    parser.add_argument('-samp', '--sampling', default="iid",
                        help='The way of samping, iid or dir')
    parser.add_argument('--alpha', default=1e-1, type=float,
                        help='The required parameter for dir sampling, which decides the statistical heterogenity')
    parser.add_argument('-ri', '--record_interval', default=100, type=int,
                        help='The interval of saving checkpoint')
    parser.add_argument('-m', '--mode', default=2, type=int,
                        help='The mode of next worker finding algorithm: 0 - the beginning algorithm; 1 - random; 2 - new one')
    parser.add_argument('-agg', '--agg', default=3, type=int,
                        help='The mode of aggregation: 0 - average weights; 1 - data volume weights; 2 - round weights; 3 - our weights')
    parser.add_argument('-bl', '--baseline', default=0, type=int,
                        help='The baselines: 0 - our algorithm; 1 - FedMAE, 2 - GossipMAE, 3 - DecenCNN')

    # Model Size Study
    parser.add_argument('-sc', '--scenario', default=0, type=int,
                        help='The scenarios: 0 - Centralized; 1 - Federated, 2 - Single Client')
    parser.add_argument('-dp', '--depths_range', default=10, type=int,
                        help='The number of model size options')
    parser.add_argument('-scp', '--start_model_ckpt_path', default='./checkpoint_cen/', type=str, help='The save path of start model checkpoints (Used for ensuring same starting weights)')
    parser.add_argument('-logp', '--logs_save_path', default='./model_size_logs/', type=str, help='The save path of training logs')

    parser.add_argument('-p', '--phase', default="pretrain",
                        help='specify the codes to: pretrain, finetune')
    parser.add_argument('-ftd', '--ft_depth', default=5, type=int,
                        help='The depth of model in finetuning')

    return parser

def train_VIT(main_args, worker_ID, dataset_train, depths, r_eps, num_classes, dataset_val=None, eval_model=False, load_path=None, save_path=None):
    """Train and/or evaluate a ViT classifier for one worker or the global model.

    Command-line options are re-parsed from ``sys.argv`` through
    ``get_args_parser``; only ``main_args.save_path`` is read from
    ``main_args``.

    Args:
        main_args: namespace providing ``save_path`` (prefix for output dirs).
        worker_ID: when truthy, outputs go to ``<save_path>worker/<worker_ID>``;
            otherwise to ``<save_path>global/``.
            NOTE(review): a worker id of ``0`` is falsy and therefore selects
            the global directory — confirm this matches the callers.
        dataset_train: training dataset, or a falsy value to skip training.
        depths: indexable; ``depths[0]`` is passed as ``depth`` and
            ``depths[-1]`` as ``embed_dim`` to the model constructor.
        r_eps: number of epochs to run in this call.
        num_classes: number of output classes (also used by mixup).
        dataset_val: optional validation dataset.
        eval_model: when True, only evaluate on ``dataset_val`` and return.
        load_path: optional checkpoint to initialise the model from.
        save_path: optional file path for the resulting checkpoint.

    Returns:
        * ``(test_loss, test_acc)`` when ``eval_model`` is True;
        * ``(checkpoint, last_train_loss, n_params_in_millions)`` when
          ``dataset_train`` is truthy (``last_train_loss`` is ``None`` if no
          epoch ran);
        * ``checkpoint`` otherwise.

    Raises:
        ValueError: if ``eval_model`` is True but no usable ``dataset_val``
            was supplied.
    """
    # Fail fast with a clear message instead of hitting an unbound
    # data_loader_val after the whole model has been built.
    if eval_model and not dataset_val:
        raise ValueError("eval_model=True requires a non-empty dataset_val")

    args = get_args_parser().parse_args()
    # Scale the batch by the number of visible GPUs; clamp to 1 so that a
    # CPU-only host does not end up with batch_size == 0.
    args.batch_size = max(torch.cuda.device_count(), 1) * args.batch_size
    if worker_ID:
        run_dir = '%sworker/%s' % (main_args.save_path, worker_ID)
    else:
        run_dir = '%sglobal/' % main_args.save_path
    args.output_dir = run_dir
    args.log_dir = run_dir

    os.makedirs(args.output_dir, exist_ok=True)
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Distributed samplers are always used, even for a single process.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    if dataset_train:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    if dataset_val:
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    # Only rank 0 writes TensorBoard events, and only outside --eval runs.
    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    if dataset_train:
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )

    if dataset_val:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=num_classes)

    model = model_ViT.__dict__[args.model](
        depth=depths[0],
        embed_dim=depths[-1],
        num_classes=num_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

    if load_path and os.path.exists(load_path) and not args.eval:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(load_path, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % load_path)
        # Accept both a bare state dict and a {'model': state_dict, ...} dict.
        checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint
        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)
    model_without_ddp = model

    n_parameters = sum(p.numel() for p in model.parameters())

    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    # Wrap exactly once.  Wrapping a DDP model in DataParallel as well would
    # make model_without_ddp the DDP wrapper, which does not expose
    # no_weight_decay().
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    elif torch.cuda.device_count() > 1:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    # build optimizer with layer-wise lr decay (lrd)
    param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay
    )
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr)

    loss_scaler = NativeScaler()

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    if eval_model:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        test_loss = float("%.5f" % test_stats['loss'])
        test_acc = float("%.2f" % test_stats['acc1'])
        if log_writer is not None:
            log_writer.close()
        return test_loss, test_acc

    # Stays None when the epoch loop does not execute (r_eps <= 0).
    train_stats = None

    if dataset_train:

        print(f"Start training for {r_eps} epochs")
        start_time = time.time()
        model.train()

        best_loss = float('inf')
        for epoch in range(args.start_epoch, args.start_epoch+r_eps):

            if args.distributed:
                data_loader_train.sampler.set_epoch(epoch)
            train_stats = train_one_epoch(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, mixup_fn,
                log_writer=log_writer,
                args=args
            )

            if args.output_dir:
                # Keep the checkpoint with the lowest training loss so far.
                if train_stats['loss'] < best_loss:
                    best_loss = train_stats['loss']
                    misc.save_model_pretrain(
                        args=args, epoch=epoch, model=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, best=True)
                # NOTE(review): the "last epoch" test compares against
                # args.epochs (CLI value), not start_epoch + r_eps, which is
                # what bounds this loop — confirm which one is intended.
                if epoch % args.record_interval == 0 or epoch + 1 == args.epochs:
                    misc.save_model_pretrain(
                        args=args, epoch=epoch, model=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                            'epoch': epoch,}

            if args.output_dir and misc.is_main_process():
                if log_writer is not None:
                    log_writer.flush()
                with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if dataset_val:
            test_stats = evaluate(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%")

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    new_checkpoint = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss_scaler': loss_scaler.state_dict()
    }

    if save_path:
        misc.save_on_master(new_checkpoint, save_path)

    # A new SummaryWriter is created on every call, so release it here.
    if log_writer is not None:
        log_writer.close()

    if dataset_train:
        last_loss = train_stats['loss'] if train_stats is not None else None
        return new_checkpoint, last_loss, n_parameters / 1.e6
    return new_checkpoint


def finetune_VIT(main_args, super_train_idxs, mean, std, dataset_name, label_rate, depths,load_path=None, save_path=None, res_out=False):
    """Fine-tune a ViT classifier on a (possibly sub-sampled) labelled dataset.

    Command-line options are re-parsed from ``sys.argv`` through
    ``get_args_parser``; only ``main_args.save_path`` is read from
    ``main_args``.  Outputs and logs go to ``<save_path>finetune/``.

    Args:
        main_args: namespace providing ``save_path`` (prefix for output dirs).
        super_train_idxs: indices of the labelled training subset.  When
            ``label_rate < 1`` and the first element is ``-1``, a new subset
            is drawn with ``divide_dataset``.
        mean, std: forwarded to ``init_finetune_transform``.
        dataset_name: forwarded to ``init_dataset`` / ``divide_dataset``.
        label_rate: the training set is only sub-sampled when this is < 1.
        depths: indexable; ``depths[0]`` is passed as ``depth`` and
            ``depths[-1]`` as ``embed_dim`` to the model constructor.
        load_path: optional pre-trained checkpoint to initialise from.
        save_path: optional file path for the final checkpoint.
        res_out: when True, return the per-epoch evaluation history.

    Returns:
        ``(test_accs, test_losses)`` — one entry per epoch — when ``res_out``
        is True; otherwise ``None``.  With ``--eval`` the process exits after
        a single evaluation pass.
    """
    args = get_args_parser().parse_args()
    # Scale the batch by the number of visible GPUs; clamp to 1 so that a
    # CPU-only host does not end up with batch_size == 0.
    args.batch_size = max(torch.cuda.device_count(), 1) * args.batch_size
    args.output_dir = '%sfinetune/' % main_args.save_path
    args.log_dir = '%sfinetune/' % main_args.save_path
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    train_transform, val_transform = init_finetune_transform(args, mean, std)
    dataset_train, dataset_val, num_classes = init_dataset(train_transform, val_transform, dataset_name)
    if label_rate < 1:
        # A leading -1 means "no subset chosen yet": draw one now.
        if super_train_idxs[0] == -1:
            super_train_idxs, _ = divide_dataset(dataset_train, label_rate, dataset_name)
        dataset_train = get_super_dataset(dataset_train, super_train_idxs)
    args.nb_classes = num_classes

    # Distributed samplers are always used, even for a single process.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    # Only rank 0 writes TensorBoard events, and only outside --eval runs.
    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    model = model_ViT.__dict__[args.model](
        depth=depths[0],
        embed_dim=depths[-1],
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

    # NOTE(review): with --eval the checkpoint is not loaded, so the
    # evaluation below runs on freshly initialised weights — confirm intent.
    if load_path and os.path.exists(load_path) and not args.eval:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(load_path, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % load_path)
        # Accept both a bare state dict and a {'model': state_dict, ...} dict.
        checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint
        state_dict = model.state_dict()
        # Drop classifier weights whose shape does not match the new head.
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        interpolate_pos_embed(model, checkpoint_model)

        # load pre-trained model
        model.load_state_dict(checkpoint_model, strict=False)

        # manually initialize fc layer
        trunc_normal_(model.head.weight, std=2e-5)

    model.to(device)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters())

    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    # Wrap exactly once.  Wrapping a DDP model in DataParallel as well would
    # make model_without_ddp the DDP wrapper, which does not expose
    # no_weight_decay().
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    elif torch.cuda.device_count() > 1:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    # build optimizer with layer-wise lr decay (lrd)
    param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay
    )
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr)

    loss_scaler = NativeScaler()

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        # sys.exit instead of the site-provided exit() builtin, which is not
        # guaranteed to exist in every interpreter configuration.
        sys.exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    model.train()

    # The object handed to misc.save_model as `model`: unwrapped when more
    # than one GPU is visible, the (possibly wrapped) model otherwise.
    save_target = model.module if torch.cuda.device_count() > 1 else model

    test_losses = []
    test_accs = []

    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, mixup_fn,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir:
            misc.save_model(
                args=args, model=save_target, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        # Save a separate "best" checkpoint whenever top-1 accuracy improves.
        if test_stats["acc1"] > max_accuracy:
            misc.save_model(
                args=args, model=save_target, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, best=True)
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if log_writer is not None:
            log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
            log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
            log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                        **{f'test_{k}': v for k, v in test_stats.items()},
                        'epoch': epoch,
                        'n_parameters': n_parameters}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

        test_losses.append(float("%.5f" % test_stats['loss']))
        test_accs.append(float("%.2f" % test_stats['acc1']))

    test_stats = evaluate(data_loader_val, model, device)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    new_checkpoint = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss_scaler': loss_scaler.state_dict()
    }

    # Previously the checkpoint was written only if save_path already
    # existed, so a fresh path was silently never saved.  Save whenever a
    # path is given, creating the parent directory if needed.
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        misc.save_on_master(new_checkpoint, save_path)

    # A new SummaryWriter is created on every call, so release it here.
    if log_writer is not None:
        log_writer.close()

    if res_out:
        return test_accs, test_losses

def train_CNN(main_args, worker_ID, dataset_train, depth, r_eps, num_classes, dataset_val=None, eval_model=False, load_path=None, save_path=None):
    """Train and/or evaluate a CNN backbone and return its checkpoint.

    Hyper-parameters are read with ``get_args_parser().parse_args()`` (i.e.
    from the process command line); only the output/log directory is derived
    from ``main_args.save_path`` and ``worker_ID``.

    Args:
        main_args: Object exposing ``save_path``, used as the directory prefix.
        worker_ID: If truthy, outputs go to ``<save_path>worker/<worker_ID>``;
            otherwise to ``<save_path>global/``.
        dataset_train: Training dataset, or a falsy value to skip training.
        depth: Backbone depth; the model is created as ``'res<depth>-origin'``.
        r_eps: Number of epochs to train in this call.
        num_classes: Number of classes passed to the backbone factory.
        dataset_val: Optional validation dataset.
        eval_model: If true (and ``dataset_val`` is given), only evaluate.
        load_path: Optional checkpoint to load (non-strict) before running.
        save_path: Optional path where the resulting checkpoint is saved.

    Returns:
        * ``(test_loss, accuracy)`` when ``eval_model`` and ``dataset_val``
          are both truthy (note the order);
        * ``(checkpoint, mean_train_loss_of_last_epoch, n_params_in_millions)``
          when ``dataset_train`` is truthy;
        * ``checkpoint`` otherwise.
    """
    args = get_args_parser().parse_args()
    # The batch is scaled by the number of visible GPUs (DataParallel is used
    # below when there are several). Use at least 1 so that a CPU-only host
    # does not end up with batch_size == 0, which DataLoader rejects.
    args.batch_size = max(torch.cuda.device_count(), 1) * args.batch_size
    if worker_ID:
        args.output_dir = '%sworker/%s' % (main_args.save_path, worker_ID)
        args.log_dir = '%sworker/%s' % (main_args.save_path, worker_ID)
    else:
        args.output_dir = '%sglobal/' % main_args.save_path
        args.log_dir = '%sglobal/' % main_args.save_path

    os.makedirs(args.output_dir, exist_ok=True)
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Distributed samplers are used unconditionally; replica count and rank
    # are passed explicitly from misc.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    if dataset_train:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    if dataset_val:
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # The writer is closed in ``finally`` so that repeated calls (one per
    # worker / round) do not accumulate open event-file writers.
    try:
        if dataset_train:
            data_loader_train = torch.utils.data.DataLoader(
                dataset_train, sampler=sampler_train,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=True,
            )

        if dataset_val:
            data_loader_val = torch.utils.data.DataLoader(
                dataset_val, sampler=sampler_val,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                pin_memory=args.pin_mem,
                drop_last=False
            )

        model = bm.create_backbone(name='res%s-origin' % depth, num_classes=num_classes)

        if load_path and os.path.exists(load_path):
            checkpoint = torch.load(load_path, map_location='cpu')

            print("Load pre-trained checkpoint from: %s" % load_path)
            # Accept both a raw state dict and a {'model': state_dict, ...} wrapper.
            checkpoint_model = checkpoint
            if 'model' in checkpoint:
                checkpoint_model = checkpoint['model']
            model.load_state_dict(checkpoint_model, strict=False)

        model.to(device)
        model_without_ddp = model

        n_parameters = sum(p.numel() for p in model.parameters())

        print('number of params (M): %.2f' % (n_parameters / 1.e6))

        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

        if args.lr is None:  # only base_lr is specified
            args.lr = args.blr * eff_batch_size / 256

        print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
        print("actual lr: %.2e" % args.lr)

        print("accumulate grad iterations: %d" % args.accum_iter)
        print("effective batch size: %d" % eff_batch_size)

        # Apply exactly one parallel wrapper. Wrapping a DDP model in
        # DataParallel as well would make ``model.module`` the DDP wrapper, so
        # the saved state dict would carry an extra ``module.`` key prefix.
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        elif torch.cuda.device_count() > 1:
            model = torch.nn.parallel.DataParallel(model)
            model_without_ddp = model.module

        # Optimizer
        optimizer = torch.optim.SGD(model_without_ddp.parameters(), lr=args.lr, momentum=0.9, weight_decay=0)

        criterion = torch.nn.CrossEntropyLoss()

        if eval_model and dataset_val:
            accuracy, test_loss, _ = eval_CNN(model, data_loader_val, device, criterion)
            print("Accuracy of the network on the %s test images: %.2f%%" % (len(dataset_val), accuracy))
            test_loss = float("%.5f" % test_loss)
            accuracy = float("%.2f" % accuracy)
            return test_loss, accuracy

        if dataset_train:

            # base_lr == final_lr and no warmup: the schedule is constant at args.lr.
            lr_scheduler = LR_Scheduler(optimizer, warmup_epochs=0, warmup_lr=args.min_lr, num_epochs=r_eps, base_lr=args.lr, final_lr=args.lr, iter_per_epoch=len(data_loader_train))

            # Train
            model.train()
            print("Training will continue for %s epochs" % r_eps)
            # Defined up front so the final return works even when r_eps == 0.
            train_loss = 0
            for epoch in range(r_eps):
                # Without set_epoch the DistributedSampler reuses the same
                # shuffled order in every epoch.
                sampler_train.set_epoch(epoch)
                train_loss = 0
                correct = 0
                total = 0
                model.train()
                for batch_idx, (inputs, targets) in enumerate(data_loader_train):
                    inputs, targets = inputs.to(device), targets.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()
                    lr_scheduler.step()
                    train_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()
                    progress_bar(batch_idx, len(data_loader_train), 'Loss: %.3f | Acc: %.3f | LR: %.3f'
                        % (train_loss/(batch_idx+1), 100.*correct/total, optimizer.param_groups[0]['lr']))

            if dataset_val:
                accuracy, test_loss, _ = eval_CNN(model, data_loader_val, device, criterion)
                print("Accuracy of the network on the %s test images: %.2f%%" % (len(dataset_val), accuracy))

        new_checkpoint = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        if save_path:
            misc.save_on_master(new_checkpoint, save_path)

        if dataset_train:
            # train_loss is the summed batch loss of the last epoch only.
            return new_checkpoint, train_loss/len(data_loader_train), n_parameters / 1.e6
        else:
            return new_checkpoint
    finally:
        if log_writer is not None:
            log_writer.close()


def finetune_CNN(main_args, super_train_idxs, mean, std, dataset_name, label_rate, depth=18, load_path=None, save_path=None):
    """Fine-tune a CNN backbone on a labelled subset and evaluate every epoch.

    Hyper-parameters are read with ``get_args_parser().parse_args()`` (i.e.
    from the process command line); output/log directories are set to
    ``<main_args.save_path>finetune/``.

    Args:
        main_args: Object exposing ``save_path``, used as the directory prefix.
        super_train_idxs: Indices of the labelled training samples. When the
            first element is ``-1`` and ``label_rate < 1`` the indices are
            drawn here via ``divide_dataset``.
        mean: Passed to ``init_linprobe_transform``.
        std: Passed to ``init_linprobe_transform``.
        dataset_name: Passed to ``init_dataset`` / ``divide_dataset``.
        label_rate: When ``< 1`` the training set is restricted to the
            labelled subset (presumably the labelled fraction - confirm
            against ``divide_dataset``).
        depth: Backbone depth; the model is created as ``'res<depth>-origin'``.
        load_path: Optional checkpoint to load (non-strict) unless ``args.eval``.
        save_path: Currently unused by this function.

    Returns:
        ``(test_accuracies, test_losses)``: two lists with one entry per
        epoch, as returned by ``eval_CNN`` on the validation set.
    """
    args = get_args_parser().parse_args()
    # The batch is scaled by the number of visible GPUs (DataParallel is used
    # below when there are several). Use at least 1 so that a CPU-only host
    # does not end up with batch_size == 0, which DataLoader rejects.
    args.batch_size = max(torch.cuda.device_count(), 1) * args.batch_size
    args.output_dir = '%sfinetune/' % main_args.save_path
    args.log_dir = '%sfinetune/' % main_args.save_path
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    train_transform, val_transform = init_linprobe_transform(mean, std)
    dataset_train, dataset_val, num_classes = init_dataset(train_transform, val_transform, dataset_name)
    if label_rate < 1:
        # A leading -1 marks "no indices supplied": draw the labelled subset here.
        if super_train_idxs[0] == -1:
            super_train_idxs, _ = divide_dataset(dataset_train, label_rate, dataset_name)
        dataset_train = get_super_dataset(dataset_train, super_train_idxs)
    args.nb_classes = num_classes

    # Distributed samplers are used unconditionally; replica count and rank
    # are passed explicitly from misc.
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # The writer is closed in ``finally`` so that repeated calls do not
    # accumulate open event-file writers.
    try:
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )

        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )

        # Model definition
        model = bm.create_backbone(name='res%s-origin' % depth, num_classes=num_classes)

        # Load pre-trained weights
        if load_path and os.path.exists(load_path) and not args.eval:
            checkpoint = torch.load(load_path, map_location='cpu')

            print("Load pre-trained checkpoint from: %s" % load_path)
            # Accept both a raw state dict and a {'model': state_dict, ...} wrapper.
            checkpoint_model = checkpoint
            if 'model' in checkpoint:
                checkpoint_model = checkpoint['model']

            model.load_state_dict(checkpoint_model, strict=False)

        model = model.to(device)
        model_without_ddp = model
        n_parameters = sum(p.numel() for p in model.parameters())

        print('number of params (M): %.2f' % (n_parameters / 1.e6))

        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

        if args.lr is None:  # only base_lr is specified
            args.lr = args.blr * eff_batch_size / 256

        print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
        print("actual lr: %.2e" % args.lr)

        print("accumulate grad iterations: %d" % args.accum_iter)
        print("effective batch size: %d" % eff_batch_size)

        # Apply exactly one parallel wrapper. Wrapping a DDP model in
        # DataParallel as well would make ``model.module`` the DDP wrapper.
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        elif torch.cuda.device_count() > 1:
            model = torch.nn.parallel.DataParallel(model)
            model_without_ddp = model.module

        # Optimizer
        optimizer = torch.optim.SGD(model_without_ddp.parameters(), lr=args.lr, momentum=0.9, weight_decay=0)

        # define lr scheduler
        # NOTE(review): the schedule peaks at args.blr and decays towards
        # args.lr (the batch-size-scaled rate); confirm this is intended
        # rather than base_lr=args.lr / final_lr=args.min_lr.
        lr_scheduler = LR_Scheduler(optimizer, warmup_epochs=args.warmup_epochs, warmup_lr=args.min_lr, num_epochs=args.epochs, base_lr=args.blr, final_lr=args.lr, iter_per_epoch=len(data_loader_train))

        test_losses = []
        test_accuracies = []
        # Train
        model.train()
        criterion = torch.nn.CrossEntropyLoss()
        for epoch in range(args.epochs):
            # Without set_epoch the DistributedSampler reuses the same
            # shuffled order in every epoch.
            sampler_train.set_epoch(epoch)
            train_loss = 0
            correct = 0
            total = 0
            # eval_CNN switches the model to eval mode; switch back each epoch.
            model.train()
            for batch_idx, (inputs, targets) in enumerate(data_loader_train):
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                progress_bar(batch_idx, len(data_loader_train), 'Loss: %.3f | Acc: %.3f | LR: %.3f'
                    % (train_loss/(batch_idx+1), 100.*correct/total, optimizer.param_groups[0]['lr']))

            # Evaluate after every epoch and keep the full history.
            accuracy, test_loss, _ = eval_CNN(model, data_loader_val, device, criterion)
            test_losses.append(test_loss)
            test_accuracies.append(accuracy)

        return test_accuracies, test_losses
    finally:
        if log_writer is not None:
            log_writer.close()

def eval_CNN(model, data_loader_val, device, criterion):
    """Evaluate ``model`` on ``data_loader_val`` without tracking gradients.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Module to evaluate; called as ``model(inputs)``.
        data_loader_val: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.

    Returns:
        A tuple ``(accuracy_percent, mean_batch_loss, result)`` where
        ``mean_batch_loss`` is the per-batch average rounded to 4 decimals and
        ``result`` holds ``correct``, ``total``, ``accuracy`` (as a fraction,
        not percent) and ``test_loss`` (the un-averaged sum of batch losses).
        For an empty loader accuracy and loss are reported as ``0.0``.
    """
    model.eval()
    correct, total, test_loss = 0, 0, 0.0
    # Counted explicitly so an empty loader does not leave the loop index unbound.
    num_batches = 0
    print("\n")
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader_val):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1
            progress_bar(batch_idx, len(data_loader_val), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Guard both divisions: nothing was evaluated if the loader was empty.
    accuracy = correct/total if total else 0.0
    mean_loss = test_loss / num_batches if num_batches else 0.0

    result = {
        'correct': correct,
        'total': total,
        'accuracy': accuracy,
        'test_loss': test_loss
    }
    return accuracy*100, float('%.4f' % mean_loss), result



# LR Scheduler
class LR_Scheduler(object):
    """Per-iteration learning-rate schedule: linear warmup, then cosine decay.

    The whole schedule is precomputed as one array with
    ``iter_per_epoch * warmup_epochs`` linearly spaced values from
    ``warmup_lr`` to ``base_lr`` followed by
    ``iter_per_epoch * (num_epochs - warmup_epochs)`` values on a half-cosine
    from ``base_lr`` towards ``final_lr``. Each call to :meth:`step` writes the
    next value into the optimizer's param groups.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        warmup_epochs: Number of warmup epochs.
        warmup_lr: Learning rate at the start of warmup.
        num_epochs: Total number of epochs (warmup included).
        base_lr: Learning rate at the end of warmup / start of decay.
        final_lr: Learning rate the cosine decay heads towards.
        iter_per_epoch: Number of ``step`` calls per epoch.
        constant_predictor_lr: If true, param groups whose ``'name'`` is
            ``'predictor'`` are pinned to ``base_lr`` instead of following
            the schedule.
    """

    def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
        self.base_lr = base_lr
        self.constant_predictor_lr = constant_predictor_lr
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
        cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))

        self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
        self.optimizer = optimizer
        self.iter = 0
        self.current_lr = 0

    def step(self):
        """Apply the next scheduled learning rate and return it.

        Raises:
            IndexError: If called more times than the schedule has entries.
        """
        # Look the value up once, before the loop, so it is defined even when
        # every param group is a constant-LR predictor group.
        lr = self.lr_schedule[self.iter]
        for param_group in self.optimizer.param_groups:
            # .get(): param groups are not required to carry a 'name' key.
            if self.constant_predictor_lr and param_group.get('name') == 'predictor':
                param_group['lr'] = self.base_lr
            else:
                param_group['lr'] = lr

        self.iter += 1
        self.current_lr = lr
        return lr

    def get_lr(self):
        """Return the learning rate applied by the most recent ``step`` (0 before the first)."""
        return self.current_lr

######### Progress bar #########
# Fixed layout width (in characters) assumed for the terminal line.
term_width = 150
# Number of character cells in the bar itself (kept as a float, as before).
TOTAL_BAR_LENGTH = 30.
# Timestamps shared across calls: previous call, and start of the current bar.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    """Write a one-line text progress bar for step ``current`` of ``total`` to stdout.

    The line consists of the bar, the time since the previous call ("Step"),
    the time since the bar was last restarted ("Tot"), the optional ``msg``,
    and a ``current+1/total`` counter written back over the middle of the bar.
    The line ends with a carriage return until the last step, which ends it
    with a newline.

    Args:
        current: Zero-based index of the step just finished; ``0`` restarts
            the total-time clock.
        total: Total number of steps.
        msg: Optional text appended after the timings.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    # A negative repeat count yields '', matching an empty range().
    pieces = [' [', '=' * filled, '>', '.' * remaining, ']']

    now = time.time()
    step_elapsed = now - last_time
    last_time = now
    total_elapsed = now - begin_time

    status = '  Step: %s' % format_time(step_elapsed)
    status += ' | Tot: %s' % format_time(total_elapsed)
    if msg:
        status += ' | ' + msg
    pieces.append(status)

    # Pad out to the assumed terminal width.
    pieces.append(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))
    # Go back to the center of the bar.
    pieces.append('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    pieces.append(' %d/%d ' % (current+1, total))
    pieces.append('\r' if current < total-1 else '\n')

    sys.stdout.write(''.join(pieces))
    sys.stdout.flush()

def format_time(seconds):
    """Format a duration in seconds using its two most significant non-zero units.

    Units are days (``D``), hours (``h``), minutes (``m``), seconds (``s``)
    and milliseconds (``ms``); components that are zero are skipped, so e.g.
    ``3600.25`` becomes ``'1h250ms'``. Returns ``'0ms'`` when no component is
    positive.
    """
    remaining = seconds
    whole_days = int(remaining / 3600/24)
    remaining = remaining - whole_days*3600*24
    whole_hours = int(remaining / 3600)
    remaining = remaining - whole_hours*3600
    whole_minutes = int(remaining / 60)
    remaining = remaining - whole_minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    # Ordered from most to least significant unit.
    components = (
        (whole_days, 'D'),
        (whole_hours, 'h'),
        (whole_minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    shown = [str(value) + suffix for value, suffix in components if value > 0]
    # Only the first two non-zero units are displayed.
    return ''.join(shown[:2]) or '0ms'
