import torch
from peft.tuners.lora import LoraLayer
import torch.nn as nn
import torch.nn.functional as F
import math
from peft.utils.other import transpose
import random
import numpy as np
import peft.tuners.lora as Lora
from peft import PeftModel
# from peft.tuners.lora import Linear, Embedding, LoraModel
import torch
import pynvml
import os
import warnings
from torch.autograd import Function
from torch.utils.data import IterableDataset
import netifaces
from trl import SFTTrainer
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from transformers.trainer import TRAINING_ARGS_NAME
import safetensors.torch
from functools import partial
from torch.optim.lr_scheduler import LambdaLR
from dataclasses import dataclass, field, asdict
from transformers import TrainerCallback, TrainerState
from collections import defaultdict
import time
import torch.distributed as dist
from sparse_topology_initialization import *
import torch.nn.functional as F


@dataclass
class LoraCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"
    args: Optional[any]
    last_lora_A: defaultdict = field(default_factory=lambda: defaultdict(lambda: None))
    last_lora_B: defaultdict = field(default_factory=lambda: defaultdict(lambda: None))
    optimizer_state_list: list = field(default_factory=list)

    def __post_init__(self):
        if self.args.no_clear_step is True:
            self.optimizer_state_list = ["state1", "state2"]
        else:
            self.optimizer_state_list = ["step", "state1", "state2"]

    def on_step_end(self, args, state: TrainerState, control, model=None, optimizer=None, **kwargs):
        print(f"current step: {state.global_step}")
        with torch.no_grad():
            for n, m in model.named_modules():
                if isinstance(m, SVDLinear):
                    m.lora_active_A[m.active_adapter[0]].data *= torch.rsqrt(m.lora_active_A[m.active_adapter[0]].pow(2).sum(1, keepdim=True) + 1e-9)
                    m.lora_active_B[m.active_adapter[0]].data *= torch.rsqrt(m.lora_active_B[m.active_adapter[0]].pow(2).sum(0, keepdim=True) + 1e-9)
                    m.lora_E[m.active_adapter[0]].data.clamp_(min=0)

        if state.global_step >= self.args.update_start and state.global_step % self.args.update_interval == 0:
            merge_refresh(model, self.args, optimizer, self.last_lora_A, self.last_lora_B, self.optimizer_state_list)
        if self.args.init == "momentum" and state.global_step % self.args.update_interval == self.args.update_interval // 2:
            print(f"record lora")
            for n, m in model.named_modules():
                if isinstance(m, LoraLayer):
                    # print(f"lora_A: {m.lora_A[m.active_adapter].weight.data}")
                    # print(f"lora_B: {m.lora_B[m.active_adapter].weight.data}")
                    self.last_lora_A[n] = m.lora_A[m.active_adapter].weight.clone().detach()
                    self.last_lora_B[n] = m.lora_B[m.active_adapter].weight.clone().detach()

def get_local_ip():
    for interface in netifaces.interfaces():
        addresses = netifaces.ifaddresses(interface)
        # Check for IPv4 addresses
        if netifaces.AF_INET in addresses:
            for link in addresses[netifaces.AF_INET]:
                ip_address = link['addr']
                # Assuming your local network IPs start with '192.168.'
                if ip_address.startswith('192.168.'):
                    return ip_address
    return "Local IP address not found"

def get_lora_parameter_names(model):
    result = []
    for param_name, param in model.named_parameters():
        if "lora" in param_name:
            result.append(param_name)
    return result

def _get_iterative_polynomial_decay_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    update_interval: int,
    num_warmup_per_interval: int,
    lr_end: float,
    power: float,
    lr_init: int,
):
    if current_step < num_warmup_steps:
        global_lr = float(current_step) / float(max(1, num_warmup_steps))
    elif current_step > num_training_steps:
        global_lr = lr_end / lr_init  # as LambdaLR multiplies by lr_init
    else:
        lr_range = lr_init - lr_end
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        global_lr = decay / lr_init  # as LambdaLR multiplies by lr_init

    local_step = current_step % update_interval
    if local_step < num_warmup_per_interval:
        local_lr = local_step / num_warmup_per_interval
    else:
        local_lr = 1

    return global_lr * local_lr

def get_iterative_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, update_interval, num_warmup_per_interval, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    lr_init = optimizer.defaults["lr"]
    # if not (lr_init > lr_end):
    #     raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")

    lr_lambda = partial(
        _get_iterative_polynomial_decay_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        update_interval=update_interval,
        num_warmup_per_interval=num_warmup_per_interval,
        lr_end=lr_end,
        power=power,
        lr_init=lr_init,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)

def _get_warm_iterative_cosine_lr_lambda(
    current_step: int, *, steps_per_cycle: int
):
    current_cycle_step = current_step % steps_per_cycle
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * (float(current_cycle_step / steps_per_cycle) + 1.0))))

