import argparse
import contextlib
import logging
import math
import os
import time
from datetime import timedelta
from math import ceil

import torch
import torch.distributed as dist
from datasets import load_from_disk, load_metric
from megatron import initialize_megatron
from megatron.mpu import initialize
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    BertForMaskedLM,
    BertForSequenceClassification,
    SchedulerType,
    ViTForImageClassification,
    ViTForMaskedImageModeling,
    get_scheduler,
    set_seed,
)
from transformers.models.vit.modeling_vit import PatchEmbeddings

from .distributed_utils import DistGroups, init_seq_dp_group
from .glue_dataset import get_glue_dataloader
from .lra_dataset import get_pathx_dataloader
from .megatron_bert import model_provider
from .mim_dataset import get_imagenet_dataloader
from .sp_data_utils import (  # get_qa_eval_batch,
    get_glue_batch,
    get_megatron_lm_batch,
    get_mlm_batch,
    get_mlm_dataloader,
    get_pathx_batch,
)
from .tasks.base.dataset import set_sp_method
from .tasks.base.model_util import (
    replace_bert_attn,
    replace_bert_attn_MQSP_ckptother,
    replace_bert_attn_qallgather,
    wrap_bert_layer_output,
)
from .tasks.qa.dataset import (
    get_qa_dataloader,
    get_qa_eval_batch,
    get_qa_train_batch,
    qa_eval_fn,
)
from .tasks.qa.model import init_qa_model
from .tasks.vit.dataset import get_image_batch, get_image_mask_batch

# GLUE task name -> (first text column, second text column), copied from
# run_glue_no_trainer.py. A ``None`` second entry means the task has no second
# text column. Key order matters: parse_args() appends these keys, in order,
# to the ``--task`` choices.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)


def save_model(model, optimizer, args):
    """Checkpoint ``model`` and ``optimizer`` state from rank 0.

    The file is written to ``<args.output_dir>/best_<args.max_seq_length>.th``
    as a dict with ``"model"`` and ``"optimizer"`` state dicts (the keys that
    ``load_model`` reads back). Nothing is written when ``args.output_dir`` is
    ``None`` or on any rank other than 0.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        args: namespace providing ``output_dir`` and ``max_seq_length``.
    """
    # Test the cheap, process-local condition first so that a disabled save
    # does not require an initialized process group.
    if args.output_dir is None or dist.get_rank() != 0:
        return
    # makedirs also creates missing parent directories and tolerates an
    # existing directory; os.mkdir failed when a parent was missing.
    os.makedirs(args.output_dir, exist_ok=True)
    save_file = os.path.join(args.output_dir, "best_%s.th" % (args.max_seq_length))
    with open(save_file, "wb") as fout:
        torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, fout)


def load_model(args, model, optimizer):
    """Restore ``model`` and ``optimizer`` from a ``best_<max_seq_length>.th`` checkpoint.

    Reads ``<args.output_dir>/best_<args.max_seq_length>.th`` (the dict with
    ``"model"`` and ``"optimizer"`` entries that ``save_model`` writes) and
    loads both entries in place. Tensors are mapped to the current CUDA device,
    or to the CPU when CUDA is not available.

    Args:
        args: namespace providing ``output_dir`` and ``max_seq_length``.
        model: module receiving the ``"model"`` state dict.
        optimizer: optimizer receiving the ``"optimizer"`` state dict.

    Returns:
        The same ``(model, optimizer)`` objects after loading.
    """
    load_file = os.path.join(args.output_dir, "best_%s.th" % (args.max_seq_length))
    # map_location is documented to take a torch.device / str / dict / callable;
    # build an explicit device rather than passing the bare int index.
    if torch.cuda.is_available():
        target = torch.device("cuda", torch.cuda.current_device())
    else:
        target = torch.device("cpu")
    with open(load_file, "rb") as fin:
        state_dict = torch.load(fin, map_location=target)
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    return model, optimizer


def assertTensorEqual(name, tensor0, tensor1, atol=1e-9, rtol=1e-6):
    """Print a debug report comparing two tensors.

    Despite its name this never raises; it only prints. If the tensors are
    ``torch.allclose`` under ``atol``/``rtol`` a one-line "equal" message with
    the shape is printed. Otherwise it prints summary statistics for both
    tensors plus several difference measures.

    Args:
        name: label used in the printed message.
        tensor0: tensor under test.
        tensor1: reference tensor; relative deltas are a percentage of it.
        atol: absolute tolerance forwarded to ``torch.allclose``.
        rtol: relative tolerance forwarded to ``torch.allclose``.
    """
    if torch.allclose(tensor0, tensor1, atol=atol, rtol=rtol):
        print(f"tensor {name} equal:{list(tensor0.shape)!r}")
        return

    def _summary(t):
        return "mean=%.3f std=%.3f max=%.3f" % (torch.mean(t), torch.std(t), torch.max(t))

    delta = tensor0 - tensor1
    abs_delta = torch.sum(torch.abs(delta))
    # Relative delta in percent of tensor1; zeros in tensor1 yield inf/nan here.
    rel_delta = delta * 100 / tensor1
    max_delta = torch.max(rel_delta)
    min_delta = torch.min(rel_delta)
    # Number of positions where the ascending sort order of the two flattened
    # tensors differs. argsort yields integer indices, so compare them exactly
    # instead of with isclose's floating-point tolerance.
    order0 = torch.argsort(tensor0.reshape(-1))
    order1 = torch.argsort(tensor1.reshape(-1))
    diff_num = int((order0 != order1).sum())
    print(
        f"tensor {name} not allclose numel={tensor0.numel()} abs={abs_delta:.3f} diff_num={diff_num} "
        f"max={max_delta:.4f} min={min_delta:.4f} "
        f"0={_summary(tensor0)} 1={_summary(tensor1)}"
    )


