from collections.abc import Sequence

import torch
import torch.nn as nn
from safetensors.torch import load_model
from huggingface_hub import hf_hub_download

from .config import SAEConfig
from .model import VanillaSAE, TopKSAE, BatchTopKSAE, JumpReLUSAE


class LoRALayer(nn.Module):
    """Low-rank additive adapter computing ``up(dropout(down(x))) * (lora_alpha / r)``.

    ``lora_down`` maps ``in_features -> r`` and ``lora_up`` maps
    ``r -> out_features``; neither projection has a bias. ``lora_up`` is
    zero-initialised, so a freshly constructed adapter outputs all zeros.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
    ):
        super().__init__()
        self.r, self.lora_alpha = r, lora_alpha

        # Build both projections first (each draws from the global RNG for its
        # default init), then register them in down -> up order.
        down = nn.Linear(in_features, r, bias=False)
        up = nn.Linear(r, out_features, bias=False)
        self.lora_down, self.lora_up = down, up
        self.dropout = nn.Dropout(p=lora_dropout)
        self.scaling = lora_alpha / r

        # Re-initialise: Kaiming-uniform for the down projection, zeros for the
        # up projection so the adapter starts as a no-op.
        nn.init.kaiming_uniform_(down.weight)
        nn.init.zeros_(up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x`` (last dim becomes ``out_features``)."""
        low_rank = self.dropout(self.lora_down(x))
        return self.scaling * self.lora_up(low_rank)


