"""
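DDP模板 — DistributedDataParallel training template for image classification with an auxiliary center loss.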
"""

import datetime
import os
import os.path as osp
import time
import warnings

import torch
import torch.utils.data
import torchvision
from torch import nn
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from mmengine.config import Config

import sys
from pathlib import Path

# Make the project root importable. ROOT is FILE.parents[1], i.e. the parent of the
# directory that contains this script.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # root directory
if str(ROOT) not in sys.path:
    # Needed so the project-local imports below (transforms, utils, sampler, datasets,
    # models, loss) resolve when the script is launched directly; must run before them.
    sys.path.append(str(ROOT))  # add ROOT to sys.path

import transforms
from utils import torch_utils as utils
from sampler import RASampler
from datasets import get_dataset
from models import load_model
from utils.misc import print_args, Logger
from utils.yolo_utils import init_seeds
from loss import get_loss_fn
from loss.center_loss import CenterLoss

# Best top-1 validation accuracy seen so far; rebound by main() through `global`.
best_acc1 = 0

# Process topology exported by the launcher (torchrun / torch.distributed.run), see
# https://pytorch.org/docs/stable/elastic/run.html . -1 means "not launched distributed".
LOCAL_RANK = int(os.environ.get('LOCAL_RANK', -1))
RANK = int(os.environ.get('RANK', -1))
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))


def _center_loss_targets(target):
    """Return one class index per sample for the center loss.

    The center loss needs hard labels. If the batch carries 2-D targets (presumably
    soft / one-hot labels produced by the MixUp/CutMix collate function — TODO confirm
    against ``transforms.RandomMixup``), use the arg-max class of each row; 1-D targets
    are returned unchanged.
    """
    return target.argmax(dim=1) if target.ndim == 2 else target


def _prepare_center_grads(criterion_cent, weight_cent):
    """Rescale and synchronize the center-loss gradients before its optimizer steps.

    The center term enters the total loss multiplied by ``weight_cent``. Dividing its
    gradients by ``weight_cent`` keeps that weight from affecting how the centers are
    learned (only ``lr_cent`` does). A zero weight skips the rescale instead of raising
    ``ZeroDivisionError``; in that case the gradients are already zero.

    The center parameters are not wrapped by DistributedDataParallel, so their gradients
    are averaged across processes here. Otherwise every rank would update its own,
    diverging copy of the centers.
    """
    grads = [p.grad for p in criterion_cent.parameters() if p.grad is not None]
    if weight_cent != 0:
        for grad in grads:
            grad.mul_(1.0 / weight_cent)
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        for grad in grads:
            torch.distributed.all_reduce(grad)
            grad.div_(world_size)


def train_one_epoch(model, criterion, criterion_cent, optimizer, optimizer_centloss,
                    data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Train ``model`` jointly with the center loss for one epoch.

    In training mode the model must return ``(features, logits)``. The optimized loss is
    ``criterion(logits, target) + criterion_cent(features, labels) * args.weight_cent``.
    ``optimizer`` updates the network; ``optimizer_centloss`` updates the centers.

    Args:
        model: network (possibly DDP-wrapped) returning ``(features, logits)``.
        criterion: classification loss on the logits.
        criterion_cent: center-loss module whose parameters are the class centers.
        optimizer: optimizer for the network parameters.
        optimizer_centloss: optimizer for ``criterion_cent`` parameters.
        data_loader: training loader yielding ``(image, target)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (for logging and the EMA warmup check).
        args: options; reads ``weight_cent``, ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps`` and ``lr_warmup_epochs``.
        model_ema: optional EMA model updated every ``args.model_ema_steps`` iterations.
        scaler: optional ``GradScaler``; enables mixed precision when given.

    Returns:
        ``(acc1, loss)``: epoch averages from this process's ``MetricLogger``
        (not synchronized across processes).
    """
    model.train()
    weight_cent = args.weight_cent
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            features, output = model(image)
            loss = (criterion(output, target)
                    + criterion_cent(features, _center_loss_targets(target)) * weight_cent)

        optimizer.zero_grad()
        optimizer_centloss.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Unscale the model optimizer's gradients first so clipping sees true magnitudes.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            _prepare_center_grads(criterion_cent, weight_cent)
            # scaler.step unscales the center gradients itself and skips the step on inf/NaN.
            scaler.step(optimizer_centloss)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            _prepare_center_grads(criterion_cent, weight_cent)
            optimizer_centloss.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset the EMA buffer so it keeps copying the weights during warmup.
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # Return (top-1 accuracy, loss) averaged over the epoch.
    return metric_logger.acc1.global_avg, metric_logger.loss.global_avg


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` under ``torch.inference_mode()``.

    Args:
        model: network to evaluate (plain, DDP-wrapped or the EMA copy). If it returns a
            tuple/list, as the training forward ``(features, logits)`` does, the last
            element is used as the logits.
        criterion: loss applied to ``(logits, target)``.
        data_loader: loader yielding ``(image, target)`` batches.
        device: device the batches are moved to.
        print_freq: logging interval in iterations.
        log_suffix: text appended to the ``"Test: "`` log header (e.g. ``"EMA"``).

    Returns:
        ``(acc1, loss)``: global averages read after
        ``metric_logger.synchronize_between_processes()`` (presumably aggregated over all
        ranks — see ``utils.MetricLogger``).
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            if isinstance(output, (tuple, list)):
                # Same convention as training: (features, logits) -> keep the logits.
                output = output[-1]
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME: a DistributedSampler may pad the dataset so every rank receives the
            # same number of samples; those duplicates are counted here too.
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # Gather the sample count from all processes.
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    # Do not call get_rank() without a process group (single-process runs): it would raise.
    dist = torch.distributed
    is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")

    return metric_logger.acc1.global_avg, metric_logger.loss.global_avg


def _build_data_loaders(args, train_dataset, test_dataset, num_classes):
    """Create the train/test samplers and data loaders.

    Returns:
        ``(data_loader, data_loader_test, train_sampler)``. ``train_sampler`` is returned
        so the caller can call ``set_epoch`` on it in distributed mode.
    """
    # Samplers: distributed (optionally Repeated Augmentation) or single process.
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            train_sampler = RASampler(train_dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        test_sampler = torch.utils.data.SequentialSampler(test_dataset)

    collate_fn = None

    # Batch-level MixUp / CutMix: one enabled transform is chosen at random per batch.
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    print("[INFO] Creating data loaders")
    data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        test_dataset,
        # Half the training batch size, but at least 1 (DataLoader rejects batch_size=0).
        batch_size=max(1, args.batch_size // 2),
        sampler=test_sampler,
        num_workers=args.workers,
    )
    return data_loader, data_loader_test, train_sampler


def _build_optimizer(args, parameters):
    """Create the network optimizer selected by ``args.opt`` (SGD[/nesterov], RMSprop, AdamW).

    Raises:
        RuntimeError: for an unsupported optimizer name.
    """
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(args, optimizer):
    """Create the per-epoch LR scheduler, optionally preceded by a warmup phase.

    Lower-cases ``args.lr_scheduler`` in place. With ``lr_warmup_epochs > 0`` a LinearLR
    or ConstantLR warmup runs first, and ``SequentialLR`` then hands over to the main
    scheduler.

    Raises:
        RuntimeError: for an unsupported scheduler or warmup method.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # The cosine decay only spans the epochs left after warmup.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )


def main(args):
    """Train (or, with ``test_only``, only evaluate) a classifier with an auxiliary center loss.

    The network is optimized with the loss from ``get_loss_fn`` plus
    ``weight_cent * center_loss``. Runs as a single process or under a distributed
    launch (DDP). Only the main process (``RANK`` -1 or 0) creates the timestamped output
    directory, the log file and the TensorBoard writer.

    Args:
        args: options from ``get_args_parser`` merged with the optional config file (an
            mmengine ``Config`` when run as a script). Mutated in place: ``output_dir``
            (main process), ``lr_scheduler`` (lower-cased), ``start_epoch`` (on resume),
            plus whatever ``utils.init_distributed_mode`` sets (presumably
            ``distributed``, ``gpu`` and ``world_size``, which are read below).
    """
    global best_acc1

    tb_writer = None
    if RANK in {-1, 0}:  # main process only: print info and create the TensorBoard writer
        timestamp = datetime.datetime.now()

        # Results go to <output_dir>/<model>/<data_name>/<YYYYmmdd>/<HHMMSS>
        args.output_dir = osp.join(
            args.output_dir,
            args.model,
            args.data_name,
            timestamp.strftime('%Y%m%d/%H%M%S'),
        )
        os.makedirs(args.output_dir, exist_ok=True)

        # Log to both the console and a file inside the output directory.
        log_file_name = f"{args.data_name}-{args.model}.log"
        sys.stdout = Logger(osp.join(args.output_dir, log_file_name))

        print(f"[INFO] rank: {RANK}")
        print(f"[INFO] result path: {osp.abspath(args.output_dir)}\n", flush=True)

        tb_writer = SummaryWriter(args.output_dir)
        print(f'[INFO] Start Tensorboard with '
              f'"tensorboard --logdir={osp.abspath(args.output_dir)}", '
              f'view at http://localhost:6006/\n')

    utils.init_distributed_mode(args)  # set up the distributed environment
    print_args(args)

    # Fix random seeds.
    init_seeds(seed=args.seed)

    # Get cpu, gpu or mps device for training.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"[INFO] Using {device} device")

    print("[INFO] Loading data")
    # TODO: plug in your own dataset here
    train_dataset, test_dataset, num_classes = get_dataset(
        data_name=args.data_name,
        data_root=args.data_path,
    )
    data_loader, data_loader_test, train_sampler = _build_data_loaders(
        args, train_dataset, test_dataset, num_classes
    )

    print("[INFO] Creating model")
    # TODO: model architecture
    model = load_model(args, num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Classification loss.
    criterion = get_loss_fn(args)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    # Center loss: compares the centers with the model features, so keep it on the training
    # device. Its centers are learned by a separate plain SGD optimizer.
    criterion_cent = CenterLoss(num_classes=num_classes, feat_dim=model.fc.in_features).to(device)
    optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)
    optimizer = _build_optimizer(args, parameters)

    # Mixed-precision training.
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Learning-rate schedule (with optional warmup).
    lr_scheduler = _build_lr_scheduler(args, optimizer)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # EMA
    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given train_dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    # Resume from a checkpoint.
    if args.resume:
        # NOTE(review): the checkpoint also stores `args` (a Config object); PyTorch >= 2.6
        # defaults to weights_only=True and may refuse to unpickle it. Only pass
        # weights_only=False for trusted files.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        # Older checkpoints lack the center-loss state; keep freshly initialized centers then.
        if "criterion_cent" in checkpoint:
            criterion_cent.load_state_dict(checkpoint["criterion_cent"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            if "optimizer_centloss" in checkpoint:
                optimizer_centloss.load_state_dict(checkpoint["optimizer_centloss"])
        args.start_epoch = checkpoint["epoch"] + 1
        best_acc1 = checkpoint.get("best_acc1", best_acc1)
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    # Evaluate the given weights only.
    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        if tb_writer is not None:
            tb_writer.close()
        return

    print("[INFO] Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_acc, train_loss = train_one_epoch(
            model,
            criterion, criterion_cent,
            optimizer, optimizer_centloss,
            data_loader, device, epoch, args, model_ema, scaler
        )
        lr_scheduler.step()
        acc1, val_loss = evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            # With EMA enabled, the EMA model's accuracy drives best-model tracking.
            acc1, val_loss = evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")

        # Remember the best acc@1.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if tb_writer is not None:
            tag_scalar_dict = {
                "train_loss": train_loss,
                "train_acc": train_acc / 100,
                "test_acc": acc1 / 100,
                "test_loss": val_loss,
                "best_acc1": best_acc1 / 100,
                "lr": optimizer.param_groups[0]["lr"],
            }

            tb_writer.add_scalars(main_tag="loss_acc", tag_scalar_dict=tag_scalar_dict, global_step=epoch)

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                # Center-loss state, so a resumed run keeps its learned centers.
                "criterion_cent": criterion_cent.state_dict(),
                "optimizer_centloss": optimizer_centloss.state_dict(),
                "epoch": epoch,
                "best_acc1": best_acc1,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()

            # Checkpoint saving is currently disabled; re-enable the save_on_master calls to persist.
            if is_best:
                print(f"\n[FEAT] best acc: {best_acc1:.2f}\n")
                # utils.save_on_master(checkpoint, osp.join(args.output_dir, "best_model.pth"))

            # utils.save_on_master(checkpoint, osp.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"\n[INFO] best acc: {best_acc1:.2f}, err rate: {(100.0 - best_acc1):.2f}")
    print(f"Training time {total_time_str}")

    if tb_writer is not None:
        tb_writer.close()


def get_args_parser(add_help=True, argv=None):
    """Define the training command-line interface and parse it.

    Note: despite its name this returns the parsed ``argparse.Namespace``, not the
    parser; existing callers rely on that.

    Args:
        add_help: whether to add the ``-h/--help`` option.
        argv: list of argument strings to parse. ``None`` (default) parses
            ``sys.argv[1:]`` exactly as before; pass a list for tests or programmatic use.

    Returns:
        argparse.Namespace holding all training options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # Optional config file path (values from the config file take precedence)
    parser.add_argument('-c', '--config', type=str, help='Path to the configuration file')

    # Dataset
    parser.add_argument("--data-path", default="~/datasets/", type=str, help="dataset path")
    parser.add_argument("--data_name", default="cifar10", type=str, help="dataset name")

    # Model architecture
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument(
        "--model_lib",
        default="torch", type=str,
        choices=["custom", "torch", "timm", "cifar100", "qt"],
        help="model library",
    )

    parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")
    parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", default=5e-4, type=float, metavar="W", help="weight decay (default: 5e-4)", dest="weight_decay")
    parser.add_argument("--norm-weight-decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)")
    parser.add_argument("--bias-weight-decay", default=None, type=float, help="weight decay for bias parameters of all layers (default: None, same value as --wd)")
    parser.add_argument("--transformer-embedding-decay", default=None, type=float, help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)")

    # Loss function
    parser.add_argument("--loss_type", default="ce", type=str, help="loss function")
    # Label smoothing for the CE loss
    parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing")

    # Activation function
    parser.add_argument("--act_layer", default="relu", type=str, help="activation function")

    # Batch-level data augmentation (MixUp / CutMix)
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")

    parser.add_argument("--lr-scheduler", default="cosineannealinglr", type=str, help="the lr scheduler (default: cosineannealinglr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")

    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")

    parser.add_argument("--print-freq", default=100, type=int, help="print frequency")
    # Output directory for logs / checkpoints
    parser.add_argument("--output-dir", default="./work_dir", type=str, help="path to save outputs")
    # Resume an interrupted run
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true")

    # Only evaluate the given weights
    parser.add_argument("--test-only", dest="test_only", help="Only test the model", action="store_true")

    # Mixed-precision training
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # Distributed training
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Exponential Moving Average of the model parameters
    parser.add_argument("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters")
    parser.add_argument("--model-ema-steps", type=int, default=32, help="the number of iterations that controls how often to update the EMA model (default: 32)")
    parser.add_argument("--model-ema-decay", type=float, default=0.99998, help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)")

    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument("--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--seed", default=0, type=int, help="random seed passed to init_seeds (default: 0)")

    # Center-loss parameters
    parser.add_argument("--lr_cent", default=0.5, type=float, help="learning rate of the SGD optimizer that updates the centers (default: 0.5)")
    parser.add_argument("--weight_cent", default=0.0005, type=float, help="weight of the center-loss term in the total loss (default: 0.0005)")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args_parser()

    # Start from the command-line values; the optional config file is merged on top,
    # so its entries take precedence.
    cfg = Config(vars(args))
    file_config = Config.fromfile(args.config).to_dict() if args.config else {}
    cfg.merge_from_dict(file_config)

    main(cfg)