def parse_args():
    """Parse the command-line options of this training script.

    ``parse_known_args`` is used so options this parser does not define (for
    example Megatron options, which ``initialize_megatron`` parses separately)
    are ignored here instead of causing an error.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a task")
    parser.add_argument(
        "--task",
        default="mlm",
        choices=["mlm", "vit_classification", "vit_image_mask", "qa", "pathx"] + list(task_to_keys.keys()),  # glue task
        help=(
            "Which task to finetune: mlm (masked language modeling), vit_classification "
            "(ViTForImageClassification), vit_image_mask (ViTForMaskedImageModeling), qa, pathx, "
            "or one of the GLUE task names."
        ),
    )
    parser.add_argument(
        "--sp_method",
        type=str,
        default=None,
        choices=["colai", "qasp", "qasp_overlap", "megatron", "single"],
        required=True,
        help="Sequence-parallel implementation to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        # torchrun exports LOCAL_RANK instead of passing --local_rank; fall back to
        # it so main() does not build the invalid device string "cuda:None".
        default=int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else None,
        help="Local rank of this process on the current node (defaults to $LOCAL_RANK when set).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=4096,
        help="The maximum total input sequence length after tokenization. Sequences longer than this will be "
        "truncated.",
    )
    parser.add_argument(
        "--sp_size",
        type=int,
        default=None,
        help="Sequence parallel size",
    )
    parser.add_argument(
        "--num_micro_q",
        type=int,
        default=16,
        help="QASP number of micro_q",
    )
    parser.add_argument(
        "--profile_step",
        type=int,
        default=0,
        help="If passed, profile each step.",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--datasets_path",
        type=str,
        default=None,
        help="Root directory of datasets saved to disk; GLUE tasks are loaded from <datasets_path>/<task>.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--num-epoch", type=int, default=10)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--restore-dir", type=str, default=None)
    parser.add_argument("--process_bar", action="store_true", default=False, help="Show a progress bar on rank 0.")
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        help="how many train steps interval when eval",
        default=100,
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides --num-epoch.",
    )
    parser.add_argument(
        "--max_eval_steps",
        type=int,
        default=100,
        help="Total number of eval steps to perform.",
    )
    parser.add_argument(
        "--earlystop",
        action="store_true",
        default=True,
        help="Stop early when the eval-set loss has not decreased (on by default; see --no_earlystop).",
    )
    # store_true with default=True can never turn early stopping off, so
    # provide an explicit negative flag writing to the same destination.
    parser.add_argument(
        "--no_earlystop",
        dest="earlystop",
        action="store_false",
        help="Disable early stopping.",
    )
    parser.add_argument(
        "--use_max_sp",
        action="store_true",
        default=False,
        help="random init input tensor for max seq length",
    )
    parser.add_argument("--metric", help="metric method", choices=["accuracy"])
    parser.add_argument("--pretrain", help="Pretrain task,training from init", action="store_true", default=False)
    parser.add_argument("--max_image_size", default=224, type=int, help="image size when using in image task")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    parser.add_argument(
        "--checkpoint-activations",
        action="store_true",
        help="Checkpoint activation to allow for training with larger models, sequences, and batch sizes.",
    )
    parser.add_argument(
        "--no_micro_q_checkpoint",
        action="store_true",
        help="Checkpoint micro_q",
    )
    parser.add_argument(
        "--mqsp_ckpt_other",
        action="store_true",
        help="Checkpoint micro_q",
    )
    parser.add_argument(
        "--save_tensor_path",
        type=str,
        default=None,
        help=(
            "save batch to file so we can compare layer "
            "output between DDP/mqsp,should be 'something_%s_%s.pth' to format layer_number and rank"
        ),
    )
    parser.add_argument(
        "--load_tensor_path",
        type=str,
        default=None,
        help="Load a batch saved with --save_tensor_path and compare; should use the same path pattern.",
    )
    parser.add_argument(
        "--donot_test", action="store_true", default=False
    )
    args, _ = parser.parse_known_args()
    return args


def get_mp_merge_args(parser):
    """Add the extra "mp merge" argument group to ``parser``.

    Used as the ``extra_args_provider`` passed to ``initialize_megatron``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser``.
    """
    group = parser.add_argument_group(title="mp merge")

    group.add_argument(
        "--model-type",
        type=str,
        required=True,
        choices=["BERT", "GPT", "RACE", "MNLI", "QQP"],
        help="Type of the model.",
    )
    group.add_argument(
        "--target-pipeline-model-parallel-size",
        type=int,
        default=1,
        help="Degree of pipeline model parallelism in output model.",
    )

    return parser


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits and (one-hot or score-style) labels to indices on the last dim.

    Returns:
        ``(predicted_ids, label_ids)``: ``argmax(dim=-1)`` of each input.
    """
    predicted_ids = logits.argmax(dim=-1)
    label_ids = labels.argmax(dim=-1)
    return predicted_ids, label_ids


def compute_metrics(eval_preds, metric):
    """Flatten predictions and labels, drop ignored positions, and evaluate ``metric``.

    Args:
        eval_preds: ``(preds, labels)`` tensors with the same number of
            elements, e.g. as returned by ``preprocess_logits_for_metrics``.
        metric: object exposing ``compute(predictions=..., references=...)``.

    Returns:
        Whatever ``metric.compute`` returns (used as a dict by ``general_eval_fn``).
    """
    preds, labels = eval_preds
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    # Drop positions carrying the -100 ignore label. This has to run on the
    # flattened tensors: the previous `labels.dim() > 1` guard was evaluated
    # after reshape(-1) and therefore never applied the mask.
    keep = labels != -100
    return metric.compute(predictions=preds[keep], references=labels[keep])


class ConcatSequence(torch.autograd.Function):
    """Gather per-rank ViT sub-sequences along dim 1 while keeping the local head token.

    Forward input (per the shape notes below): ``(B, sub_seq_length + 1, hidden_size)``
    where index 0 on dim 1 is a head token and the rest is this rank's
    sub-sequence. Output: ``[head] + [sub-sequence of rank 0, 1, ...]``
    concatenated on dim 1. Backward routes the head gradient back to position
    0 and reduce-scatters the sub-sequence gradient so every rank receives the
    summed gradient of its own chunk.
    """

    @staticmethod
    def forward(ctx, sequence_output, group=None):
        """
        concat ViTModel output between rank

        ``group`` must be a process group (``group.size()`` is called), despite
        the ``None`` default.
        """
        ctx.group = group
        # (B, sub_seq_length+1, hidden_size)
        # Take the head token *before* dropping it from the sequence; slicing
        # first (as before) made the "head" the second token instead of the first,
        # which also disagreed with backward's routing of seq_grad[:, :1].
        head_tensor = sequence_output[:, :1]
        tail_tensor = sequence_output[:, 1:].contiguous()
        seq_world_size = group.size()
        gathered = [torch.zeros_like(tail_tensor) for _ in range(seq_world_size)]
        dist.all_gather(gathered, tail_tensor, group=group)
        sequence_output = torch.cat([head_tensor] + gathered, dim=1)
        return sequence_output.requires_grad_(requires_grad=True)

    @staticmethod
    def backward(ctx, seq_grad):
        """Map the full-sequence gradient back to ``(B, sub_seq_length + 1, hidden_size)``."""
        # grad: (B, seq_length+1, hidden_size)
        tail_tensor = seq_grad[:, 1:]
        split_grads = [t.contiguous() for t in torch.chunk(tail_tensor, ctx.group.size(), dim=1)]
        c_micro = torch.zeros_like(split_grads[0])
        # Each rank receives the cross-rank sum of the chunk matching its own shard.
        dist.reduce_scatter(c_micro, split_grads, group=ctx.group)
        new_grads = torch.cat([seq_grad[:, :1], c_micro], dim=1)
        # grad: (B, sub_seq_length+1, hidden_size)
        return new_grads, None


