from sched import scheduler

from PIL import Image
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
import torch
import torch.distributed as dist
IGNORE_INDEX = -100

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward

from transformers.processing_utils import Unpack
from transformers.utils import (
    LossKwargs,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from typing import List, Optional, Tuple, Union
import torch.nn as nn
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn.utils import clip_grad_norm
from torch.utils.data import Dataset,DataLoader
from typing import Dict, Optional, Sequence, List
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRotaryEmbedding,LlamaMLP,LlamaRMSNorm
import transformers
import tokenizers
from packaging import version
import math
import os
import time
import pandas as pd
import csv

from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.utils import logging
from transformers.cache_utils import Cache, DynamicCache, StaticCache

# Label value marking tokens that must not contribute to the loss (same value as
# IGNORE_INDEX); preprocess() writes it into the label tensors.
IGNORE_TOKEN_ID = -100
#from test2 import casual_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings

# `logging` here is transformers.utils.logging (the last import of that name), which
# shadows the standard-library `logging` module imported at the top of the file.
logger = logging.get_logger(__name__)

def rank0_print(*args):
    """Print ``args`` only on the process whose ``LOCAL_RANK`` is 0.

    ``LOCAL_RANK`` is set by distributed launchers such as ``torchrun``. When it is
    absent (e.g. a plain single-process run) the process is treated as local rank 0
    rather than failing with ``KeyError``.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(*args)

def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to plcae the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        #attention_mask = attention_mask.to(device)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask


def _update_causal_mask(
        config,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
):
    """Pick or build the attention mask suited to ``config._attn_implementation``.

    - ``flash_attention_2``: returns the 2D mask only when it contains a 0 (padding),
      otherwise ``None``.
    - ``sdpa``: returns ``None`` when ``AttentionMaskConverter._ignore_causal_mask_sdpa``
      says SDPA's own ``is_causal`` path can be used. This is skipped for a static cache or
      when attentions are requested.
    - Otherwise: builds a 4D additive mask via
      ``_prepare_4d_causal_attention_mask_with_cache_position``. For SDPA on CUDA, fully
      masked rows are then un-masked.

    Note: ``is_training`` is always passed as ``False`` to the SDPA shortcut check.
    """
    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    impl = config._attn_implementation

    if impl == "flash_attention_2":
        has_padding = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_padding else None

    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache = isinstance(past_key_values, StaticCache)

    # SDPA can rely on `is_causal` instead of an explicit mask, which is incompatible
    # with a static cache; when attentions are requested, SDPA falls back to eager.
    sdpa_may_skip = impl == "sdpa" and not static_cache and not output_attentions
    if sdpa_may_skip and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask=attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=False,
    ):
        return None

    dtype = input_tensor.dtype
    device = input_tensor.device
    min_dtype = torch.finfo(dtype).min
    query_len = input_tensor.shape[1]

    if static_cache:
        target_length = past_key_values.get_max_length()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = seen_tokens + query_len + 1

    # A 2D mask is expanded into a 4D causal mask here.
    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=attention_mask,
        sequence_length=query_len,
        target_length=target_length,
        dtype=dtype,
        device=device,
        min_dtype=min_dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmask = (
            impl == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
    )
    if needs_unmask:
        # The memory-efficient SDPA path needs fully masked rows (e.g. left padding) to
        # attend to something. Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Render conversations with the Vicuna template, tokenize them, and mask non-assistant tokens.

    Args:
        sources: Iterable of conversations. Each is a list of dicts with a ``"from"`` speaker
            label (a key of ``roles`` below) and a ``"value"`` text. Empty conversations are
            skipped, so the output can have fewer rows than ``sources``.
        tokenizer: Tokenizer used for the batch encoding and for measuring each turn. Reads
            ``pad_token_id`` and ``model_max_length``, and ``legacy`` for turns after the first.

    Returns:
        dict with:
        - ``input_ids``: padded to the longest row and truncated to ``model_max_length``.
        - ``labels``: a copy of ``input_ids`` where every token outside an assistant reply
          is ``IGNORE_TOKEN_ID``.
        - ``attention_mask``: boolean, ``input_ids != pad_token_id``.
    """
    conv = get_conversation_template("vicuna")

    # Map every speaker label that appears in the data onto one of the template's roles.
    roles = {
        "human": conv.roles[0],
        "gpt": conv.roles[1],
        "system": "system",
        "bing": conv.roles[1],
        "chatgpt": conv.roles[1],
        "bard": conv.roles[1],
        "user": conv.roles[0],
        "USER": conv.roles[0],
        "ASSISTANT": conv.roles[1],
    }

    conversations = []

    for source in sources:
        if len(source) == 0:
            continue
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the user side.
            source = source[1:]

        conv.messages = []

        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Stop at the first message that breaks strict user/assistant alternation.
            if role != conv.roles[j % 2]:
                break
            conv.append_message(role, sentence["value"])

        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        # The first position (presumably the BOS token) is never a training target.
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        # If the per-turn lengths don't add up to the real (non-pad) length, the offsets
        # above are unreliable, so the whole row is excluded from the loss. Rows whose
        # computed length reaches model_max_length (presumably truncated) skip this check.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" #turn = {len(turns) - 1}. (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


class SupervisedDataset(Dataset):
    """Supervised fine-tuning dataset that tokenizes every example eagerly at construction."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super().__init__()
        # Every record is expected to carry a "conversations" list (the input of preprocess).
        encoded = preprocess([record["conversations"] for record in raw_data], tokenizer)
        self.input_ids, self.labels, self.attention_mask = (
            encoded["input_ids"],
            encoded["labels"],
            encoded["attention_mask"],
        )

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Row i of each pre-computed tensor.
        return {
            "input_ids": self.input_ids[i],
            "labels": self.labels[i],
            "attention_mask": self.attention_mask[i],
        }


class LazySupervisedDataset(Dataset):
    """Supervised fine-tuning dataset that tokenizes each example on first access.

    Processed examples are memoized in ``cached_data_dict`` keyed by index, so later
    accesses to the same index return the same dict.
    """

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer
        # Records are indexed and read for "conversations" (and "id" for diagnostics).
        self.raw_data = raw_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]
        if len(self.raw_data[i]["conversations"]) == 0:
            # preprocess() skips empty conversations, so the [0] indexing below is
            # expected to fail for this record; consider filtering such records upstream.
            print("WARNING: empty conversation", self.raw_data[i]["id"])
        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
        # Drop the batch dimension of the single-example batch.
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )

        self.cached_data_dict[i] = ret

        return ret
def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_path,lazy_preprocess
) -> Dict:
    """Load a JSON training file and wrap it in a supervised fine-tuning dataset.

    Args:
        tokenizer: Tokenizer passed through to the dataset class.
        data_path: Path to a JSON file; its top-level value is handed to the dataset class
            as ``raw_data`` (presumably a list of records with "conversations").
        lazy_preprocess: If truthy use ``LazySupervisedDataset`` (tokenize on access),
            otherwise ``SupervisedDataset`` (tokenize everything now).

    Returns:
        ``{"train_dataset": <dataset>, "eval_dataset": None}``. No evaluation split is built.
    """
    dataset_cls = LazySupervisedDataset if lazy_preprocess else SupervisedDataset

    # Close the file deterministically. JSON text is UTF-8 (RFC 8259), so don't depend
    # on the platform's default encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        train_json = json.load(f)
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer)

    eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)



class DataCollatorForSupervisedDataset(object):
    """Collate dataset items into right-padded batch tensors for supervised fine-tuning."""

    def __init__(self, tokenizer):
        # Only ``pad_token_id`` and ``model_max_length`` are read from the tokenizer.
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pad, truncate and stack a list of examples.

        Args:
            instances: Items holding 1-D ``"input_ids"`` and ``"labels"`` tensors. Any other
                keys (including a per-item ``"attention_mask"``) are ignored.

        Returns:
            A tuple ``(input_ids, labels, attention_mask)``:
            - ``input_ids`` is padded with ``pad_token_id``.
            - ``labels`` is padded with ``IGNORE_INDEX``.
            - Both are truncated to ``model_max_length``.
            - ``attention_mask`` is a boolean mask, true wherever ``input_ids`` differs
              from ``pad_token_id``.
        """
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["input_ids"] for instance in instances],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            [instance["labels"] for instance in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        max_len = self.tokenizer.model_max_length
        input_ids = input_ids[:, :max_len]
        labels = labels[:, :max_len]
        return input_ids, labels, input_ids.ne(self.tokenizer.pad_token_id)

def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from the previous pipeline rank into ``recv_buffer`` in place.

    Rank 0 is the first stage and has no predecessor, so it returns immediately.
    Every other rank posts an ``irecv`` from ``rank - 1``, waits for it, and then
    synchronizes CUDA.

    Args:
        recv_buffer: Preallocated tensor matching what the previous rank sends.
            Required on every rank except 0.
    """
    # Only ranks that have a predecessor receive. The check must be `rank != 0`: a
    # negated form would make rank 0 receive from rank -1.
    if dist.get_rank() != 0:
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            dist.get_rank() - 1)
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send ``tensor`` to the next pipeline rank.

    The last rank (``world_size - 1``) has no successor, so it returns immediately.
    Every other rank posts an ``isend`` to ``rank + 1``, waits for it, and then
    synchronizes CUDA.

    Args:
        tensor: Tensor to send. Required on every rank except the last.
    """
    # Only ranks that have a successor send. A negated check would make the last rank
    # send to the nonexistent rank ``world_size``.
    if dist.get_rank() != dist.get_world_size() - 1:
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            dist.get_rank() + 1)
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



def recv_list_from_prev_pipeline_rank(recv_buffers):
    """Fill every tensor in ``recv_buffers`` in place from the previous pipeline rank.

    Rank 0 returns immediately. Other ranks post one ``irecv`` per buffer (all from
    ``rank - 1``, in list order), wait for all of them, then synchronize CUDA.
    """
    rank = dist.get_rank()
    if rank == 0:
        return
    assert recv_buffers is not None and type(recv_buffers) is list
    ops = []
    for buffer in recv_buffers:
        ops.append(torch.distributed.P2POp(torch.distributed.irecv, buffer, rank - 1))
    for handle in torch.distributed.batch_isend_irecv(ops):
        handle.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

def send_list_to_next_pipeline_rank(tensors):
    """Send every tensor in ``tensors`` to the next pipeline rank.

    The last rank returns immediately. Other ranks post one ``isend`` per tensor (all to
    ``rank + 1``, in list order), wait for all of them, then synchronize CUDA.
    """
    rank = dist.get_rank()
    if rank == dist.get_world_size() - 1:
        return
    assert tensors is not None and type(tensors) is list
    ops = []
    for item in tensors:
        ops.append(torch.distributed.P2POp(torch.distributed.isend, item, rank + 1))
    for handle in torch.distributed.batch_isend_irecv(ops):
        handle.wait()
    torch.cuda.synchronize()


