from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
import torch.nn as nn
import torch
from opendelta.delta_models.layers.activations import Activations
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
import numpy as np
import os
from opendelta import global_setting
from dataclasses import dataclass, field

from collections import OrderedDict

logger = logging.get_logger(__name__)


class SoftThresholdActivation(nn.Module):
    """Element-wise gate that soft-thresholds a learnable vector ``g``.

    The input is multiplied by ``sign(g) * relu(|g| - t)``, where ``t`` is the
    smaller of ``|threshold|`` and the second-largest entry of ``|g|``.

    Both parameters are allocated with ``torch.empty`` and therefore hold
    arbitrary values until something initializes them; ``init_threshold`` is
    only stored here so an initializer can read it later.

    Args:
        bottleneck_dim: number of entries in ``g``.
        init_device: device on which the parameters are allocated.
        hidden_dtype: dtype of the parameters.
        init_threshold: value kept on the module as ``self.init_threshold``.
    """

    def __init__(self, bottleneck_dim, init_device, hidden_dtype, init_threshold):
        super().__init__()
        param_kwargs = {"device": init_device, "dtype": hidden_dtype}
        self.g = nn.Parameter(torch.empty(bottleneck_dim, **param_kwargs))
        self.threshold = nn.Parameter(torch.empty(1, **param_kwargs))
        self.init_threshold = init_threshold

    def forward(self, x):
        """Return ``x`` scaled by the soft-thresholded gate vector."""
        magnitude = self.g.abs()

        if magnitude.numel() <= 1:
            # A single entry has no runner-up; fall back to its own magnitude.
            runner_up = magnitude.max()
        else:
            # Second-largest magnitude: the last of the top-2 values.
            runner_up = torch.topk(magnitude.flatten(), k=2).values[-1]

        # The effective threshold never exceeds the second-largest magnitude.
        effective_threshold = torch.min(self.threshold.abs(), runner_up)
        gate = torch.sign(self.g) * torch.relu(magnitude - effective_threshold)
        return x * gate


class InterFaceMixin:
    """Helpers for permuting 3-D tensors and round-tripping dtype/device.

    The axis permutation comes from ``global_setting.axis_order``; its inverse
    is derived with ``np.argsort``. ``_convert_data_type`` calls
    ``self._get_device()``, which this mixin does not define, so the host
    class has to provide it.
    """

    def __init__(self):
        axis_order = global_setting.axis_order
        self._axis_order = axis_order
        # argsort of a permutation yields the permutation that undoes it.
        self._reverse_axis_order = np.argsort(axis_order).tolist()

    def _transpose(self, tensor):
        """Permute a 3-D tensor by the stored axis order; pass others through."""
        return tensor.permute(*self._axis_order) if tensor.dim() == 3 else tensor

    def _reverse_transpose(self, tensor):
        """Undo ``_transpose`` on a 3-D tensor, returning a contiguous result."""
        if tensor.dim() != 3:
            return tensor
        restored = tensor.permute(*self._reverse_axis_order)
        return restored.contiguous()

    def _convert_data_type(self, tensor):
        """Record the tensor's dtype/device, then cast to float32 on ``_get_device()``."""
        self._data_type_record, self._device_record = tensor.dtype, tensor.device
        as_float = tensor.to(torch.float32)
        return as_float.to(self._get_device())

    def _reverse_data_type(self, tensor):
        """Cast back to the dtype and device recorded by ``_convert_data_type``."""
        original_dtype = tensor.to(self._data_type_record)
        return original_dtype.to(self._device_record)