class ConcatSequenceOrigin(torch.autograd.Function):
    """All-gather sequence shards along dim 1; backward reduce-scatters the gradient.

    Forward: each rank holds ``[batch, sub_seq_length, hidden_size]`` and every
    rank receives ``[batch, sub_seq_length * group.size(), hidden_size]`` with
    the shards ordered by rank within ``group``.
    Backward: the incoming gradient is chunked along dim 1 into
    ``group.size()`` pieces and reduce-scattered, so each rank gets the
    cross-rank sum of the gradient for its own shard.
    """

    @staticmethod
    def forward(ctx, sub_seq_output, group=None):
        """
        concat seq_len output between rank
            rank0 [batch, sub_seq_length, hidden_size]
            rank1 [batch, sub_seq_length, hidden_size]
        after concat:
            rank0 [batch, seq_length, hidden_size]

        ``group`` must be a process group (``group.size()`` is called), despite
        the ``None`` default.
        """
        ctx.group = group
        local = sub_seq_output.contiguous()
        # Plain receive buffers: all_gather overwrites them and the gradient is
        # supplied by backward(), so they need no autograd tracking.
        shards = [torch.zeros_like(local) for _ in range(group.size())]
        dist.all_gather(shards, local, group=group)
        return torch.cat(shards, dim=1)

    @staticmethod
    def backward(ctx, seq_grad):
        """
        split seq_len grad between rank
            rank0 [batch, seq_length, hidden_size]
            rank1 [batch, seq_length, hidden_size]
        after split:
            rank0 [batch, sub_seq_length, hidden_size]
        """
        split_grads = [chunk.contiguous() for chunk in torch.chunk(seq_grad, ctx.group.size(), dim=1)]
        sub_seq_grad = torch.zeros_like(split_grads[0])
        dist.reduce_scatter(sub_seq_grad, split_grads, group=ctx.group)
        return sub_seq_grad, None


def general_eval_fn(model, args, eval_iter_or_dataset, eval_length, get_batch_fn, metric):
    """Run ``eval_length`` evaluation steps and return a rank-averaged score.

    Each step fetches a batch with ``get_batch_fn(eval_iter)``, moves every
    tensor in it to the current CUDA device and runs ``model(**batch)`` under
    ``torch.no_grad()``. Per-step values are summed over all ranks and divided
    by the world size. The model is put in eval mode for the loop and set back
    to train mode afterwards, even if evaluation raises.

    Args:
        model: model called as ``model(**batch)``.
        args: unused here; kept for interface compatibility.
        eval_iter_or_dataset: iterable wrapped with ``iter()``, or ``None``, in
            which case ``get_batch_fn`` receives ``None``.
        eval_length: number of evaluation steps; must be >= 1.
        get_batch_fn: callable returning a dict of tensors for one step.
        metric: object with ``compute(predictions=..., references=...)``
            returning a dict of numbers, or ``None`` to treat ``outputs[0`` ``]``
            as the loss.

    Returns:
        ``metric is None``: ``{"perplexity": exp(mean loss)}`` (``inf`` on overflow).
        Otherwise: a list with ``1 - mean`` of each metric value over the steps,
        in the metric dict's order (e.g. the error rate for accuracy).
    """
    losses = []
    device = torch.cuda.current_device()
    world_size = dist.get_world_size()
    eval_iter = iter(eval_iter_or_dataset) if eval_iter_or_dataset is not None else None
    model.eval()
    try:
        with torch.no_grad():
            for _ in tqdm(range(eval_length), desc="eval", disable=dist.get_rank() != 0):
                batch = get_batch_fn(eval_iter)
                batch = {k: v.to(device=device, non_blocking=True) for k, v in batch.items()}
                outputs = model(**batch)
                if metric is not None:
                    logits = outputs["logits"]
                    labels = batch["labels"]
                    if len(logits.shape) != len(labels.shape):
                        # Presumably class-index labels: one-hot them so that
                        # preprocess_logits_for_metrics can argmax both tensors.
                        labels = torch.nn.functional.one_hot(labels)
                    eval_preds = preprocess_logits_for_metrics(logits, labels)
                    metrics = compute_metrics(eval_preds, metric)
                    step_value = torch.tensor(list(metrics.values()), dtype=torch.float32, device=device)
                else:
                    step_value = outputs[0]
                dist.all_reduce(step_value, op=dist.ReduceOp.SUM)
                losses.append((step_value / world_size).unsqueeze(0))
    finally:
        model.train()

    if metric is None:
        mean_loss = torch.mean(torch.cat(losses))
        try:
            perplexity = math.exp(mean_loss)
        except OverflowError:
            perplexity = float("inf")
        return {"perplexity": perplexity}
    # Rows are steps, columns are metrics: (eval_length, num_metrics). Average
    # over steps only so different metrics are never mixed together.
    per_metric = torch.mean(torch.cat(losses, dim=0), dim=0)
    return (1 - per_metric).tolist()


def change_position_embedding(model, args):
    """Extend BERT's position-embedding table to ``args.max_seq_length`` rows.

    When ``args.max_seq_length`` exceeds the number of rows of
    ``model.bert.embeddings.position_embeddings``, the existing table is tiled
    along the position axis and truncated to ``args.max_seq_length`` rows
    (rather than initialized randomly). The ``position_ids`` buffer, and the
    ``token_type_ids`` buffer when the embeddings module has one, are resized
    to match. Otherwise the model is left unchanged. Rank 0 prints the
    embedding size and the experiment settings.

    Args:
        model: model exposing ``bert.embeddings.position_embeddings`` (an ``nn.Embedding``).
        args: namespace with ``max_seq_length``, ``sp_method`` and
            ``per_device_train_batch_size``.
    """
    bert_emb = model.bert.embeddings
    old_emb = bert_emb.position_embeddings
    position_weight = old_emb.weight
    is_rank0 = torch.distributed.get_rank() == 0
    if is_rank0:
        print("position_embeddings size:", position_weight.size(), old_emb.padding_idx)
        print(
            f"exp: method={args.sp_method} batch={args.per_device_train_batch_size} seq_length={args.max_seq_length} "
        )

    num_positions = position_weight.size(0)
    if args.max_seq_length <= num_positions:
        return

    # Tile the existing table (e.g. 512 -> 512 * n rows) and cut it to length.
    # detach(): the new table is a fresh parameter, not a function of the old one.
    post_weight = position_weight.detach().repeat(ceil(args.max_seq_length / num_positions), 1)
    post_weight = post_weight[: args.max_seq_length, :]
    if is_rank0:
        print(f"post_weight.shape={post_weight.shape}")

    bert_emb.position_embeddings = torch.nn.Embedding(
        args.max_seq_length,
        position_weight.shape[1],
        _weight=post_weight,
        max_norm=old_emb.max_norm,
        norm_type=old_emb.norm_type,
        scale_grad_by_freq=old_emb.scale_grad_by_freq,
        padding_idx=old_emb.padding_idx,
    )
    device = position_weight.device
    bert_emb.register_buffer("position_ids", torch.arange(args.max_seq_length, device=device).expand((1, -1)))
    # Some transformers versions keep a pre-sized all-zero token_type_ids buffer
    # and slice it to the input length when no token_type_ids are passed; grow
    # it as well so inputs longer than the old table can still be expanded.
    if "token_type_ids" in dict(bert_emb.named_buffers(recurse=False)):
        bert_emb.register_buffer(
            "token_type_ids",
            torch.zeros(1, args.max_seq_length, dtype=torch.long, device=device),
            persistent=False,
        )