def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda



def _is_cuda_contiguous(tensor):
    """Assert that *tensor* is present, on a CUDA device, and contiguous in memory."""
    assert tensor is not None and tensor.is_cuda
    assert tensor.is_contiguous()
def send_token_and_probs_to_first_pipeline_stage(has_early_exited, token_tensor=None, prob_tensor=None, is_final=False):
    """Report this stage's outcome to rank 0.

    Always sends a one-element int8 signal to rank 0: 1 when this stage has a result to
    deliver (it exited early or is the final stage), 0 otherwise. When the signal is 1,
    ``token_tensor`` and then ``prob_tensor`` follow, both required to be CUDA tensors.
    """
    delivering = has_early_exited or is_final
    signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device())
    signal_tensor[0] = 1 if delivering else 0
    if delivering:
        _is_cuda(token_tensor)
        _is_cuda(prob_tensor)
    dist.send(tensor=signal_tensor, dst=0)
    if delivering:
        dist.send(tensor=token_tensor, dst=0)
        dist.send(tensor=prob_tensor, dst=0)

def send_probs_and_indices_to_first_pipeline_stage(has_early_exited, probs_tensor=None,indices_tensor=None, is_final=False):
    """Report this stage's outcome (probs and indices) to rank 0.

    Always sends a one-element int8 signal to rank 0: 1 when this stage has a result to
    deliver (it exited early or is the final stage), 0 otherwise. When the signal is 1,
    ``probs_tensor`` and then ``indices_tensor`` follow, both required to be CUDA tensors.
    """
    delivering = has_early_exited or is_final
    signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device())
    signal_tensor[0] = 1 if delivering else 0
    if delivering:
        _is_cuda(probs_tensor)
        _is_cuda(indices_tensor)
    dist.send(tensor=signal_tensor, dst=0)
    if delivering:
        dist.send(tensor=probs_tensor, dst=0)
        dist.send(tensor=indices_tensor, dst=0)

def recv_probs_and_indices(has_early_exited,probs_tensor_buffer, indices_tensor_buffer):
    """On rank 0, collect ``(probs, indices)`` from the first later stage that reports a result.

    Counterpart of ``send_probs_and_indices_to_first_pipeline_stage``. Ranks
    ``1 .. world_size - 1`` are polled in order for an int8 signal. The first rank that
    signals 1 then sends the probs tensor followed by the indices tensor, which are written
    into the given buffers in place. Ranks after that one are not polled.

    Args:
        has_early_exited: If true, rank 0 produced the result itself (presumably already in
            the buffers) and nothing is received.
        probs_tensor_buffer: Destination for probs; shape and dtype must match the sender's.
        indices_tensor_buffer: Destination for indices; shape and dtype must match the sender's.

    Returns:
        ``0`` if ``has_early_exited``. Otherwise the rank whose data was received, or the
        last rank polled if none signalled (buffers are then left untouched). With a
        single-rank world nothing is polled and ``0`` is returned.
    """
    assert probs_tensor_buffer is not None and indices_tensor_buffer is not None

    # The first stage already holds its own result; there is nothing to receive.
    if has_early_exited:
        return 0

    # dist.recv needs contiguous destinations. Stage through contiguous scratch tensors
    # with the same shape and dtype as each buffer, checking each buffer independently.
    probs_recv = (
        probs_tensor_buffer
        if probs_tensor_buffer.is_contiguous()
        else torch.empty_like(probs_tensor_buffer, memory_format=torch.contiguous_format)
    )
    indices_recv = (
        indices_tensor_buffer
        if indices_tensor_buffer.is_contiguous()
        else torch.empty_like(indices_tensor_buffer, memory_format=torch.contiguous_format)
    )

    signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device())

    # get tensor from subsequent stages one by one
    src = 0
    received = False
    for src in range(1, dist.get_world_size()):
        dist.recv(tensor=signal_tensor, src=src)
        if signal_tensor[0] == 1:
            dist.recv(tensor=probs_recv, src=src)
            dist.recv(tensor=indices_recv, src=src)
            received = True
            break

    # Copy back only real data: scratch tensors are uninitialized if nothing arrived.
    if received:
        if probs_recv is not probs_tensor_buffer:
            probs_tensor_buffer[...] = probs_recv
        if indices_recv is not indices_tensor_buffer:
            indices_tensor_buffer[...] = indices_recv
    return src