class AdapterLayer(nn.Module, InterFaceMixin):
    """Bottleneck adapter applied to a module's output via ``post_forward``.

    The sub-modules (down projection, sparse gate, up projection) are created
    lazily on the first ``post_forward`` call, because the hidden size and
    dtype are read from the tensor seen at that point.

    Args:
        bottleneck_dim: width of the bottleneck between the two projections.
        device: device on which the sub-modules are created.
        backend: backend identifier; stored on the layer as ``self.backend``.
        threshold: initial value for the gate's learnable threshold.
    """

    # Number of AdapterLayer instances constructed so far (class-wide).
    layer_count = 0

    @classmethod
    def count_layer(cls):
        """Increment the class-wide layer counter."""
        cls.layer_count += 1

    @classmethod
    def get_layer_count(cls):
        """Return the class-wide layer counter."""
        return cls.layer_count

    def __init__(self, bottleneck_dim=24, device=None, backend="hf", threshold=1e-4):
        super().__init__()
        InterFaceMixin.__init__(self)
        self.bottleneck_dim = bottleneck_dim
        self.init_device = device
        self.instantiated = False
        self.backend = backend
        self.threshold = threshold

        # The id is the counter value before this layer is counted.
        self.layer_id = AdapterLayer.get_layer_count()
        AdapterLayer.count_layer()

    def _get_device(self):
        """Return the device of the down projection, or the init device before instantiation."""
        if self.instantiated:
            # The down projection is registered as ``down_sample`` in ``instantiate``.
            return self.modulelist.down_sample.weight.device
        else:
            return self.init_device

    def instantiate(self, hiddens):
        """Build and initialize the adapter sub-modules from a sample hidden tensor.

        Args:
            hiddens: tensor whose last dimension gives the hidden size and
                whose dtype is used for all created parameters.
        """
        self.hidden_dim = hiddens.shape[-1]
        self.hidden_dtype = hiddens.dtype

        # Build the module chain as a Sequential with named stages.
        self.modulelist = nn.Sequential(OrderedDict([
            ('down_sample', nn.Linear(self.hidden_dim,
                                      self.bottleneck_dim,
                                      device=self.init_device,
                                      dtype=self.hidden_dtype,
                                      bias=False)
            ),
            ('sparse_activation', SoftThresholdActivation(self.bottleneck_dim,
                                                          self.init_device,
                                                          self.hidden_dtype,
                                                          self.threshold)
            ),
            ('up_sample', nn.Linear(self.bottleneck_dim,
                                    self.hidden_dim,
                                    device=self.init_device,
                                    dtype=self.hidden_dtype,
                                    bias=False)
            )
        ]))

        self.instantiated = True

        # Apply the weight initialization to every sub-module.
        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialize one sub-module; used as the callback for ``self.apply``."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, SoftThresholdActivation):
            nn.init.normal_(module.g, std=0.01)
            # Fill the threshold with the initial value passed to the gate.
            nn.init.constant_(module.threshold, module.init_threshold)

    def post_forward(self, output):
        """Add the adapter's residual update to a module output.

        Args:
            output: either a tensor of hidden states or a tuple whose first
                element is that tensor.

        Returns:
            The same structure as ``output`` with the hidden states replaced
            by ``adapter(hiddens) + hiddens``; tuple inputs come back as a
            plain tuple.

        Raises:
            TypeError: if ``output`` is neither a tuple nor a ``torch.Tensor``.
        """
        # Validate up front so unsupported outputs fail with the intended
        # TypeError instead of an AttributeError from tensor-only calls below.
        if isinstance(output, tuple):
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError(
                f"AdapterLayer expects a torch.Tensor or a tuple, got {type(output).__name__}"
            )

        # Move to the internal axis order.
        hiddens = self._transpose(hiddens)

        # Lazy initialization on first use.
        if not self.instantiated:
            self.instantiate(hiddens)

        # Run the module chain and add the residual connection.
        adapter_output = self.modulelist(hiddens)
        modified_output = adapter_output + hiddens

        # Restore the original axis order.
        modified_output = self._reverse_transpose(modified_output)

        # Preserve the structure of the original output.
        if isinstance(output, tuple):
            return (modified_output,) + output[1:]
        return modified_output


