import torch
from peft import PeftModelForCausalLM,PeftConfig,PeftType,TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING
import warnings
from typing import Any, Literal, Optional, Union, List
from transformers import DynamicCache
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForCausalLM, 
    BitsAndBytesConfig, StoppingCriteriaList, StoppingCriteria
)
# Per-model integer table keyed by Hugging Face model id.
# NOTE(review): not referenced in the code shown here (get_input_embed's `bias` argument
# shadows this name) -- confirm whether other modules still rely on it.
bias = {
    'lmsys/vicuna-7b-v1.5': 1,
    'meta-llama/Llama-2-7b-chat-hf': 2,
}
def find_all(string, sub):
    """Return the start indices of all non-overlapping occurrences of ``sub`` in ``string``.

    Occurrences are collected left to right with repeated ``str.find`` calls, skipping past
    each match. An empty ``sub`` matches at every position ``0..len(string)`` (consistent
    with ``str.count``) instead of looping forever.

    >>> find_all("abcabc", "bc")
    [1, 4]
    """
    # Advance past each match; step by 1 for an empty pattern so the scan always progresses.
    step = len(sub) or 1
    positions = []
    start = string.find(sub)
    while start != -1:
        positions.append(start)
        start = string.find(sub, start + step)
    return positions
def _get_batch_size(input_ids, inputs_embeds) -> int:
    """Get the batch size based on either input_ids or input_embeds

    Raises an ValueError if both are None.

    """
    if (input_ids is None) and (inputs_embeds is None):
        raise ValueError("You have to provide either input_ids or inputs_embeds")

    if input_ids is not None:
        batch_size = input_ids.shape[0]
    else:
        batch_size = inputs_embeds.shape[0]
    return batch_size