def init_mlm_model(args):
    """Create the BERT masked-LM model for the ``mlm`` task and move it to the current GPU.

    ``--checkpoint-activations`` adds ``gradient_checkpointing=True`` to the
    config. With ``--pretrain`` a ``BertForMaskedLM`` is built from the config
    alone; otherwise weights are loaded from ``args.model_name_or_path``.
    The position-embedding table is then stretched by
    ``change_position_embedding`` before the move to the device.
    """
    config_overrides = {"gradient_checkpointing": True} if args.checkpoint_activations else {}
    config = AutoConfig.from_pretrained(args.model_name_or_path, **config_overrides)
    if not args.pretrain:
        model = AutoModelForMaskedLM.from_pretrained(
            args.model_name_or_path,
            from_tf=".ckpt" in args.model_name_or_path,
            config=config,
        )
    else:
        model = BertForMaskedLM(config)

    target_device = torch.cuda.current_device()
    change_position_embedding(model, args)
    model.to(target_device)
    return model


def change_img_size(args, config, model):
    """Rebuild ``model.vit`` patch/position embeddings for ``args.max_image_size``.

    Only prints when ``args.max_image_size`` already equals ``config.image_size``.
    Otherwise it:

    * replaces ``model.vit.embeddings.patch_embeddings`` with a new
      ``PatchEmbeddings`` built for the new image size;
    * when ``args.sp_size > 1``, wraps that module's ``forward`` so each rank
      keeps only chunk ``DistGroups["sp"].rank()`` of the ``sp_size`` chunks
      taken along dim 1 of the embedding output;
    * replaces ``position_embeddings`` with a zero-initialized parameter of
      ``num_patches // sp_size + 1`` rows (``sp_size > 1``) or
      ``num_patches + 1`` rows, so any pretrained position embeddings are dropped;
    * writes the new size back into ``config.image_size``.

    Raises:
        AssertionError: if the new patch count is not divisible by ``args.sp_size``
            (checked only after the forward wrapper is installed).
    """
    print("change_img_size", args.max_image_size, config.image_size)
    # (max_image_size / patch_size) ** 2 must be divisible by sp_size; asserted below.
    if args.max_image_size != config.image_size:
        # wrap embedding output

        model.vit.embeddings.patch_embeddings = PatchEmbeddings(
            image_size=args.max_image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        # Capture the new module's original forward before it is overridden below.
        old_fn = model.vit.embeddings.patch_embeddings.forward
        # The wrapper chunks dim 1 of the output; presumably the patch axis of
        # length num_patches — TODO confirm against PatchEmbeddings.
        if args.sp_size > 1:
            sp_size = args.sp_size

            # NOTE: *args here shadows the outer ``args`` namespace; only the
            # captured ``sp_size`` is used inside.
            def split_sp_fn(pixel_values, *args, **kwargs):
                embedding_output = old_fn(pixel_values, *args, **kwargs)
                chunks = embedding_output.chunk(sp_size, dim=1)
                return chunks[DistGroups["sp"].rank()]

            # ViTEmbeddings: split the patch_embeddings output on the sequence dim.
            model.vit.embeddings.patch_embeddings.forward = split_sp_fn
        num_patches = model.vit.embeddings.patch_embeddings.num_patches
        assert num_patches % args.sp_size == 0, "Image num_patch need devided by sp_size,now :%s/%s" % (
            num_patches,
            args.sp_size,
        )
        print("split num_patches", args.max_image_size, num_patches, args.sp_size, flush=True)
        if args.sp_size > 1:  # NOTE(review): each rank gets a table sized for its own slice; shared vs. offset positions across ranks is unresolved
            model.vit.embeddings.position_embeddings = torch.nn.Parameter(
                torch.zeros(1, num_patches // args.sp_size + 1, config.hidden_size)
            )
        else:
            model.vit.embeddings.position_embeddings = torch.nn.Parameter(
                torch.zeros(1, num_patches + 1, config.hidden_size)
            )
        config.image_size = args.max_image_size


def init_vit_masked_image_model(args):
    """Load ``ViTForMaskedImageModeling`` and adapt it for sequence parallelism.

    Resizes the patch/position embeddings via ``change_img_size``, wraps
    ``model.vit.forward`` (currently a pass-through: the sequence concat is
    commented out) and attaches ``get_masked_im_loss`` to the model as an
    attribute. That loss is only used if ``modeling_vit.py`` is patched to call
    it (see the comment before it is attached).
    """
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = ViTForMaskedImageModeling.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
    )
    change_img_size(args, config, model)
    old_vit_forward = model.vit.forward

    def new_vit_forward(pixel_values, *args, **kwargs):
        # Pass-through; the disabled lines below would gather the sequence output
        # across the "sp" group.
        output = old_vit_forward(pixel_values, *args, **kwargs)
        # seq_output = output[0]
        # seq_output_after = ConcatSequence.apply(seq_output, DistGroups["sp"])
        # output.last_hidden_state = seq_output_after
        return output

    model.vit.forward = new_vit_forward
    # bool_masked_pos: (batch, sub_num_patch), split per rank in get_image_mask_batch
    # sub_num_patch = num_patch // sp_size
    # With bool_masked_pos, the current rank only computes the loss between its own
    # masked part and the image; all other parts are set to 0.

    def get_masked_im_loss(pixel_values, reconstructed_pixel_values, bool_masked_pos):
        """Masked-image L1 loss restricted to this rank's slice.

        Slices this rank's share of ``pixel_values`` (indices derived from the
        "sp" group rank), computes an element-wise L1 loss against
        ``reconstructed_pixel_values``, keeps this rank's slice of that loss
        along the last dim, and averages it over the positions selected by the
        expanded ``bool_masked_pos`` mask, divided by ``num_channels``.

        For reference, HF's ViT embeddings apply the mask like this:

        in VitEmbedding
        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
        """
        patch_num = model.config.image_size // model.config.patch_size
        # total patch count is patch_num * patch_num
        sp_size, sp_rank = DistGroups["sp"].size(), DistGroups["sp"].rank()

        # NOTE(review): for NCHW input dims 2/3 are height/width; these names look swapped — confirm.
        batch_size, num_channel, width, height = pixel_values.shape
        # bool_masked_pos shape (per the note above): (batch, sub_num_patch)
        flatten_on_patch_pixel = pixel_values.view(batch_size, num_channel, model.config.patch_size, -1)
        # This rank's contiguous range of patch indices.
        patch_step = patch_num * patch_num // sp_size
        patch_start_idx = sp_rank * patch_step
        patch_end_idx = (sp_rank + 1) * patch_step
        part_pixel = flatten_on_patch_pixel[
            :, :, :, model.config.image_size * patch_start_idx : model.config.image_size * patch_end_idx
        ]
        # Open question from the original author: why patch_num rather than patch_size?
        bool_masked_pos_mask = bool_masked_pos.reshape(-1, patch_num, patch_num // sp_size)
        mask = bool_masked_pos_mask.repeat_interleave(patch_num, 2).unsqueeze(1).contiguous()
        # NOTE(review): debug print on every loss call; the comma after the second
        # f-string makes print() receive two arguments (joined by a space).
        print(
            f"pixel_values.shape={pixel_values.shape} \n"
            f"reconstructed_pixel_values.shape={reconstructed_pixel_values.shape}\n",
            f"bool_masked_pos = {bool_masked_pos.shape}\n"
            f"bool_masked_pos_mask={bool_masked_pos_mask.shape}\n"
            f"part_pixel={part_pixel.shape}\n",
        )
        reconstruction_loss = torch.nn.functional.l1_loss(part_pixel, reconstructed_pixel_values, reduction="none")
        # batch, num_channel, width, height = reconstruction_loss.shape
        # width * height
        sp_length = height // sp_size
        # TODO: mask at reconstruction_loss num_patch
        partion_loss = reconstruction_loss[:, :, :, sp_length * sp_rank : sp_length * (sp_rank + 1)]

        # Masked mean (1e-5 guards an all-zero mask), normalized per channel.
        masked_im_loss = (partion_loss * mask).sum() / (mask.sum() + 1e-5) / model.config.num_channels
        return masked_im_loss

    # add get_masked_im_loss fn, need patch transformer code so this function will be called:
    # patch /opt/conda/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py \
    # path_to_mqsp/MQSP_evaluation/modeling_vit.py.patch
    model.get_masked_im_loss = get_masked_im_loss
    return model


def init_vit_classification_model(args):
    """Load ``ViTForImageClassification`` and gather each rank's first token.

    After ``change_img_size`` resizes the embeddings, ``model.vit.forward`` is
    wrapped: the first token of each rank's local output is all-gathered over
    ``DistGroups["sp"]`` with ``ConcatSequenceOrigin`` and stored back as
    ``last_hidden_state``.
    """
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=".ckpt" in args.model_name_or_path,
        config=config,
    )
    change_img_size(args, config, model)
    wrapped_forward = model.vit.forward

    def forward_with_gathered_head(pixel_values, *fwd_args, **fwd_kwargs):
        vit_output = wrapped_forward(pixel_values, *fwd_args, **fwd_kwargs)
        # Only seq_output[:, 0, :] is used downstream, so gather just the first
        # token of every rank rather than the whole sequence.
        first_tokens = vit_output[0][:, :1, :]
        vit_output.last_hidden_state = ConcatSequenceOrigin.apply(first_tokens, DistGroups["sp"])
        return vit_output

    model.vit.forward = forward_with_gathered_head
    return model


def init_glue_model(args):
    """Load a sequence-classification model (e.g. bert-base-cased) for a GLUE task.

    The dataset saved at ``<args.datasets_path>/<args.task>`` is loaded to
    determine the label count: 1 for the ``stsb`` regression task, otherwise
    the number of names of the train split's ``label`` feature.

    Returns:
        The model returned by ``AutoModelForSequenceClassification.from_pretrained``.
    """
    raw_datasets = load_from_disk(os.path.join(args.datasets_path, args.task))
    # args.task is necessarily set here (it is part of the path above), so the
    # former "task is None" fallback that inferred labels from the data was
    # unreachable and has been removed.
    if args.task == "stsb":  # regression: a single output
        num_labels = 1
    else:
        num_labels = len(raw_datasets["train"].features["label"].names)
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
    )
    return model