def recv_token_and_probs(has_early_exited,token_tensor_buffer, prob_tensor_buffer):
    """On rank 0, collect ``(token, prob)`` from the first later stage that reports a result.

    Counterpart of ``send_token_and_probs_to_first_pipeline_stage``. Ranks
    ``1 .. world_size - 1`` are polled in order for an int8 signal. The first rank that
    signals 1 then sends the token tensor followed by the prob tensor, which are written
    into the given buffers in place. Ranks after that one are not polled.

    Args:
        has_early_exited: If true, rank 0 produced the result itself (presumably already in
            the buffers) and nothing is received.
        token_tensor_buffer: Destination for tokens; shape and dtype must match the sender's.
        prob_tensor_buffer: Destination for probs; shape and dtype must match the sender's.

    Returns:
        ``0`` if ``has_early_exited``. Otherwise the rank whose data was received, or the
        last rank polled if none signalled (buffers are then left untouched). With a
        single-rank world nothing is polled and ``0`` is returned.
    """
    assert token_tensor_buffer is not None and prob_tensor_buffer is not None

    # The first stage already holds its own result; there is nothing to receive.
    if has_early_exited:
        return 0

    # dist.recv needs contiguous destinations. Stage through contiguous scratch tensors
    # with the same shape and dtype as each buffer, checking each buffer independently.
    token_recv = (
        token_tensor_buffer
        if token_tensor_buffer.is_contiguous()
        else torch.empty_like(token_tensor_buffer, memory_format=torch.contiguous_format)
    )
    prob_recv = (
        prob_tensor_buffer
        if prob_tensor_buffer.is_contiguous()
        else torch.empty_like(prob_tensor_buffer, memory_format=torch.contiguous_format)
    )

    signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device())

    # get tensor from subsequent stages one by one
    src = 0
    received = False
    for src in range(1, dist.get_world_size()):
        dist.recv(tensor=signal_tensor, src=src)
        if signal_tensor[0] == 1:
            dist.recv(tensor=token_recv, src=src)
            dist.recv(tensor=prob_recv, src=src)
            received = True
            break

    # Copy back only real data: scratch tensors are uninitialized if nothing arrived.
    if received:
        if token_recv is not token_tensor_buffer:
            token_tensor_buffer[...] = token_recv
        if prob_recv is not prob_tensor_buffer:
            prob_tensor_buffer[...] = prob_recv
    return src

def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last rank (``world_size - 1``) to all ranks.

    This is a collective over the default process group, so every rank must call it.
    - The last rank passes its CUDA, contiguous ``tensor``.
    - Every other rank allocates an uninitialized CUDA tensor of ``size``/``dtype``
      to receive into.

    Returns:
        The broadcast tensor, on every rank.
    """
    last_rank = dist.get_world_size() - 1
    is_last_stage = dist.get_rank() == last_rank
    # Rank 1 is treated as the first pipeline stage here. If it is also the last
    # stage there is no pipeline parallelism and no need to communicate.
    # NOTE(review): in that case rank 0 does not take this early return and still
    # enters the broadcast below — confirm callers never hit this configuration.
    if dist.get_rank() == 1 and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # The source must be the same rank that `is_last_stage` tests for, so the
    # broadcast is correct for any world size.
    torch.distributed.broadcast(tensor, last_rank)

    return tensor


def broadcast_from_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the first pipeline stage (rank 1 here) to all ranks.

    This is a collective over the default process group, so every rank must call it.
    - Rank 1 passes its CUDA, contiguous ``tensor``.
    - Every other rank allocates an uninitialized CUDA tensor of ``size``/``dtype``
      to receive into.

    Returns:
        The broadcast tensor, on every rank.
    """
    # Rank used as the first pipeline stage (and broadcast source) by this helper.
    first_rank = 1
    is_first_stage = dist.get_rank() == first_rank
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if dist.get_rank() == dist.get_world_size() - 1 and is_first_stage:
        return tensor

    if is_first_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(tensor, first_rank)

    return tensor


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Transfer tensor values from the last rank to the first pipeline stage (rank 1).

    Only those two ranks communicate, using point-to-point send/recv. A collective on the
    default group would need every rank to join, but other ranks return without
    communicating.

    Returns:
        - On the last rank: the input ``tensor`` (must be CUDA and contiguous).
        - On rank 1: a new CUDA tensor of ``size``/``dtype`` holding the received values.
        - On every other rank: ``None``.
    """
    first_rank = 1
    last_rank = dist.get_world_size() - 1
    rank = dist.get_rank()
    is_last_stage = rank == last_rank
    is_first_stage = rank == first_rank
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
        dist.send(tensor=tensor, dst=first_rank)
        return tensor

    if is_first_stage:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
        dist.recv(tensor=tensor, src=last_rank)
        return tensor

    # Ranks other than first/last take no part.
    return None



def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last rank into ``tensor`` on rank 0, in place.

    Only rank 0 and the last rank (``world_size - 1``) communicate, using point-to-point
    send/recv. ``torch.distributed.broadcast`` takes a ``ProcessGroup`` (not a list of
    ranks) and would require every rank of that group to participate.

    Args:
        size: Shape of the staging buffer used when rank 0's ``tensor`` is not contiguous.
        dtype: dtype of that staging buffer.
        tensor: CUDA tensor. The source on the last rank; the destination (updated in
            place) on rank 0.
    """
    first_rank = 0
    last_rank = dist.get_world_size() - 1
    rank = dist.get_rank()
    is_last_stage = rank == last_rank
    is_first_stage = rank == first_rank
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only first and last pipeline stages need to be involved.
    if not (is_last_stage or is_first_stage):
        return

    _is_cuda(tensor)
    if is_last_stage:
        # .contiguous() returns the tensor itself when it is already contiguous.
        dist.send(tensor=tensor.contiguous(), dst=first_rank)
        return

    if tensor.is_contiguous():
        dist.recv(tensor=tensor, src=last_rank)
    else:
        # dist.recv needs a contiguous destination; receive into a staging buffer first.
        staging = torch.empty(size,
                              dtype=dtype,
                              device=torch.cuda.current_device())
        dist.recv(tensor=staging, src=last_rank)
        tensor[...] = staging

def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` along the last dimension of *x*.

    The split point is ``x.shape[-1] // 2``, so for an odd size the second half is
    the larger one.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to query and key tensors.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``. First ``cos`` and
    ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they broadcast
    against ``q``/``k``. For example, with cos/sin of shape ``[batch, seq, head_dim]``:
    - ``unsqueeze_dim=1`` suits q/k laid out as ``[batch, heads, seq, head_dim]``;
    - ``unsqueeze_dim=2`` suits ``[batch, seq, heads, head_dim]``.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine part of the rotary embedding.
        sin (`torch.Tensor`): Sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis at which cos/sin are unsqueezed.

    Returns:
        `tuple(torch.Tensor)`: the rotated ``(q, k)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: maps
    ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1`` the
    input object itself is returned (after its 4D shape is unpacked).
    """
    b, kv_heads, seq, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, seq, dim)
    return tiled.reshape(b, kv_heads * n_rep, seq, dim)


class DecomposedLinear(nn.Module):
    """Linear layer built from a series (low-rank) matrix decomposition.

    Input is first projected to ``rank`` features by ``U``, then to ``out_features``
    by ``V``. ``bias`` controls whether both sub-layers carry a bias.
    """

    def __init__(self, in_features, out_features, rank, bias):
        super().__init__()
        # The sub-module names U / V determine the state_dict keys; keep them stable.
        self.U = nn.Linear(in_features, rank, bias=bias)
        self.V = nn.Linear(rank, out_features, bias=bias)

    def forward(self, x):
        reduced = self.U(x)
        return self.V(reduced)



class DecomposedAttention(nn.Module):
    """Llama-style multi-head attention whose q/k/v/o projections are low-rank factorized.

    Each dense projection is replaced by two bias-free linear layers applied in sequence:
    ``*_v_proj`` maps its input down to ``low_rank = int(hidden_size * ratio / 2)`` features,
    then ``*_u_proj`` maps back up. RoPE, optional KV-cache update, grouped-query head
    repetition, softmax (in fp32) and dropout follow the eager Llama attention path.

    Args:
        config: Llama-style config. Reads ``attention_dropout``, ``hidden_size``,
            ``num_attention_heads``, optional ``head_dim``, ``num_key_value_heads``,
            ``max_position_embeddings`` and ``rope_theta``.
        layer_idx: Index of this layer, passed to ``past_key_value.update``.
        ratio: Compression ratio that sets ``low_rank``.
    """

    def __init__(self, config, layer_idx,ratio):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.ratio = ratio
        low_rank = int(self.hidden_size * self.ratio/2)
        # Keys/values have num_key_value_heads heads; forward() expands them to num_heads
        # with repeat_kv(). Sizing them with num_heads would give the expanded K/V
        # num_heads * num_key_value_groups heads and break the attention matmul whenever
        # num_key_value_heads < num_heads (GQA). For MHA configs the shapes are unchanged.
        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
        self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.k_u_proj = nn.Linear(low_rank, kv_dim, bias=False)
        self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.v_u_proj = nn.Linear(low_rank, kv_dim, bias=False)
        self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
        self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False)

        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over ``hidden_states`` of shape ``(bsz, q_len, hidden_size)``.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``. ``attn_output`` has
            shape ``(bsz, q_len, hidden_size)``; ``attn_weights`` is returned only when
            ``output_attentions`` is true.
        """
        bsz, q_len, _ = hidden_states.size()

        # Each projection is the low-rank pair: down (*_v_proj) then up (*_u_proj).
        query_states = self.q_u_proj(self.q_v_proj(hidden_states))
        key_states = self.k_u_proj(self.k_v_proj(hidden_states))
        value_states = self.v_u_proj(self.v_v_proj(hidden_states))

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand K/V heads to match the number of query heads (no-op when groups == 1).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

class DecomposedFlashAttention2(DecomposedAttention):
    """
    Flash-attention-2 variant of ``DecomposedAttention``.

    The factorised projection modules (``*_v_proj`` followed by ``*_u_proj``) come unchanged from the
    parent class; only the forward pass differs. It calls ``_flash_attention_forward`` (passing the raw
    ``attention_mask`` through) instead of materialising the attention matrix, so attention weights are
    never returned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this
        # difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute low-rank self-attention with flash-attention-2.

        Args:
            hidden_states: input activations, unpacked below as ``(bsz, q_len, hidden)``.
            attention_mask: forwarded unchanged to ``_flash_attention_forward``.
            position_ids: used for RoPE when ``position_embeddings`` is None; also forwarded to
                ``_flash_attention_forward``.
            past_key_value: optional cache updated in place via ``update``; ``StaticCache`` is rejected.
            output_attentions: ignored; flash attention never materialises the weights.
            use_cache: unused here; the cache is updated whenever ``past_key_value`` is given.
            cache_position: forwarded to the cache update.
            position_embeddings: precomputed ``(cos, sin)``; computed from ``position_ids`` when None.
            **kwargs: extra flash-attention keyword arguments, passed through.

        Returns:
            ``(attn_output, None, past_key_value)``.

        Raises:
            ValueError: if ``past_key_value`` is a ``StaticCache``.
        """
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        # Low-rank projections: x -> *_v_proj -> *_u_proj.
        query_states = self.q_u_proj(self.q_v_proj(hidden_states))
        key_states = self.k_u_proj(self.k_v_proj(hidden_states))
        value_states = self.v_u_proj(self.v_v_proj(hidden_states))

        # Split heads: (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim) for RoPE and the
        # cache. They are transposed back to (bsz, q_len, heads, head_dim) before calling flash attention.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are inefficient, but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. The KV cache would need refactoring
        # to avoid many of these transpose/reshape/view calls.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, layer norms are usually cast to float32 for training stability, so the input hidden
        # states get silently cast to float32. Cast them back to the correct dtype so flash attention works.
        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
        # to fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                # This forward only uses the factorised q_v_proj/q_u_proj; a dense `q_proj` is presumably
                # absent on this class, so read the dtype from the factorised weight instead.
                target_dtype = self.q_u_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        # Flash attention never materialises the attention matrix, so no weights are returned.
        return attn_output, None, past_key_value

class DecomposedSdpaAttention(DecomposedAttention):
    """Low-rank self-attention computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Queries, keys and values come from the factorised projections (``*_v_proj`` then ``*_u_proj``).
    Attention weights are never materialised: the second element of the returned tuple is always
    ``None`` and ``output_attentions`` is ignored.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run SDPA over the low-rank projections.

        Returns:
            ``(attn_output, None, past_key_value)``; ``past_key_value`` is updated in place when given.
        """
        batch, seq_len, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim). The head count is inferred
            # with -1 because it may differ per shard under tensor parallelism.
            return projected.view(batch, seq_len, -1, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_u_proj(self.q_v_proj(hidden_states)))
        k = split_heads(self.k_u_proj(self.k_v_proj(hidden_states)))
        v = split_heads(self.v_u_proj(self.v_v_proj(hidden_states)))

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_value is not None:
            # RoPE-specific sin/cos plus cache_position (needed by the static cache).
            k, v = past_key_value.update(
                k,
                v,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Expand grouped KV heads to the full query head count.
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        mask = None if attention_mask is None else attention_mask[:, :, :, : k.shape[-2]]

        # SDPA's memory-efficient backend (torch==2.1.2) is bugged with non-contiguous inputs plus a custom
        # attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if mask is not None and q.device.type == "cuda":
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        # Kept as an explicit conditional (not inlined into the SDPA call) so torch.compile's dynamic
        # shapes and full-graph modes both work.
        is_causal = True if mask is None and seq_len > 1 else False

        out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)

        return self.o_u_proj(self.o_v_proj(out)), None, past_key_value

class DecomposedMLP(nn.Module):
    """Gated MLP (``down(act(gate(x)) * up(x))``) with each projection factorised into two linears.

    Each dense ``in -> out`` projection is replaced by ``*_v_proj`` (``in -> low_rank``) followed by
    ``*_u_proj`` (``low_rank -> out``), all without bias.

    Args:
        config: provides ``hidden_size``, ``intermediate_size`` and ``hidden_act``.
        ratio: fraction of the dense parameter count to keep per projection. ``low_rank`` is chosen
            so that ``low_rank * (hidden + intermediate) ~= ratio * hidden * intermediate``
            (truncated to int).
    """

    def __init__(self, config, ratio):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.ratio = ratio
        # Use the config-derived sizes; the bare names used previously were never defined in this
        # scope, so construction raised NameError.
        hidden_size = self.hidden_size
        intermediate_size = self.intermediate_size
        low_rank = int(intermediate_size * hidden_size * self.ratio / (intermediate_size + hidden_size))

        self.gate_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
        self.gate_v_proj = nn.Linear(hidden_size, low_rank, bias=False)

        self.down_u_proj = nn.Linear(low_rank, hidden_size, bias=False)
        self.down_v_proj = nn.Linear(intermediate_size, low_rank, bias=False)

        self.up_u_proj = nn.Linear(low_rank, intermediate_size, bias=False)
        self.up_v_proj = nn.Linear(hidden_size, low_rank, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the factorised gated MLP; output has the same last dimension as ``x`` (hidden_size)."""
        up = self.up_u_proj(self.up_v_proj(x))
        gate = self.gate_u_proj(self.gate_v_proj(x))
        return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))
        
# Maps `config._attn_implementation` to the decomposed attention class that
# DecomposedTransformerLayer instantiates.
LLAMA_ATTENTION_CLASSES = dict(
    eager=DecomposedAttention,
    flash_attention_2=DecomposedFlashAttention2,
    sdpa=DecomposedSdpaAttention,
)
class DecomposedTransformerLayer(nn.Module):
    """Pre-norm decoder layer whose attention and MLP use decomposed (factorised) linear layers."""

    def __init__(self, config, layer_idx, rank):
        super().__init__()
        attention_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        # Self-attention built from decomposed linear layers.
        self.self_attn = attention_cls(config, layer_idx, rank)
        # Feed-forward block with its linear layers replaced by decomposed ones.
        # NOTE(review): DecomposedMLP's second parameter is `ratio` (a parameter-count fraction), but
        # `rank` is passed here; confirm the attention and MLP expect the same quantity.
        self.mlp = DecomposedMLP(config, rank)

        eps = config.rms_norm_eps
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply attention and MLP sub-blocks, each as ``x + f(norm(x))``.

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights if
            ``output_attentions`` and then the attention layer's cache if ``use_cache``.
        """
        attn_out, attn_weights, present = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present)
        return (hidden_states, *extras)

class DecomposedModel(nn.Module):
    """A frozen 16-layer Llama decoder stack plus one decomposed (low-rank) decoder layer.

    The 16 ``LlamaDecoderLayer`` weights are loaded from ``llama_decoder_16layers.pth`` (relative to
    the current working directory) and frozen (``requires_grad = False``). ``forward`` runs only the
    frozen stack; ``self.decomposed_layer`` is constructed but not used by ``forward``.

    Args:
        rank: accepted for interface compatibility but unused; the decomposition rank is fixed
            at ``self.rank = 1024``.
    """

    def __init__(self, rank):
        super().__init__()
        config = LlamaConfig()
        config.hidden_size = 4096
        config._attn_implementation = "sdpa"
        self.layers = torch.nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(16)]
        )
        for param in self.layers.parameters():
            param.requires_grad = False

        self.num_components = 1
        self.rank = 1024
        self.layers.load_state_dict(torch.load('llama_decoder_16layers.pth'))
        # DecomposedTransformerLayer takes (config, layer_idx, rank). The positional values
        # 4096 / 32 / 11008 look like hidden_size / num_attention_heads / intermediate_size, which
        # `config` carries (NOTE(review): confirm LlamaConfig's defaults still give 32 heads and an
        # intermediate size of 11008). layer_idx = len(self.layers) presumably places the layer right
        # after the frozen stack for cache indexing — confirm.
        # NOTE(review): DecomposedTransformerLayer passes `rank` to DecomposedMLP as its `ratio`;
        # with rank=1024 that gives a very large MLP. Confirm the intended value.
        self.decomposed_layer = DecomposedTransformerLayer(config, len(self.layers), self.rank)

    def forward(self, hidden_states, position_ids, past_key_values):
        """Run the frozen decoder stack with caching enabled.

        Returns:
            ``(hidden_states, next_cache)`` where ``next_cache`` is element 1 of the last layer's
            output, or ``past_key_values`` unchanged when there are no layers.
        """
        next_cache = past_key_values
        for layer in self.layers:
            # The decoder layer's mask parameter is `attention_mask`; an unknown keyword such as
            # `causal_mask` would land in its **kwargs and be ignored.
            layer_outputs = layer(
                hidden_states,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=True,
                cache_position=None,
                position_embeddings=None,
            )
            hidden_states = layer_outputs[0]
            # Assumes the (hidden, present_kv) layout returned when output_attentions is False.
            next_cache = layer_outputs[1]
        return hidden_states, next_cache

class SVD_LlamaMLP(nn.Module):
    """Llama-style gated MLP whose three projections are each a product of two bias-free linears.

    Args:
        hidden_size: model width (input/output dimension).
        intermediate_size: inner MLP width.
        hidden_act: key into ``ACT2FN`` selecting the gate activation.
        ratio: target fraction of the dense parameter count; the shared rank ``r`` satisfies
            ``r * (hidden + intermediate) ~= ratio * hidden * intermediate`` (truncated to int).
    """

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str, ratio=1):
        super().__init__()
        self.ratio = ratio
        rank = int(intermediate_size * hidden_size * self.ratio / (intermediate_size + hidden_size))

        def factor(in_features, out_features):
            # Returns (U, V): V maps in_features -> rank and U maps rank -> out_features.
            # U is built before V so parameters are created in the same order as before.
            u = nn.Linear(rank, out_features, bias=False)
            v = nn.Linear(in_features, rank, bias=False)
            return u, v

        self.gate_u_proj, self.gate_v_proj = factor(hidden_size, intermediate_size)
        self.down_u_proj, self.down_v_proj = factor(intermediate_size, hidden_size)
        self.up_u_proj, self.up_v_proj = factor(hidden_size, intermediate_size)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Compute ``down(act(gate(x)) * up(x))`` through the factorised projections."""
        up_branch = self.up_v_proj(x)
        up_branch = self.up_u_proj(up_branch)
        gate_branch = self.gate_v_proj(x)
        gate_branch = self.gate_u_proj(gate_branch)
        mixed = self.act_fn(gate_branch) * up_branch
        return self.down_u_proj(self.down_v_proj(mixed))

    
class MyModel(nn.Module):
    """One shard of a Llama decoder stack: this rank's share of the decoder layers.

    Layers are split as evenly as possible; when ``num_hidden_layers`` does not divide evenly,
    the last ``num_hidden_layers % world_size`` ranks each get one extra layer.

    Args:
        rank: index of this shard, in ``[0, world_size)``.
        config: Llama config. Its ``_attn_implementation`` is overwritten with ``'sdpa'``
            (this mutates the caller's object).
        model_path: prefix prepended verbatim to ``decoder_layers{rank}.pth`` (include any
            trailing separator).
        world_size: total number of shards.

    Raises:
        ValueError: if ``rank`` is outside ``[0, world_size)``.
    """

    def __init__(self, rank, config, model_path, world_size):
        super().__init__()
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
        config._attn_implementation = 'sdpa'  # FIXME: only support sdpa implementation

        base, extra = divmod(config.num_hidden_layers, world_size)
        # The last `extra` ranks each take one additional layer.
        num_local_layers = base + (1 if rank >= world_size - extra else 0)
        # NOTE(review): layer_idx restarts at 0 on every shard, so each shard presumably uses its own
        # cache — confirm.
        self.layers = torch.nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(num_local_layers)]
        )

        self.layers.load_state_dict(torch.load(f'{model_path}decoder_layers{rank}.pth'))

    def forward(self, hidden_states, causal_mask, position_ids, past_key_values, cache_position):
        """Run this shard's layers in order with caching enabled.

        Returns:
            ``(hidden_states, next_cache)``. ``next_cache`` is element 1 of the last layer's output,
            or ``past_key_values`` unchanged when this shard holds no layers (previously an
            UnboundLocalError).
        """
        next_cache = past_key_values
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=True,
                cache_position=cache_position,
                position_embeddings=None,
            )
            hidden_states = layer_outputs[0]
            # Assumes the (hidden, present_kv) layout returned when output_attentions is False.
            next_cache = layer_outputs[1]

        return hidden_states, next_cache


class MTPhead(nn.Module):
    """Fuse two hidden-state inputs and run a single Llama decoder layer on the result.

    ``forward`` RMS-normalises ``hidden_states`` (``norm1``) and ``x`` (``norm2``) separately,
    concatenates them on the last dimension (``hidden_states`` first), projects
    ``2 * hidden_size -> hidden_size`` with ``self.linear``, and passes that through ``self.layer``.

    Args:
        config: Llama config providing ``hidden_size``, ``rms_norm_eps``, ``vocab_size`` and
            ``pad_token_id``.
        rank: accepted for interface compatibility; unused.
    """

    def __init__(self, config, rank):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size

        self.norm1 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.linear = nn.Linear(2 * config.hidden_size, config.hidden_size)
        # forward() calls self.layer, so it has to exist. Without it every call raised AttributeError.
        # Built the same way as Transformerhead's decoder layer.
        # NOTE(review): layer_idx 0 — confirm it does not collide with another layer sharing the
        # same past_key_values cache.
        self.layer = LlamaDecoderLayer(config, 0)

    def forward(self, x, hidden_states, position_ids, past_key_values):
        """Normalise, fuse, project and run the decoder layer.

        Returns:
            Elements 0 and 1 of the decoder layer's output (hidden states and, with
            ``use_cache=True``, the cache).
        """
        hidden_states = self.norm1(hidden_states)
        x = self.norm2(x)
        fused = self.linear(torch.cat((hidden_states, x), dim=-1))

        # The decoder layer's mask parameter is `attention_mask`; `causal_mask` would be ignored.
        layer_outputs = self.layer(
            fused,
            attention_mask=None,
            position_ids=position_ids,
            past_key_value=past_key_values,
            use_cache=True,
            cache_position=None,
            position_embeddings=None,
        )

        return layer_outputs[0], layer_outputs[1]


class Normhead(nn.Module):
    """Output head: RMSNorm followed by a bias-free projection to the vocabulary.

    ``model_path`` and ``inf_flag`` are accepted for interface compatibility but unused. No
    checkpoint is loaded here (loading ``lmhead.pth`` / ``norm.pth`` from ``model_path`` is
    currently disabled).
    """

    def __init__(self, config, model_path, inf_flag=False):
        super().__init__()
        width = config.hidden_size
        self.norm = LlamaRMSNorm(width, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, x):
        """Return vocabulary logits for ``x``."""
        normalised = self.norm(x)
        return self.lm_head(normalised)
    

class MLPhead(nn.Module):
    """Output head: Llama MLP, then RMSNorm, then a bias-free projection to the vocabulary.

    Unless ``inf_flag`` is set, the ``lm_head`` and RMSNorm weights are loaded from
    ``model_path + "lmhead.pth"`` and ``model_path + "norm.pth"``. The MLP weights are not loaded here.
    ``rank`` is accepted for interface compatibility; unused.
    """

    def __init__(self, config, rank, model_path, inf_flag=False):
        super().__init__()
        eps = config.rms_norm_eps
        self.mlp = LlamaMLP(config)
        self.rmsnorm = LlamaRMSNorm(config.hidden_size, eps=eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if inf_flag:
            return
        # model_path is used as a plain string prefix (no separator is inserted).
        for module, filename in ((self.lm_head, "lmhead.pth"), (self.rmsnorm, "norm.pth")):
            module.load_state_dict(torch.load(model_path + filename))

    def forward(self, x):
        """Return vocabulary logits for ``x``."""
        return self.lm_head(self.rmsnorm(self.mlp(x)))


class Transformerhead(nn.Module):
    """Output head: one Llama decoder layer, RMSNorm, then a bias-free projection to the vocabulary.

    ``model_path`` and ``inf_flag`` are accepted for interface compatibility but unused. No
    checkpoint is loaded here (loading ``next_trm{rank}.pth`` / ``lmhead.pth`` / ``norm.pth``
    from ``model_path`` is currently disabled).
    """

    def __init__(self, config, model_path, inf_flag=False):
        super().__init__()
        # NOTE(review): layer_idx 0 — confirm it does not collide with another layer sharing the
        # same past_key_values cache.
        self.trm = LlamaDecoderLayer(config, 0)
        self.rmsnorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, x, position_ids, past_key_values):
        """Run the decoder layer with caching, then normalise and project to logits.

        Returns:
            ``(logits, cache)``, where ``cache`` is element 1 of the decoder layer's output.
        """
        # The decoder layer's mask parameter is `attention_mask`; a `causal_mask=` keyword would
        # land in its **kwargs and be silently ignored.
        outputs = self.trm(
            x,
            attention_mask=None,
            position_ids=position_ids,
            past_key_value=past_key_values,
            use_cache=True,
            cache_position=None,
            position_embeddings=None,
        )
        x = self.rmsnorm(outputs[0])
        x = self.lm_head(x)
        return x, outputs[1]

def get_interval(rank, world_size):
    """Return the hard-coded per-rank interval for ``rank`` in a run of ``world_size`` processes.

    Only world sizes 3 through 8 are tabulated. Any other ``world_size`` returns ``None``.
    ``rank`` indexes the per-size list directly, so an out-of-range rank raises ``IndexError``
    and negative ranks count from the end.
    """
    intervals_by_size = {
        3: [2, 1, 0],
        4: [3, 1, 1, 0],
        5: [3, 1, 1, 1, 0],
        6: [4, 1, 1, 1, 1, 0],
        7: [5, 3, 1, 1, 1, 1, 0],
        8: [5, 3, 2, 1, 1, 1, 1, 0],
    }
    table = intervals_by_size.get(world_size)
    if table is None:
        return None
    return table[rank]