class DPromptModel(PeftModelForCausalLM):
    """Prompt-tuned causal LM whose soft prompt is conditioned on the documents in its input.

    For every example the prompt text is recovered by decoding ``input_ids``. The span from the
    first "Document [" marker (preceded by a newline) to the last "Question:" is cut into one
    chunk per newline-"Document" marker. Each chunk is encoded with BERT
    (``google-bert/bert-base-uncased``). :class:`DocumentPromptEncoder` then combines the pooled
    BERT vectors with the learned virtual-token embeddings into the prompt that is prepended to
    the LLM's input embeddings. Any base-model parameter whose name contains ``"lora"`` is made
    trainable.
    """

    # Strings whose token ids end generation when ``self.stopping_criteria`` is used.
    DEFAULT_STOP_LIST = ('\nHuman:', '\n```\n', '\nQuestion:', '<|endoftext|>', '\n')

    def __init__(
        self,model_id, model: torch.nn.Module, peft_config: PeftConfig,train_bert=False,num_document_token=0,num_virtual_document=1, percentage=0.3, adapter_name: str = "default", **kwargs
    ) -> None:
        """Wrap ``model`` for document-conditioned prompt tuning.

        Args:
            model_id: hub id / path of the base LLM; used to load ``self.temp_tokenizer``.
            model: the base causal LM.
            peft_config: PEFT prompt-learning config; ``num_virtual_tokens`` is the prompt length.
            train_bert: if true, the BERT document encoder is trainable and is saved/loaded too.
            num_document_token: number of pooled document vectors grouped per example by
                :class:`DocumentPromptEncoder`.
            num_virtual_document: number of virtual tokens produced from the documents.
            percentage: passed to :class:`DocumentPromptEncoder`.
            adapter_name: PEFT adapter name.
            **kwargs: forwarded to ``PeftModelForCausalLM.__init__``.
        """
        super().__init__(model, peft_config, adapter_name, **kwargs)
        print("base_model==", self.base_model)

        # BERT encoder used to embed each document chunk; frozen unless train_bert is set.
        document_encoder = 'google-bert/bert-base-uncased'
        from transformers import BertTokenizer, BertModel
        self.document_encoder_tokenizer = BertTokenizer.from_pretrained(document_encoder)
        self.document_encoder = BertModel.from_pretrained(document_encoder).to('cuda:0')
        for param in self.document_encoder.parameters():
            param.requires_grad = bool(train_bert)
        self.train_bert = train_bert
        self.percentage = percentage

        # The LLM's own tokenizer: decodes input_ids back to text so documents can be located.
        self.temp_tokenizer = AutoTokenizer.from_pretrained(
            model_id, padding_side="left", truncation_side="left", model_max_length=4096
        )
        self.temp_tokenizer.pad_token = self.temp_tokenizer.eos_token
        self.stop_list = list(self.DEFAULT_STOP_LIST)
        self.stopping_criteria = self._define_stopping_criteria()
        self.prefix_encoder = DocumentPromptEncoder(
            self.peft_config[adapter_name],
            self.prompt_encoder[self.active_adapter],
            self.percentage,
            num_document_token,
            num_virtual_document,
        ).to('cuda:0')
        self.num_document_token = num_document_token
        self.total_virtual_num = self.peft_config[adapter_name].num_virtual_tokens
        self.num_virtual_token = self.total_virtual_num - self.num_document_token
        # Unfreeze any LoRA weights present in the base model.
        for name, param in self.base_model.named_parameters():
            if 'lora' in name:
                param.requires_grad = True

    def _define_stopping_criteria(self) -> StoppingCriteriaList:
        """Build stopping criteria that fire once a stop string has just been generated.

        Stop strings come from ``self.stop_list``, falling back to ``DEFAULT_STOP_LIST``. They are
        tokenized with ``add_special_tokens=False``. A BOS token prepended by the tokenizer never
        appears at the end of a generated sequence, so it would make the match impossible.
        """
        stop_list = getattr(self, "stop_list", self.DEFAULT_STOP_LIST)
        stop_token_ids = []
        for stop_str in stop_list:
            ids = self.temp_tokenizer(stop_str, add_special_tokens=False)['input_ids']
            if ids:  # nothing to match for an empty id sequence
                stop_token_ids.append(torch.LongTensor(ids).to(self.device))
        # NOTE(review): SentencePiece tokenizers may still emit a leading-space piece for strings
        # such as a bare newline, which generation need not produce -- verify ids for the model.

        class StopOnTokens(StoppingCriteria):
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                # Only the first sequence of the batch is inspected.
                generated = input_ids[0]
                for stop_ids in stop_token_ids:
                    n = stop_ids.shape[0]
                    if n <= generated.shape[0] and torch.equal(generated[-n:], stop_ids.to(generated.device)):
                        return True
                return False

        return StoppingCriteriaList([StopOnTokens()])

    def get_input_embed(self, input_ids, inputs_embeds, bias):
        """Prepend the document-conditioned soft prompt to ``inputs_embeds``.

        Each row of ``input_ids`` is decoded to text. The text between the first
        newline-"Document [" marker and the last "Question:" is split into one chunk per
        newline-"Document" marker, and the chunks are encoded by :meth:`get_document_prompt`.

        Args:
            input_ids: batch-first token ids; required because the documents are read from them.
            inputs_embeds: embeddings of ``input_ids``; concatenated after the prompt on dim 1.
            bias: unused; kept for backward compatibility.

        Returns:
            ``torch.cat((prompt_embedding, inputs_embeds), dim=1)``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("DPromptModel needs input_ids to locate the documents in the prompt")
        batch_size = input_ids.shape[0]
        batch_documents = []
        for row in input_ids:
            text = self.temp_tokenizer.decode(row.tolist(), skip_special_tokens=True)
            question_start = text.rfind('Question:')
            documents_start = text.find('\nDocument [')
            # NOTE(review): when a marker is missing, find/rfind return -1 and this slice yields an
            # empty or truncated span -- confirm every prompt contains both markers.
            document_str = text[documents_start:question_start]
            bounds = find_all(document_str, '\nDocument') + [len(document_str)]
            batch_documents.append([document_str[a:b] for a, b in zip(bounds, bounds[1:])])
        prompt_embedding = self.get_document_prompt(batch_size, batch_documents)
        prompt_embedding = prompt_embedding.to(inputs_embeds.dtype)
        return torch.cat((prompt_embedding, inputs_embeds), dim=1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task_ids=None,
        **kwargs,
    ):
        """Run the base model, adding the virtual-token prompt for prompt-learning configs.

        Non-prompt-learning configs are delegated to the base model unchanged. Otherwise the
        attention mask is extended by ``num_virtual_tokens`` ones and labels by ``-100``. Prefix
        tuning passes ``past_key_values`` from :meth:`get_prompt`; any other prompt-learning type
        prepends the document-conditioned embeddings from :meth:`get_input_embed`.
        """
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if self.base_model.config.model_type == "mpt":
                if inputs_embeds is not None:
                    raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

        batch_size = _get_batch_size(input_ids, inputs_embeds)

        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # concat prompt labels; virtual tokens are ignored by the loss (-100)
        if labels is not None:
            prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        my_input_embeds = self.get_input_embed(input_ids, inputs_embeds, bias=1)
        return self.base_model(inputs_embeds=my_input_embeds, **kwargs)

    def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
        """Adapt the base model's generation inputs for prompt learning.

        On every step the attention mask is extended by ``num_virtual_tokens``. On the first
        step (empty KV cache), prefix tuning injects ``past_key_values`` from :meth:`get_prompt`.
        Other prompt-learning types replace ``input_ids`` with ``inputs_embeds`` carrying the
        document-conditioned prompt from :meth:`get_input_embed`.
        """
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

        # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
        # for some architectures which requires a special fix for prompt tuning etc.
        # TODO: starting with transformers 4.38, all architectures should support caching.
        import packaging.version
        import transformers
        uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
        uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
        transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
        uses_cache = uses_transformers_4_38 or (
            uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
        )
        if peft_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids
        if peft_config.is_prompt_learning:
            past_key_values = model_kwargs.get("past_key_values")
            # past_key_values may be None or absent instead of an empty cache object; treat
            # both like an empty cache rather than crashing on len(None).
            cache_is_empty = past_key_values is None or len(past_key_values) == 0

            if uses_cache and not cache_is_empty:
                # change in the logic of `prepare_inputs_for_generation` makes the below code necessary
                # In prompt learning methods, past key values are longer when compared to the `input_ids`.
                # As such only consider the last input ids in the autogressive generation phase.
                if isinstance(past_key_values, (tuple, list)):
                    seq_len = past_key_values[0][0].shape[-2]
                else:  # using transformers kv cache
                    seq_len = past_key_values.get_seq_length()
                if seq_len >= model_kwargs["input_ids"].shape[1]:
                    model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

            if model_kwargs.get("attention_mask", None) is not None:
                size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                model_kwargs["attention_mask"] = torch.cat(
                    (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
                )

            if model_kwargs.get("position_ids", None) is not None:
                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
                model_kwargs["position_ids"] = None

            # Clear on model_kwargs: that is the dict returned to the generation loop.
            if model_kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                model_kwargs["token_type_ids"] = None

            if cache_is_empty:
                if peft_config.peft_type == PeftType.PREFIX_TUNING:
                    model_kwargs["past_key_values"] = self.get_prompt(
                        batch_size=model_kwargs["input_ids"].shape[0]
                    )
                else:
                    inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
                    model_kwargs["inputs_embeds"] = self.get_input_embed(
                        model_kwargs["input_ids"], inputs_embeds, bias=1
                    )
                    model_kwargs["input_ids"] = None
        # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
        # passed in the forward pass to keep track of the position ids of the cache. We have to
        # pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed
        # `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
        _ = model_kwargs.pop("cache_position", None)

        return model_kwargs

    def generate(self, *args, **kwargs):
        """Generate with the base model while routing input preparation through this wrapper.

        The base model's ``prepare_inputs_for_generation`` is temporarily replaced by
        :meth:`prepare_inputs_for_generation`, and always restored afterwards, including on error.
        """
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        if hasattr(self.base_model, "model"):
            self.base_model.model.generation_config = self.generation_config
        else:
            self.base_model.generation_config = self.generation_config
        try:
            if not peft_config.is_prompt_learning:
                with self._enable_peft_forward_hooks(*args, **kwargs):
                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                    return self.base_model.generate(*args, **kwargs)
            # NOTE(review): positional args are not forwarded in this branch; pass inputs by keyword.
            return self.base_model.generate(**kwargs)
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation

    def generate_answer(self, prompt, tokenizer):
        """Greedily generate up to 30 new tokens for ``prompt`` and return the decoded texts."""
        inputs = tokenizer(
            prompt,
            padding=True,
            truncation=True,
            max_length=4096,
            return_tensors="pt",
        ).to(self.device)

        generated_ids = self.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=30,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def get_document_prompt(self, batch_size: int, document=None,) -> torch.Tensor:
        """Return the virtual-token prompt for a batch.

        For prefix tuning this returns post-processed ``past_key_values`` (``document`` unused).
        Otherwise every example's document chunks are encoded with BERT, the ``pooler_output``
        vectors are stacked per batch, and :class:`DocumentPromptEncoder` produces the prompt
        embeddings.

        Args:
            batch_size: number of examples.
            document: per-example lists of document strings; required for prompt-embedding types.

        Raises:
            ValueError: if ``document`` is None for a non-prefix-tuning config.
        """
        peft_config = self.active_peft_config
        prompt_encoder = self.prompt_encoder[self.active_adapter]
        prompt_tokens = (
            self.prompt_tokens[self.active_adapter]
            .unsqueeze(0)
            .expand(batch_size, -1)
            .to(prompt_encoder.embedding.weight.device)
        )
        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
            if peft_config.inference_mode:
                past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                past_key_values = prompt_encoder(prompt_tokens)
            if self.base_model_torch_dtype is not None:
                past_key_values = past_key_values.to(self.base_model_torch_dtype)
            past_key_values = past_key_values.view(
                batch_size,
                peft_config.num_virtual_tokens,
                peft_config.num_layers * 2,
                peft_config.num_attention_heads,
                peft_config.token_dim // peft_config.num_attention_heads,
            )
            if peft_config.num_transformer_submodules == 2:
                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
                peft_config.num_transformer_submodules * 2
            )
            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                past_key_values = post_process_fn(past_key_values)
            return past_key_values

        # Prompt-embedding path. It is used in inference_mode as well, because a raw copy of the
        # embedding table would ignore the documents that the prompt was trained on.
        if document is None:
            raise ValueError("get_document_prompt requires the per-example document chunks")
        encoder_device = self.document_encoder.device
        pooled = []
        for chunks in document:
            document_token = self.document_encoder_tokenizer(
                chunks, return_tensors='pt', padding="max_length", truncation=True, max_length=512
            ).to(encoder_device)
            pooled.append(self.document_encoder(**document_token).pooler_output)
        # Stacking requires every example to have the same number of chunks.
        documents = torch.stack(pooled, dim=0)
        documents = documents.to(self.prefix_encoder.prompt_encoder.embedding.weight.device)
        return self.prefix_encoder(prompt_tokens, documents)

    def save_pretrained(
        self,
        save_directory: str,
        safe_serialization: bool = True,
        selected_adapters = None,
        save_embedding_layers: Union[str, bool] = "auto",
        is_main_process: bool = True,
        convert_pissa_to_lora: Optional[str] = None,
        path_initial_model_for_weight_conversion: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Save the trainable pieces as ``.pth`` files in ``save_directory``.

        Writes ``prefix_encoder_params.pth``, ``document_encoder_params.pth`` (only when
        ``train_bert``) and ``lora_params.pth`` (all base-model parameters whose name contains
        ``"lora"``). Only the main process writes. Other arguments are accepted for signature
        compatibility and ignored.
        """
        import os
        if not is_main_process:
            return
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.prefix_encoder.state_dict(), os.path.join(save_directory, 'prefix_encoder_params.pth'))
        if self.train_bert:
            torch.save(
                self.document_encoder.state_dict(),
                os.path.join(save_directory, 'document_encoder_params.pth'),
            )
        lora_paras = {name: param.data for name, param in self.base_model.named_parameters() if 'lora' in name}
        torch.save(lora_paras, os.path.join(save_directory, 'lora_params.pth'))

    def load_prefix_encoder(self, path):
        """Load the files written by :meth:`save_pretrained` from directory ``path``.

        Raises:
            KeyError: if a LoRA parameter of the base model is missing from ``lora_params.pth``.
        """
        import os
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        file = os.path.join(path, 'prefix_encoder_params.pth')
        state_dict = torch.load(file)
        print("state dict==", state_dict)
        self.prefix_encoder.load_state_dict(state_dict)
        if self.train_bert:
            file = os.path.join(path, 'document_encoder_params.pth')
            state_dict = torch.load(file)
            self.document_encoder.load_state_dict(state_dict)
        file = os.path.join(path, 'lora_params.pth')
        lora_paras = torch.load(file, map_location=torch.device('cuda'))
        for name, param in self.base_model.named_parameters():
            if 'lora' in name:
                param.data = lora_paras[name].data