def get_warm_iterative_cosine(
    optimizer, steps_per_cycle: int, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    lr_lambda = partial(
        _get_warm_iterative_cosine_lr_lambda,
        steps_per_cycle=steps_per_cycle
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)

def replace_with_svd(model, config, lora_head, mode):
    for n, m in model.named_modules():
        if isinstance(m, Lora.Linear):
            parent = model.get_submodule(".".join(n.split(".")[:-1]))
            child_name = n.split(".")[-1]
            active_adapter = m.active_adapter[0]
            new_module = SVDLinear(m.base_layer, active_adapter, config.r, lora_head, config.lora_alpha, config.lora_dropout, init_lora_weights=False, mode=mode)
            model._replace_module(parent, child_name, new_module, m)
            new_module.to(m.weight.device)
        elif isinstance(m, Lora.Embedding):
            raise Exception("no svd embedding")

class SVDFunction(Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, A, E, B):
        output = input @ (A * E).T @ B.T
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, A, E, B = inputs
        ctx.save_for_backward(input, A, E, B)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, A, E, B = ctx.saved_tensors
        grad_input = grad_A = grad_E = grad_B = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ B @ (A * E)
        if ctx.needs_input_grad[1]:
            grad_A = ((input.T @ grad_output @ B)).T
        if ctx.needs_input_grad[2]:
            grad_E = (A.T * (input.T @ grad_output @ B)).T
        if ctx.needs_input_grad[3]:
            grad_B = ((A) @ input.T @ grad_output).T

        return grad_input, grad_A, grad_E, grad_B


class SVDLinear(nn.Module, LoraLayer):
    adapter_layer_names = ()
    # Lora implemented in a dense layer
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        head: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        mode: str = None,
        **kwargs,
    ):
        assert(init_lora_weights is False)
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out
        self.mode = mode

        self.lora_A = nn.ParameterDict({})
        self.lora_B = nn.ParameterDict({})
        self.lora_E = nn.ParameterDict({})
        self.lora_active_A = nn.ParameterDict({})
        self.lora_active_B = nn.ParameterDict({})

        self.head = head
        assert (r % head == 0)
        self.block_size = r // head

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

        active_adapter = self.active_adapter[0]
        epsilon = 1e-9

        self.lora_active_A[active_adapter].register_hook(lambda grad: grad / (torch.abs(self.lora_E[active_adapter][:self.block_size, :]) + epsilon))
        self.lora_active_B[active_adapter].register_hook(lambda grad: grad / (torch.abs(self.lora_E[active_adapter][:self.block_size, :]) + epsilon).T)


    def update_mask(self):
        adapter = self.active_adapter[0]
        with torch.no_grad():
            if self.mode == "svd_init" or self.mode == "svd_shuffle":
                assert dist.is_initialized() is False
                new_active_A = self.lora_A[adapter][:self.block_size, :].data
                new_A = torch.cat([self.lora_A[adapter][self.block_size:, :].data, self.lora_active_A[adapter].data], dim=0)
                self.lora_active_A[adapter].data = new_active_A
                self.lora_A[adapter].data = new_A

                new_active_B = self.lora_B[adapter][:, :self.block_size].data
                new_B = torch.cat([self.lora_B[adapter][:, self.block_size:].data, self.lora_active_B[adapter].data],
                                  dim=1)
                self.lora_active_B[adapter].data = new_active_B
                self.lora_B[adapter].data = new_B

                self.lora_E[adapter].data = torch.cat([self.lora_E[adapter].data[self.block_size:, :], self.lora_E[adapter].data[:self.block_size, :]], dim=0)
            if self.mode == "svd_adaptive":
                p = torch.abs(self.lora_E[adapter].squeeze(1))
                p = p + torch.mean(p)
                p = p / p.sum()
                indices = torch.multinomial(p, num_samples=self.block_size, replacement=False)
                is_dist = dist.is_initialized()
                if is_dist:
                    dist.broadcast(indices, 0)
                full_A = torch.cat([self.lora_active_A[adapter].data, self.lora_A[adapter].data], dim=0)
                full_B = torch.cat([self.lora_active_B[adapter].data, self.lora_B[adapter].data], dim=1)
                mask = torch.zeros_like(p, dtype=torch.bool, device=self.lora_E[adapter].device)
                mask[indices] = True

                self.lora_active_A[adapter].data = full_A[mask]
                self.lora_A[adapter].data = full_A[~mask]
                self.lora_active_B[adapter].data = full_B[:, mask]
                self.lora_B[adapter].data = full_B[:, ~mask]
                self.lora_E[adapter].data = torch.cat([self.lora_E[adapter].data[mask], self.lora_E[adapter].data[~mask]], dim=0)
            else:
                raise Exception("invalid svd init in update_mask")


    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_A[adapter_name] = nn.Parameter(torch.randn(r - self.block_size, self.in_features), requires_grad=False)
            self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1))
            self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r - self.block_size), requires_grad=False)
            self.lora_active_A[adapter_name] = nn.Parameter(torch.randn(self.block_size, self.in_features))
            self.lora_active_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, self.block_size))
            self.scaling[adapter_name] = lora_alpha / r


        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)


    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
        if self.merged:
            warnings.warn(
                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                f"You are now additionally merging {','.join(self.active_adapters)}."
            )

        if adapter_names is None:
            adapter_names = self.active_adapters

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                base_layer.weight = torch.nn.parameter.Parameter(self.get_delta_weight(active_adapter)).to(self.lora_A[active_adapter].device)
                self.merged_adapters.append(active_adapter)

    def get_delta_weight(self, adapter):
        return (
            transpose(
                torch.cat([self.lora_active_B[adapter], self.lora_B[adapter]], dim=1) @ (torch.cat([self.lora_active_A[adapter], self.lora_A[adapter]], dim=0) * self.lora_E[adapter]),
                self.fan_in_fan_out,
            )
            * self.scaling[adapter]
        )

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        previous_dtype = x.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            # lora_A = self.lora_A[active_adapter]
            # lora_B = self.lora_B[active_adapter]
            # dropout = self.lora_dropout[active_adapter]
            # scaling = self.scaling[active_adapter]
            x = x.to(self.lora_A[active_adapter].dtype)
            # result += lora_B(lora_A(dropout(x))) * scaling
            # if x.dim() == 3:
            #     dim_0, dim_1, dim_2 = x.shape
            #     x = x.view(-1, dim_2)
            #     svd_result = SVDFunction.apply(self.lora_dropout[active_adapter](x), self.lora_A[active_adapter],self.lora_E[active_adapter], self.lora_B[active_adapter]) * self.scaling[active_adapter]
            #     svd_result = svd_result.view(dim_0, dim_1, -1)
            #     result += svd_result
            # else:
            #     result += SVDFunction.apply(self.lora_dropout[active_adapter](x), self.lora_A[active_adapter],self.lora_E[active_adapter], self.lora_B[active_adapter]) * self.scaling[active_adapter]
            result = self.lora_dropout[active_adapter](x) @ (torch.cat([self.lora_active_A[active_adapter], self.lora_A[active_adapter]], dim=0) * self.lora_E[active_adapter]).T @ torch.cat([self.lora_active_B[active_adapter], self.lora_B[active_adapter]], dim=1).T

        if self.base_layer.bias is not None:
            result += self.base_layer.bias

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

