import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPast
from typing import Optional, Tuple, Union
from transformers import GPTNeoXForSequenceClassification, GPTNeoXModel
from transformers.modeling_outputs import SequenceClassifierOutputWithPast, BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from dataclasses import dataclass
import torch.nn.functional as F

from trl.modified_reward_model.utils import add_gaussian_noise, pad_and_apply_attention_mask, scatter_to_dim
from trl.modified_reward_model.aggregate_module import CrossAttention


@dataclass
class BaseModelOutputWithPastWithFeedback(BaseModelOutputWithPast):
    """`BaseModelOutputWithPast` extended with two optional diagnostic tensors.

    Both extra fields default to ``None`` and are only populated when the
    producing model is asked for them.
    """

    # Detached copy of the final-layer-norm output, taken before any feedback
    # aggregation is applied. A single tensor, not a tuple.
    raw_last_hidden_state: Optional[torch.FloatTensor] = None
    # Output of the cross-attention aggregator for the token selected by
    # `sequence_lengths` in each batch row. A single tensor, not a tuple.
    last_seq_tokens_h: Optional[torch.FloatTensor] = None


class GPTNeoXModelWithExtraInput(GPTNeoXModel):
    """GPT-NeoX backbone that can mix an external tensor into its final hidden states.

    After the regular decoder stack and final layer norm, `forward` optionally
    combines the hidden states with `extra_input` according to `config.agg`:

    * ``"mlp"``: concatenate hidden states with a zero-padded copy of
      `extra_input`, run the result through `self.mlp`, and interpolate with
      the original hidden states using weight `self.fw`.
    * ``"attention"``: for the token at `sequence_lengths` in each batch row,
      run `self.attention` against `extra_input` and blend the result back
      into that token's hidden state.
    * ``"noise"``: replace the hidden states by
      ``add_gaussian_noise(hidden_states, self.fw)``.

    Config attributes read here: `fw`, `agg`, `lqh`, `enable_lm`,
    `hidden_size`, `policy_hidden_size` and `mlp_hidden_size`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.fw = config.fw  # blend weight given to the aggregated feedback
        # Optional override of the feedback weight; -1 means "use self.fw".
        self.cur_fw = -1
        self.agg = config.agg
        self.lqh = config.lqh  # stored only; not read by this class's forward
        self.enable_lm = config.enable_lm

        if self.agg == "mlp":
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size + config.policy_hidden_size, config.mlp_hidden_size),
                nn.GELU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size)
            )
        if self.agg == "attention":
            self.attention = CrossAttention(hidden_dim=config.hidden_size, policy_hidden_dim=config.policy_hidden_size, num_heads=8)
            # NOTE(review): `adapter` is constructed but not used by `forward` below.
            self.adapter = nn.Sequential(
                nn.Linear(config.policy_hidden_size, config.mlp_hidden_size),
                nn.GELU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size)
            )
        if self.enable_lm:
            # NOTE(review): `mean_net` / `var_net` are constructed but not used
            # by `forward` below; presumably consumed elsewhere.
            self.mean_net = nn.Sequential(
                nn.Linear(config.hidden_size, config.mlp_hidden_size),
                nn.ReLU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size)
            )
            self.var_net = nn.Sequential(
                nn.Linear(config.hidden_size, config.mlp_hidden_size),
                nn.ReLU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size)
            )

    def _aggregate_extra_input(
        self,
        hidden_states: torch.Tensor,
        extra_input: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        sequence_lengths: Optional[torch.Tensor],
        return_last_seq_tokens_h: Optional[bool],
    ):
        """Combine `extra_input` with `hidden_states` according to `self.agg`.

        Returns:
            A ``(hidden_states, last_token_h)`` pair. `last_token_h` is the
            cross-attention output when ``self.agg == "attention"`` and
            `return_last_seq_tokens_h` is truthy, otherwise ``None``.

        Raises:
            ValueError: if ``self.agg == "attention"`` and `sequence_lengths`
                is ``None``.
        """
        last_token_h = None

        if self.agg == "mlp":
            # NOTE(review): `zeros_like(hidden_states)` has the backbone's
            # feature width, so this branch appears to assume
            # config.policy_hidden_size == config.hidden_size — confirm.
            extra_hidden_states = torch.zeros_like(hidden_states)
            extra_hidden_states[:, :extra_input.shape[1], :] = extra_input
            # The last position carries the backbone's own last hidden state.
            extra_hidden_states[:, -1, :] = hidden_states[:, -1, :]
            mixed = self.mlp(torch.cat([hidden_states, extra_hidden_states], dim=2))
            hidden_states = (1 - self.fw) * hidden_states + self.fw * mixed

        elif self.agg == "attention":
            if sequence_lengths is None:
                # Indexing with `None` would insert a new axis instead of
                # selecting one token per row, so fail early and explicitly.
                raise ValueError(
                    "`sequence_lengths` must be provided when `extra_input` is used with agg='attention'."
                )

            extra_input = pad_and_apply_attention_mask(extra_input, attention_mask)

            batch_indices = torch.arange(extra_input.shape[0], device=extra_input.device)
            # Detached copy: the attention query does not backpropagate into the backbone.
            last_seq_tokens_h = hidden_states[batch_indices, sequence_lengths].detach().clone().to(extra_input.dtype)

            attn_dtype = self.attention.q_proj.weight.dtype
            last_seq_tokens_h = self.attention(last_seq_tokens_h.to(attn_dtype), extra_input.to(attn_dtype))

            if return_last_seq_tokens_h:
                last_token_h = last_seq_tokens_h

            if self.cur_fw != -1:
                # NOTE(review): the two weights only sum to 1 while
                # cur_fw <= fw — confirm this is the intended schedule.
                new_weight = self.cur_fw
                old_weight = 1 - min(self.cur_fw, self.fw)
            else:
                new_weight = self.fw
                old_weight = 1 - self.fw

            current = hidden_states[batch_indices, sequence_lengths]
            blended = new_weight * last_seq_tokens_h + old_weight * current
            # The attention module may run in a different dtype than the
            # backbone; an indexed in-place write needs matching dtypes.
            hidden_states[batch_indices, sequence_lengths] = blended.to(hidden_states.dtype)

        elif self.agg == "noise":
            hidden_states = add_gaussian_noise(hidden_states, self.fw)

        return hidden_states, last_token_h

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        extra_input: Optional[torch.FloatTensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        return_raw_hiddenstates: Optional[bool] = False,
        return_last_seq_tokens_h: Optional[bool] = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPastWithFeedback:
        """Run the decoder stack, then optionally aggregate `extra_input`.

        Args beyond the standard GPT-NeoX ones:
            extra_input: external features to mix into the final hidden
                states; when ``None`` no aggregation is performed.
            sequence_lengths: per-row token index used by the ``"attention"``
                aggregator to pick the hidden state that is updated.
            return_raw_hiddenstates: if truthy, also return a detached copy of
                the post-layer-norm hidden states taken before aggregation.
            return_last_seq_tokens_h: if truthy (and the ``"attention"``
                aggregator runs), also return the cross-attention output.

        Raises:
            ValueError: if not exactly one of `input_ids` / `inputs_embeds` is
                given, or if the ``"attention"`` aggregator runs without
                `sequence_lengths`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # Local import: this module defines no module-level `logger`.
                from transformers.utils import logging as hf_logging

                hf_logging.get_logger(__name__).warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_in(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        converted_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        if head_mask is not None:
            # Turn the keep-mask into an additive mask: dropped heads get the
            # most negative representable value.
            converted_head_mask = ~converted_head_mask.bool() * torch.finfo(inputs_embeds.dtype).min
            converted_head_mask = converted_head_mask.to(dtype=self.dtype, device=self.device)
        head_mask = converted_head_mask

        hidden_states = self.emb_dropout(inputs_embeds)

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    head_mask[i],
                    use_cache,
                    past_key_values,
                    output_attentions,
                    cache_position,
                    position_embeddings,
                )
            else:
                outputs = layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    layer_past=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )
            hidden_states = outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (outputs[1],)

        hidden_states = self.final_layer_norm(hidden_states)

        # Snapshot before aggregation so callers can compare raw vs. mixed states.
        raw_hidden_states = hidden_states.clone().detach() if return_raw_hiddenstates else None

        last_token_h = None
        if extra_input is not None:
            hidden_states, last_token_h = self._aggregate_extra_input(
                hidden_states,
                extra_input,
                attention_mask,
                sequence_lengths,
                return_last_seq_tokens_h,
            )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastWithFeedback(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            raw_last_hidden_state=raw_hidden_states,
            last_seq_tokens_h=last_token_h
        )


