# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
import types
from typing import List, Optional, Tuple

import safetensors
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from peft import PeftModel
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, trainer
from transformers.modeling_utils import unwrap_model

from swift.utils import get_logger, torchacc_trim_graph, use_torchacc

# Module-level logger used by the helpers in this file.
logger = get_logger()


# DataLoader
def get_bucket_sizes(max_length: int) -> List[int]:
    """Get the bucket sizes for TorchAcc.

    Sequences are padded up to one of these lengths so that only a bounded
    set of sequence shapes is seen.

    If the environment variable TORCHACC_DATA_BUCKETS is set, it is parsed as
    a comma-separated list of integers and used as the bucket sizes. Otherwise
    a fixed set of small buckets (16..128) is used, followed by buckets that
    grow geometrically from 256 by a factor of TORCHACC_PADDING_P_BASE
    (default 2, or 1.4 when TORCHACC_CACHE_PATH is set), each rounded up to a
    multiple of 128.

    Args:
        max_length: The maximum sequence length; it is always appended as the
            last bucket.

    Returns:
        The list of bucket sizes.

    Raises:
        ValueError: If TORCHACC_DATA_BUCKETS contains a non-integer, if
            TORCHACC_PADDING_P_BASE is not a float, or if it is not greater
            than 1 (the geometric growth would then never reach max_length).
    """
    custom_buckets = os.getenv("TORCHACC_DATA_BUCKETS")
    if custom_buckets is not None:
        bucket_sizes = [int(x) for x in custom_buckets.split(",")]
        bucket_sizes.append(max_length)
        return bucket_sizes

    # padding strategy when persistent cache is enabled: a smaller growth
    # factor yields more, tighter buckets.
    default_base = 1.4 if os.getenv("TORCHACC_CACHE_PATH") is not None else 2
    padding_p_base = os.getenv("TORCHACC_PADDING_P_BASE", default_base)
    try:
        padding_p_base = float(padding_p_base)
    except ValueError as e:
        logger.error(
            f"Expect TORCHACC_PADDING_P_BASE to be a float number, but encountered {padding_p_base}"
        )
        raise e
    if padding_p_base <= 1:
        # base_size would never grow, so the loop below would not terminate.
        raise ValueError(
            f"Expect TORCHACC_PADDING_P_BASE to be greater than 1, but encountered {padding_p_base}"
        )

    bucket_sizes = [16, 32, 48, 64, 96, 128]
    base_size = 256
    while base_size < max_length:
        # Round up to a multiple of 128.
        bucket_sizes.append((int(base_size) + 127) // 128 * 128)
        base_size *= padding_p_base
    bucket_sizes.append(max_length)
    return bucket_sizes


def _get_closet_bucket(bucket_sizes, data_length):
    """Select the one from bucket_sizes that is closest in distance to
    data_length. This is required for TorchAcc.
    """
    closest_length = sys.maxsize
    for b in bucket_sizes:
        if b == data_length or ((b < closest_length) and (b > data_length)):
            closest_length = b

    if closest_length == sys.maxsize:
        bucket_sizes.append(data_length)
        closest_length = data_length

    return closest_length


def pad_and_split_batch(
    padding_to,
    input_ids,
    attention_mask,
    labels,
    loss_scale,
    max_length,
    tokenizer,
    rank,
    world_size,
    padding_right,
):
    """Pad a collated batch to a TorchAcc bucket length and keep this rank's slice.

    Args:
        padding_to: When None, the batch is padded along the sequence
            dimension up to the closest bucket from
            ``get_bucket_sizes(max_length)``. Otherwise no padding is done here.
        input_ids, attention_mask, labels: 2-D tensors indexed as
            ``[batch, seq]``.
        loss_scale: Optional tensor indexed like ``labels``, or None.
        max_length: Maximum sequence length used to build the buckets.
        tokenizer: Supplies ``pad_token_id`` used to pad ``input_ids``
            (only read when padding).
        rank: Data-parallel rank of this process.
        world_size: Number of data-parallel ranks.
        padding_right: Pad on the right if True, on the left otherwise.

    Returns:
        Tuple ``(input_ids, attention_mask, labels, loss_scale)``.

    Note:
        The batch is split into ``batch // world_size`` rows per rank, so any
        remainder rows are dropped. If the batch has fewer rows than
        ``world_size``, every rank keeps the whole batch.
    """
    if padding_to is None:
        longest_len = input_ids.shape[-1]
        bucket_sizes = get_bucket_sizes(max_length)
        bucket_data_length = _get_closet_bucket(bucket_sizes, longest_len)
        padding_length = bucket_data_length - input_ids.shape[1]
        pad_tuple = (0, padding_length) if padding_right else (padding_length, 0)
        input_ids = F.pad(input_ids, pad_tuple, "constant", tokenizer.pad_token_id)
        attention_mask = F.pad(attention_mask, pad_tuple, "constant", 0)
        # `is not None`: truth-testing a multi-element tensor raises.
        if loss_scale is not None:
            loss_scale = F.pad(loss_scale, pad_tuple, "constant", 0.0)
        labels = F.pad(labels, pad_tuple, "constant", -100)

    # manually split the batch to different DP rank.
    batch_size = input_ids.shape[0] // world_size
    if batch_size > 0:
        start = rank * batch_size
        end = (rank + 1) * batch_size
        input_ids = input_ids[start:end, :]
        attention_mask = attention_mask[start:end, :]
        labels = labels[start:end, :]
        if loss_scale is not None:
            loss_scale = loss_scale[start:end, :]
    return input_ids, attention_mask, labels, loss_scale


def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size):
    """Build the TorchAcc training dataloader.

    Besides returning a ``torchacc.AsyncLoader`` over ``train_dataset``, this
    replaces ``transformers.trainer.skip_first_batches``. When a run is resumed,
    the loader is then rebuilt (again wrapped in ``AsyncLoader``) with the
    already-consumed batches skipped.

    Args:
        train_dataset: Dataset to iterate. ``torch.utils.data.IterableDataset``
            instances get neither ``sampler``, ``drop_last`` nor
            ``worker_init_fn``.
        data_collator: ``collate_fn`` for the DataLoader.
        sampler: Sampler used for map-style datasets.
        args: Training arguments providing the ``dataloader_*`` options and
            ``device``.
        batch_size: Batch size of the training DataLoader.

    Returns:
        A ``torchacc.AsyncLoader`` wrapping a ``torch.utils.data.DataLoader``.
    """
    import torchacc as ta

    is_iterable = isinstance(train_dataset, torch.utils.data.IterableDataset)

    # patch skip_first_batches for customized dataloader.
    def acc_skip_first_batches(dataloader, num_batches=0):
        """Rebuild ``dataloader`` so that its first ``num_batches`` batches are skipped."""
        from accelerate.data_loader import SkipBatchSampler

        try:
            dataset = dataloader.dataset
        except AttributeError:
            dataset = dataloader._loader.dataset
        dataloader_params = {
            "collate_fn": data_collator,
            "num_workers": args.dataloader_num_workers,
            "pin_memory": args.dataloader_pin_memory,
            "persistent_workers": args.dataloader_persistent_workers,
        }

        if not is_iterable:
            dataloader_params["batch_sampler"] = SkipBatchSampler(
                dataloader._loader.batch_sampler, skip_batches=num_batches
            )
            dataloader_params["worker_init_fn"] = trainer.seed_worker
        else:
            # Without a batch_sampler, DataLoader falls back to batch_size=1;
            # keep the batch size of the original training loader instead.
            # NOTE(review): num_batches is not applied for iterable datasets,
            # so iteration presumably restarts from the dataset's beginning.
            dataloader_params["batch_size"] = batch_size

        return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device)

    trainer.skip_first_batches = acc_skip_first_batches

    # dataloader for TorchAcc.
    dataloader_params = {
        "batch_size": batch_size,
        "collate_fn": data_collator,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": args.dataloader_pin_memory,
        "persistent_workers": args.dataloader_persistent_workers,
    }

    if not is_iterable:
        dataloader_params["sampler"] = sampler
        dataloader_params["drop_last"] = args.dataloader_drop_last
        dataloader_params["worker_init_fn"] = trainer.seed_worker

    return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device)