# def get_delta_weight(m):
#     return transpose(
#         m.lora_B[m.active_adapter].weight @ m.lora_A[m.active_adapter].weight,
#         False,
#     ) * m.scaling[m.active_adapter]

def print_lora(m, mode):
    active_adapter = m.active_adapter[0]
    if isinstance(m, Lora.Linear) or isinstance(m, Lora.Conv2d):
        print(f"{mode} lora_A: {torch.mean(torch.abs(m.lora_A[active_adapter].weight.data))}")
        print(f"{mode} lora_B: {torch.mean(torch.abs(m.lora_B[active_adapter].weight.data))}")
    elif isinstance(m, Lora.Embedding):
        print(f"{mode} lora_embedding_A: {torch.mean(torch.abs(m.lora_embedding_A[active_adapter].data))}")
        print(f"{mode} lora_embedding_B: {torch.mean(torch.abs(m.lora_embedding_B[active_adapter].data))}")
    elif isinstance(m, SVDLinear):
        print(f"{mode} lora_A: {torch.mean(torch.abs(m.lora_A[active_adapter]))}, top 5 lora_A: {torch.mean(torch.abs(m.lora_A[active_adapter][:5, :]), dim=1)}")
        print(f"{mode} lora_active_A: {torch.mean(torch.abs(m.lora_active_A[active_adapter]))}")
        print(f"{mode} lora_E: {torch.mean(torch.abs(m.lora_E[active_adapter]))}, min: {torch.min(m.lora_E[active_adapter])}, max: {torch.max(m.lora_E[active_adapter])}")
        print(f"{mode} lora_B: {torch.mean(torch.abs(m.lora_B[active_adapter]))}")
        print(f"{mode} lora_active_B: {torch.mean(torch.abs(m.lora_active_B[active_adapter]))}, min: {torch.min(m.lora_active_B[active_adapter])}, max: {torch.max(m.lora_active_B[active_adapter])}")
        # print(f"{mode} weight: {torch.mean(torch.abs(m.base_layer.weight))}")
        print(f"{mode} merged weight: {torch.mean(m.get_delta_weight(active_adapter))}")

def merge_refresh(model: torch.nn.Module, args, optimizer, last_lora_A, last_lora_B, key_list, iteration_num):
    t1 = time.time()
    for n, m in model.named_modules():
        # print(f"{n}: {m}, {type(m)}")
        if isinstance(m, LoraLayer):
            print(f"{n}")
            active_adapter = m.active_adapter[0]
            if isinstance(m, SVDLinear):
                # cur_head = (m.cur_head + 1) % args.round
                if iteration_num % args.round == 0:
                    print("resvd")
                    print_lora(m, "before")
                    init_layer(m, 1, args, last_lora_A[n], last_lora_B[n])
                else:
                    m.update_mask()
                print_lora(m, "after")
            else:
                m.weight.data += m.get_delta_weight(active_adapter)
                print_lora(m, "before")
                init_layer(m, 1, args, last_lora_A[n], last_lora_B[n])
                print_lora(m, "after")
                m.weight.data -= m.get_delta_weight(active_adapter)
            if isinstance(m, Lora.Linear):
                if not set(optimizer.state[m.lora_A[active_adapter].weight].keys()).issubset(set(key_list)):
                    print(optimizer.state[m.lora_A[active_adapter].weight].keys())
                    raise Exception("invalid optimizer key")
            for key in key_list:
                if isinstance(m, Lora.Linear) or isinstance(m, Lora.Conv2d):
                    parameters = [m.lora_A[active_adapter].weight, m.lora_B[active_adapter].weight]
                elif isinstance(m, Lora.Embedding):
                    parameters = [m.lora_embedding_A[active_adapter], m.lora_embedding_B[active_adapter]]
                elif isinstance(m, SVDLinear):
                    parameters = [m.lora_active_A[active_adapter], m.lora_active_B[active_adapter], m.lora_E[active_adapter]]
                else:
                    raise Exception("invalid type of loralayer")
                for parameter in parameters:
                    if isinstance(optimizer.state[parameter][key], int):
                        optimizer.state[parameter][key] = 0
                    else:
                        optimizer.state[parameter][key].zero_()

    print(f"merge_refresh cost: {time.time() - t1}s")

def init_layer(m, beta, args, last_lora_A=None, last_lora_B=None, clear=False, first=False):
    with torch.no_grad():
        if clear == True:
            nn.init.zeros_(m.weight)
        if isinstance(m, Lora.Linear) or isinstance(m, Lora.Conv2d):
            # assert dist.is_initialized() is False
            active_adapter = m.active_adapter[0]
            fan_in = m.in_features
            r = m.r[active_adapter]
            fan_out = m.out_features
            if (args.init in ["lora_half", "lora_momentum"] and first == True) or args.init == "lora":
                nn.init.kaiming_uniform_(m.lora_A[active_adapter].weight, a=math.sqrt(5))
                nn.init.zeros_(m.lora_B[active_adapter].weight)
            elif args.init == "lora_B":
                nn.init.zeros_(m.lora_A[active_adapter].weight)
                nn.init.kaiming_uniform_(m.lora_B[active_adapter].weight, a=math.sqrt(5))
            elif args.init == "lora_half":
                bound = 1 / math.sqrt(fan_in)
                value = torch.rand(r // 2, fan_in) * 2 * bound - bound
                m.lora_A[active_adapter].weight[r // 2:, :] = value
                m.lora_A[active_adapter].weight.data = torch.flip(m.lora_A[active_adapter].weight, [0])
                value = torch.zeros(fan_out, r // 2)
                m.lora_B[active_adapter].weight[:, r // 2:] = value
                m.lora_B[active_adapter].weight.data = torch.flip(m.lora_B[active_adapter].weight, [1])
            elif (args.init == "momentum" and first == True) or args.init == "random":
                weight_bound_A = 1 / math.sqrt(math.sqrt(fan_in * r * beta))
                weight_bound_B = 1 / math.sqrt(math.sqrt(fan_out * r * beta))
                torch.nn.init.uniform_(m.lora_A[active_adapter].weight, -weight_bound_A, weight_bound_A)
                torch.nn.init.uniform_(m.lora_B[active_adapter].weight, -weight_bound_B, weight_bound_B)
            elif args.init == "momentum":
                weight_bound_A = 1 / math.sqrt(math.sqrt(fan_in * r * beta))
                weight_bound_B = 1 / math.sqrt(math.sqrt(fan_out * r * beta))
                lora_A_init = m.lora_A[active_adapter].weight.data - last_lora_A
                variance = lora_A_init.pow(2).sum()
                lora_A_init = torch.clamp(
                    lora_A_init * torch.rsqrt(variance + 1e-8), -weight_bound_A,
                    weight_bound_A)
                m.lora_A[active_adapter].weight.data = lora_A_init
                lora_B_init = m.lora_B[active_adapter].weight.data - last_lora_B
                variance = lora_B_init.pow(2).sum()
                lora_B_init = torch.clamp(
                    lora_B_init * torch.rsqrt(variance + 1e-8), -weight_bound_B,
                    weight_bound_B)
                m.lora_B[active_adapter].weight.data = lora_B_init
            elif args.init == "lora_momentum":
                weight_bound_A = 1 / math.sqrt(fan_in * beta)
                lora_A_init = m.lora_A[active_adapter].weight.data - last_lora_A
                variance = lora_A_init.pow(2).sum()
                lora_A_init = torch.clamp(
                    lora_A_init * torch.rsqrt(variance + 1e-8), -weight_bound_A,
                    weight_bound_A)
                m.lora_A[active_adapter].weight.data = lora_A_init
                nn.init.zeros_(m.lora_B[active_adapter].weight)
            else:
                raise Exception("invalid init")
        elif isinstance(m, SVDLinear):
            active_adapter = m.active_adapter[0]
            if first:
                weight = m.base_layer.weight
            else:
                weight = m.get_delta_weight(active_adapter)
            r = m.r[active_adapter]
            U, S, Vh = torch.linalg.svd(weight.to(torch.float32), full_matrices=False)
            U = U.contiguous()
            S = S.contiguous()
            Vh = Vh.contiguous()
            is_dist = dist.is_initialized()
            if is_dist:
                dist.broadcast(U, 0)
                dist.broadcast(S, 0)
                dist.broadcast(Vh, 0)

            if args.init == "svd_shuffle":
                rand_index = torch.randperm(r)
                U = U[:, rand_index]
                Vh = Vh[rand_index, :]
                S = S[rand_index]
            m.lora_active_B[active_adapter][:, :] = U[:, :m.block_size]
            m.lora_B[active_adapter][:, :] = U[:, m.block_size:r]
            m.lora_active_A[active_adapter][:, :] = Vh[:m.block_size, :]
            m.lora_A[active_adapter][:, :] = Vh[m.block_size:r, :]
            m.lora_E[active_adapter][:, :] = S[:r].unsqueeze(1)
            if args.init == "svd_adaptive":
                m.update_mask()
            # if first:
            #     m.weight.data -= m.get_delta_weight(active_adapter)

        elif isinstance(m, Lora.Embedding):
            assert dist.is_initialized() is False
            active_adapter = m.active_adapter[0]
            nn.init.zeros_(m.lora_embedding_A[active_adapter])
            nn.init.kaiming_uniform_(m.lora_embedding_B[active_adapter], mode="fan_out", a=math.sqrt(5))
        else:
            raise Exception("invalid type of loralayer")


def get_visible_devices():
    cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if cuda_visible_devices is not None:
        return list(map(int, cuda_visible_devices.split(',')))
    else:
        return list(range(pynvml.nvmlDeviceGetCount()))

def get_free_memory():
    visible_devices = get_visible_devices()
    print(f"visible_devices: {visible_devices}")
    memory_info = []
    for virtual_index, real_index in enumerate(visible_devices):
        handle = pynvml.nvmlDeviceGetHandleByIndex(real_index)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        free_memory = info.free
        memory_info.append((virtual_index, free_memory))
    return memory_info


def select_device_with_most_free_memory():
    pynvml.nvmlInit()
    free_memory = get_free_memory()
    device_with_max_memory = max(free_memory, key=lambda x: x[1])
    return device_with_max_memory[0]


class CustomConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
    The dataset also formats the text before tokenization with a specific format that is provided
    by the user.

        Args:
            tokenizer (`transformers.PreTrainedTokenizer`):
                The processor used for processing the data.
            dataset (`dataset.Dataset`):
                Dataset with text files.
            dataset_text_field (`str`, **optional**):
                Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`.
            formatting_func (`Callable`, **optional**):
                Function that formats the text before tokenization. Usually it is recommended to have follows a certain
                pattern such as `"### Question: {question}\n ### Answer: {answer}\n"`
            infinite (`bool`, *optional*, defaults to `False`):
                If True the iterator is reset after dataset reaches end else stops.
            seq_length (`int`, *optional*, defaults to `1024`):
                Length of token sequences to return.
            num_of_sequences (`int`, *optional*, defaults to `1024`):
                Number of token sequences to keep in buffer.
            chars_per_token (`int`, *optional*, defaults to `3.6`):
                Number of characters per token used to estimate number of tokens in text buffer.
            eos_token_id (`int`, *optional*, defaults to `0`):
                Id of the end of sequence token if the passed tokenizer does not have an EOS token.
            shuffle ('bool', *optional*, defaults to True)
                Shuffle the examples before they are returned
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        dataset_text_field=None,
        formatting_func=None,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        eos_token_id=0,
        shuffle=True,
    ):
        self.tokenizer = tokenizer

        if tokenizer.eos_token_id is None:
            warnings.warn(
                "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds"
                f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
            )

        self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.shuffle = shuffle
        if formatting_func is None:
            self.formatting_func = lambda x: x[dataset_text_field]
        else:
            self.formatting_func = formatting_func

        self.current_epoch = 0
        if formatting_func is not None:
            formatting_func_signature = formatting_func.__code__.co_varnames
            if len(formatting_func_signature) > 1:
                warnings.warn(
                    "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                    " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
                )

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(self.formatting_func(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        self.current_epoch += 1
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }


class MySFTTrainer(SFTTrainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            raise Exception("invalid model type")
            # if state_dict is None:
            #     state_dict = self.model.state_dict()
            #
            # if isinstance(unwrap_model(self.model), supported_classes):
            #     if isinstance(unwrap_model(self.model), PeftModel):
            #         print("save all peft model")
            #         unwrap_model(self.model).base_model.model.save_pretrained(
            #             output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            #         )
            #     else:
            #         print("save full parameter model")
            #         unwrap_model(self.model).save_pretrained(
            #             output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            #         )
            # else:
            #     print("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            #     if self.args.save_safetensors:
            #         safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
            #     else:
            #         torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(self.model, PeftModel):
                print("save all peft model")
                torch.save(state_dict, os.path.join(output_dir, "all_model.pt"))
                # self.model.base_model.model.save_pretrained(
                #     output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                # )
            else:
                print("save full parameter model")
                torch.save(state_dict, os.path.join(output_dir, "all_model.pt"))
            # self.model.save_pretrained(
            #     output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            # )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        pass


class onebitlowranklinear(nn.Module):

    def __init__(self, in_features, out_features, weight, args, device, dtype):
        super().__init__()
        self.args = args
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(weight.detach().clone().to(device=device, dtype=dtype), requires_grad=True)
        self.onebit_weight = torch.zeros_like(self.latent, requires_grad=False, device=device, dtype=dtype)

        self.r = self.args.rank
        if self.r <= 0:
            raise ValueError("r must be positive.")
        self.lora_A = nn.Parameter(torch.empty(self.r, in_features, dtype=dtype, device=device))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_B = nn.Parameter(torch.empty(out_features, self.r, dtype=dtype, device=device))
        nn.init.zeros_(self.lora_B)
        self._update_weight()
        self.scaling = self.args.lora_alpha / self.r

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features}, rank={self.r}, bias={self.bias is not None}')

    def forward(self, input):
        """
        Reparameterized sparse and low rank linear layer
                    x W_a @ W_b * lora_alpha / r + x W_sp + bias
        Notice that scale = lora_alpha / r.
        Notice that this class cannot be wrapped to linear layer and thus cannot be used for fine-tune
        """
        W = self.onebit_weight + (self.weight - self.weight.detach())
        result = input @ W.T + input @ self.lora_A.T @ self.lora_B.T  * self.scaling
        if self.bias is not None:
            result += self.bias
        
        return torch.matmul(x, W.T)
    
    def _update_weight(self):
        thre = torch.median(self.latent)
        self.onebit_weight[self.latent.data > thre] = 1
        self.onebit_weight[self.latent.data == thre] = 0
        self.onebit_weight[self.latent.data < thre] = -1
        
    def _update_device(self):
        self.onebit_weight = self.onebit_weight.to(self.latent.device)

class sllinear(nn.Module):
    def __init__(self, in_features, out_features, weight, args, device, dtype, bias=True):
        super().__init__()
        self.args = args
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(weight.detach().clone().to(device=device, dtype=dtype), requires_grad=True)

        self.r = self.args.rank
        if self.r <= 0:
            raise ValueError("r must be positive.")
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True))
            a = 1/math.sqrt(out_features)
            nn.init.uniform_(self.bias, -a, a)
        else:
            self.register_parameter('bias', None)
            
        self.lora_A = nn.Parameter(torch.empty(self.r, in_features, dtype=dtype, device=device))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_B = nn.Parameter(torch.empty(out_features, self.r, dtype=dtype, device=device))
        nn.init.zeros_(self.lora_B)
        # self.reset_parameters()
        self.scaling = self.args.lora_alpha / self.r

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Reparameterized sparse and low rank linear layer
                    x W_a @ W_b * lora_alpha / r + x W_sp + bias
        Notice that scale = lora_alpha / r.
        Notice that this class cannot be wrapped to linear layer and thus cannot be used for fine-tune
        """
        if self.r > 0:
            result = input @ self.weight.T + input @ self.lora_A.T @ self.lora_B.T  * self.scaling
            if self.bias is not None:
                result += self.bias
        else:
            result = F.linear(input, self.weight, self.bias)
        return result
        
    
    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features}, rank={self.r}, '
                f'sparsity={self.args.sparsity}, bias={self.bias is not None}')

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
def build_slmodel(model, args):
    # replace nn.linear layers with sllinear layer
    
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_modules:
            # get the input and output features
            in_features = module.in_features
            out_features = module.out_features
            # get the device and dtype of the original layer
            device = module.weight.device
            dtype = module.weight.dtype
            weight = module.weight
            bias = module.bias is not None
            # create new sllinear layer

            if args.sltrain:
                new_module = sllinear(in_features, out_features, weight, args, device, dtype, bias)
            elif args.onebitlowranktrain:
                new_module = onebitlowranklinear(in_features, out_features, weight, args, device, dtype)
            else:
                raise NotImplementedError("only sltrain and onebitlowranktrain are supported")
            
            # replace the module properly (supporting nested modules)
            parent_name = ".".join(name.split(".")[:-1])
            child_name = name.split(".")[-1]
            parent_module = model.get_submodule(parent_name) if parent_name else model
            setattr(parent_module, child_name, new_module)
    return model


import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn import functional as F

def interp1_nearest(xs_sorted, ys_sorted, x):
    """
    PyTorch equivalent of MATLAB's interp1 with 'nearest' and 'extrap'
    
    Args:
        xs_sorted: sorted x coordinates (1D tensor)
        ys_sorted: corresponding y values (1D tensor) 
        x: query points (tensor of any shape)
    
    Returns:
        interpolated values with same shape as x
    """
    # Find insertion points
    idx = torch.searchsorted(xs_sorted, x, right=True)
    
    # Handle boundary cases for extrapolation
    idx = torch.clamp(idx, 0, len(xs_sorted) - 1)
    
    # For points exactly on grid, searchsorted gives the next index
    # For nearest neighbor, we need to check which is closer
    mask = idx > 0
    left_dist = torch.where(mask, torch.abs(x - xs_sorted[idx - 1]), float('inf'))
    right_dist = torch.abs(x - xs_sorted[idx])
    
    # Use left point if it's closer
    use_left = (left_dist < right_dist) & mask
    idx = torch.where(use_left, idx - 1, idx)
    
    return ys_sorted[idx]


def compute_gllf_gradient(x,xmin,xmax,yL,yR,mu,I,mode):
    eps=1e-6
    mu_safe = mu.clamp(min=eps, max=1-eps)  # 避免 mu=0 或 mu=1
    x_denom = (x - xmin).clamp(min=eps)     # 避免除零
    xmax_safe = xmax + eps

    z = (1/mu_safe - 2) * (I - (x - xmin) / (xmax - xmin + eps))
    exp_z = torch.exp(torch.clamp(z, -50, 50))  # 避免溢出

    denom = ((exp_z * (x - xmax_safe)) / x_denom - 1)
    denom = denom**2 + eps  # 避免除零

    if mode == 'x':
        grad_x = ((yL - yR) *
                  ((exp_z * (x - xmax_safe)) / (x_denom**2) -
                   exp_z / x_denom +
                   exp_z * (1/mu_safe - 2) * (x - xmax_safe) / (x_denom * (xmax - xmin + eps)))
                  ) / denom
        return grad_x

    elif mode == 'mu':
        grad_mu = (exp_z * (x - xmax_safe) * (yL - yR) *
                   (I - (x - xmin) / (xmax - xmin + eps))) / (
                       mu_safe**2 * denom * x_denom)
        return grad_mu

    elif mode == 'I':
        grad_I = -(exp_z * (1/mu_safe - 2) * (x - xmax_safe) * (yL - yR)) / (
            denom * x_denom)
        return grad_I



class compute_gllf_neuronwise(Function):
    '''
    Each neuron learns a unique CM-GLLF curve
    '''
    @staticmethod
    def forward(ctx, x, mu, I, rectify, mode, yL=None, yR=None):
        '''
        x: for MLP (batch_size,num_neuron)
           for CNN (batch_size,num_channel,height,width)
        mu: 1D tensor of size (num_neuron,) or (1,num_channel,1,1)
        I: 1D tensor of size (num_neuron,) or (1,num_channel,1,1)
        rectify(bool): if y values less than 0 should be set to 0
        mode: 3 modes for CM-GLLF and 1 for ReLU
            logistic: search the range of 0<Mu<=0.5
            logit: search the range of 0.5=<Mu<1
            all: mu not limited
        '''
        assert not torch.isnan(x).any(), "x has nan"
        xmax = x.abs().max()
        xmin = -xmax
        eps = 1e-6
        if yL == None:
            yL = xmin
        if yR == None:
            yR = xmax     
        # calculate the output value
        # For logistic phase (0<=mu<=1)
        if mode in ['logistic','all']:
            y_logistic = yL + (yR-yL) / (1 + (xmax - x) / (x - xmin + eps)*torch.exp(torch.clamp((2-1/mu)*((x-xmin)/(xmax-xmin)-I), -50, 50)))
        if mode in ['logit','all']:
            # in logit phase - using more efficient interp1_nearest
            y_logit = torch.zeros(x.shape, device=x.device)
            # Process each neuron separately for neuronwise parameters
            if mu.dim() > 0 and mu.numel() > 1:  # neuronwise case
                for i in range(mu.numel()):
                    # Get current neuron's parameters
                    mu_i = mu.flatten()[i] if mu.dim() > 0 else mu
                    I_i = I.flatten()[i] if I.dim() > 0 else I
                    
                    # Generate inverse lookup table for this neuron
                    ys = torch.linspace(yL, yR, 1000, device=x.device)
                    xs = xmin + (xmax-xmin) / (1 + (yR-ys)/(ys-yL)*torch.exp((2-1/(1-mu_i))*((ys-yL)/(yR-yL)-I_i)))
                    
                    # Extract data for current neuron
                    if x.dim() == 2:  # MLP case
                        x_neuron = x[:, i]
                        y_logit[:, i] = interp1_nearest(xs, ys, x_neuron)
                    elif x.dim() == 4:  # CNN case
                        x_neuron = x[:, i, :, :]
                        y_logit[:, i, :, :] = interp1_nearest(xs, ys, x_neuron)
                    else:
                        raise ValueError('Shape of x is neither 2 nor 4!')
            else:  # single parameter case
                # Generate inverse lookup table
                ys = torch.linspace(yL, yR, 1000, device=x.device)
                xs = xmin + (xmax-xmin) / (1 + (yR-ys)/(ys-yL)*torch.exp((2-1/(1-mu))*((ys-yL)/(yR-yL)-I)))
                y_logit = interp1_nearest(xs, ys, x)
        if mode == 'all':
            # assign CM-GLLF values based on the cases
            y = torch.where((mu >= 0) & (mu <= 0.5),y_logistic,y_logit)   
        elif mode == 'logistic':
            y = y_logistic
        elif mode == 'logit':
            # this makes logit phase include ReLU
            y = y_logit
        else:
            raise ValueError("Mode should be in [all, logistic, logit]")
        assert not torch.isnan(y).any(), "y has nan"
        # save the data for backward pass
        ctx.save_for_backward(x, y, mu, I,)
        ctx.scalars = (xmin, xmax, yL, yR, rectify, mode)
        return torch.clamp(y, min=0) if rectify else y  
    @staticmethod
    def backward(ctx, grad_output):
        x, y, mu, I = ctx.saved_tensors
        xmin, xmax, yL, yR, rectify, mode = ctx.scalars
        # check if mu and I needs gradient
        mu_need_grad = ctx.needs_input_grad[1]
        I_need_grad = ctx.needs_input_grad[2]
        if mode in ['logistic','all']:
            # for gradients of logistic phase (see ./matlab_utils/get_gllf_drv)
            grad_x_logistic = compute_gllf_gradient(x,xmin,xmax,yL,yR,mu,I,mode='x')
            assert not torch.isnan(grad_x_logistic).any(), "grad_x_logistic nan"
            if mu_need_grad:
                grad_mu_logistic = compute_gllf_gradient(x,xmin,xmax,yL,yR,mu,I,mode='mu')
            if I_need_grad:
                grad_I_logistic = compute_gllf_gradient(x,xmin,xmax,yL,yR,mu,I,mode='I')
        if mode in ['logit','all']:
            # for gradients of logit phase
            # derivative of inverse function is 1/f'(f^(-1)(x))
            inv_grad_x_logit = compute_gllf_gradient(y,xmin,xmax,yL,yR,1-mu,I,mode='x')
            grad_x_logit = 1.0 / inv_grad_x_logit
            # derivative of logit function respect to mu and I
            if mu_need_grad:
                grad_mu_logit = (-1) * compute_gllf_gradient(y,xmin,xmax,yL,yR,1-mu,I,mode='mu') / inv_grad_x_logit
            if I_need_grad:
                grad_I_logit = (-1) * compute_gllf_gradient(y,xmin,xmax,yL,yR,1-mu,I,mode='I') / inv_grad_x_logit
        # merge the gradients, assign based on logistic and logit
        if mode == 'all':
            grad_x = torch.where((mu >= 0) & (mu <= 0.5), grad_x_logistic, grad_x_logit)
            if ctx.needs_input_grad[1]:  
                grad_mu = torch.where((mu >= 0) & (mu <= 0.5), grad_mu_logistic, grad_mu_logit)
            else:
                grad_mu = None
            if ctx.needs_input_grad[2]:  
                grad_I = torch.where((mu >= 0) & (mu <= 0.5), grad_I_logistic, grad_I_logit)
            else:
                grad_I = None         
        elif mode == 'logistic':
            grad_x = grad_x_logistic
            if ctx.needs_input_grad[1]:  
                grad_mu = grad_mu_logistic
            else:
                grad_mu = None
            if ctx.needs_input_grad[2]:  
                grad_I = grad_I_logistic
            else:
                grad_I = None                    
        elif mode == 'logit':
            grad_x = grad_x_logit
            if ctx.needs_input_grad[1]:  
                grad_mu = grad_mu_logit
            else:
                grad_mu = None
            if ctx.needs_input_grad[2]:  
                grad_I = grad_I_logit
            else:
                grad_I = None                     
        # if rectify, do:
        if rectify:
            grad_x[y<=0] = 0
            if ctx.needs_input_grad[1]:
                grad_mu[y<=0] = 0
            if ctx.needs_input_grad[2]:
                grad_I[y<=0] = 0

        assert not torch.isnan(grad_x).any(), "grad_x nan"
        assert not torch.isnan(grad_output).any(), "grad_output nan"
        assert not torch.isnan(grad_mu).any() if ctx.needs_input_grad[1] else True, "grad_mu nan"
        assert not torch.isnan(grad_I).any() if ctx.needs_input_grad[2] else True, "grad_I nan"
        # chain rule
        drv_x = drv_mu = drv_I = None
        if ctx.needs_input_grad[0]:
            drv_x = grad_output * grad_x
        if ctx.needs_input_grad[1]:
            if x.dim()==2:
                drv_mu = (grad_output * grad_mu).sum(dim=0)
            elif x.dim()==4:
                drv_mu = torch.sum(grad_output * grad_mu, dim=(0, 2, 3), keepdim=True)
        if ctx.needs_input_grad[2]:
            if x.dim()==2:
                drv_I = (grad_output * grad_I).sum(dim=0)
            elif x.dim()==4:
                drv_I = torch.sum(grad_output * grad_I, dim=(0, 2, 3), keepdim=True)     
        return drv_x, drv_mu, drv_I, None, None, None, None

class LearnableSigmoid(nn.Module):
    def __init__(self, init_alpha=1.0):
        super().__init__()
        # alpha 是可学习参数
        self.alpha = nn.Parameter(torch.tensor(init_alpha))

    def forward(self, x):
        alpha = torch.clamp(self.alpha, min=0.01, max=10.0)
        return torch.sigmoid(alpha * x)



class NeuroGliaLayer(nn.Module):
    def __init__(self, in_features, out_features, weight, args, device, dtype, bias=True):
        super().__init__()
        self.args = args
        self.in_features = in_features
        self.out_features = out_features
        if args.nonlinear_function=="sigmoid":
            self.expert_nonlinear=torch.nn.Sigmoid()
        elif args.nonlinear_function=="softmax":
            self.expert_nonlinear=torch.nn.Softmax(dim=-1)
        elif args.nonlinear_function=="learnable_sigmoid":
            self.expert_nonlinear=LearnableSigmoid(init_alpha=1.0)

        if args.use_output_cmgllf:
            self.I = nn.Parameter(torch.tensor(0.5), requires_grad=True)
            self.mu = nn.Parameter(torch.tensor(0.5), requires_grad=True)
            self.rectify = False
        if args.use_input_cmgllf:
            self.I = nn.Parameter(torch.full((in_features,), 0.5), requires_grad=True)
            self.mu = nn.Parameter(torch.full((in_features,), 0.5), requires_grad=True)
            self.rectify = False
            
        self.weight = nn.Parameter(weight.detach().clone().to(device=device, dtype=dtype), requires_grad=True)


        if self.args.not_contextual:
            self.glia = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        else:
            self.hidden_size = args.hidden_size
            if self.args.channel_wise:
                self.glia_channel_wise = nn.Sequential(
                    nn.Linear(self.in_features, self.hidden_size),
                    nn.SiLU(),
                    nn.Linear(self.hidden_size, out_features)
                )
                nn.init.kaiming_uniform_(self.glia_channel_wise[0].weight, a=math.sqrt(5))
                nn.init.kaiming_uniform_(self.glia_channel_wise[2].weight, a=math.sqrt(5))
            if self.args.scalar_wise:
                self.glia_scalar_wise = nn.Sequential(
                    nn.Linear(self.in_features, self.hidden_size),
                    nn.Sigmoid(),
                    nn.Linear(self.hidden_size, 1)
                )
                # initialize the weights in the glia module
                nn.init.kaiming_uniform_(self.glia_scalar_wise[0].weight, a=math.sqrt(5))
                nn.init.kaiming_uniform_(self.glia_scalar_wise[2].weight, a=math.sqrt(5))
        # exit()
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype, requires_grad=True))
            a = 1/math.sqrt(out_features)
            nn.init.uniform_(self.bias, -a, a)
        else:
            self.register_parameter('bias', None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        #TODO
        if self.args.use_input_cmgllf:
            input = compute_gllf_neuronwise.apply(input, self.mu, self.I, self.rectify, 'logistic')
        if self.args.not_contextual:
            glia_weight = self.glia
        else:
            if self.args.channel_wise:
                self.glia_weight_channel_wise = 2 * self.expert_nonlinear(self.glia_channel_wise(input)) # (batch_size, out_features)

                # print(glia_weight.shape)
            if self.args.scalar_wise:
                self.glia_weight_scalar_wise = 2 * self.expert_nonlinear(self.glia_scalar_wise(input)) # (batch_size, 1)

            # glia_weight = self.glia(input) # (batch_size, 1)
            # print(glia_weight.shape)


        # if self.args.use_output_cmgllf:
        #     glia_weight = compute_gllf_neuronwise.apply(
        #             glia_weight, self.mu, self.I, self.rectify, 'logistic',
        #             torch.tensor(0.0, device=glia_weight.device),
        #             torch.tensor(1.0, device=glia_weight.device)
        #         )
        # else:
        #     glia_weight = 2 * self.expert_nonlinear(glia_weight)

        output = F.linear(input, self.weight, self.bias)
        if  self.args.channel_wise:
            output = output * self.glia_weight_channel_wise
        if  self.args.scalar_wise:
            output = output * self.glia_weight_scalar_wise

        return output
        
    
    def extra_repr(self) -> str:
        #TODO
        return

def build_neuro_glia_network(model, args):
    # replace nn.linear layers with neuro-glia layers

    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_modules:
            print(name)
            # get the input and output features
            in_features = module.in_features
            out_features = module.out_features
            # get the device and dtype of the original layer
            device = module.weight.device
            dtype = module.weight.dtype
            weight = module.weight
            bias = module.bias is not None
            # create new sllinear layer
            new_module = NeuroGliaLayer(in_features, out_features, weight, args, device, dtype, bias)
            
            # replace the module properly (supporting nested modules)
            parent_name = ".".join(name.split(".")[:-1])
            child_name = name.split(".")[-1]
            parent_module = model.get_submodule(parent_name) if parent_name else model
            setattr(parent_module, child_name, new_module)
    return model