class ModifiedRewardModel(GPTNeoXForSequenceClassification):
    """GPT-NeoX sequence classifier whose backbone is `GPTNeoXModelWithExtraInput`.

    The backbone created by the parent constructor is replaced so that
    `feedback_hidden_states` can be forwarded to it as `extra_input`.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap in the backbone that accepts `extra_input` / `sequence_lengths`.
        self.gpt_neox = GPTNeoXModelWithExtraInput(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        feedback_hidden_states: Optional[torch.FloatTensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None
    ) -> SequenceClassifierOutputWithPast:
        """Score each sequence, optionally conditioning on feedback hidden states.

        Args beyond the standard sequence-classification ones:
            feedback_hidden_states: passed to the backbone as `extra_input`.
            sequence_lengths: passed to the backbone unchanged. Pooling below
                does not use it; the pooled position is derived from
                `config.pad_token_id` and `input_ids`.

        Returns:
            `SequenceClassifierOutputWithPast` whose `logits` are the scores
            at the pooled position of each row, and whose `loss` is set only
            when `labels` is given.

        Raises:
            ValueError: if `config.pad_token_id` is ``None`` and the batch
                size is not 1.
        """
        outputs: BaseModelOutputWithPast = self.gpt_neox(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            extra_input=feedback_hidden_states,
            sequence_lengths=sequence_lengths
        )
        hidden_states = outputs.last_hidden_state
        logits = self.score(hidden_states)

        batch_size = logits.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Rightmost non-pad position per row: pad positions are zeroed, so
            # argmax over (index * mask) picks the largest non-pad index.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            # Local import: this module defines no module-level `logger`.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )