from transformers import LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel, LlamaForCausalLM, AutoModelForCausalLM
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
from typing import List, Optional, Tuple, Union
from torch.utils.data import Dataset, DataLoader
from backpack import backpack, extend
from backpack.extensions import BatchGrad
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)


# Argument documentation for LlamaForAIO.forward. It is passed to
# `add_start_docstrings_to_model_forward`, which prepends it to the method's
# `__doc__` at class-definition time, so this text is part of the runtime docstring.
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""

# Config class name that `replace_return_docstrings` inserts into the
# forward() return-value documentation.
_CONFIG_FOR_DOC = "LlamaConfig"

# ============================================================================
# Shared tensor settings
# ============================================================================

# Default device/dtype keyword arguments (CUDA, float64); presumably splatted
# into tensor constructors by importers of this module — NOTE(review): confirm.
tkwargs = dict(
    device=torch.device("cuda"),
    dtype=torch.double,
)


class LlamaForAIO(LlamaForCausalLM):
    """``LlamaForCausalLM`` extended with hooks for approximating black-box LLM gradients.

    Besides the ordinary causal-LM forward pass, ``forward`` supports two mutually
    exclusive logit-rewriting modes, toggled by flags on the instance:

    * ``logits_perturbation_flag``: add ``perturb_direction * epsilon_val`` times the
      stored Gaussian direction (``random_gaussian_vec``) to the logits of every position.
    * ``apply_black_box_proxy_m_flag``: replace the logits by their projection onto the
      scaled Gaussian direction(s) stored by ``inject_data_into_proxy_weight_m``
      ("forward with injected gradients").

    Several helpers read ``self.training_args``, which must be supplied with
    ``set_training_args`` before they are called.
    """

    def __init__(self, config, max_sequence_len_est=1000):
        """Build the causal LM and initialise the gradient-approximation state.

        Args:
            config: Model configuration, forwarded to ``LlamaForCausalLM``.
            max_sequence_len_est: Stored as ``self.max_sequence_len_est``; not read
                by this class.
        """
        super().__init__(config)

        # ---- State for approximating the black-box LLM gradients ----
        self.logits_perturbation_flag = False
        self.apply_black_box_proxy_m_flag = False
        self.vocab_size = config.vocab_size
        self.config = config
        # dtype of the LM head weights; injected multiplier vectors are cast to it.
        self.target_dtype = self.lm_head.weight.data.dtype
        self.max_sequence_len_est = max_sequence_len_est
        # Gaussian direction(s) drawn by refresh_gaussian_vec: a single vector when
        # training_args.zo_approx_steps == 1, otherwise a matrix of directions.
        self.random_gaussian_vec = None
        self.random_gaussian_m = None
        # Direction(s) scaled by finite-difference coefficients; written by
        # inject_data_into_proxy_weight_m and read by forward.
        self.scaled_gaussian_multiplier_vec = None
        self.scaled_gaussian_multiplier_m = None

    def set_training_args(self, training_args):
        """Store the training arguments read by the gradient-approximation helpers.

        Fields read through ``self.training_args`` in this class: ``zo_approx_steps``,
        ``TS_aided_grad_approx``, ``TS_candidate_arms`` and ``TS_single_direction_reward``.
        """
        self.training_args = training_args

    @torch.no_grad()
    def get_last_token_hidden_state(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        sequence_lengths: Optional[int] = None,
        n_prompt_tokens: Optional[int] = 0,
        pooling: Optional[str] = "last",
    ) -> Tuple:
        r"""Run the base transformer without gradients and pool its final hidden states.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states, return_dict:
                Forwarded unchanged to ``self.model`` (under CUDA autocast).
            labels: Accepted for signature compatibility; not used.
            sequence_lengths: Position index (per example, or a single index) taken
                when ``pooling == "last"``. When omitted it is derived from
                ``input_ids`` and ``config.pad_token_id``, falling back to ``-1``
                (the final position).
            n_prompt_tokens: Offset added to the derived index — presumably the
                number of soft-prompt positions prepended to the sequence.
            pooling: ``"last"`` (hidden state at ``sequence_lengths``), ``"mean"`` or
                ``"max"`` over dimension 1 (the sequence positions).

        Returns:
            A 1-tuple ``(pooled_states,)``.

        Raises:
            ValueError: If ``pooling`` is not ``"last"``, ``"mean"`` or ``"max"``.
        """
        # Validate before the (expensive) forward pass; an unknown mode used to fail
        # afterwards with an UnboundLocalError.
        if pooling not in ("last", "mean", "max"):
            raise ValueError(f"Unsupported pooling mode {pooling!r}; expected 'last', 'mean' or 'max'.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        with torch.cuda.amp.autocast():
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        if sequence_lengths is None:
            if self.config.pad_token_id is not None and input_ids is not None:
                # (#non-pad tokens - 1) is the index of the last real token only for
                # right-padded input — NOTE(review): confirm callers right-pad.
                sequence_lengths = (
                    torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + n_prompt_tokens
                ).to(hidden_states.device)
            else:
                # No way to locate padding: take the final position.
                sequence_lengths = -1

        if pooling == "last":
            pooled_states = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        elif pooling == "mean":
            # NOTE(review): averages over every position, padding included.
            pooled_states = hidden_states.mean(dim=1)
        else:  # "max"
            pooled_states = hidden_states.max(dim=1).values
        return (pooled_states,)

    ####################################################################################

    def register_TS_instance(self, TS_model):
        """Attach the TS (presumably Thompson-sampling) model used by ``refresh_gaussian_vec``.

        Only ``TS_model.get_ranking_score(items=...)`` is called by this class.
        """
        self.TS_model = TS_model

    def refresh_gaussian_vec(self, random_seed):
        """Reseed torch's global RNG and draw new Gaussian direction(s) over the vocabulary.

        When ``training_args.zo_approx_steps > 1`` the directions are stored in
        ``self.random_gaussian_m`` (one row per direction); otherwise a single
        direction of length ``vocab_size`` is stored in ``self.random_gaussian_vec``.

        With ``training_args.TS_aided_grad_approx`` the directions are selected from
        ``TS_candidate_arms`` random candidates plus their negations, using scores
        from ``self.TS_model.get_ranking_score``; otherwise they are sampled directly
        from N(0, 1).
        """
        torch.manual_seed(random_seed)

        if self.training_args.TS_aided_grad_approx:
            num_candidate_arms = self.training_args.TS_candidate_arms
            # Candidates and their negations: 2 * num_candidate_arms arms in total.
            candidate_items = torch.normal(mean=0, std=1, size=(num_candidate_arms, self.vocab_size))
            aug_candidate_arms = torch.cat([(-1 * candidate_items.detach().clone()), candidate_items], dim=0)

            ranking_scores = self.TS_model.get_ranking_score(items=aug_candidate_arms)

            # topk below keeps the largest scores. With a single-direction reward the
            # raw scores are negated (lowest raw score wins); otherwise arms are ranked
            # by magnitude. TODO: revisit how the ranking scores should be determined.
            if self.training_args.TS_single_direction_reward:
                ranking_scores = -1 * ranking_scores
            else:
                ranking_scores = torch.abs(ranking_scores)

            if self.training_args.zo_approx_steps > 1:
                top_k = self.training_args.zo_approx_steps
                indices = torch.topk(ranking_scores, top_k).indices
                self.random_gaussian_m = aug_candidate_arms[indices, :]
            else:
                indice = torch.topk(ranking_scores, 1).indices
                self.random_gaussian_vec = aug_candidate_arms[indice, :].view(-1, )
        else:
            if self.training_args.zo_approx_steps > 1:
                self.random_gaussian_m = torch.normal(mean=0, std=1, size=(self.training_args.zo_approx_steps, self.vocab_size))
            else:
                self.random_gaussian_vec = torch.normal(mean=0, std=1, size=(self.vocab_size, ))

    def manually_update_gaussian_vec(self, gaussian_vec):
        """Replace the single Gaussian direction with a caller-supplied vector."""
        self.random_gaussian_vec = gaussian_vec

    def set_epsilon_val_and_direction(self, epsilon_val, perturb_direction):
        """Set the step size and sign used by the logit perturbation in ``forward``.

        ``forward`` adds ``perturb_direction * epsilon_val * random_gaussian_vec`` to the
        logits while ``logits_perturbation_flag`` is set.
        """
        self.epsilon_val = epsilon_val
        self.perturb_direction = perturb_direction

    def set_logits_perturbation_flag(self, logits_perturbation_flag):
        """Enable/disable the Gaussian logit perturbation in ``forward``."""
        self.logits_perturbation_flag = logits_perturbation_flag

    def set_apply_black_box_proxy_m_flag(self, apply_black_box_proxy_m_flag):
        """Enable/disable the injected-gradient projection of the logits in ``forward``."""
        self.apply_black_box_proxy_m_flag = apply_black_box_proxy_m_flag

    def inject_data_into_proxy_weight_m(self, epsilon_val, loss_diff, min_coef=0.001, inject_data_vec=None):
        """Scale the current Gaussian direction(s) by finite-difference coefficients.

        The coefficient is ``loss_diff / (2 * epsilon_val)`` — a central-difference
        slope, presumably with ``loss_diff = L(+eps) - L(-eps)``. When ``min_coef > 0``,
        coefficients that are exactly zero are replaced by ``min_coef``.

        Args:
            epsilon_val: Perturbation size used when ``loss_diff`` was measured.
            loss_diff: Loss difference(s). With ``training_args.zo_approx_steps > 1`` it
                must be a 1-D tensor with one entry per row of ``random_gaussian_m``
                (it feeds ``torch.diag``); otherwise a scalar.
            min_coef: Replacement for zero coefficients; ``<= 0`` disables the
                replacement.
            inject_data_vec: If given, it is stored directly (cast to ``target_dtype``
                and moved to ``self.device``) as ``scaled_gaussian_multiplier_vec`` and
                no coefficient is computed.
        """
        with torch.no_grad():
            if inject_data_vec is not None:
                self.scaled_gaussian_multiplier_vec = inject_data_vec.to(self.target_dtype).to(self.device)
                return

            coefficient = loss_diff / (2 * epsilon_val)

            if self.training_args.zo_approx_steps > 1:
                if min_coef > 0:
                    coefficient = coefficient.masked_fill(coefficient == 0, min_coef)
                print("[Gaussian Vec Coef]: ", coefficient)
                # diag(c) @ M scales row i of M by c[i].
                coef_diag_m = torch.diag(coefficient)
                self.scaled_gaussian_multiplier_m = torch.matmul(coef_diag_m, self.random_gaussian_m)
            else:
                # Previously left unassigned when min_coef <= 0 (UnboundLocalError).
                scaled_coefficient = coefficient
                if min_coef > 0 and coefficient == 0:
                    scaled_coefficient = min_coef
                print("[Gaussian Vec Coef]: ", scaled_coefficient)
                self.scaled_gaussian_multiplier_vec = scaled_coefficient * self.random_gaussian_vec.to(self.target_dtype)

    def set_modules_require_gradients(self, training_args, module_names=None):
        """Toggle ``requires_grad`` on the model parameters for full fine-tuning.

        Does nothing when ``training_args.only_tune_soft_prompt_embedding`` or
        ``training_args.lora`` is set. Otherwise, when ``module_names`` is given
        (parameter names such as ``'lm_head.weight'``) only those parameters are made
        trainable and all others are frozen; when it is ``None`` every parameter is
        made trainable.
        """
        if training_args.only_tune_soft_prompt_embedding or training_args.lora:
            return
        for name, param in self.named_parameters():
            param.requires_grad_(module_names is None or name in module_names)

    ####################################################################################
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Causal-LM forward pass with optional logit rewriting for black-box gradient approximation.

        - With `apply_black_box_proxy_m_flag` set, `labels` is ignored and the returned `logits` is the
          projection of the vocabulary logits onto the injected scaled Gaussian direction(s): a tensor of
          shape `(1,)` when `training_args.zo_approx_steps > 1`, otherwise a 0-d tensor. Requires batch size 1.
        - With `logits_perturbation_flag` set, `perturb_direction * epsilon_val * random_gaussian_vec` is
          added to the logits of every position.

        The two flags are mutually exclusive; enabling both raises `ValueError`.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Checked up front: previously the check ran only after both rewrites, and
        # the perturbation step crashed on the projected (non-3-D) logits first.
        if self.apply_black_box_proxy_m_flag and self.logits_perturbation_flag:
            raise ValueError(
                "apply_black_box_proxy_m_flag and logits_perturbation_flag cannot both be enabled."
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)

        ############################# Forward with Injected Gradients #############################
        if self.apply_black_box_proxy_m_flag:
            print("[Original last hidden]: ", hidden_states.shape)
            print("[Original last hidden Finite Check]: ", torch.isfinite(hidden_states).all())
            print("[Original Forward Logits]: ", logits.shape)
            print("[Original Forward Logits Finite Check]: ", torch.isfinite(logits).all())

            # The projected output is not per-token vocabulary logits, so the
            # cross-entropy loss below cannot apply.
            labels = None
            if not (logits.dim() == 3 and logits.shape[0] == 1):
                raise ValueError(
                    f"apply_black_box_proxy_m_flag expects logits of shape (1, seq_len, vocab); got {tuple(logits.shape)}."
                )

            if self.training_args.zo_approx_steps > 1:
                # (vocab, k): one column per scaled direction.
                avg_gaussian_m = self.scaled_gaussian_multiplier_m[:, :logits.shape[-1]].T.to(logits.device).to(logits.dtype)
                # (1, seq, vocab) x (vocab, k) -> (1, seq, k): every position projected on every direction.
                logits = torch.einsum('aij,jk->aik', logits, avg_gaussian_m)
                # Average over the k directions ...
                logits = torch.mean(logits, dim=2)
                # ... then sum over sequence positions -> shape (1,).
                logits = torch.sum(logits, dim=1)

                print("[AIO_compute_loss Logits]: ", logits)

            else:
                logits_vec = logits.view(-1, )
                # Tile the scaled direction over every (batch, position) so the inner
                # product sums the per-position projections into one scalar.
                gaussian_multiplier_vec = \
                    self.scaled_gaussian_multiplier_vec[:logits.shape[-1]].repeat([logits.shape[0], logits.shape[1], 1]).view(-1, ).to(logits.device)
                logits = torch.inner(logits_vec, gaussian_multiplier_vec)

                print("[AIO_compute_loss Logits]: ", logits)
                print("[Gaussian_multiplier_vec]: ", gaussian_multiplier_vec)
                print("[Logit vec]: ", logits_vec)

        ############################# Add Gaussian perturbation #############################
        if self.logits_perturbation_flag:
            # The (vocab,) direction broadcasts over (batch, seq, vocab): same values as
            # tiling it explicitly, without materialising a logits-sized copy.
            gaussian_direction = self.random_gaussian_vec[:logits.shape[-1]].to(logits.device)
            logits = logits + (self.perturb_direction * self.epsilon_val * gaussian_direction)

        #####################################################################################
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

