from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
from typing import *
import torch
import torch.nn as nn
from opendelta import BaseDeltaConfig
from decorator import decorate
import torch.nn.functional as F
from opendelta import logging
logger = logging.get_logger(__name__)

class SoftPromptConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`SoftPromptModel`

    """
    def __init__(
        self, 
        soft_token_num=100,
        init_range = 0.5,
        token_init = True,
        **kwargs
    ): 
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # the arg has not been registered in parent config
                setattr(self, arg_name, locals()[arg_name])

        

class SoftPromptLayer(nn.Module):
    r"""This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
    This template also does not need any textual template. Addition tokens are directly
    concatenated into the input ids. There are two initializations of the new tokens. 
    (1). random initialization. (2) initialize with the tokens of the plm (We simply take 
    the first n_tokens similar to their implementation).

    Note that this template can be simply achieved by :obj:`SoftManualTemplate`, in which
    you set ``n_token`` <soft> tokens template before the <text_a> will give the same result.
    """

    def __init__(self,
                 soft_token_num: int = 100,
                 raw_embedding: Optional[torch.Tensor] = None,
                 init_range: Optional[float] = 0.5,
                 token_init = False,
                 pad_id = 0,
                 device: Optional[str]=None,
                ):
        super().__init__()
        self.__dict__['raw_embedding'] = raw_embedding
        
        self.init_range = init_range
        self.num_tokens = soft_token_num
        self.pad_id = pad_id
        self.token_init = token_init
        self.device = device
    
        assert self.num_tokens>0
        self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])

    def pre_forward(self, *args, **kwargs):
        # if attention_mask is passed as PLM's input, modify it here
        if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None: 
            # In generation, the input is forward through the model again.
            return args, kwargs

        if 'input_ids' in kwargs:
            input_ids = kwargs['input_ids']
            kwargs['input_ids'] = None
        elif len(args) > 0:
            input_ids = args[0]
            args = args[1:]
        else:
            input_ids = None


        if 'attention_mask' not in kwargs or kwargs['attention_mask'] is None:
            # infer attention mask
            if input_ids is None:
                raise RuntimeError("no input ids found")
            kwargs['attention_mask'] = (input_ids != self.pad_id).to(torch.int64)
            
        if 'inputs_embeds' not in kwargs or kwargs['inputs_embeds'] is None:
            try:
                inputs_embeds = self.raw_embedding(input_ids)
            except:
                raise RuntimeError("neither inputs_embeds nor input_ids is specified.")
        else:
            inputs_embeds = kwargs['inputs_embeds']
            
        
        
        batch_size = inputs_embeds.size(0)
        soft_embeds = self.soft_embeds.repeat(batch_size, 1, 1)
        inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
        kwargs['inputs_embeds'] = inputs_embeds

        am = kwargs['attention_mask']
        am.data = torch.cat([torch.ones((*am.shape[:-1], inputs_embeds.shape[-2]-am.shape[-1]), dtype = am.dtype,device=am.device), am], dim=-1)

        return args, kwargs

    def instantiate(self, hidden_dim) -> None:
        """
        generate parameters needed for soft tokens embedding in soft-prompt
        for soft tokens, use a new embedding layer which is initialized with their corresponding embedding of hard tokens
        """
        soft_embeds = torch.FloatTensor(self.num_tokens, hidden_dim)
        if self.token_init:
            soft_embeds.data = torch.clone(self.raw_embedding(torch.tensor([i for i in range(self.num_tokens)])))
        else:
            soft_embeds = soft_embeds.uniform_(-self.init_range, self.init_range)

        self.soft_embeds = nn.Parameter(soft_embeds, requires_grad=True).to(self.device)


class SoftPromptModel(DeltaBase):
    r"""
    This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
    This template also does not need any textual template. Addition tokens are directly
    concatenated into the input ids. There are two initializations of the new tokens. 
    (1). random initialization. (2) initialize with the tokens of the plm (We simply take 
    the first n_tokens similar to their implementation).

    Note that this template can be simply achieved by :obj:`SoftManualTemplate`, in which
    you set ``n_token`` <soft> tokens template before the <text_a> will give the same result.

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified. 
        soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input.
        init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution.
        token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the plm
        modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
                        the implemented ones)
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
                         together with the prefix parameters.
        common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.

    """
    config_class = SoftPromptConfig
    delta_type = "soft_prompt"
    default_modified_modules = ["root"]  # not used
    def __init__(self, 
                 backbone_model: nn.Module,
                 soft_token_num=100,
                 init_range = 0.5,
                 token_init=True,
                 modified_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                ):
        DeltaBase.__init__(self, 
                           backbone_model = backbone_model,
                           modified_modules = ["root"],
                           unfrozen_modules = unfrozen_modules,
                           common_structure = False,
                           interactive_modify = interactive_modify,
                          )

        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])


        try:
            self.__dict__['raw_embedding'] = self.backbone_model.get_input_embeddings()
        except AttributeError:
            raise AttributeError(f"'{type(self.backbone_model)}' object has no attribute 'get_input_embeddings', please pass "+
            "input embeddings into 'self.raw_embedding' in user-specific ways.")
        
        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                      )

    def add_all_delta_to_backbone(self, 
                 module: nn.Module, 
                 modified_modules: List[str],
                ) -> nn.Module:
        self.update_module()
        self.mark_as_delta()
        return module
    
    def update_module(self):
        soft_prompt_layer = self.new_module_like(self.raw_embedding)
        self.insert_sequential_module(self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model,
                                          delta_module=soft_prompt_layer, 
                                          delta_name="soft_prompt_layer"  )

    def new_module_like(self, module):
        module_device = get_device(module)
        soft_prompt_layer = SoftPromptLayer(
            soft_token_num = self.soft_token_num,
            raw_embedding = self.raw_embedding,
            token_init = self.token_init,
            init_range = self.init_range,
            device = module_device,
        )
        self.delta_modules.append(soft_prompt_layer)  
        return soft_prompt_layer
    