import abc
import os
import torch
import torch.nn as nn
from torch.func import jvp, make_functional_with_buffers
import torch
import torch.nn as nn
from peft import LoraLayer
import utils
import copy

class LinearizedModel(nn.Module):
    """Creates a linearized version of a nn.Module.

    The linearized version of a model is a proper PyTorch model and can be
    trained as any other nn.Module. Its output is the first-order Taylor
    expansion of the wrapped model around a fixed set of parameters
    ``params0``::

        f_lin(x; params) = f(x; params0) + J_params f(x; params0) (params - params0)

    Args:
        model (nn.Module): The model to linearize. The trainable parameters of
            the linearized model will be initialized to the parameters of this
            model (the very same ``nn.Parameter`` objects are reused).
        init_model (nn.Module): A model of the same type as `model` containing
            the parameters around which the model is initialized. If not
            provided, `model` is used as the initialization model.

    Attributes:
        params (nn.ParameterList): Trainable parameters, shared with `model`.
        params0 (nn.ParameterList): Frozen copies of the parameters of
            `init_model`, taken at construction time.
        func0 (callable): ``func0(params, x)`` evaluates `init_model` (in eval
            mode) with `params` substituted for its parameters.
        func (callable): Same as `func0`, but for `model`.
    """

    def __init__(self, model: nn.Module, init_model: nn.Module = None) -> None:
        """Initializes the linearized model."""
        super().__init__()
        if init_model is None:
            init_model = model

        # `functional_call` wants a {name: tensor} mapping whereas `jvp` passes
        # the parameters around as a flat tuple, so remember the names (in the
        # same order as `.parameters()`) to rebuild the mapping on each call.
        names0 = [name for name, _ in init_model.named_parameters()]
        names = [name for name, _ in model.named_parameters()]

        def _as_mapping(param_names, params):
            # Accept a ready-made mapping as well as a flat sequence.
            if isinstance(params, dict):
                return params
            return dict(zip(param_names, params))

        # The wrapped modules are deliberately only captured in closures so
        # that they are not registered as submodules of this module.
        def func0(params, x):
            # The reference model is always evaluated in eval mode.
            return torch.func.functional_call(
                init_model.eval(), _as_mapping(names0, params), (x,)
            )

        def func(params, x):
            return torch.func.functional_call(
                model, _as_mapping(names, params), (x,)
            )

        self.func0 = func0
        self.func = func

        self.params = nn.ParameterList(list(model.parameters()))
        # Detached copies: when `init_model` is `model` the two lists would
        # otherwise alias the same tensors, the expansion point would follow
        # the trainable parameters and `params - params0` would always be zero.
        # The initial parameters are not trainable.
        self.params0 = nn.ParameterList(
            [
                nn.Parameter(p.detach().clone(), requires_grad=False)
                for p in init_model.parameters()
            ]
        )
        self._model_name = model.__class__.__name__

        # The params are.
        # NOTE(review): this also unfreezes parameters that were frozen in
        # `model` (e.g. base weights of an adapter layer) -- confirm intended.
        for p in self.params:
            p.requires_grad = True

    def forward(self, x) -> torch.Tensor:
        """Computes the linearized model output using a first-order Taylor decomposition.

        Args:
            x: Input forwarded as the single positional argument of the
                wrapped model.

        Returns:
            torch.Tensor: ``f(x; params0) + J f(x; params0) (params - params0)``.
        """
        dparams = tuple(p - p0 for p, p0 in zip(self.params, self.params0))
        # `out` is the output at the expansion point, `dp` the directional
        # derivative along `dparams`; gradients reach `params` through `dp`.
        out, dp = jvp(
            lambda params: self.func0(params, x),
            (tuple(self.params0),),
            (dparams,),
        )
        return out + dp


def linearize_lora_model(model: nn.Module) -> nn.Module:
    """Replaces every LoRA linear layer of `model` by a linearized version.

    Each submodule that is both a `LoraLayer` and an `nn.Linear` is wrapped in
    a `LinearizedModel` and swapped in place on its parent module.

    Args:
        model (nn.Module): The model to convert. It is modified in place.

    Returns:
        nn.Module: The same `model` object, after conversion.
    """
    # Snapshot the module list first: the loop body replaces submodules, and
    # the module tree should not be mutated while its generator is being walked.
    for key, module in list(model.named_modules()):
        if not (isinstance(module, LoraLayer) and isinstance(module, nn.Linear)):
            continue

        print(f">>> convert {key} to linearized lora layer")

        # `utils.rgetattr` is unpacked as (parent module, resolved module,
        # attribute name of the resolved module on its parent).
        parent, target, target_name = utils.rgetattr(model, key)
        setattr(parent, target_name, LinearizedModel(target))

    return model