class DocumentPromptEncoder(torch.nn.Module):
    """Soft-prompt encoder whose trailing virtual tokens are computed from document vectors.

    With ``num_virtual_document == 0`` it simply returns ``prompt_encoder(indice)``. Otherwise:

    * document vectors pass through ``fc`` -> ``Tanh`` -> ``fc_out`` (width ``hidden_dim``);
    * every ``num_document_tokens`` consecutive vectors are concatenated into one row;
    * each of the ``num_virtual_document`` heads in ``fcs`` maps a row to one ``token_dim`` token;
    * those tokens replace the last ``num_virtual_document`` positions of the learned prompt.

    Args:
        config: object providing ``token_dim`` and ``num_virtual_tokens``.
        prompt_encoder: module mapping virtual-token indices to embeddings.
        percentage: stored as ``self.percentage``; not used by this class.
        num_document_token: number of document vectors concatenated per example.
        num_virtual_document: number of virtual tokens produced from the documents.
        document_dim: width of each incoming document vector (768 for BERT-base).
        hidden_dim: width of the intermediate document projection.
    """

    def __init__(self, config, prompt_encoder, percentage, num_document_token, num_virtual_document,
                 document_dim=768, hidden_dim=1024):
        super().__init__()
        token_dim = config.token_dim
        self.num_virtual_tokens = config.num_virtual_tokens
        self.prompt_encoder = prompt_encoder
        self.percentage = percentage
        self.num_document_tokens = num_document_token
        self.token_dim = token_dim
        self.num_virtual_document = num_virtual_document
        self.document_dim = document_dim
        self.hidden_dim = hidden_dim
        if self.num_virtual_document > 0:
            self.fc = torch.nn.Linear(document_dim, hidden_dim)
            self.activation = torch.nn.Tanh()
            self.fc_out = torch.nn.Linear(hidden_dim, hidden_dim)
            self.fcs = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim * self.num_document_tokens, token_dim)
                for _ in range(self.num_virtual_document)
            )

    def forward(self, indice: torch.Tensor, document):
        """Return the prompt embeddings for virtual-token indices ``indice``.

        ``document`` is ignored when ``num_virtual_document == 0``.
        """
        if self.num_virtual_document == 0:
            return self.prompt_encoder(indice)

        document_embedding = self.fc_out(self.activation(self.fc(document)))
        # assumes document is (batch, num_document_tokens, document_dim) -- TODO confirm;
        # otherwise the reshape regroups vectors across examples without raising.
        document_embedding = document_embedding.reshape(-1, self.num_document_tokens * self.hidden_dim)
        document_embedding = torch.stack([fc(document_embedding) for fc in self.fcs], dim=-2)

        num_prompt_tokens = self.num_virtual_tokens - self.num_virtual_document
        prompt_embedding = self.prompt_encoder(indice)[:, :num_prompt_tokens, :]
        return torch.cat([prompt_embedding, document_embedding], dim=-2)

