# --------------------------------------------------------
# Adapted from BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os

from pathlib import Path

from timm.models import create_model
from timm.utils import ModelEmaV2
from optim_factory import create_optimizer

from datasets import build_beit_pretraining_dataset
from engine_for_cyclical import train_one_epoch
from utils import NativeScalerWithGradNormCount as NativeScaler
import utils
from scipy import interpolate
import modeling_cyclical


def get_args(argv=None):
    """Parse the command-line options of the pre-training script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the previous
            behavior of this function.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    # This parser is parsed directly (not used as a parent parser), so enable
    # -h/--help; otherwise the help texts below could never be shown.
    parser = argparse.ArgumentParser("BEiT pre-training script", add_help=True)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--epochs", default=15, type=int)
    parser.add_argument("--save_ckpt_freq", default=10, type=int)
    # Model parameters
    parser.add_argument(
        "--model",
        default="deit_base_patch16_224",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument("--rel_pos_bias", action="store_true")
    parser.add_argument(
        "--disable_rel_pos_bias", action="store_false", dest="rel_pos_bias"
    )
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument("--abs_pos_emb", action="store_true")
    parser.set_defaults(abs_pos_emb=False)
    parser.add_argument(
        "--layer_scale_init_value",
        default=0.1,
        type=float,
        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale",
    )

    parser.add_argument(
        "--num_mask_patches",
        default=75,
        type=int,
        help="number of the visual tokens/patches need be masked",
    )
    parser.add_argument("--max_mask_patches_per_block", type=int, default=None)
    parser.add_argument("--min_mask_patches_per_block", type=int, default=16)

    parser.add_argument(
        "--input_size", default=224, type=int, help="images input size for backbone"
    )

    parser.add_argument(
        "--drop_path",
        type=float,
        default=0.1,
        metavar="PCT",
        help="Drop path rate (default: 0.1)",
    )
    parser.add_argument(
        "--drop",
        type=float,
        default=0.0,
        metavar="PCT",
        help="Dropout rate (default: 0.)",
    )

    # Optimizer parameters
    parser.add_argument(
        "--opt",
        default="adamw",
        type=str,
        metavar="OPTIMIZER",
        help='Optimizer (default: "adamw")',
    )
    parser.add_argument(
        "--opt_eps",
        default=1e-8,
        type=float,
        metavar="EPSILON",
        help="Optimizer Epsilon (default: 1e-8)",
    )
    parser.add_argument(
        "--opt_betas",
        default=None,
        type=float,
        nargs="+",
        metavar="BETA",
        help="Optimizer Betas (default: None, use opt default)",
    )
    parser.add_argument(
        "--clip_grad",
        type=float,
        default=None,
        metavar="NORM",
        help="Clip gradient norm (default: None, no clipping)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        metavar="M",
        help="SGD momentum (default: 0.9)",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
    )
    parser.add_argument(
        "--weight_decay_end",
        type=float,
        default=None,
        help="""Final value of the
        weight decay. We use a cosine schedule for WD.
        (Set the same value with args.weight_decay to keep weight decay no change)""",
    )

    parser.add_argument(
        "--lr",
        type=float,
        default=5e-4,
        metavar="LR",
        help="learning rate (default: 5e-4)",
    )
    parser.add_argument(
        "--warmup_lr",
        type=float,
        default=1e-6,
        metavar="LR",
        help="warmup learning rate (default: 1e-6)",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=1e-5,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0 (1e-5)",
    )

    # argparse %-formats help strings, so literal percent signs must be "%%";
    # a bare "5% warmup" makes help rendering raise ValueError.
    parser.add_argument(
        "--tri_phase_schedule",
        type=str,
        default=None,
        help="string containing a tuple with phase ratios for warmup and decay. "
        "e.g. '(0.05,0.15)' means 5%% warmup, 80%% hold, 15%% decay",
    )

    parser.add_argument(
        "--warmup_epochs",
        type=int,
        default=5,
        metavar="N",
        help="epochs to warmup LR, if scheduler supports",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=-1,
        metavar="N",
        help="steps to warmup LR, if scheduler supports",
    )

    # Augmentation parameters
    parser.add_argument(
        "--color_jitter",
        type=float,
        default=0.4,
        metavar="PCT",
        help="Color jitter factor (default: 0.4)",
    )
    parser.add_argument(
        "--train_interpolation",
        type=str,
        default="bicubic",
        help='Training interpolation (random, bilinear, bicubic default: "bicubic")',
    )
    parser.add_argument("--aug_level", default=-1, type=int)

    parser.add_argument(
        "--target_layers", type=str, default="[]", help="target layers (python list)"
    )

    # Dataset parameters
    parser.add_argument(
        "--data_path",
        default="/datasets01/imagenet_full_size/061417/",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_set",
        default="IMNET",
        choices=["CIFAR100", "CIFAR10", "IMNET", "image_folder", "tiny_IMNET"],
        type=str,
        help="dataset to train on",
    )
    parser.add_argument(
        "--imagenet_default_mean_and_std", default=False, action="store_true"
    )

    parser.add_argument(
        "--output_dir", default="", help="path where to save, empty for no saving"
    )
    parser.add_argument("--log_dir", default=None, help="path where to tensorboard log")
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--auto_resume", action="store_true")
    parser.add_argument("--no_auto_resume", action="store_false", dest="auto_resume")
    parser.set_defaults(auto_resume=True)

    parser.add_argument("--ema_decay_init", default=0.999, type=float)
    parser.add_argument("--ema_decay", default=0.9998, type=float)
    parser.add_argument("--ema_start_at", default=25000, type=int)

    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem", help="")
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )

    parser.add_argument("--seed_model", default=None, type=str, help="seed model")
    parser.add_argument("--model_key", default="model|module", type=str)
    parser.add_argument("--model_prefix", default="", type=str)

    parser.add_argument("--l2_loss", default=False, action="store_true")
    parser.add_argument("--l1_beta", default=0.12, type=float)

    parser.add_argument("--layer_results", default="end", type=str)

    parser.add_argument("--var_w0", default=0.0, type=float)
    parser.add_argument("--var_w1", default=0.0, type=float)
    parser.add_argument("--var_margin0", default=0.5, type=float)
    parser.add_argument("--var_margin1", default=0.5, type=float)
    parser.add_argument("--skip_ema_during_lr_decay_for_tri", action="store_true")
    parser.add_argument("--loss_scale", default=-1, type=float)
    parser.add_argument("--ema_annealing_till_end", default=False, action="store_true")
    parser.add_argument("--attn_drop_rate", default=0.0, type=float)
    parser.add_argument(
        "--mask_dropout_prob",
        default=-1.0,
        type=float,
        help="prob of flipping already masked position to unmasked",
    )

    # Target normalisation switches (target layer norm is on unless disabled).
    parser.add_argument("--no_target_layer_norm_last", default=False, action="store_true")
    parser.add_argument("--target_batch_norm", default=False, action="store_true")
    parser.add_argument("--target_instance_norm", default=False, action="store_true")
    parser.add_argument("--post_target_instance_norm", default=False, action="store_true")
    parser.add_argument("--post_target_layer_norm", default=False, action="store_true")
    parser.add_argument("--gp_layer", default=False, action="store_true")
    parser.add_argument("--gumbel_softmax", default=False, action="store_true")
    parser.add_argument("--sinkformer", action="store_true")
    parser.add_argument("--h_sto_trans", default=False, action="store_true")
    parser.add_argument("--stochastic", default=False, action="store_true")
    parser.add_argument("--lambda_pretraining", type=float, default=1e-5)

    return parser.parse_args(argv)


def get_model(args):
    """Build the (randomly initialised) backbone named by ``args.model``.

    The model is looked up through timm's ``create_model`` registry. Regularisation
    (drop path, dropout, attention dropout), position-encoding options, the
    layer-scale init value and the stochastic-attention switches are taken
    from *args*.

    Args:
        args: Parsed command-line namespace from ``get_args``.

    Returns:
        The model instance returned by ``create_model``.
    """
    print(f"Creating model: {args.model}")
    construction_options = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_rate=args.drop,
        use_shared_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        attn_drop_rate=args.attn_drop_rate,
        gp_layer=args.gp_layer,
        gumbel_softmax=args.gumbel_softmax,
        sinkformer=args.sinkformer,
        h_sto_trans=args.h_sto_trans,
    )
    return create_model(args.model, **construction_options)


def _geometric_progression(a, r, n):
    """Return the sum of the first *n* terms of the geometric series a, a*r, a*r**2, ..."""
    return a * (1.0 - r ** n) / (1.0 - r)


def _interpolate_rel_pos_bias(model, checkpoint_model, state_dict):
    """Adapt relative-position-bias entries of *checkpoint_model* to *model*, in place.

    Every ``relative_position_index`` key is dropped. Each
    ``relative_position_bias_table`` whose spatial size differs from the model's
    table is resampled per attention head with a bicubic spline. The spline's
    source grid is geometrically spaced so that it spans the destination grid.
    The trailing ``num_extra_tokens`` rows are copied unchanged.

    Args:
        model: Target model; ``model.patch_embed.patch_shape`` must be square.
        checkpoint_model: Mutable mapping of checkpoint tensors (modified in place).
        state_dict: ``model.state_dict()``, used for the destination table sizes.

    Raises:
        NotImplementedError: If the model's patch grid is not square.
    """
    for key in list(checkpoint_model.keys()):
        if "relative_position_index" in key:
            checkpoint_model.pop(key)
            continue
        if "relative_position_bias_table" not in key:
            continue

        rel_pos_bias = checkpoint_model[key]
        src_num_pos, num_attn_heads = rel_pos_bias.size()
        dst_num_pos, _ = state_dict[key].size()
        dst_patch_shape = model.patch_embed.patch_shape
        if dst_patch_shape[0] != dst_patch_shape[1]:
            raise NotImplementedError()
        num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
            dst_patch_shape[1] * 2 - 1
        )
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
        if src_size == dst_size:
            continue

        print(
            "Position interpolate for %s from %dx%d to %dx%d"
            % (key, src_size, src_size, dst_size, dst_size)
        )
        # Split by absolute index: with num_extra_tokens == 0 the negative-slice
        # form ``[:-0]`` would yield an empty table.
        split_at = src_num_pos - num_extra_tokens
        extra_tokens = rel_pos_bias[split_at:, :]
        rel_pos_bias = rel_pos_bias[:split_at, :]

        # Bisect for the ratio q whose progression of src_size // 2 terms reaches
        # dst_size // 2, so the stretched source grid covers the target grid.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            if _geometric_progression(1, q, src_size // 2) > dst_size // 2:
                right = q
            else:
                left = q

        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q ** (i + 1)

        r_ids = [-d for d in reversed(dis)]
        x = r_ids + [0] + dis
        y = r_ids + [0] + dis

        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)

        print("Original positions = %s" % str(x))
        print("Target positions = %s" % str(dx))

        all_rel_pos_bias = []
        for i in range(num_attn_heads):
            z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
            # scipy.interpolate.interp2d was removed in SciPy 1.14. On a rectilinear
            # grid, RectBivariateSpline is its documented drop-in replacement. It
            # indexes z as (x, y) where interp2d used (y, x), hence both transposes.
            spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
            all_rel_pos_bias.append(
                torch.Tensor(spline(dx, dy).T)
                .contiguous()
                .view(-1, 1)
                .to(rel_pos_bias.device)
            )

        rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        checkpoint_model[key] = torch.cat((rel_pos_bias, extra_tokens), dim=0)


def _interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model["pos_embed"]`` to *model*'s patch grid, in place.

    Leading extra tokens (their count is inferred from ``model.pos_embed``) are
    kept as-is. The square grid of patch tokens is resized with bicubic
    interpolation. Nothing happens when the checkpoint has no ``pos_embed``,
    when the model has none, or when the grid sizes already match.
    """
    if "pos_embed" not in checkpoint_model:
        return
    if getattr(model, "pos_embed", None) is None:
        # No model-side embedding gives no target size; leave the key for
        # utils.load_state_dict instead of failing on ``None.shape``.
        return

    pos_embed_checkpoint = checkpoint_model["pos_embed"]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint / new position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size, orig_size, new_size, new_size)
    )
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(
        0, 3, 1, 2
    )
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens,
        size=(new_size, new_size),
        mode="bicubic",
        align_corners=False,
    )
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((extra_tokens, pos_tokens), dim=1)


