import torch
import torch.nn as nn
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import (
    GPT2Model, GPT2LMHeadModel, GPT2Block, GPT2Attention, GPT2MLP
)


class GPT2BlockNoResidual(GPT2Block):
    """GPT-2 transformer block with every residual connection removed.

    Each sub-layer's output *replaces* ``hidden_states`` instead of being added
    to it. The self-attention and MLP sub-layers receive their input without
    the pre-LayerNorms ``ln_1`` / ``ln_2``. Those modules are still created by
    ``GPT2Block.__init__`` but are never called here. The optional
    cross-attention sub-layer still applies ``ln_cross_attn``.

    NOTE(review): skipping ``ln_1``/``ln_2`` while keeping ``ln_cross_attn``
    is inconsistent. Confirm which behaviour is intended. Also, ``ln_1``/``ln_2``
    are parameters that receive no gradient, which matters for DDP with
    ``find_unused_parameters=False``.

    NOTE(review): ``forward`` mirrors the keyword-argument signature of
    ``GPT2Block.forward`` in older ``transformers`` releases
    (``layer_past=...``). Confirm it matches how the installed
    ``GPT2Model.forward`` calls its blocks.
    """

    def __init__(self, config, layer_idx=None):
        """Build the block.

        Args:
            config: The ``GPT2Config`` passed through to ``GPT2Block``.
            layer_idx: Position of this block in the stack. It is forwarded to
                ``GPT2Block`` so the attention module knows its layer index,
                as in the stock ``GPT2Model``. ``transformers`` uses it e.g.
                for ``config.scale_attn_by_inverse_layer_idx``. The default
                ``None`` reproduces the previous behaviour.
        """
        super().__init__(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run self-attention, optional cross-attention and the MLP with no residuals.

        Args:
            hidden_states: Block input. Presumably ``(batch, seq, n_embd)``;
                TODO confirm.
            layer_past, attention_mask, head_mask, use_cache, output_attentions:
                Forwarded unchanged to ``self.attn``.
            encoder_hidden_states: When not ``None``, a cross-attention
                sub-layer runs on these encoder states.
            encoder_attention_mask: Forwarded to ``self.crossattention``.

        Returns:
            A tuple whose first element is the block output. If ``use_cache``
            is true, it is followed by every extra element returned by
            ``self.attn``. Otherwise the first extra element is dropped; in
            ``transformers`` that is the ``present`` key/value cache. Any
            cross-attention extras from index 2 onward are appended before
            that slicing.

        Raises:
            ValueError: ``encoder_hidden_states`` was given but the block has no
                ``crossattention`` sub-module.
        """
        # Self-attention. No residual: the attention output replaces the input.
        # ln_1 is intentionally not applied (see class docstring).
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = attn_outputs[0]
        outputs = attn_outputs[1:]

        # Optional cross-attention, also without residual.
        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attn_outputs[0]
            # Skip index 1 (the cache slot) and keep only the optional extras.
            outputs = outputs + cross_attn_outputs[2:]

        # Feed-forward. No residual, and ln_2 is intentionally not applied.
        hidden_states = self.mlp(hidden_states)

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the cache slot returned by self.attn when caching is off.
            outputs = (hidden_states,) + outputs[1:]

        return outputs


class GPT2ModelNoResidual(GPT2Model):
    """``GPT2Model`` whose transformer stack is made of ``GPT2BlockNoResidual`` blocks."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock blocks built by GPT2Model.__init__. Pass layer_idx
        # the same way GPT2Model does for its own blocks.
        self.h = nn.ModuleList(
            [GPT2BlockNoResidual(config, layer_idx=i) for i in range(config.n_layer)]
        )
        # The new blocks were created after GPT2Model.__init__ already ran
        # post_init(), so they still have PyTorch default initialisation.
        # Re-run it so they get GPT-2's weight init (incl. scaled c_proj init).
        self.post_init()


class GPT2LMHeadModelNoResidual(GPT2LMHeadModel):
    """``GPT2LMHeadModel`` backed by a ``GPT2ModelNoResidual`` transformer."""

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2ModelNoResidual(config)
        # GPT2LMHeadModel.__init__ ran post_init() against the *original*
        # transformer, so lm_head.weight is tied to that discarded module's
        # input embeddings (when config.tie_word_embeddings is set). Re-run
        # post_init() so weights are initialised and lm_head is re-tied to the
        # new transformer's embeddings.
        self.post_init()