def ta_eval_dataloader(eval_dataset, data_collator, sampler, args):
    """Wrap ``eval_dataset`` in a TorchAcc ``AsyncLoader`` for evaluation.

    Uses ``args.eval_batch_size``. The ``sampler`` and
    ``args.dataloader_drop_last`` are only applied to map-style datasets,
    never to ``torch.utils.data.IterableDataset`` instances.
    """
    import torchacc as ta

    loader_kwargs = {
        "batch_size": args.eval_batch_size,
        "collate_fn": data_collator,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": args.dataloader_pin_memory,
        "persistent_workers": args.dataloader_persistent_workers,
    }
    map_style = not isinstance(eval_dataset, torch.utils.data.IterableDataset)
    if map_style:
        loader_kwargs.update(
            {"sampler": sampler, "drop_last": args.dataloader_drop_last}
        )

    torch_loader = DataLoader(eval_dataset, **loader_kwargs)
    return ta.AsyncLoader(torch_loader, args.device)


def ta_test_dataloader(test_dataset, data_collator, sampler, args):
    """Wrap ``test_dataset`` in a TorchAcc ``AsyncLoader`` for prediction.

    The eval batch size (``args.eval_batch_size``) is reused for testing.
    The ``sampler`` and ``args.dataloader_drop_last`` are only applied to
    map-style datasets, never to ``torch.utils.data.IterableDataset``
    instances.
    """
    import torchacc as ta

    loader_kwargs = {
        "batch_size": args.eval_batch_size,
        "collate_fn": data_collator,
        "num_workers": args.dataloader_num_workers,
        "pin_memory": args.dataloader_pin_memory,
        "persistent_workers": args.dataloader_persistent_workers,
    }
    map_style = not isinstance(test_dataset, torch.utils.data.IterableDataset)
    if map_style:
        loader_kwargs.update(
            {"sampler": sampler, "drop_last": args.dataloader_drop_last}
        )

    torch_loader = DataLoader(test_dataset, **loader_kwargs)
    return ta.AsyncLoader(torch_loader, args.device)