def init_pathx_model(args):
    """Build a small BERT sequence classifier (2 labels) for the Path-X task.

    Only training from scratch (``--pretrain``) is supported. The config from
    ``args.model_name_or_path`` is overridden with a small architecture, the
    model is moved to the current CUDA device, and its pooler is wrapped so
    the per-rank hidden states are concatenated over ``DistGroups["sp"]`` with
    ``ConcatSequenceOrigin`` before pooling.

    Raises:
        ValueError: if ``args.pretrain`` is not set.
    """
    if not args.pretrain:
        # The former fallback after a bare `raise ValueError` was unreachable
        # (and loaded a masked-LM head for a classification task).
        raise ValueError("init_pathx_model only supports training from scratch; pass --pretrain")

    config_kwargs = {"num_labels": 2}
    if args.checkpoint_activations:
        config_kwargs["gradient_checkpointing"] = True
    config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
    # Temporary small architecture for Path-X (overrides the loaded config).
    config.hidden_size = 64
    config.num_attention_heads = 8
    config.num_hidden_layers = 6
    config.intermediate_size = config.hidden_size * 4
    config.hidden_dropout_prob = 0.0
    config.attention_probs_dropout_prob = 0.0
    config.max_position_embeddings = args.max_seq_length

    model = BertForSequenceClassification(config)
    model.to(torch.cuda.current_device())

    # Concatenate the sequence-parallel shards before the pooler.
    old_pooler_forward = model.bert.pooler.forward

    def new_pooler_forward(hidden_states):
        hidden_states = ConcatSequenceOrigin.apply(hidden_states, DistGroups["sp"])
        return old_pooler_forward(hidden_states)

    model.bert.pooler.forward = new_pooler_forward
    return model