class MyPrefixEncoder(torch.nn.Module):
    """Prefix encoder whose first virtual token is overwritten by a document projection.

    When ``config.prefix_projection`` is set and ``config.inference_mode`` is not, two MLPs are
    built. ``transform`` maps the learned prefix embeddings (width ``token_dim``), and
    ``transform_document`` maps a 768-wide document vector. Both output
    ``num_layers * 2 * token_dim`` values. Otherwise the embedding table stores that many values
    per virtual token directly.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        dim = config.token_dim
        layers = config.num_layers
        hidden = config.encoder_hidden_size
        n_tokens = config.num_virtual_tokens
        out_width = layers * 2 * dim

        def mlp(in_features):
            # Linear -> Tanh -> Linear projecting in_features to out_width.
            return torch.nn.Sequential(
                torch.nn.Linear(in_features, hidden),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden, out_width),
            )

        if self.prefix_projection and not config.inference_mode:
            self.embedding = torch.nn.Embedding(n_tokens, dim)
            self.transform = mlp(dim)
            self.transform_document = mlp(768)
        else:
            self.embedding = torch.nn.Embedding(n_tokens, out_width)

    def forward(self, prefix: torch.Tensor, document):
        """Encode virtual-token indices ``prefix``; ``document`` is used only with projection."""
        if not self.prefix_projection:
            return self.embedding(prefix)
        encoded = self.transform(self.embedding(prefix))
        document_encoded = self.transform_document(document)
        # Replace the first virtual token of each batch row with that row's document projection.
        for row in range(encoded.shape[0]):
            encoded[row][0] = document_encoded[row]
        return encoded