# Save/load checkpoint
def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir):
    """Write this rank's optimizer and LR-scheduler state dicts into ``output_dir``.

    Every rank saves its own ``optimizer_<ordinal>.pt`` and
    ``scheduler_<ordinal>.pt`` (``master_only=False``). The calls are bracketed
    by ``xm.rendezvous`` so all ranks enter and leave the save together.
    """
    import torch_xla.core.xla_model as xm

    xm.rendezvous("saving_optimizer_states")
    ordinal = xm.get_ordinal()
    targets = (
        (optimizer, f"optimizer_{ordinal}.pt"),
        (lr_scheduler, f"scheduler_{ordinal}.pt"),
    )
    for stateful, filename in targets:
        xm.save(
            stateful.state_dict(),
            os.path.join(output_dir, filename),
            master_only=False,
        )
    xm.rendezvous("saving_optimizer_states_done")


def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device):
    """Restore this rank's optimizer and LR-scheduler state from ``checkpoint``.

    Reads ``optimizer_<ordinal>.pt`` and ``scheduler_<ordinal>.pt`` onto the
    CPU, moves the tensors to ``device``, and loads them into ``optimizer`` and
    ``lr_scheduler``.

    Args:
        optimizer: Optimizer whose state is restored.
        lr_scheduler: LR scheduler whose state is restored.
        checkpoint: Directory containing the per-ordinal state files.
        device: XLA device the state tensors are sent to.

    Returns:
        Tuple ``(optimizer, lr_scheduler)``; the same objects that were passed in.
    """
    import torch_xla.core.xla_model as xm

    ordinal = xm.get_ordinal()
    optimizer_state = torch.load(
        os.path.join(checkpoint, f"optimizer_{ordinal}.pt"), map_location="cpu"
    )
    lr_scheduler_state = torch.load(
        os.path.join(checkpoint, f"scheduler_{ordinal}.pt"), map_location="cpu"
    )
    # send_cpu_data_to_device returns the moved copy instead of converting in
    # place, so the result must be kept.
    optimizer_state = xm.send_cpu_data_to_device(optimizer_state, device)
    lr_scheduler_state = xm.send_cpu_data_to_device(lr_scheduler_state, device)

    optimizer.load_state_dict(optimizer_state)
    lr_scheduler.load_state_dict(lr_scheduler_state)
    return optimizer, lr_scheduler