def _describe_shapes(values):
    """Summarise a batch or a model output as ``{key: shape}`` for logging.

    Accepts mappings (a batch dict, a ``transformers`` ``ModelOutput``) and plain
    tuples/lists. One example is the ``(lm_loss,)`` tuple returned by the patched
    Megatron forward in ``main``. Sequence entries are keyed by their index.
    Values without a ``shape`` attribute are reported by their type name.
    """
    items = values.items() if hasattr(values, "items") else enumerate(values)
    return {key: getattr(value, "shape", type(value).__name__) for key, value in items}


def _load_sp_batch_from_file(args):
    """Load a saved batch from ``args.load_tensor_path`` and keep this rank's sequence shard.

    The saved dict must contain ``input_ids``, ``token_type_ids``,
    ``attention_mask``, ``start_positions`` and ``end_positions``.

    Dim 1 of ``input_ids`` and ``token_type_ids`` is split into
    ``DistGroups["sp"].size()`` equal slices; any remainder is dropped. The slice
    at this rank's index in the ``"sp"`` group is kept.

    ``attention_mask`` is kept whole for the colai/megatron/single methods. For
    ``qasp*`` methods it is sliced the same way as the inputs.

    Raises:
        ValueError: if ``args.sp_method`` is none of the methods above.
    """
    # torch.load unpickles the file: only point load_tensor_path at trusted files.
    with open(args.load_tensor_path, "rb") as fin:
        data_b = torch.load(fin)
    sp_group = DistGroups["sp"]
    sub_seq_length = data_b["input_ids"].size(1) // sp_group.size()
    shard = slice(sp_group.rank() * sub_seq_length, (sp_group.rank() + 1) * sub_seq_length)

    if args.sp_method in ("colai", "megatron", "single"):
        attention_mask = data_b["attention_mask"].contiguous()
    elif args.sp_method.startswith("qasp"):
        attention_mask = data_b["attention_mask"][:, shard].contiguous()
    else:
        raise ValueError(f"unsupported sp_method for load_tensor_path: {args.sp_method!r}")

    return {
        "input_ids": data_b["input_ids"][:, shard].long().contiguous(),
        "token_type_ids": data_b["token_type_ids"][:, shard].long().contiguous(),
        "attention_mask": attention_mask,
        "start_positions": data_b["start_positions"].long().contiguous(),
        "end_positions": data_b["end_positions"].long().contiguous(),
    }


