import copy

from transformers.models.gpt_neox.modeling_gpt_neox import *


def _append_chunk_kv(full_cache, chunk_cache, chunk_len):
    """Extend a per-layer ``(key, value)`` cache with the newest positions of another cache.

    ``chunk_cache`` is the cache returned for one chunk: the past that chunk
    attended to, followed by the chunk's own keys/values along dim -2. Only its
    last ``chunk_len`` positions (the chunk itself) are appended to the matching
    layer of ``full_cache``. Returns a new tuple; neither input is modified.
    """
    return tuple(
        (
            torch.cat([layer[0], new_layer[0][:, :, -chunk_len:, :]], dim=-2),
            torch.cat([layer[1], new_layer[1][:, :, -chunk_len:, :]], dim=-2),
        )
        for layer, new_layer in zip(full_cache, chunk_cache)
    )


def chunk_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Causal-LM forward pass that processes a fresh prompt in chunks.

    When ``past_key_values`` is None, ``position_set.set_align_position`` is
    called with the prompt length. The prompt is then fed to ``self.gpt_neox``
    as a first chunk of ``position_set.first_chunk`` tokens, followed by chunks
    of ``position_set.context_window_length`` tokens:

    * every middle chunk attends only to the cached first chunk;
    * the final chunk attends to the cache of all previous chunks.

    Logits and per-layer key/value caches of all chunks are concatenated along
    the sequence axis.

    When ``past_key_values`` is supplied, a single ordinary forward pass is run
    on top of it.

    Args:
        labels: optional token ids. If given, a shifted next-token
            cross-entropy loss is returned in ``loss``.
        use_cache: whether the returned object carries ``past_key_values``
            (defaults to ``self.config.use_cache``). The chunk loop always
            requests caches internally, because it needs them to chain chunks.
        return_dict: accepted for API compatibility only. The result is always
            a ``CausalLMOutputWithPast``.

    Returns:
        ``CausalLMOutputWithPast`` with ``hidden_states`` and ``attentions``
        set to None.

    Raises:
        RuntimeError: if the backbone returns no cache during chunked prefill.
    """
    # Caller's cache preference; only affects what is returned.
    want_cache = use_cache if use_cache is not None else self.config.use_cache

    if past_key_values is None:
        # Fresh prompt: chunked prefill. Either input form may carry the length.
        source = input_ids if input_ids is not None else inputs_embeds
        input_length = source.shape[1]

        position_set.set_align_position(input_length)
        first_chunk_length = position_set.first_chunk
        chunk_width = position_set.context_window_length

        chunk_logits = []         # logits of every chunk, in order
        full_cache = None         # KV cache covering every token processed so far
        first_chunk_cache = None  # KV cache of the first chunk only
        past = None               # cache the next chunk attends to

        beg, end = 0, first_chunk_length
        while beg < input_length:
            outputs = self.gpt_neox(
                input_ids=input_ids[..., beg:end] if input_ids is not None else None,
                attention_mask=attention_mask[..., beg:end] if attention_mask is not None else None,
                position_ids=position_ids[..., beg:end] if position_ids is not None else None,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds[..., beg:end, :] if inputs_embeds is not None else None,
                past_key_values=past,
                # Chaining chunks needs each chunk's cache, read by attribute;
                # force both regardless of the caller's flags.
                use_cache=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            chunk_cache = outputs.past_key_values
            if chunk_cache is None:
                raise RuntimeError("chunked prefill requires the backbone to return past_key_values")

            logits = self.embed_out(outputs[0])
            is_first_chunk = not chunk_logits
            chunk_logits.append(logits)

            if is_first_chunk:
                full_cache = chunk_cache
                # Immutable snapshot of the per-layer (key, value) pairs. copy.deepcopy
                # is avoided: it raises on non-leaf tensors while autograd records.
                first_chunk_cache = tuple((layer[0], layer[1]) for layer in chunk_cache)
            else:
                full_cache = _append_chunk_kv(full_cache, chunk_cache, logits.shape[1])

            beg = end
            # The upcoming chunk is [end, end + chunk_width). If it does not reach
            # the end of the input, it is a middle chunk and sees only the first
            # chunk. Otherwise it is the last chunk and sees everything so far.
            past = first_chunk_cache if end + chunk_width < input_length else full_cache
            end += chunk_width

        # Concatenate once instead of re-concatenating after every chunk.
        lm_logits = torch.cat(chunk_logits, dim=1) if chunk_logits else None
        past_key_values = full_cache if want_cache else None

    else:
        # A cache was supplied: one ordinary forward pass on top of it.
        outputs = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Results are read by attribute below, so always request a ModelOutput.
            return_dict=True,
        )

        lm_logits = self.embed_out(outputs[0])
        past_key_values = outputs.past_key_values

    lm_loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(lm_logits.device)
        # Next-token prediction: logits at position t are scored against the label at t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss_fct = CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    return CausalLMOutputWithPast(
        loss=lm_loss,
        logits=lm_logits,
        past_key_values=past_key_values,
        hidden_states=None,
        attentions=None,
    )


def attention_chunk_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: torch.FloatTensor,
    position_ids: torch.LongTensor,
    head_mask: Optional[torch.FloatTensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
):
    """Replacement for ``GPTNeoXAttention.forward`` used by the chunked-inference patch.

    Differences from the upstream logic kept (commented out) further below:

    * Keys/values are concatenated with ``layer_past`` and stored in ``present``
      BEFORE the rotary embedding is applied, so the cache holds un-rotated keys.
    * Rotary positions are rebuilt on every call from the full key length
      ``kv_seq_len``. The incoming ``position_ids`` values are discarded; only
      their device is reused. When ``kv_seq_len`` exceeds
      ``position_set.train_length``, older key positions are compressed (see the
      branch below).
    * Query and key are each rotated with the trailing slice of the rebuilt
      positions that matches their own sequence length.

    ``attention_mask`` is forwarded to ``self._attn``. At the bottom of this
    module, ``self._attn`` is replaced by ``_chunk_attn``, which does not apply
    the mask.

    Returns:
        ``(attn_output, present)``, plus ``attn_weights`` when
        ``output_attentions`` is true. ``present`` is None unless ``use_cache``.
    """
    has_layer_past = layer_past is not None

    # Compute QKV
    # Attention heads [batch, seq_len, hidden_size]
    #   --> [batch, seq_len, (np * 3 * head_size)]
    qkv = self.query_key_value(hidden_states)

    # [batch, seq_len, (num_heads * 3 * head_size)]
    #   --> [batch, seq_len, num_heads, 3 * head_size]
    new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
    qkv = qkv.view(*new_qkv_shape)

    # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
    query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
    key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

    # Cache keys/values WITHOUT positional encoding. The past is prepended here,
    # and `present` captures these un-rotated tensors, so rotary positions can be
    # re-assigned over the whole key sequence on every call.
    if has_layer_past:
        past_key = layer_past[0]
        past_value = layer_past[1]
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    present = (key, value) if use_cache else None


    # Full key length (cached + current tokens); rotary tables are requested for all of it.
    kv_seq_len = key.shape[-2]
    cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)

    # Position parameters adjusted for chunking.
    train_max_len = position_set.train_length
    push_width = position_set.push_width
    # In the long-sequence branch below, the most recent `push_pos` keys keep their true positions.
    push_pos = (kv_seq_len - 1) % position_set.context_window_length

    # NOTE(review): hard-coded override. The value read from position_set.push_width
    # above is discarded. Confirm which value is intended.
    push_width = 20

    if kv_seq_len > train_max_len:
        # pos1: true positions for the last `push_pos` keys.
        # pos2: `last_pos` positions just before them, each repeated `push_width`
        # times (sorted), so runs of consecutive keys share one rotary position.
        # Because last_pos * push_width > kv_seq_len - push_pos, the concatenation
        # is longer than kv_seq_len; only its tail is used by the slices below.
        pos1 = torch.arange(kv_seq_len - push_pos, kv_seq_len, dtype=torch.long).to(position_ids.device)
        last_pos = (kv_seq_len - push_pos) // push_width + 1
        pos2_indices = torch.arange(kv_seq_len - push_pos - last_pos, kv_seq_len - push_pos, dtype=torch.long).to(
            position_ids.device)
        pos2_repeat = pos2_indices.repeat(push_width)
        sorted_pos2, _ = torch.sort(pos2_repeat)
        rope_position = torch.concat([sorted_pos2, pos1], dim=0).to(position_ids.device)[None,]
    else:
        # Within the training length: ordinary positions 0 .. kv_seq_len - 1.
        rope_position = torch.arange(kv_seq_len, dtype=torch.long).to(position_ids.device)[None,]

    # The incoming position_ids are replaced; only their device is reused.
    position_ids = rope_position.to(position_ids.device)


    # Split the rotary and non-rotary (pass-through) dimensions.
    query_rot = query[..., : self.rotary_ndims]
    query_pass = query[..., self.rotary_ndims :]
    key_rot = key[..., : self.rotary_ndims]
    key_pass = key[..., self.rotary_ndims :]

    # Query uses the last q_len rebuilt positions; key uses the last kv_seq_len.
    # apply_rotary_pos_emb returns a pair. Each call passes the same tensor twice
    # and keeps only one output.
    query, _ = apply_rotary_pos_emb(query_rot, query_rot, cos, sin, position_ids[:, -query_rot.shape[2]:])
    _, key = apply_rotary_pos_emb(key_rot, key_rot, cos, sin, position_ids[:, -key_rot.shape[2]:])

    query = torch.cat((query, query_pass), dim=-1)
    key = torch.cat((key, key_pass), dim=-1)

    # # # # Disabled: log-n scaling of the query, based on key length
    # log_n = (torch.arange(1, key_states.shape[2] + 1)[None,][:, None, :, None].log() / np.log(train_max_len)).clip(
    #     1).to(query_states.dtype)
    # query_states = query_states * log_n[:, :, -query_states.shape[2]:, :].to(query_states.device)



    # # Upstream implementation (RoPE applied before caching), kept for reference.
    # # Compute token offset for rotary embeddings (when decoding)
    # seq_len = key.shape[-2]
    # if has_layer_past:
    #     seq_len += layer_past[0].shape[-2]
    # cos, sin = self.rotary_emb(value, seq_len=seq_len)
    # query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
    # query = torch.cat((query, query_pass), dim=-1)
    # key = torch.cat((key, key_pass), dim=-1)
    #
    # # Cache QKV values
    # if has_layer_past:
    #     past_key = layer_past[0]
    #     past_value = layer_past[1]
    #     key = torch.cat((past_key, key), dim=-2)
    #     value = torch.cat((past_value, value), dim=-2)
    # present = (key, value) if use_cache else None

    # Compute attention
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

    # Reshape outputs
    attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
    attn_output = self.dense(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)

    return outputs

def _chunk_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
    # compute causal mask from causal mask buffer
    batch_size, num_attention_heads, query_length, attn_head_size = query.size()
    key_length = key.size(-2)

    # dynamically increase the causal mask with the key length, if needed.
    if key_length > self.bias.shape[-1]:
        self._init_bias(key_length, device=key.device)
    causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

    query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
    key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)

    # attn_scores = torch.zeros(
    #     batch_size * num_attention_heads,
    #     query_length,
    #     key_length,
    #     dtype=query.dtype,
    #     device=key.device,
    # )
    # attn_scores = torch.baddbmm(
    #     attn_scores,
    #     query,
    #     key.transpose(1, 2),
    #     beta=1.0,
    #     alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
    # )

    # 修改 attention 矩阵乘法
    attn_scores = torch.matmul(query, key.transpose(1, 2)) / self.norm_factor

    attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

    mask_value = torch.finfo(attn_scores.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
    attn_scores = torch.where(causal_mask, attn_scores, mask_value)

    # if attention_mask is not None:
    #     # Apply the attention mask
    #     attn_scores = attn_scores + attention_mask

    attn_weights = nn.functional.softmax(attn_scores, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_weights = self.attention_dropout(attn_weights)

    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights


class PositionSet:
    """Tunable lengths shared by the chunked-inference patch.

    Every setting is a class-level default. The module keeps one instance whose
    attributes may be overridden at runtime. ``set_align_position`` rewrites
    ``context_window_length`` in place for long inputs.
    """

    # Width (tokens) of every chunk after the first. Recomputed by
    # set_align_position for long inputs. Other values tried: 900, 512, 600, 1900.
    context_window_length: int = 1800
    # Not read by the code in this module. Other values tried: 1200, 5000.
    align_position: int = 1024

    # Initial minimum length of the last chunk.
    last_context: int = 512

    # Not read by the code in this module.
    total_length: int = 0
    # Length threshold; presumably the model's training context length.
    train_length: int = 2048

    # Length of the first chunk. Other values tried: 10, 120, 80.
    first_chunk: int = 100

    # Added to the input length when planning chunks (presumably a generation budget).
    max_token_len: int = 50
    # Headroom subtracted from train_length when sizing the widest allowed chunk.
    first_chunk_cache: int = 100

    # Stretch width for rotary-position compression. Other value tried: 100.
    push_width: int = 50

    def set_align_position(self, total_length):
        """Pick ``context_window_length`` for an input of ``total_length`` tokens.

        Inputs shorter than ``train_length - max_token_len`` return immediately
        and leave every attribute untouched. NOTE(review): a width chosen by an
        earlier, longer call is therefore kept.

        Otherwise the span ``total_length + max_token_len - last_context -
        first_chunk`` is split into chunks. If the widest allowed width
        (``train_length - first_chunk - first_chunk_cache``) leaves a remainder
        under 200 tokens, that width is used. Otherwise the span is divided
        evenly into one more chunk. A one-line summary is printed.

        Raises:
            AssertionError: if the even split still leaves a remainder of 200 or
                more tokens, or if the chosen width plus ``first_chunk`` does not
                fit within ``train_length``.
        """
        if total_length < self.train_length - self.max_token_len:
            return

        span = total_length + self.max_token_len - self.last_context - self.first_chunk
        widest = self.train_length - self.first_chunk - self.first_chunk_cache
        n_chunks, leftover = divmod(span, widest)

        if leftover >= 200:
            # Too much left over: shrink the width so the span divides more evenly.
            width = span // (n_chunks + 1)
            n_chunks, leftover = divmod(span, width)
            assert leftover < 200, "non reasonable setting"
            self.context_window_length = width
        else:
            self.context_window_length = widest
        assert self.context_window_length + self.first_chunk < self.train_length

        chunks_length = self.context_window_length * n_chunks + self.first_chunk
        last_chunk_length = total_length + self.max_token_len - chunks_length
        summary = "first chunk:{}, chunk-size:{}, chunks-length:{}, last-chunk-length:{}, N:{}".format(
            self.first_chunk, self.context_window_length, chunks_length, last_chunk_length, n_chunks)
        print(summary)

# Install the patches. Each setattr below replaces a method on a transformers
# GPT-NeoX class, so it affects every model instance in this process:
#   1. chunked prefill in GPTNeoXForCausalLM.forward
#   2. rotary-position remapping in GPTNeoXAttention.forward and its _attn kernel
import transformers.models.gpt_neox.modeling_gpt_neox as pythia_modeling

for _patch_cls, _patch_attr, _patch_fn in (
    (pythia_modeling.GPTNeoXForCausalLM, "forward", chunk_forward),
    (pythia_modeling.GPTNeoXAttention, "forward", attention_chunk_forward),
    (pythia_modeling.GPTNeoXAttention, "_attn", _chunk_attn),
):
    setattr(_patch_cls, _patch_attr, _patch_fn)
del _patch_cls, _patch_attr, _patch_fn

# 3. Shared chunk/position settings. The patched functions read this module
#    global at call time, so it only needs to exist before the first forward pass.
position_set = PositionSet()