def save_ta_ddp_checkpoint(
    self_model, tokenizer, args, output_dir: Optional[str] = None
):
    """Save a TorchAcc DDP model, its training args and its tokenizer.

    Only the global master ordinal writes files.

    - ``PreTrainedModel`` / ``PeftModel`` instances, directly or after
      ``unwrap_model``, are saved with ``save_pretrained``.
    - Any other model is saved as a plain state dict. It goes to
      ``model.safetensors`` or ``pytorch_model.bin`` depending on
      ``args.save_safetensors``.

    Args:
        self_model: The (possibly wrapped) model to save.
        tokenizer: Optional tokenizer, saved when ``args.should_save``.
        args: Training arguments (``output_dir``, ``should_save``,
            ``save_safetensors``).
        output_dir: Target directory; defaults to ``args.output_dir``.
    """
    import torch_xla.core.xla_model as xm

    output_dir = output_dir if output_dir is not None else args.output_dir
    model = self_model

    # Flush pending lazy XLA computations on every rank (as
    # save_ta_fsdp_checkpoint does), not only on the master, so all ranks
    # cut their graphs at the same point.
    xm.mark_step()

    if not xm.is_master_ordinal(local=False):
        return

    os.makedirs(output_dir, exist_ok=True)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    supported_classes = (PreTrainedModel, PeftModel)
    if isinstance(model, supported_classes):
        model.save_pretrained(
            output_dir,
            is_main_process=args.should_save,
            save_function=xm.save,
            safe_serialization=args.save_safetensors,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
        )
    elif isinstance(unwrap_model(model), supported_classes):
        unwrap_model(model).save_pretrained(
            output_dir,
            is_main_process=args.should_save,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
            save_function=xm.save,
            safe_serialization=args.save_safetensors,
        )
    else:
        logger.info(
            "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
        )
        state_dict = xm._maybe_convert_to_cpu(model.state_dict())
        if args.save_safetensors:
            # `import safetensors` alone does not load the torch submodule.
            import safetensors.torch

            safetensors.torch.save_file(
                state_dict, os.path.join(output_dir, "model.safetensors")
            )
        else:
            torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir)