class SAELoRAWrapper(nn.Module):
    """Wrapper for adding LoRA to a pre-trained SAE.

    The base SAE's parameters are frozen in ``__init__``. Trainable low-rank
    adapters (``LoRALayer``) are optionally attached to the encoder
    (``"W_enc"``) and/or decoder (``"W_dec"``) projections. ``parameters()``
    returns only the adapter parameters, so an optimizer built from it never
    updates the base SAE. ``merge_and_unload()`` folds the adapters back into
    the base weights.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_id: str,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.0,
        target_modules: Sequence[str] = ("W_enc", "W_dec"),
        revision: str = "main",
        device: str = "cuda",
    ):
        """Download a pre-trained SAE from the Hugging Face Hub and wrap it.

        Args:
            pretrained_id: Hub repository id holding the SAE config and a
                ``model.safetensors`` weights file.
            r: LoRA rank.
            lora_alpha: LoRA scaling numerator (scaling is ``lora_alpha / r``).
            lora_dropout: Dropout probability used inside the adapters.
            target_modules: Projections to adapt; any of ``"W_enc"``, ``"W_dec"``.
            revision: Hub revision to load the config and weights from.
            device: Device the base SAE is moved to; the adapters follow it.

        Returns:
            A new ``SAELoRAWrapper`` around the loaded SAE.

        Raises:
            KeyError: if ``config.sae_type`` is not one of the supported types.
        """
        config = SAEConfig.from_pretrained(pretrained_id, revision=revision)
        weights_path = hf_hub_download(
            repo_id=pretrained_id,
            filename="model.safetensors",
            revision=revision,
        )

        sae_class = {
            "vanilla": VanillaSAE,
            "topk": TopKSAE,
            "batchtopk": BatchTopKSAE,
            "jumprelu": JumpReLUSAE,
        }[config.sae_type]

        base_sae = sae_class(config)
        load_model(base_sae, weights_path)
        base_sae.to(device)

        # __init__ builds the adapters on the base weights' device/dtype, so
        # they land on ``device`` together with the base SAE.
        return cls(
            base_sae=base_sae,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,
        )

    def __init__(
        self,
        base_sae: nn.Module,
        r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        target_modules: Sequence[str] = ("W_enc", "W_dec"),
    ):
        """Freeze ``base_sae`` and attach LoRA adapters to the requested projections.

        Args:
            base_sae: SAE exposing ``config`` (with ``act_size``/``dict_size``),
                ``W_enc``, ``W_dec``, ``b_enc`` and ``b_dec``.
            r: LoRA rank.
            lora_alpha: LoRA scaling numerator.
            lora_dropout: Dropout probability used inside the adapters.
            target_modules: Projections to adapt; any of ``"W_enc"``, ``"W_dec"``.
        """
        super().__init__()
        self.base_sae = base_sae
        self.config = base_sae.config

        # Only the adapters are trained.
        for param in base_sae.parameters():
            param.requires_grad = False

        if "W_enc" in target_modules:
            self.lora_enc = self._make_adapter(
                self.config.act_size,
                self.config.dict_size,
                r,
                lora_alpha,
                lora_dropout,
                like=base_sae.W_enc,
            )
        else:
            self.lora_enc = None

        if "W_dec" in target_modules:
            self.lora_dec = self._make_adapter(
                self.config.dict_size,
                self.config.act_size,
                r,
                lora_alpha,
                lora_dropout,
                like=base_sae.W_dec,
            )
        else:
            self.lora_dec = None

    @staticmethod
    def _make_adapter(in_features, out_features, r, lora_alpha, lora_dropout, like):
        """Build a ``LoRALayer`` on the same device and dtype as the base weight ``like``."""
        adapter = LoRALayer(
            in_features,
            out_features,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        # Without this the adapter stays on the CPU in float32 and forward()
        # fails when the base SAE lives elsewhere (e.g. on CUDA).
        return adapter.to(device=like.device, dtype=like.dtype)

    def parameters(self, recurse: bool = True):
        """Return only the LoRA parameters (encoder adapter first, then decoder).

        ``recurse`` mirrors ``nn.Module.parameters``. The wrapper registers no
        parameters of its own, so ``recurse=False`` yields an empty list.
        """
        if not recurse:
            return []
        return [
            param
            for adapter in (self.lora_enc, self.lora_dec)
            if adapter is not None
            for param in adapter.parameters()
        ]

    def forward(self, x: torch.Tensor, use_pre_enc_bias: bool = False) -> dict:
        """Run the base SAE with the LoRA deltas added and return its loss dict.

        Args:
            x: Input activations; passed through ``base_sae.preprocess_input``.
            use_pre_enc_bias: If True, subtract ``base_sae.b_dec`` from the
                preprocessed input before encoding.

        Returns:
            The dict produced by ``base_sae.get_loss_dict``.
        """
        x, x_mean, x_std = self.base_sae.preprocess_input(x)
        x_cent = x - self.base_sae.b_dec if use_pre_enc_bias else x

        # Encoder: frozen projection + optional LoRA delta, then bias and ReLU.
        pre_acts = x_cent @ self.base_sae.W_enc
        if self.lora_enc is not None:
            pre_acts = pre_acts + self.lora_enc(x_cent)
        acts = torch.relu(pre_acts + self.base_sae.b_enc)

        # ``codes`` are the sparse activations the decoder actually consumes.
        # For the top-k variants ``acts`` stays the dense ReLU output and the
        # sparse codes are passed to the loss separately as a dense
        # ``acts_topk`` tensor of the same shape.
        # NOTE(review): acts_topk is only passed for the top-k variants; this
        # assumes the vanilla/JumpReLU get_loss_dict does not require it —
        # confirm against .model.
        loss_kwargs = {}
        if isinstance(self.base_sae, TopKSAE):
            top = torch.topk(acts, self.config.topk2, dim=-1)
            codes = torch.zeros_like(acts).scatter(-1, top.indices, top.values)
            loss_kwargs["acts_topk"] = codes
        elif isinstance(self.base_sae, BatchTopKSAE):
            # Keep topk2 activations per row on average across the whole batch.
            # Counting rows over all leading dims also handles >2-D inputs.
            n_rows = acts.shape[:-1].numel()
            flat = acts.flatten()
            top = torch.topk(flat, self.config.topk2 * n_rows, dim=-1)
            codes = (
                torch.zeros_like(flat)
                .scatter(-1, top.indices, top.values)
                .reshape(acts.shape)
            )
            loss_kwargs["acts_topk"] = codes
        elif isinstance(self.base_sae, JumpReLUSAE):
            acts = self.base_sae.jump_relu(acts)
            codes = acts
        else:
            codes = acts

        # Decoder: frozen projection + optional LoRA delta, then bias.
        x_reconstruct = codes @ self.base_sae.W_dec
        if self.lora_dec is not None:
            x_reconstruct = x_reconstruct + self.lora_dec(codes)
        x_reconstruct = x_reconstruct + self.base_sae.b_dec

        self.base_sae.update_inactive_features(codes)
        return self.base_sae.get_loss_dict(
            x=x,
            x_reconstruct=x_reconstruct,
            acts=acts,
            x_mean=x_mean,
            x_std=x_std,
            **loss_kwargs,
        )

    def merge_and_unload(self):
        """Fold the LoRA deltas into the base SAE weights and return the base SAE.

        Each adapter contributes ``(lora_up.weight @ lora_down.weight) * scaling``,
        transposed because the base SAE applies its weights as ``x @ W``.
        The adapters stay attached to this wrapper, so running the wrapper's
        forward after merging would apply each delta twice. Calling this method
        twice also adds the deltas twice.
        """
        with torch.no_grad():
            targets = (
                (self.lora_enc, self.base_sae.W_enc),
                (self.lora_dec, self.base_sae.W_dec),
            )
            for adapter, weight in targets:
                if adapter is None:
                    continue
                delta = adapter.lora_up.weight @ adapter.lora_down.weight * adapter.scaling
                weight.add_(delta.T)

        return self.base_sae