def main():
    """Build the model for ``args.task``, apply sequence parallelism, and train it.

    Everything is driven by ``parse_args()``:

    1. Initialise distributed state. For ``sp_method == "megatron"`` this is
       Megatron; otherwise it is a plain NCCL process group. Then the
       sequence/data-parallel groups are created via ``init_seq_dp_group``.
    2. Build the task model (``task2method``). Unless the method is megatron or
       single, swap the encoder attention for a sequence-parallel implementation
       (colai / qasp / qasp_overlap). Then create AdamW and wrap the model in
       ``DistributedDataParallel``.
    3. Build dataloaders on rank 0 of each ``"sp"`` group only. The lengths are
       broadcast so that ranks without a dataloader know how many steps to run.
    4. Train for ``args.num_epoch`` epochs, or until ``args.max_train_steps``
       optimizer steps have run. Evaluate every ``args.eval_steps`` steps, save
       the best model on rank 0, and early-stop after repeated regressions.
    """
    args = parse_args()
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
    )

    if args.seed:
        set_seed(args.seed)
    # Select this process's GPU before any other CUDA call. Synchronizing first would
    # touch the default device (cuda:0) in every process.
    device = torch.device("cuda:{}".format(args.local_rank))
    torch.cuda.set_device(device)
    torch.cuda.synchronize()

    # ---- distributed init and model construction ----
    if args.sp_method == "megatron":
        initialize_megatron(extra_args_provider=get_mp_merge_args, ignore_unknown_args=True)
        # Megatron takes most model hyper-parameters from its own global args.
        model = model_provider()
        old_forward = model.forward

        def new_forward(*fwd_args, **fwd_kwargs):
            # Reduce Megatron's per-token LM loss to a mean masked by attention_mask,
            # and return it in a tuple so that `outputs[0]` is the loss as for HF models.
            loss_mask = fwd_kwargs.get("attention_mask")
            lm_loss_, _sop_logits = old_forward(*fwd_args, **fwd_kwargs)
            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
            return (lm_loss,)

        model.forward = new_forward
        if torch.distributed.get_rank() == 0:
            print(model)
    else:
        torch.distributed.init_process_group(backend="nccl", timeout=timedelta(seconds=600))
        if args.sp_size is None:
            args.sp_size = torch.distributed.get_world_size()
        task2method = {
            "mlm": init_mlm_model,
            "vit_classification": init_vit_classification_model,
            "vit_image_mask": init_vit_masked_image_model,
            "qa": init_qa_model,
            "pathx": init_pathx_model,
            "default": init_glue_model,
        }
        init_method = task2method.get(args.task, task2method["default"])
        model = init_method(args)
    if args.checkpoint_activations and args.sp_method != "megatron":
        model.gradient_checkpointing_enable()
    init_seq_dp_group(args.sp_size, is_all_dp=False)
    if args.sp_method == "megatron":
        # Make Megatron's tensor-parallel collectives run over the sequence-parallel group.
        dist.destroy_process_group(initialize._TENSOR_MODEL_PARALLEL_GROUP)
        initialize._TENSOR_MODEL_PARALLEL_GROUP = DistGroups["sp"]
    rank = dist.get_rank()

    # ---- sequence-parallel attention injection ----
    if rank == 0:
        numel_sum = sum(p.numel() for p in model.parameters())
        print(
            "before replace: model have:%d elements args.checkpoint_activations:%s"
            % (numel_sum, args.checkpoint_activations)
        )
    model.to(device)

    # Debug fingerprint (mean/std/sum) of the attention weights, printed on every rank
    # so that ranks can be compared.
    with torch.no_grad():
        tensors = [
            param.data.flatten()
            for name, param in model.named_parameters()
            if "attention.self" in name or "attention.attention" in name
        ]
        if tensors:
            all_param_tensor = torch.cat(tensors)
            print(
                "before replace rank %d mean=%.3f std=%.3f sum=%.3f"
                % (
                    rank,
                    torch.mean(all_param_tensor),
                    torch.std(all_param_tensor),
                    torch.sum(all_param_tensor),
                ),
                flush=True,
            )
    if args.sp_method not in ("megatron", "single"):  # these two keep the original attention
        if args.task in ("mlm", "qa", "pathx") or args.task in task_to_keys:
            to_replace_model = model.bert.encoder
        else:
            to_replace_model = model.vit.encoder
        if args.sp_method == "colai":
            replace_bert_attn(to_replace_model)
            set_sp_method("colai")
        else:
            if args.mqsp_ckpt_other:
                if args.checkpoint_activations:
                    raise ValueError("mqsp finer-grained ckpt needs no layer ckpt")
                replace_fn = replace_bert_attn_MQSP_ckptother
            else:
                replace_fn = replace_bert_attn_qallgather
            if args.sp_method == "qasp":
                replace_fn(
                    to_replace_model, num_micro_q=args.num_micro_q, micro_q_checkpoint=not args.no_micro_q_checkpoint
                )
                set_sp_method("qasp")
            elif args.sp_method == "qasp_overlap":
                replace_fn(
                    to_replace_model,
                    num_micro_q=args.num_micro_q,
                    async_op=True,
                    post_reduce_scatter=True,
                    micro_q_checkpoint=not args.no_micro_q_checkpoint,
                )
                set_sp_method("qasp_overlap")
        # To compare per-layer outputs between runs, wrap the layers with:
        # wrap_bert_layer_output(to_replace_model, args.save_tensor_path, args.load_tensor_path)
    if rank == 0:
        numel_sum = sum(p.numel() for p in model.parameters())
        print("after replace: model have:%d elements" % numel_sum)
        print(model)

    # ---- optimizer and DDP ----
    # Order matters: 1. move to GPU again (the replacement may have added new modules),
    # 2. build the optimizer, 3. wrap in DDP.
    model.to(device)
    # Split weights into two groups: one with weight decay, one without.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[device],
        find_unused_parameters=True,
    )
    # Private DDP API: declares that the autograd graph is the same every iteration.
    model._set_static_graph()

    get_eval_batch_fn = None
    if args.restore_dir:
        load_model(args, model, optimizer)

    # ---- per-task data pipeline ----
    if args.task == "mlm":
        get_dataloader_fn = get_mlm_dataloader
        get_batch_fn = get_mlm_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            model_name_or_path=args.model_name_or_path,
            train_file=args.train_file,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
            max_seq_length=args.max_seq_length,
        )
        if args.sp_method == "megatron":
            get_batch_fn = get_megatron_lm_batch
    elif args.task == "vit_classification":
        get_dataloader_fn = get_imagenet_dataloader
        get_batch_fn = get_image_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            model_name_or_path=args.model_name_or_path,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
            max_seq_length=args.max_seq_length,
            max_image_size=args.max_image_size,
        )
    elif args.task == "vit_image_mask":
        get_dataloader_fn = get_imagenet_dataloader
        get_batch_fn = get_image_mask_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            model_name_or_path=args.model_name_or_path,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
            max_seq_length=args.max_seq_length,
            max_image_size=args.max_image_size,
        )
    elif args.task == "qa":
        get_dataloader_fn = get_qa_dataloader
        get_batch_fn = get_qa_train_batch
        get_eval_batch_fn = get_qa_eval_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            model_name_or_path=args.model_name_or_path,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
            max_seq_length=args.max_seq_length,
        )
    elif args.task == "pathx":
        get_dataloader_fn = get_pathx_dataloader
        get_batch_fn = get_pathx_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
        )
    else:
        get_dataloader_fn = get_glue_dataloader
        get_batch_fn = get_glue_batch
        kwargs = dict(
            datasets_path=args.datasets_path,
            task_name=args.task,
            model_name_or_path=args.model_name_or_path,
            per_device_train_batch_size=args.per_device_train_batch_size,
            per_device_eval_batch_size=args.per_device_eval_batch_size,
            max_seq_length=args.max_seq_length,
        )
    if get_eval_batch_fn is None:
        get_eval_batch_fn = get_batch_fn

    # Only rank 0 of each "sp" group builds dataloaders. The other ranks learn the
    # lengths from a broadcast.
    is_data_leader = DistGroups["sp"].rank() == 0
    train_dataloader = eval_dataloader = test_dataloader = None
    if is_data_leader:
        numel_sum = sum(p.numel() for p in model.parameters())
        print("model have:%d elements" % numel_sum)
        train_dataloader, eval_dataloader, test_dataloader, _tokenizer = get_dataloader_fn(**kwargs)
        local_lengths = [len(train_dataloader), len(eval_dataloader), len(test_dataloader)]
        length_tensor = torch.tensor(local_lengths, dtype=torch.int32, device=torch.cuda.current_device())
    else:
        length_tensor = torch.empty(3, dtype=torch.int32, device=torch.cuda.current_device())
    # NOTE(review): this broadcasts from *global* rank 0 over the default group. With
    # several "sp" groups, non-leader ranks receive rank 0's lengths while every leader
    # keeps its own. Confirm that all data-parallel shards have the same length.
    torch.distributed.broadcast(length_tensor, src=0)
    if is_data_leader:
        train_length, eval_length, test_length = local_lengths
    else:
        train_length, eval_length, test_length = length_tensor.tolist()
    if train_length == 0:
        raise ValueError("the training dataloader is empty; nothing to train")

    if args.max_train_steps is None:
        args.max_train_steps = args.num_epoch * train_length
    else:
        args.num_epoch = math.ceil(args.max_train_steps / train_length)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Attention parameters whose statistics are printed at the start of every epoch.
    self_params = [
        p
        for name, p in model.named_parameters()
        if "attention.self" in name or "query_key_value" in name or "attention.attention" in name
    ]

    # Early stopping: the counter is reset on every new best. It is decremented on every
    # evaluation scoring more than 10% above the best; training stops when it reaches 0.
    early_stop_patience = 10
    patience_left = early_stop_patience
    min_perplexity = float("inf")
    should_stop = torch.tensor([0], dtype=torch.int32, device=device)

    eval_fn = general_eval_fn
    if args.task in ["vit_classification", "pathx"]:
        metric_file = os.path.join(os.path.dirname(__file__), "accuracy.py")
        metric = load_metric(metric_file)
    elif args.task in task_to_keys:
        metric_file = os.path.join(os.path.dirname(__file__), "glue.py")
        metric = load_metric(metric_file, args.task)
    elif args.task == "qa":
        metric_file = os.path.join(os.path.dirname(__file__), "squad.py")
        metric = load_metric(metric_file)
        eval_fn = qa_eval_fn
    else:
        metric = None

    mega_bytes = 1024.0 * 1024.0
    total_step = 0
    # ---- training loop ----
    for epoch in range(args.num_epoch):
        if train_dataloader is not None:
            if args.task == "mlm":  # random sampler: reseed its shuffle for this epoch
                train_dataloader.sampler.set_epoch(epoch)
            train_iter = iter(train_dataloader)
        else:
            train_iter = None

        if rank == 0:
            print(f"Epoch {epoch} start.")
        t = time.time()
        pbar = tqdm(
            range(train_length),
            total=train_length,
            disable=not (rank == 0 and args.process_bar),
            desc=f"Epoch:{epoch}",
        )
        with torch.no_grad():
            tensors = [param.data.flatten() for param in self_params]
            if tensors:
                all_param_tensor = torch.cat(tensors)
                print(
                    "rank %d epoch=%d mean=%.3f std=%.3f sum=%.3f"
                    % (
                        rank,
                        epoch,
                        torch.mean(all_param_tensor),
                        torch.std(all_param_tensor),
                        torch.sum(all_param_tensor),
                    ),
                    flush=True,
                )
        model.train()

        torch.cuda.synchronize()
        train_cost = 0.0  # training seconds this epoch, evaluation time excluded
        epoch_steps = 0
        reached_max_steps = False
        for step in pbar:
            total_step = train_length * epoch + step
            profile_this_step = bool(step and args.profile_step and step % args.profile_step == 0)
            if profile_this_step:
                prof_ctx = torch.profiler.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ]
                )
                torch.cuda.synchronize()
                dist.barrier()
            else:
                prof_ctx = contextlib.nullcontext()

            if args.load_tensor_path:  # replay a saved batch instead of reading the dataloader
                batch = _load_sp_batch_from_file(args)
            else:
                batch = get_batch_fn(train_iter)

            if args.use_max_sp:
                # Stretch the inputs to max_seq_length // sp_size along dim 1 to measure
                # the memory a full-length sequence shard would use.
                sub_seq_length = args.max_seq_length // args.sp_size
                real_seq_length = batch["input_ids"].size(1)
                if step == 0 and rank == 0:
                    print(rank, "before repeat input:", _describe_shapes(batch))
                if sub_seq_length > real_seq_length:
                    repeat_times = sub_seq_length // real_seq_length + 1
                    batch = {
                        k: tensor.repeat_interleave(repeat_times, dim=1)[:, :sub_seq_length].contiguous()
                        for k, tensor in batch.items()
                    }

            batch = {k: v.to(device=device, non_blocking=True) for k, v in batch.items()}
            if step == 0 and rank == 0:
                print(rank, "input:", _describe_shapes(batch))
            with prof_ctx as prof:
                outputs = model(**batch)
                loss = outputs[0]  # (loss, prediction_scores/logits, ...)
                if step == 0 and rank == 0:
                    # The Megatron path returns a plain tuple, so do not assume `.items()`.
                    print(rank, "output:", _describe_shapes(outputs), "loss:", loss)
                loss.backward()
                if args.save_tensor_path and step == 0:
                    # Dump the first batch plus selected parameters and their gradients,
                    # presumably for comparison with another run (see load_tensor_path).
                    focus_param_names = {
                        "module.bert.encoder.layer.21.attention.output.dense.bias",
                        "module.bert.encoder.layer.21.output.dense.bias",
                        "module.bert.encoder.layer.21.attention.output.dense.weight",
                        "module.bert.encoder.layer.21.attention.self.query.weight",
                        "module.bert.encoder.layer.21.attention.self.query.bias",
                        "module.bert.encoder.layer.21.output.dense.weight",
                        "module.bert.encoder.layer.14.attention.output.LayerNorm.bias",
                        "module.bert.encoder.layer.14.attention.output.LayerNorm.weight",
                    }
                    focus_params = [(name, p) for name, p in model.named_parameters() if name in focus_param_names]
                    torch.cuda.synchronize()

                    out_filename = args.save_tensor_path % ("input", rank)
                    outd = {name + ".grad": p.grad for name, p in focus_params}
                    outd.update(batch)
                    outd.update({name: p.data for name, p in focus_params})
                    with open(out_filename, "wb") as fout:
                        torch.save(outd, fout)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                pbar.set_description(f"Epoch: {epoch} {loss:.3f}")
            epoch_steps += 1
            if profile_this_step and rank == 0:
                prof.export_chrome_trace(f"trace-{step}.json")

            if total_step and total_step % args.eval_steps == 0:
                torch.cuda.synchronize()
                train_cost += time.time() - t
                if args.task == "qa":
                    eval_iter = eval_dataloader  # QA evaluation gets the dataloader itself
                elif eval_dataloader is not None:
                    eval_iter = iter(eval_dataloader)
                else:
                    eval_iter = None
                perplexity = eval_fn(
                    model, args, eval_iter, min(eval_length, args.max_eval_steps), get_eval_batch_fn, metric
                )
                if isinstance(perplexity, list):
                    perplexity = perplexity[0]
                elif isinstance(perplexity, dict):
                    perplexity = perplexity.get("perplexity", list(perplexity.values())[0])
                if rank == 0:
                    print(
                        f"Step {total_step},method={args.sp_method} ,perplexity={perplexity:.2f}"
                        f" memory (MB) | allocated: {torch.cuda.memory_allocated() / mega_bytes:.1f}"
                        f" memory (MB) | max allocated: {torch.cuda.max_memory_allocated() / mega_bytes:.1f}"
                        f" memory (MB) | reserved: {torch.cuda.memory_reserved() / mega_bytes:.1f}"
                        f" memory (MB) | max reserved: {torch.cuda.max_memory_reserved() / mega_bytes:.1f}",
                        flush=True,
                    )
                # NOTE(review): lower is treated as better here. When eval_fn returns a
                # metric dict without "perplexity" (accuracy / glue / squad), its first
                # value is used instead, which is likely higher-is-better. Confirm this.
                if perplexity < min_perplexity:
                    min_perplexity = perplexity
                    patience_left = early_stop_patience
                    if rank == 0:
                        save_model(model, optimizer, args)
                elif perplexity > min_perplexity * 1.1:
                    patience_left -= 1

                if patience_left <= 0:
                    should_stop = torch.tensor([1], dtype=torch.int32, device=device)
                # SUM all-reduce: if any rank asks to stop, every rank stops.
                dist.all_reduce(should_stop)
                # Restart the training timer now, so eval time is never counted, even
                # when we break right below.
                t = time.time()
                if int(should_stop) > 0:
                    break

            # total_step is 0-based, so total_step + 1 optimizer steps have now run.
            if total_step + 1 >= args.max_train_steps:
                if rank == 0:
                    print(f"reach max_train_steps:{total_step + 1} ,finish")
                reached_max_steps = True
                break
        pbar.close()
        train_cost += time.time() - t

        if int(should_stop) > 0:
            # Every rank must leave the epoch loop: should_stop was all-reduced. Otherwise
            # the remaining ranks would keep training and block in DDP's gradient all-reduce.
            if rank == 0:
                print(f"early stop on Step {total_step}")
            break

        if rank == 0:
            print(
                f"Step {total_step},method={args.sp_method} speed={epoch_steps / train_cost:.3f} it/s",
                flush=True,
            )
        torch.cuda.synchronize()
        if reached_max_steps:
            break


if __name__ == "__main__":
    main()