def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
    """Save a TorchAcc XLA-FSDP model by consolidating per-rank shards.

    Each rank writes its shard (state dict plus shard metadata) to
    ``rank<i>-of-<n>-<name>.bin``. The global master then consolidates all
    shards into a full state dict and saves it, with ``save_pretrained`` for
    supported model classes or as a plain state dict otherwise. Finally, every
    rank deletes its own shard file. The ``xm.rendezvous`` calls order these
    phases across ranks.

    Args:
        self_model: TorchAcc-wrapped model exposing ``_get_underlay_model()``.
        tokenizer: Optional tokenizer, saved when ``args.should_save``.
        args: Training arguments (``process_index``, ``global_world_size``,
            ``should_save``, ``save_safetensors``).
        output_dir: Target directory.
    """
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

    xm.mark_step()

    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, "training_args.bin"))

    supported_classes = (PreTrainedModel, PeftModel)
    underlay_model = self_model._get_underlay_model()
    model = underlay_model.module.module
    unwrapped_model = unwrap_model(model)
    # PeftModel shards use an adapter_model.bin suffix, others pytorch_model.bin.
    shard_name = (
        "adapter_model.bin" if isinstance(model, PeftModel) else "pytorch_model.bin"
    )

    xm.rendezvous("saving_checkpoint")
    ckpt = {
        "model": underlay_model.state_dict(),
        "shard_metadata": underlay_model.get_shard_metadata(),
    }
    ckpt_path = os.path.join(
        output_dir,
        f"rank{args.process_index}-of-{args.global_world_size}-{shard_name}",
    )
    xm.save(ckpt, ckpt_path, master_only=False)
    # Make sure all ranks have saved checkpoints
    xm.rendezvous("save_full_checkpoints")

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(
            output_dir,
            is_main_process=xm.is_master_ordinal(local=False),
            save_function=xm.save,
        )

    # rank 0 consolidates and saves the whole checkpoint.
    if xm.is_master_ordinal(local=False):
        full_state_dict, _ = consolidate_sharded_model_checkpoints(
            ckpt_prefix=os.path.join(output_dir, ""),
            ckpt_suffix=f"rank*-of-*-{shard_name}",
            save_model=False,
        )

        if isinstance(unwrapped_model, supported_classes):
            unwrapped_model.save_pretrained(
                output_dir,
                state_dict=full_state_dict,
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info(
                "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
            )
            if args.save_safetensors:
                # `import safetensors` alone does not load the torch submodule.
                import safetensors.torch

                safetensors.torch.save_file(
                    full_state_dict, os.path.join(output_dir, "model.safetensors")
                )
            else:
                torch.save(
                    full_state_dict, os.path.join(output_dir, "pytorch_model.bin")
                )

    xm.rendezvous("ckpt_consolidation")
    # delete the sharded checkpoint.
    os.remove(ckpt_path)


def ta_trim_graph():
    """Call ``torchacc.mark_step()`` when both TorchAcc flags are on.

    The call happens only when ``use_torchacc()`` and
    ``torchacc_trim_graph()`` both return truthy values; otherwise this is a
    no-op and torchacc is not imported.
    """
    if not (use_torchacc() and torchacc_trim_graph()):
        return
    import torchacc as ta

    ta.mark_step()


# Model patch
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2`` into ``(x1, x2)`` and
    the result is ``cat((-x2, x1))`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat((second.neg(), first), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            When given, ``cos`` and ``sin`` are treated as lookup tables and
            indexed with these positions (e.g. offset positions when working
            with a KV-cache). When None, ``cos`` and ``sin`` are used as-is.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into ``cos``/``sin`` so they broadcast against
            ``q`` and ``k``. For ``cos``/``sin`` of shape
            [batch_size, seq_len, head_dim], use 1 when q/k are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` of the rotated query and key tensors.
    """
    if position_ids is not None:
        cos, sin = cos[position_ids], sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)


def patch_acc_model(args, model):
    """Patch ``model`` for TorchAcc based on the prefix of ``args.model_type``.

    Dispatch, checked in this order:

    - ``qwen1half*`` / ``qwen2*``: ``patch_qwen2_model``
    - other ``qwen*``: ``torchacc.patch_qwen_model``
    - ``baichuan*``: ``patch_baichuan_model``
    - ``llama*`` / ``yi*``: ``patch_llama_model``
    - ``chatglm*``: ``patah_chatglm_model``

    Other model types are returned unchanged. A warning is logged when
    ``args.use_flash_attn`` is falsy, because the patched attention forwards in
    this file call TorchAcc flash-attention ops regardless.

    Returns:
        The (patched) model.
    """
    if not args.use_flash_attn:
        logger.warning("Currently use flash attn for torchacc.")
    model_type = args.model_type
    # qwen1half/qwen2 must be matched before the generic "qwen" prefix.
    if model_type.startswith(("qwen1half", "qwen2")):
        model = patch_qwen2_model(model)
    elif model_type.startswith("qwen"):
        import torchacc as ta

        model = ta.patch_qwen_model(model)
    elif model_type.startswith("baichuan"):
        model = patch_baichuan_model(model)
    elif model_type.startswith(("llama", "yi")):
        model = patch_llama_model(model)
    elif model_type.startswith("chatglm"):
        model = patah_chatglm_model(model)
    return model


def patch_llama_model(model):
    """Replace Llama attention forwards with a TorchAcc flash-attention version.

    Every ``model.model.layers[i].self_attn.forward`` is rebound to
    ``llama_attn_forward``. On transformers >= 4.38 the model's
    ``_update_causal_mask`` is also replaced to return None, since the patched
    attention ignores ``attention_mask`` and runs a causal kernel.

    Returns:
        The same ``model`` object, patched in place.
    """

    def update_causal_mask(self, *args, **kwargs):
        # attention_mask is not supported in TorchAcc.
        return None

    def llama_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Causal flash attention over ``hidden_states`` of shape (bsz, q_len, hidden).

        ``attention_mask`` and ``cache_position`` are ignored. ``past_key_value``
        and ``output_attentions`` are not supported (asserted).
        """
        from torchacc.ops import flash_attn_varlen_xla
        import einops

        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to [bsz, heads, q_len, head_dim].
        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        assert past_key_value is None, "past_key_value is not supported"

        position_embeddings = kwargs.get("position_embeddings")
        if position_embeddings is not None:
            # Newer transformers compute (cos, sin) once at the model level and
            # pass them to each attention layer, already gathered per position.
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        elif version.parse(transformers.__version__) >= version.parse("4.38"):
            # Since 4.38 the rotary embedding takes position_ids and returns
            # per-position (cos, sin); older versions take seq_len instead.
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )

        assert not output_attentions, "output_attentions is not supported"

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # attention_mask is ignored: every sequence in the batch is treated as a
        # full-length causal sequence (uniform cu_seqlens of step q_len).
        # See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
        q = einops.rearrange(query_states, "b h s ... -> (b s) h ...")
        k = einops.rearrange(key_states, "b h s ... -> (b s) h ...")
        v = einops.rearrange(value_states, "b h s ... -> (b s) h ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
        )
        output = flash_attn_varlen_xla(
            q,
            k,
            v,
            cu_q_lens,
            cu_q_lens,
            max_s,
            max_s,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output = einops.rearrange(output, "(b s) ... -> b s ...", b=bsz)

        return (
            self.o_proj(einops.rearrange(output, "b s h d -> b s (h d)")),
            None,
            past_key_value,
        )

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse("4.38"):
        model.model._update_causal_mask = types.MethodType(
            update_causal_mask, model.model
        )

    return model


def patah_chatglm_model(model):
    """Patch ChatGLM self-attention and MLP activation for TorchAcc.

    For every layer in ``model.transformer.encoder.layers``:

    - ``self_attention.forward`` is rebound to ``chatglm_attn_forward``, which
      uses TorchAcc's packed-QKV flash attention.
    - ``mlp.activation_func`` is replaced with ``torchacc_swiglu``.

    Returns:
        The same ``model`` object, patched in place.
    """

    def split_tensor_along_last_dim(tensor: torch.Tensor, num_partitions: int):
        """Split ``tensor`` into ``num_partitions`` equal chunks along its last dim."""
        last_dim_size = tensor.size(-1) // num_partitions
        return torch.split(tensor, last_dim_size, dim=-1)

    def chatglm_apply_rotary_pos_emb(
        x: torch.Tensor, rope_cache: torch.Tensor
    ) -> torch.Tensor:
        """Rotate the first ``rot_dim`` channels of ``x`` using ``rope_cache``.

        ``rope_cache`` is read as (c, s) pairs in its last dimension, and each
        channel pair of ``x`` is rotated as a complex multiply. Channels beyond
        ``rot_dim`` pass through unchanged.
        """
        # x: [sq, b, np, hn]
        sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
        rot_dim = rope_cache.shape[-2] * 2
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # truncate to support variable sizes
        rope_cache = rope_cache[:sq]
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def chatglm_attn_forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
        **kwargs,
    ):
        """Causal flash attention; ``attention_mask`` is ignored.

        Returns ``(output, kv_cache)`` where ``output`` is [sq, b, h].
        """
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(
                mixed_x_layer, 3
            )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            # Repeat each key/value group so every query head has its own copy.
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )

        # ==================================
        # core attention computation
        # ==================================

        from torchacc.ops import flash_attn_varlen_qkvpacked_xla
        import einops

        # [sq, b, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
        ]
        bsz, _, q_len, _ = query_layer.size()
        # [b, np, 3, sq, hn] -> [b, sq, 3, np, hn] -> [(b sq), 3, np, hn]
        qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)
        qkv = qkv.transpose(1, 3)
        qkv = einops.rearrange(qkv, "b s ... -> (b s) ...")
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        context_layer = flash_attn_varlen_qkvpacked_xla(
            qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True
        )
        # [(b sq), np, hn] -> [sq, b, np, hn] -> [sq, b, hidden]
        context_layer = einops.rearrange(context_layer, "(b s) ... -> b s ...", b=bsz)
        context_layer = context_layer.permute(1, 0, 2, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.core_attention.hidden_size_per_partition,
        )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache

    def torchacc_swiglu(x):
        """SwiGLU: ``silu(first_half) * second_half`` over the last dim."""
        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]).to(x[0].dtype) * x[1]

    # patch attention
    for layer in model.transformer.encoder.layers:
        layer.self_attention.forward = types.MethodType(
            chatglm_attn_forward, layer.self_attention
        )
        layer.mlp.activation_func = torchacc_swiglu

    return model