class AdapterConfig(BaseDeltaConfig):
    r"""
    Configuration holder for :py:class:`~AdapterModel`.

    Args:
        bottleneck_dim (:obj:`int`): stored as ``self.bottleneck_dim``.
        threshold (:obj:`float`): stored as ``self.threshold``.
        **kwargs: forwarded unchanged to the parent config's ``__init__``.

    """
    def __init__(
        self,
        bottleneck_dim: Optional[int]=24,
        threshold: Optional[float] = 1e-4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.bottleneck_dim = bottleneck_dim
        arg_names = get_arg_names_inside_func(self.__init__)
        # Snapshot this frame's variables once; the argument values do not
        # change while the loop below runs.
        frame_values = locals()
        for name in arg_names:
            if hasattr(self, name):
                # Already registered (here or by the parent config).
                continue
            setattr(self, name, frame_values[name])



class AdapterModel(DeltaBase):
    r""" The implementation of Adapter(`Parameter-Efficient Transfer Learning for NLP <https://arxiv.org/abs/1902.00751>`_ ) .
    Add adapters to the designated ``modified_modules``. In the sequential paradigm, each module's output is then passed
    into the adapter's post_forward.

    One adapter is inserted per task name: one for every name listed in the ``DONE_TASKS_NAME`` environment variable
    (comma-separated), followed by one for ``CURRENT_TASK_NAME`` (default ``"current_task"``). Each is registered under
    the delta name ``"<task>_adapter"``.

    .. note::
        We **assume** the output of the modified module is the hidden state or a tuple where the hidden state is the
        first element. This is true for most PLMs. However, we admit that currently it's not rigorous. We will improve
        it in the next version. Currently, if you encounter an error here for your backbone, you can modify the code to
        get the hidden state.

    class attributes:
        - default_modified_modules = ["attn@.proj@", "ff@.w2@"]
        - delta_type = "adapter"

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
        bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
        threshold (:obj:`float`): Initial value of each adapter's learnable soft threshold.
        modified_modules (:obj:`List[str]`): modules to add adapter after them.
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): forwarded to :obj:`DeltaBase`.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
        common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`False`): forwarded to :obj:`DeltaBase`.
        backend (:obj:`str`): choose the backend of plm, 'hf' for huggingface transformers, 'bmt' for bmtrain.

    """
    config_class = AdapterConfig
    delta_type = "adapter"
    default_modified_modules = ["attn@.proj@", "ff@.w2@"]
    _supported_backends = ['hf', 'bmt']
    _need_pseudo_data = True
    def __init__(self,
                 backbone_model: nn.Module,
                 bottleneck_dim: Optional[int]=24,
                 threshold: Optional[float] = 1e-4,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 backend: Optional[str] = 'hf',
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           backend=backend,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])

        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                   self.modified_modules,
                                   )

    def update_module(self, module: nn.Module, key: str):
        """Insert one adapter per task after the module addressed by ``key``.

        Adapters for the finished tasks in ``DONE_TASKS_NAME`` are inserted
        first, in the order listed, followed by the adapter for
        ``CURRENT_TASK_NAME``.

        Args:
            module: root module passed to ``find_module``.
            key: address of the target sub-module inside ``module``.
        """
        _, _, ref = self.find_module(module, key)

        # Strip whitespace and drop empty entries so that values such as
        # "a, b," do not create adapters named " b_adapter" or "_adapter".
        done_task_names = [
            name.strip()
            for name in os.getenv("DONE_TASKS_NAME", "").split(",")
            if name.strip()
        ]
        current_task_name = os.getenv("CURRENT_TASK_NAME", "current_task")

        for task_name in done_task_names + [current_task_name]:
            # A fresh adapter for each task, inserted sequentially after ``ref``.
            adapterlayer = self.new_module_like(ref)
            self.insert_sequential_module(ref,
                                          delta_module=adapterlayer,
                                          delta_name=f"{task_name}_adapter")

    def new_module_like(self, module):
        """Create an adapter on ``module``'s device and track it in ``delta_modules``.

        Args:
            module: reference module; only its device is used.

        Returns:
            The newly created :class:`AdapterLayer`.
        """
        module_device = get_device(module)
        adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim,
                                    device=module_device,
                                    backend=self.backend,
                                    threshold=self.threshold)
        self.delta_modules.append(adapterlayer)
        return adapterlayer