def _load_seed_checkpoint(model, args):
    """Initialise *model* from the checkpoint file at ``args.seed_model``.

    The state dict is taken from the first ``|``-separated key in
    ``args.model_key`` that is present in the checkpoint. If none is present,
    the checkpoint itself is used. ``head.weight`` / ``head.bias`` entries
    whose shape differs from the model's are dropped. Position-related tensors
    are resized to the model's patch grid. The result is loaded with
    ``utils.load_state_dict`` using ``args.model_prefix``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    checkpoint = torch.load(args.seed_model, map_location="cpu")
    print("Load ckpt from %s" % args.seed_model)

    checkpoint_model = None
    for model_key in args.model_key.split("|"):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:
        checkpoint_model = checkpoint

    state_dict = model.state_dict()
    for k in ["head.weight", "head.bias"]:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    _interpolate_rel_pos_bias(model, checkpoint_model, state_dict)
    _interpolate_pos_embed(model, checkpoint_model)

    utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)


def _build_lr_schedule(args, num_training_steps_per_epoch):
    """Return ``(lr_schedule_values, start_lr_decay_at_step)`` for the run.

    ``args.tri_phase_schedule`` (a Python tuple literal ``(warmup, decay)``)
    selects ``utils.tri_phase_scheduler``; otherwise ``utils.cosine_scheduler``
    is used. ``start_lr_decay_at_step`` is -1 unless the tri-phase schedule is
    combined with ``--skip_ema_during_lr_decay_for_tri``. In that case it is
    the step at which the decay phase begins.
    """
    from ast import literal_eval

    start_lr_decay_at_step = -1
    if args.tri_phase_schedule is not None:
        warmup_phase, decay_phase = literal_eval(args.tri_phase_schedule)
        print("Use tri phase lr schedule!", warmup_phase, decay_phase)
        lr_schedule_values = utils.tri_phase_scheduler(
            args.lr,
            args.min_lr,
            args.epochs,
            num_training_steps_per_epoch,
            warmup_perc=warmup_phase,
            decay_perc=decay_phase,
        )
        if args.skip_ema_during_lr_decay_for_tri:
            start_lr_decay_at_step = (
                (1 - decay_phase) * args.epochs * num_training_steps_per_epoch
            )
            print("ema will be skipped after " + str(start_lr_decay_at_step) + " updates")
    else:
        print("Use step level LR & WD scheduler!")
        lr_schedule_values = utils.cosine_scheduler(
            args.lr,
            args.min_lr,
            args.epochs,
            num_training_steps_per_epoch,
            warmup_epochs=args.warmup_epochs,
            warmup_steps=args.warmup_steps,
        )
    return lr_schedule_values, start_lr_decay_at_step


def main(args):
    """Run pre-training as configured by *args* (the namespace from ``get_args``).

    Steps:
        1. Initialise distributed mode and validate ``--target_layers``.
        2. Seed the RNGs, build the model and optionally load ``--seed_model``.
        3. Build the sharded data loader, the EMA copy, the optimizer and the
           LR/WD schedules.
        4. Auto-resume, then train for ``args.epochs`` epochs, saving
           checkpoints every ``save_ckpt_freq`` epochs and at the end.
        5. Append per-epoch stats to ``output_dir/log.txt``.

    Raises:
        ValueError: If ``--target_layers`` is empty, or if the dataset is too
            small to give one full batch per process.
    """
    from ast import literal_eval

    utils.init_distributed_mode(args)

    print(args)

    # Validate before any expensive setup; an ``assert`` would vanish under -O.
    target_layers = literal_eval(args.target_layers)
    if len(target_layers) == 0:
        raise ValueError(
            "--target_layers must list at least one layer, got %r" % args.target_layers
        )
    print(f"target layers: {target_layers}")

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # NOTE(review): Python's ``random`` module is not seeded here.

    cudnn.benchmark = True

    model = get_model(args)

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (
        args.input_size // patch_size[0],
        args.input_size // patch_size[1],
    )
    args.patch_size = patch_size

    if args.seed_model:
        _load_seed_checkpoint(model, args)

    # get dataset
    dataset_train = build_beit_pretraining_dataset(args)

    # A DistributedSampler is used unconditionally; with a single process it
    # simply covers the whole dataset.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks
    if num_training_steps_per_epoch == 0:
        raise ValueError(
            "dataset of %d samples gives no full batch of %d per process across %d processes"
            % (len(dataset_train), args.batch_size, num_tasks)
        )

    print("pre-sampler", num_tasks, global_rank, global_rank)
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print("number of params:", n_parameters)

    model_ema = ModelEmaV2(model, decay=args.ema_decay)
    print("Using EMA with decay = %.8f" % args.ema_decay)

    total_batch_size = args.batch_size * utils.get_world_size()
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print(
        "Number of training examples per epoch = %d"
        % (total_batch_size * num_training_steps_per_epoch)
    )

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True
        )
        model_without_ddp = model.module

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_schedule_values, start_lr_decay_at_step = _build_lr_schedule(
        args, num_training_steps_per_epoch
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        num_training_steps_per_epoch,
    )
    print(
        "Max WD = %.7f, Min WD = %.7f"
        % (max(wd_schedule_values), min(wd_schedule_values))
    )

    utils.auto_load_model(
        args=args,
        model=model,
        model_without_ddp=model_without_ddp,
        optimizer=optimizer,
        loss_scaler=loss_scaler,
        model_ema=model_ema,
    )

    print(f"Start training for {args.epochs} epochs")

    if args.ema_annealing_till_end:
        args.ema_start_at = args.epochs * num_training_steps_per_epoch
        print("EMA annealing till the end activated")

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Always advance the sampler's epoch: it is a DistributedSampler even in
        # non-distributed runs, and without set_epoch every epoch would reuse
        # the same shuffle order.
        sampler_train.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            model,
            model_ema,
            args.ema_start_at,
            args.ema_decay_init,
            args.ema_decay,
            target_layers,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            l1_beta=args.l1_beta,
            log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            l2_loss=args.l2_loss,
            layer_results=args.layer_results,
            var_w0=args.var_w0,
            var_w1=args.var_w1,
            var_margin0=args.var_margin0,
            var_margin1=args.var_margin1,
            start_lr_decay_at_step=start_lr_decay_at_step,
            loss_scale=args.loss_scale,
            mask_dropout_prob=args.mask_dropout_prob,
            target_layer_norm_last=not args.no_target_layer_norm_last,
            target_batch_norm=args.target_batch_norm,
            target_instance_norm=args.target_instance_norm,
            post_target_instance_norm=args.post_target_instance_norm,
            post_target_layer_norm=args.post_target_layer_norm,
            stochastic=args.stochastic,
            lambda_pretraining=args.lambda_pretraining,
        )

        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args,
                    model=model,
                    model_without_ddp=model_without_ddp,
                    optimizer=optimizer,
                    loss_scaler=loss_scaler,
                    epoch=epoch,
                    model_ema=model_ema,
                )

        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(
                os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
            ) as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


if __name__ == "__main__":
    # Script entry point: parse the CLI, make sure the output directory exists
    # (when one was requested), then hand over to main().
    cli_args = get_args()
    requested_output_dir = cli_args.output_dir
    if requested_output_dir:
        Path(requested_output_dir).mkdir(parents=True, exist_ok=True)
    main(cli_args)