def patch_baichuan_model(model):
    """Replace Baichuan attention forwards with a TorchAcc flash-attention version.

    Rebinds ``self_attn.forward`` on every layer of ``model.base_model.layers``
    to ``baichuan_attn_forward``.

    Returns:
        The same ``model`` object, patched in place.
    """

    def baichuan_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Causal flash attention over ``hidden_states`` of shape (bsz, q_len, hidden).

        ``attention_mask`` and ``output_attentions`` are ignored; the attention
        weights slot of the returned tuple is always None.
        """
        from torchacc.ops import flash_attn_varlen_xla
        import einops

        bsz, q_len, _ = hidden_states.size()

        # W_pack yields fused [q | k | v]:
        # [bsz, q_len, 3 * hidden] -> [3, bsz, q_len, hidden].
        proj = self.W_pack(hidden_states)
        proj = (
            proj.unflatten(-1, (3, self.hidden_size))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
        )
        # Each of q/k/v: [bsz, num_heads, q_len, head_dim].
        query_states = (
            proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        )

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # NOTE(review): cu_q_lens and q_len are used for keys/values as well,
        # which only matches their length when no past_key_value was
        # concatenated above. Confirm before relying on KV caching here.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        q, k, v = [
            einops.rearrange(x, "b s ... -> (b s) ...")
            for x in [query_states, key_states, value_states]
        ]
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
        )
        output = flash_attn_varlen_xla(
            q,
            k,
            v,
            cu_q_lens,
            cu_q_lens,
            q_len,
            q_len,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output = einops.rearrange(output, "(b s) ... -> b s ...", b=bsz)
        output = self.o_proj(einops.rearrange(output, "b s h d -> b s (h d)"))
        return output, None, past_key_value

    for layer in model.base_model.layers:
        layer.self_attn.forward = types.MethodType(
            baichuan_attn_forward, layer.self_attn
        )

    return model


def patch_qwen2_model(model):
    """Patch a Qwen2 model so its attention runs on TorchAcc's XLA flash attention.

    Every ``model.model.layers[i].self_attn.forward`` is replaced by
    ``qwen2_attn_forward``. Attention masks are not supported in TorchAcc, so
    mask construction is bypassed:

    * transformers >= 4.43: ``model.model._update_causal_mask`` is overridden
      to return None.
    * older versions: ``model.model.forward`` is replaced by ``qwen2_forward``.
      It passes ``attention_mask`` straight through to the layers without
      building a causal mask; the patched attention ignores it anyway.

    Args:
        model: a Qwen2 causal-LM model exposing ``model.model.layers``.

    Returns:
        The same ``model`` object, patched in place.
    """

    def update_causal_mask(self, *args, **kwargs):
        # attention_mask is not supported in TorchAcc. No mask is built;
        # causality is enforced inside the attention kernel (causal=True).
        return None

    def qwen2_attn_forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        """Causal self-attention computed with ``torchacc.ops.flash_attn_varlen_xla``.

        ``attention_mask`` and ``cache_position`` are accepted only for
        signature compatibility and are ignored. Every batch row is treated
        as a full causal sequence of ``q_len`` tokens.

        Returns:
            ``(attn_output, None, past_key_value)``. Flash attention never
            materializes attention probabilities, so the weights slot is
            always None, including when ``output_attentions`` is True.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # NOTE(review): past_key_value is only consulted for its length and is
        # returned unchanged. Cached keys/values are never read or updated, so
        # this forward is presumably meant for training, not cached decoding.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        if version.parse(transformers.__version__) >= version.parse("4.45"):
            if position_embeddings is None:
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        else:
            # For padded inputs, the rotary table would need
            # max(kv_seq_len, position_ids[:, -1].max()) + 1 positions. Here
            # kv_seq_len + 1 is used instead (presumably to avoid a
            # data-dependent ``.item()`` on XLA).
            rotary_seq_len = kv_seq_len + 1
            cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Back to (bsz, q_len, heads, head_dim), the layout flash attention expects.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        from torchacc.ops import flash_attn_varlen_xla
        import einops

        # Flatten the batch into one packed varlen sequence: (bsz * q_len, heads, head_dim).
        q, k, v = [
            einops.rearrange(x, "b s ... -> (b s) ...")
            for x in [query_states, key_states, value_states]
        ]
        # Cumulative sequence offsets [0, q_len, 2*q_len, ..., bsz*q_len]:
        # every row is a full-length sequence (no padding awareness).
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
        )

        attn_output = flash_attn_varlen_xla(
            q,
            k,
            v,
            cu_q_lens,
            cu_q_lens,
            q_len,
            q_len,
            dropout_rate,
            softmax_scale=None,
            causal=True,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash attention never produces attention probabilities. Assign
        # unconditionally so output_attentions=True does not reach an unbound name.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def qwen2_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Qwen2Model.forward replacement for transformers < 4.43.

        This follows the decoder stack of the upstream forward, but passes
        ``attention_mask`` through unchanged instead of building a 4D causal
        mask, because TorchAcc attention does not support masks.

        Returns:
            A tuple of the non-None outputs when ``return_dict`` is falsy;
            otherwise a ``BaseModelOutputWithPast``.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds``
                are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        # Caching is incompatible with gradient checkpointing during training.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        past_key_values_length = 0
        # Bound on every path; only read when use_cache is True.
        use_legacy_cache = False

        if use_cache:
            # Imported locally so this path does not rely on a module-level import.
            from transformers.cache_utils import Cache, DynamicCache

            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        from transformers.modeling_outputs import BaseModelOutputWithPast

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse("4.43"):
        model.model._update_causal_mask = types.MethodType(
            update_causal_mask, model.model
        )
    else:
        model.model.forward = types.MethodType(qwen2_forward, model.model)
    return model


def patch_clip_grad_norm(accelerator):
    """Install an XLA/FSDP-aware ``clip_grad_norm_`` on an accelerate Accelerator.

    The replacement follows accelerate's own dispatch on ``distributed_type``.
    On XLA with ``ACCELERATE_USE_FSDP=true`` it delegates clipping to
    ``model._get_underlay_model().clip_grad_norm_`` of the prepared model
    that owns the given parameters.

    Args:
        accelerator: the ``accelerate.Accelerator`` to patch in place.

    Returns:
        The same ``accelerator`` object.
    """
    from accelerate.utils import DistributedType
    from accelerate.optimizer import AcceleratedOptimizer
    import torch_xla.core.xla_model as xm

    def _owning_model(acc, params):
        # Return the first prepared model whose parameter list equals `params`, else None.
        for candidate in acc._models:
            if params == list(candidate.parameters()):
                return candidate
        return None

    def _all_reduce_xla_grads(acc):
        # All-reduce gradients in place for every optimizer not yet synced,
        # then mark the shared gradient state as synced so the optimizer step
        # does not reduce a second time.
        for wrapped in acc._optimizers:
            if wrapped.gradient_state.is_xla_gradients_synced:
                continue
            inner = wrapped
            while isinstance(inner, AcceleratedOptimizer):
                inner = inner.optimizer
            grads = xm._fetch_gradients(inner)
            # xm.all_reduce is in-place; reducing tensor by tensor via
            # acc.reduce would not be.
            xm.all_reduce("sum", grads, scale=1.0 / acc.num_processes)
            wrapped.gradient_state.is_xla_gradients_synced = True

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        """
        Drop-in replacement for `torch.nn.utils.clip_grad_norm_` under accelerate.

        Returns:
            `torch.Tensor`: total norm of the parameter gradients viewed as a
            single vector, or None under DeepSpeed (which clips internally).
        """
        dist_kind = self.distributed_type

        if dist_kind == DistributedType.FSDP:
            self.unscale_gradients()
            parameters = list(parameters)
            owner = _owning_model(self, parameters)
            if owner is not None:
                return owner.clip_grad_norm_(max_norm, norm_type)
        elif dist_kind == DistributedType.DEEPSPEED:
            # DeepSpeed clips inside `accelerator.backward`; no norm is available.
            return None
        elif dist_kind == DistributedType.XLA:
            _all_reduce_xla_grads(self)
            if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
                self.unscale_gradients()
                parameters = list(parameters)
                owner = _owning_model(self, parameters)
                if owner is not None:
                    return owner._get_underlay_model().clip_grad_norm_(
                        max_norm, norm_type
                    )

        # Fallback: plain PyTorch clipping over the given parameters.
        self.unscale_gradients()
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

    # TODO(baole): This should be removed once accelerate is updated.
    accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
    return accelerator


def ta_accelerate(
    model,
    fsdp_num,
    layer_cls_name,
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    fsdp_flatten_parameters=False,
):
    """Accelerate LLM training using TorchAcc (only available internally).

    Args:
        model: the model passed to ``torchacc.accelerate``.
        fsdp_num: FSDP group size (``config.dist.fsdp.size``). Values > 1
            also set the environment variable ``ACCELERATE_USE_FSDP=true``.
        layer_cls_name: presumably the class name of the transformer layer.
            It is used as the FSDP wrap unit and, when gradient checkpointing
            is on, as the checkpointing unit.
        bf16: forwarded to ``config.compute.bf16``.
        fp16: forwarded to ``config.compute.fp16``.
        gradient_checkpointing: forwarded to ``config.memory.gc``.
        fsdp_flatten_parameters: forwarded to
            ``config.dist.fsdp.flatten_parameters``.

    Returns:
        The model returned by ``torchacc.accelerate``.

    Raises:
        ValueError: if ``layer_cls_name`` is None.
    """
    # Validate explicitly: an ``assert`` is stripped under ``python -O``, which
    # would silently configure FSDP/GC with ``{None}``. Checking before the
    # torchacc import also fails fast.
    if layer_cls_name is None:
        raise ValueError("layer_cls_name must be specified for TorchAcc acceleration.")

    import torchacc as ta

    def get_ta_config():
        config = ta.Config()
        config.compute.fp16 = fp16
        config.compute.bf16 = bf16

        config.memory.gc = gradient_checkpointing
        if config.memory.gc:
            config.memory.gc_cls = {layer_cls_name}

        config.dist.fsdp.size = fsdp_num
        config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
        config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
        config.dist.dp.size = 1

        # Process-wide side effect: signals FSDP mode to the patched
        # clip_grad_norm_ (it reads ACCELERATE_USE_FSDP).
        if fsdp_num > 1:
            os.environ["ACCELERATE_USE_FSDP"] = "true"

        return config

    ta_config = get_ta_config()
    model = ta.accelerate(model, config=ta_config)
    return model
