import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPast
from typing import Optional, Tuple
from transformers import LlamaForSequenceClassification,LlamaModel
from transformers.modeling_outputs  import SequenceClassifierOutputWithPast,BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from dataclasses import dataclass

import torch.nn.functional as F

from trl.modified_reward_model.utils import add_gaussian_noise, pad_and_apply_attention_mask, scatter_to_dim

from trl.modified_reward_model.aggregate_module import CrossAttention

# from transformers.masking_utils import create_causal_mask

@dataclass
class BaseModelOutputWithPastWithFeedback(BaseModelOutputWithPast):
    """`BaseModelOutputWithPast` extended with the extra tensors produced by
    `LlamaModelWithExtraInput.forward`.

    Both extra fields stay ``None`` unless the caller explicitly asks for them.
    """

    # Detached copy of the final-norm hidden states, taken before any feedback
    # aggregation is applied (populated when ``return_raw_hiddenstates=True``).
    raw_last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-sequence hidden state of the last valid token after cross-attention
    # with the policy hidden states, presumably shaped [B, D] — TODO confirm
    # (populated when ``return_last_seq_tokens_h=True`` and the "attention"
    # aggregation runs).
    last_seq_tokens_h: Optional[torch.FloatTensor] = None



class LlamaModelWithExtraInput(LlamaModel):
    """`LlamaModel` whose final hidden states can be blended with an extra input.

    After the decoder stack and the final norm, ``forward`` optionally mixes the
    hidden states with ``extra_input`` (hidden states from a policy model)
    according to ``config.agg``:

    * ``"mlp"``: concatenate the hidden states with a zero-padded copy of
      ``extra_input``, project through ``self.mlp``, and blend with weight
      ``config.fw``.
    * ``"attention"``: cross-attend from each sequence's last valid token to
      ``extra_input`` and blend the result into that token's hidden state.
    * ``"noise"``: pass the hidden states and ``config.fw`` to
      ``add_gaussian_noise`` (``extra_input`` only acts as an on/off switch).

    Extra config attributes read here: ``fw``, ``agg``, ``lqh``, ``enable_lm``
    and, depending on the options, ``policy_hidden_size`` / ``mlp_hidden_size``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.fw = config.fw
        # Optional externally-set override of ``fw`` for the attention
        # aggregation; -1 means "use ``fw``".
        self.cur_fw = -1
        self.agg = config.agg
        self.lqh = config.lqh
        self.enable_lm = config.enable_lm
        if self.agg == "mlp":
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size + config.policy_hidden_size, config.mlp_hidden_size),
                nn.GELU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size),
            )
        if self.agg == "attention":
            self.attention = CrossAttention(
                hidden_dim=config.hidden_size,
                policy_hidden_dim=config.policy_hidden_size,
                num_heads=8,
            )
            # Adapter meant to align the semantic space of the policy with that
            # of the reward model. NOTE(review): it is constructed but not used
            # by ``forward`` at the moment.
            self.adapter = nn.Sequential(
                nn.Linear(config.policy_hidden_size, config.mlp_hidden_size),
                nn.GELU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size),
            )
        if self.enable_lm:
            self.mean_net = nn.Sequential(
                nn.Linear(config.hidden_size, config.mlp_hidden_size),
                nn.ReLU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
            )
            self.var_net = nn.Sequential(
                nn.Linear(config.hidden_size, config.mlp_hidden_size),
                nn.ReLU(),
                nn.Linear(config.mlp_hidden_size, config.hidden_size),
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        extra_input: Optional[torch.FloatTensor] = None,  # [B,S-1,D]
        sequence_lengths: Optional[torch.Tensor] = None,
        return_raw_hiddenstates: Optional[bool] = False,
        return_last_seq_tokens_h: Optional[bool] = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the Llama decoder stack, then optionally aggregate ``extra_input``.

        Args:
            extra_input: Policy hidden states to aggregate into the final hidden
                states; ``None`` skips aggregation. Presumably shaped
                ``[B, S-1, D_policy]`` per the original annotation — TODO confirm.
            sequence_lengths: Index of the last valid token of each sequence.
                Required when ``config.agg == "attention"`` and ``extra_input``
                is given.
            return_raw_hiddenstates: Also return a detached copy of the
                final-norm hidden states taken before aggregation.
            return_last_seq_tokens_h: Also return the cross-attended last-token
                states (attention aggregation only).
            The remaining arguments follow ``LlamaModel.forward``.

        Returns:
            A ``BaseModelOutputWithPastWithFeedback``.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are
                given, if ``past_key_values`` is not a ``Cache``, or if the
                attention aggregation is requested without ``sequence_lengths``.
        """
        # Needed on the gradient-checkpointing path; it was previously used
        # without ever being imported.
        from functools import partial

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            # No module-level `logger` exists in this file; fetch the
            # transformers logger lazily so this path cannot raise NameError.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)  # [B,S,D]

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Position embeddings are computed once and shared across all decoder layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        raw_hidden_states = None
        if return_raw_hiddenstates:
            # Snapshot taken before aggregation so callers can compare against it.
            raw_hidden_states = hidden_states.clone().detach()

        last_token_h = None
        if extra_input is not None:
            hidden_states, last_token_h = self._aggregate_feedback(
                hidden_states,
                extra_input,
                attention_mask,
                sequence_lengths,
                return_last_seq_tokens_h,
            )

        # Add the hidden states from the last decoder layer (post-aggregation).
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPastWithFeedback(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            raw_last_hidden_state=raw_hidden_states,
            last_seq_tokens_h=last_token_h,  # cross-attended last-token states
        )

    def _aggregate_feedback(
        self,
        hidden_states: torch.Tensor,
        extra_input: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        sequence_lengths: Optional[torch.Tensor],
        return_last_seq_tokens_h: Optional[bool],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Blend ``extra_input`` into ``hidden_states`` according to ``self.agg``.

        Returns the (possibly) updated hidden states, plus the cross-attended
        last-token states when ``self.agg == "attention"`` and
        ``return_last_seq_tokens_h`` is set (``None`` otherwise). An unknown
        ``agg`` value leaves the hidden states unchanged.
        """
        last_token_h = None

        if self.agg == "mlp":
            # NOTE(review): this writes extra_input into a buffer of width
            # hidden_size and the MLP expects hidden_size + policy_hidden_size
            # inputs, so it presumably assumes policy_hidden_size == hidden_size.
            extra_hidden_states = torch.zeros_like(hidden_states)
            extra_hidden_states[:, : extra_input.shape[1], :] = extra_input
            # The final position always carries the reward model's own state.
            extra_hidden_states[:, -1, :] = hidden_states[:, -1, :]
            hidden_states = (1 - self.fw) * hidden_states + self.fw * self.mlp(
                torch.cat([hidden_states, extra_hidden_states], dim=2)
            )

        elif self.agg == "attention":
            if sequence_lengths is None:
                raise ValueError(
                    "`sequence_lengths` is required when `agg == 'attention'` and `extra_input` is given."
                )
            attn_dtype = self.attention.q_proj.weight.dtype
            # Zero the hidden states of pad tokens so they cannot be attended to.
            extra_input = pad_and_apply_attention_mask(extra_input, attention_mask).to(attn_dtype)

            # Build the indices on hidden_states' device so indexing works even
            # when the inputs live on another device.
            device = hidden_states.device
            batch_indices = torch.arange(hidden_states.shape[0], device=device)
            last_token_index = (batch_indices, sequence_lengths.to(device))

            # Detached query: the cross-attention does not backpropagate into
            # the base model through the query path. Advanced indexing copies.
            query = hidden_states[last_token_index].detach().to(attn_dtype)  # [B, D]
            attended = self.attention(query, extra_input)

            if return_last_seq_tokens_h:
                last_token_h = attended

            own = hidden_states[last_token_index]
            # index_put requires source and destination dtypes to match.
            attended = attended.to(hidden_states.dtype)
            if self.cur_fw != -1:
                # NOTE(review): these weights only sum to 1 when cur_fw <= fw;
                # confirm the asymmetry is intended.
                blended = self.cur_fw * attended + (1 - min(self.cur_fw, self.fw)) * own
            else:
                blended = self.fw * attended + (1 - self.fw) * own
            # Out-of-place write: the norm output is not mutated.
            hidden_states = hidden_states.index_put(last_token_index, blended)

        elif self.agg == "noise":
            hidden_states = add_gaussian_noise(hidden_states, self.fw)

        return hidden_states, last_token_h



class ModifiedRewardModel(LlamaForSequenceClassification):
    """Llama sequence-classification reward model with a feedback-aware backbone.

    Behaves like `LlamaForSequenceClassification`, except that the backbone is a
    `LlamaModelWithExtraInput`. ``forward`` passes ``feedback_hidden_states`` (as
    ``extra_input``) and ``sequence_lengths`` through to that backbone.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock backbone built by the parent constructor.
        # NOTE(review): the parent has already allocated a plain LlamaModel,
        # which is discarded here.
        self.model = LlamaModelWithExtraInput(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        feedback_hidden_states: Optional[torch.FloatTensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        feedback_hidden_states (`torch.FloatTensor`, *optional*):
            Policy hidden states forwarded to the backbone as `extra_input`.
        sequence_lengths (`torch.Tensor`, *optional*):
            Per-sequence last-valid-token indices forwarded to the backbone (used by its
            attention aggregation). The pooling below computes its own index from `input_ids`.
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            extra_input=feedback_hidden_states,
            sequence_lengths=sequence_lengths,
        )
        hidden_states = transformer_outputs.last_hidden_state

        logits = self.score(hidden_states)  # [B, S, num_labels]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            # No module-level `logger` exists in this file; fetch the
            # transformers logger lazily so this path cannot raise NameError.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        # One score per sequence, taken at its last non-pad position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )