from peft import LoraConfig
from peft.mapping import get_peft_model_state_dict
from peft.utils import PeftType
from dataclasses import dataclass, field
from typing import Optional, Union, List, Any
import torch.nn as nn
from enum import Enum


# Identifier for the sine LoRA PEFT type. This is a standalone enum; it does not add a member to peft's PeftType.
class SinePeftType(str, Enum):
    """Identifiers for the custom PEFT types defined in this module.

    Mixes in ``str`` so each member compares equal to its plain string value,
    e.g. ``SinePeftType.SINE_LORA == "SINE_LORA"``.
    """

    SINE_LORA = "SINE_LORA"


@dataclass
class sinLoraConfig(LoraConfig):
    """
    Sine-based parameter-efficient configuration extending the standard LoRA config.

    Adds the hyperparameters of the sine transformation sin(ω·AB^T) applied to the
    low-rank update, intended for bounded parameter-efficient unlearning.
    create_sine_lora_model reports the scaled form as sin(freq * AB^T) / s.

    Attributes:
        s: Scale parameter for the sine transformation. It acts as a divisor of the
            sine term, so it must be non-zero.
        freq: Frequency parameter ω for the sine transformation.
        peft_type: Always the plain string ``"SINE_LORA"``
            (``SinePeftType.SINE_LORA.value``). It is not a constructor argument.

    Raises:
        ValueError: If ``s`` is zero.
    """

    s: float = field(default=1.0, metadata={"help": "Scale parameter for sine transformation"})
    freq: int = field(default=1, metadata={"help": "Frequency parameter ω for sine transformation"})
    # init=False: the PEFT type is fixed by this class and cannot be passed in by callers.
    # `.value` keeps the stored value a plain str, exactly as before.
    peft_type: str = field(default=SinePeftType.SINE_LORA.value, init=False)

    def __post_init__(self):
        super().__post_init__()
        # The parent's __post_init__ presumably stamps its own type (PeftType.LORA);
        # re-assert the sine LoRA type afterwards so it wins.
        self.peft_type = SinePeftType.SINE_LORA.value
        if self.s == 0:
            # s divides the sine term; zero would yield inf/NaN adapter weights.
            raise ValueError(f"sinLoraConfig.s must be non-zero, got {self.s!r}")


@dataclass
class sinDoraConfig(sinLoraConfig):
    """
    Sine-based DoRA configuration; currently identical to ``sinLoraConfig``.

    Declares no additional fields and overrides nothing, so every hyperparameter,
    default and ``__post_init__`` step is inherited unchanged.

    NOTE(review): nothing here enables DoRA-specific behaviour. Presumably callers
    must turn it on through the inherited LoRA options (e.g. ``use_dora``). TODO confirm.
    """


def create_sine_lora_model(model: nn.Module, config: sinLoraConfig):
    """
    Build a sine LoRA model implementing sin(ωAB^T).

    Prints a one-line summary of the transform to stdout, then delegates all
    work to ``sine_lora_simple.apply_sine_lora_to_model``.

    Args:
        model: The base module to adapt.
        config: Sine LoRA hyperparameters. ``freq`` and ``s`` appear in the summary.

    Returns:
        Whatever ``apply_sine_lora_to_model(model, config)`` returns.
    """
    # Imported at call time rather than module load, presumably to avoid a
    # circular or optional dependency. TODO confirm.
    from sine_lora_simple import apply_sine_lora_to_model

    summary = f"Creating Sine LoRA model with sin({config.freq} * AB^T) / {config.s}"
    print(summary)
    adapted = apply_sine_lora_to_model(model, config)
    return adapted