import torch, random, math
from torch import nn
from pyvene import (
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
    CollectIntervention,
    InterventionOutput,
    SigmoidMaskIntervention,
)


class LowRankRotateLayer(torch.nn.Module):
    """Bias-free linear map ``x @ W`` with an optionally orthogonal weight.

    Args:
        n: Size of the last dimension of the input.
        m: Size of the last dimension of the output.
        init_orth: If true, the weight is filled by
            ``torch.nn.init.orthogonal_``; otherwise it keeps whatever
            ``torch.empty`` allocated (i.e. it is uninitialized).
    """

    def __init__(self, n, m, init_orth=True):
        super().__init__()
        initial = torch.empty(n, m)
        if init_orth:
            torch.nn.init.orthogonal_(initial)
        self.weight = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        """Cast ``x`` to the weight dtype and right-multiply it by the weight."""
        rotation = self.weight
        return x.to(rotation.dtype) @ rotation


class TopKReLUSubspaceIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = (h - ReLU(h@v)*v) + Mean(TopK(ReLU(h@v)))*v

    The per-token ReLU activation along ``v`` is subtracted from ``h`` and
    replaced by one shared magnitude: the mean of the top-k activations over
    the sequence.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)
        with torch.no_grad():
            self.proj.bias.fill_(0)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Apply the intervention to ``base``.

        Args:
            base: 3-D tensor (``torch.bmm`` is applied to it); the comments
                below name its axes bs, s, h.
            source: Unused.
            subspaces: Mapping that must contain ``"k"`` (top-k size along the
                sequence axis). If it contains ``"subspaces"``, each entry
                selects one row of ``self.proj.weight`` per batch element;
                otherwise row 0 is used for every batch element.

        Returns:
            ``InterventionOutput`` whose ``output`` is the steered tensor in
            ``base``'s dtype and whose ``latent`` is
            ``[latent, non_topk_latent, max_mean_latent]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1
        direction_t = direction.permute(0, 2, 1)  # bs, 1, h

        # Per-token activation along the direction.
        latent = torch.relu(torch.bmm(base, direction)).squeeze(dim=-1)  # bs, s
        topk_acts, topk_indices = latent.topk(k=subspaces["k"], dim=-1, sorted=False)

        # Copy of the latent with the top-k positions zeroed out.
        non_topk_latent = latent.clone()
        non_topk_latent.scatter_(-1, topk_indices, 0)

        # Remove the current activation along the direction from every token.
        removed = torch.bmm(latent.unsqueeze(dim=-1), direction_t)  # bs, s, h
        base_orthogonal = base - removed

        # One magnitude per example: mean of its top-k activations.
        max_mean_latent = topk_acts.mean(dim=-1, keepdim=True)  # bs, 1
        steering_vec = max_mean_latent.unsqueeze(dim=-1) * direction_t  # bs, 1, h

        steered = base_orthogonal + steering_vec
        return InterventionOutput(
            output=steered.to(base.dtype),
            latent=[latent, non_topk_latent, max_mean_latent]
        )


class TopKReLUIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + Mean(TopK(ReLU(h@v)))*v

    The magnitude can be overridden with ``subspaces["max_acts"]`` and scaled
    with ``subspaces["steering_factor"]``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)
        with torch.no_grad():
            self.proj.weight.fill_(0.01)
            self.proj.bias.fill_(0)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Add the selected direction, scaled by a top-k magnitude, to ``base``.

        Args:
            base: 3-D tensor (``torch.bmm`` is applied to it); axes are named
                bs, s, h in the comments below.
            source: Unused.
            subspaces: Mapping that must contain ``"k"``. Optional keys:
                ``"subspaces"`` (row of ``self.proj.weight`` per batch element,
                default row 0), ``"max_acts"`` (used as the magnitude instead
                of the top-k mean) and ``"steering_factor"`` (extra multiplier
                on the steering vector).

        Returns:
            ``InterventionOutput`` with the steered tensor in ``base``'s dtype
            and ``latent=[latent, non_topk_latent, max_mean_latent]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1

        # Per-token activation along the direction.
        latent = torch.relu(torch.bmm(base, direction)).squeeze(dim=-1)  # bs, s
        topk_acts, topk_indices = latent.topk(k=subspaces["k"], dim=-1, sorted=False)

        # Copy of the latent with the top-k positions zeroed out.
        non_topk_latent = latent.clone()
        non_topk_latent.scatter_(-1, topk_indices, 0)

        # Magnitude: caller-provided value if present, else the top-k mean.
        if "max_acts" in subspaces:
            max_mean_latent = subspaces["max_acts"]
        else:
            max_mean_latent = topk_acts.mean(dim=-1, keepdim=True)  # bs, 1

        steering_vec = max_mean_latent.unsqueeze(dim=-1) * direction.permute(0, 2, 1)  # bs, 1, h
        if "steering_factor" in subspaces:
            factor = subspaces["steering_factor"].unsqueeze(dim=-1).unsqueeze(dim=-1)
            steering_vec = factor * steering_vec

        steered = base + steering_vec
        return InterventionOutput(
            output=steered.to(base.dtype),
            latent=[latent, non_topk_latent, max_mean_latent]
        )


class ConceptReFTIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + R^T(Wh + b - Rh)
    Ref: https://arxiv.org/pdf/2404.03592

    Note that this intervention is used for concept-based Direft.
    The main difference is that weights are assumed to be trained and saved as 3D tensors.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        n_concepts = kwargs["n_concepts"]
        rank = kwargs["low_rank_dimension"]
        # All parameters start at zero; one (embed_dim, rank) slice per concept.
        self.W_proj = nn.Parameter(torch.zeros(n_concepts, self.embed_dim, rank))
        self.W_source = nn.Parameter(torch.zeros(n_concepts, self.embed_dim, rank))
        self.b_source = nn.Parameter(torch.zeros(n_concepts, rank))

    def encode(
        self, base, source=None, subspaces=None
    ):
        """Project ``base`` with the per-example slices of ``W_proj``.

        Args:
            base: 3-D tensor (batch_size, seq_len, embed_dim).
            source: Unused.
            subspaces: Mapping whose ``"input_subspaces"`` entry indexes the
                first axis of ``W_proj``, one index per batch element.

        Returns:
            Tensor of shape (batch_size, seq_len, low_rank_dimension) in the
            dtype of ``W_proj``.
        """
        proj_weight = self.W_proj[subspaces["input_subspaces"]]  # batch_size, embed_dim, low_rank_dimension
        # torch.bmm needs both operands in the same dtype.
        return torch.bmm(base.to(proj_weight.dtype), proj_weight)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Apply Phi(h) = h + R^T(Wh + b - Rh) with per-example weights.

        Args:
            base: 3-D tensor (batch_size, seq_len, embed_dim).
            source: Unused.
            subspaces: Mapping whose ``"idx"`` entry indexes the first axis of
                the parameters, one index per batch element.

        Returns:
            Tensor shaped like ``base`` and cast back to ``base.dtype``.
        """
        idx = subspaces["idx"]
        proj_weight = self.W_proj[idx]  # batch_size, embed_dim, low_rank_dimension
        source_weight = self.W_source[idx]  # batch_size, embed_dim, low_rank_dimension
        source_bias = self.b_source[idx].unsqueeze(dim=1)  # batch_size, 1, low_rank_dimension

        # Cast once to the parameter dtype: torch.bmm rejects mixed dtypes, so
        # every product below must see the same dtype as the weights.
        hidden = base.to(proj_weight.dtype)

        rotated_base = torch.bmm(hidden, proj_weight)  # batch_size, seq_len, low_rank_dimension
        learned_source = torch.bmm(hidden, source_weight) + source_bias  # batch_size, seq_len, low_rank_dimension
        update = torch.bmm(learned_source - rotated_base, proj_weight.transpose(-1, -2))
        output = hidden + update
        return output.to(base.dtype)
    

class AdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + max_act * (mag + b[idx]) * W[idx]

    ``W`` and ``b`` are the weight and bias of ``self.proj``.
    """
    def __init__(self, **kwargs):
        # The projection keeps nn.Linear's default initialisation here;
        # presumably pre-trained weights are loaded over it afterwards.
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank, bias=True)

    def forward(self, base, source=None, subspaces=None):
        """Add the scaled weight row selected by ``subspaces["idx"]`` to ``base``.

        Args:
            base: Tensor to steer; the vector is broadcast over dim 1
                (assumed to be the sequence axis).
            source: Unused.
            subspaces: Mapping with ``"idx"`` (rows of ``self.proj`` to use),
                ``"mag"`` and ``"max_act"`` (presumably one value per batch
                element, confirm against the caller).

        Returns:
            ``base`` plus the steering vector.
        """
        idx = subspaces["idx"]
        direction = self.proj.weight[idx]
        magnitude = subspaces["mag"] + self.proj.bias[idx]
        steering_vec = subspaces["max_act"].unsqueeze(dim=-1) * \
            magnitude.unsqueeze(dim=-1) * direction
        return base + steering_vec.unsqueeze(dim=1)
    
class AdditionSuppressionIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    """Steer or suppress along a row of ``self.proj.weight`` depending on the sign of ``mag``.

    For examples with ``mag > 0`` every position receives ``(mag + bias[idx]) * w``.
    For examples with ``mag <= 0`` only positions where ``h @ w + bias[idx] > 0``
    receive ``mag * w``.
    """
    def __init__(self, **kwargs):
        # The projection keeps nn.Linear's default initialisation here;
        # presumably pre-trained weights are loaded over it afterwards.
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["low_rank_dimension"], bias=True)
        # bias=True: the bias is used in forward() both for the firing test and the positive branch.

    def forward(self, base, source=None, subspaces=None):
        """Return ``base`` plus a per-position multiple of the selected weight row.

        ``subspaces`` must provide ``"idx"`` and ``"mag"``. The einsum subscripts
        below treat ``mag`` as 1-D (one value per batch element) and ``base`` as
        (batch, seq, hidden). ``source`` is unused.
        """
        # use subspaces["idx"] to select the correct weight vector
        steering_vec = self.proj.weight[subspaces["idx"]]
        neg_mask = subspaces["mag"] <= 0 # (b,) bool: examples where suppression is applied
        pos_mask = subspaces["mag"] > 0.0 # (b,) bool: examples where steering is applied
        # Boolean "feature fires" indicator per position: h @ w + bias[idx] > 0.
        latent = (torch.bmm(base, steering_vec.unsqueeze(-1))+ self.proj.bias[subspaces["idx"]].unsqueeze(-1).unsqueeze(-1)) > 0 # bs, s, 1
        # Keep the firing indicator only for examples with mag <= 0.
        # NOTE(review): both einsum operands are bool here - confirm the targeted torch version accepts that.
        neg_steering_mask = torch.einsum("bsq, b->bsq", latent, neg_mask) # bs, s, 1
        # Suppression term: mag at firing positions of mag <= 0 examples (bs, s).
        # Steering term: (mag + bias[idx]) for mag > 0 examples, (bs, 1) broadcast over positions.
        combined_steering_factor = torch.einsum("bsq, b->bs", neg_steering_mask, subspaces["mag"]) + torch.einsum("b, b->b", (subspaces["mag"] + self.proj.bias[subspaces["idx"]]), pos_mask).unsqueeze(-1) # bs, s
        # Scale the selected weight row by the per-position factor.
        steering_vec = torch.einsum("bh, bs->bsh", steering_vec, combined_steering_factor) # bs, s, h
        output = base + steering_vec

        return output
    

class ThresholdingIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + 1[h@v < m] * m * v,  with m = max_act * mag

    The steering vector is added only at positions whose projection onto
    ``v`` is below the desired magnitude ``m``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank, bias=True)

    def forward(self, base, source=None, subspaces=None):
        """Add ``m * v`` to every position whose projection on ``v`` is below ``m``.

        Args:
            base: 3-D tensor [batch, seq_len, embed_dim].
            source: Unused.
            subspaces: Mapping with ``"idx"`` (row of ``self.proj.weight`` per
                batch element), ``"max_act"`` and ``"mag"`` (assumed [batch]).

        Returns:
            Tensor shaped like ``base``.
        """
        direction = self.proj.weight[subspaces["idx"]]  # [batch, embed_dim]
        target = subspaces["max_act"] * subspaces["mag"]  # [batch]
        target_col = target.unsqueeze(-1)  # [batch, 1]

        # Current activation of every position along the direction.
        projection = torch.bmm(base, direction.unsqueeze(-1)).squeeze(-1)  # [batch, seq_len]

        # 1.0 where the position is still below the target, else 0.0.
        below_target = (projection < target_col).float()  # [batch, seq_len]

        steering_vec = target_col * direction  # [batch, embed_dim]
        masked_steering = below_target.unsqueeze(-1) * steering_vec.unsqueeze(1)  # [batch, seq_len, embed_dim]
        return base + masked_steering


class SigmoidMaskAdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Add a steering vector decoded from sigmoid-masked latent weights.

    Latent weights are ``ReLU(source) * sigmoid(mask / temperature)``; they are
    mapped to the embedding space by ``self.proj`` and added to ``base``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # `sae_width` is the number of latents the mask ranges over;
        # `low_rank_dimension` is the number of (mask, source) rows that are learned.
        self.proj = torch.nn.Linear(
            kwargs["sae_width"], self.embed_dim, bias=True)
        self.mask = torch.nn.Parameter(
            torch.zeros(kwargs["low_rank_dimension"], kwargs["sae_width"]), requires_grad=True)
        self.source = torch.nn.Parameter(
            0.001 * torch.ones(kwargs["low_rank_dimension"], kwargs["sae_width"]), requires_grad=True)
        # Stored as a non-trainable parameter so it is saved with the module.
        self.temperature = torch.nn.Parameter(torch.tensor(0.01), requires_grad=False)

    def get_temperature(self) -> torch.Tensor:
        """Return the current sigmoid temperature."""
        return self.temperature

    def set_temperature(self, temp: torch.Tensor) -> None:
        """Replace the data of the temperature parameter with ``temp``."""
        self.temperature.data = temp

    def get_latent_weights(self) -> torch.Tensor:
        """Return ``ReLU(source) * sigmoid(mask / temperature)`` with a leading axis of size 1."""
        # Use the parameter directly: wrapping an existing tensor in
        # torch.tensor() copies it and emits a UserWarning on every call.
        mask_sigmoid = torch.sigmoid(self.mask / self.temperature)
        masked_source = (torch.relu(self.source) * mask_sigmoid).unsqueeze(0)
        return masked_source

    def forward(self, base, source=None, subspaces=None) -> torch.Tensor:
        """Add the decoded steering vector to ``base``.

        ``source`` and ``subspaces`` are not used.
        """
        masked_source = self.get_latent_weights()  # 1, low_rank_dimension, sae_width
        steering_vec = self.proj(masked_source)  # 1, low_rank_dimension, embed_dim
        # NOTE(review): after unsqueeze the vector is 4-D, so adding it to a
        # 3-D `base` broadcasts to a 4-D result - confirm this is intended.
        output = base + steering_vec.unsqueeze(dim=1)
        return output


class SubspaceIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = (h - ReLU(h@v)*v) + max_act*mag*v

    The ReLU activation along ``v`` is replaced by a caller-specified magnitude.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank, bias=True)

    def forward(self, base, source=None, subspaces=None):
        """Swap the activation along the selected direction for ``max_act * mag``.

        Args:
            base: 3-D tensor (bs, s, h); ``torch.bmm`` is applied to it.
            source: Unused.
            subspaces: Mapping with ``"idx"`` (row of ``self.proj.weight`` per
                batch element), ``"max_act"`` and ``"mag"`` (assumed (bs,)).

        Returns:
            Tensor shaped like ``base``.
        """
        direction = self.proj.weight[subspaces["idx"]].unsqueeze(dim=-1)  # bs, h, 1
        direction_t = direction.permute(0, 2, 1)  # bs, 1, h

        # Strip the current (rectified) activation along the direction.
        latent = torch.relu(torch.bmm(base, direction))  # bs, s, 1
        base_orthogonal = base - torch.bmm(latent, direction_t)  # bs, s, h

        # Requested magnitude, broadcast over sequence and hidden axes.
        max_act = subspaces["max_act"].unsqueeze(-1).unsqueeze(-1)
        mag = subspaces["mag"].unsqueeze(-1).unsqueeze(-1)
        steering_vec = (max_act * mag) * direction_t  # bs, 1, h

        return base_orthogonal + steering_vec


class DictionaryAdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Anthropic's intervention method.

    Encode the activation with a thresholded-ReLU SAE, overwrite the selected
    feature with ``mag * max_act``, decode, and add back the SAE's
    reconstruction error.

    For smaller models, we just gave up on this ...
    But feel free to try it and see if it works for you.
    """
    def __init__(self, **kwargs):
        # Everything starts at zero because pre-trained SAE weights are loaded
        # afterwards. Use a proper initialization if you train your own SAE.
        super().__init__(**kwargs, keep_last_dim=True)
        width = kwargs["low_rank_dimension"]
        self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, width))
        self.W_dec = nn.Parameter(torch.zeros(width, self.embed_dim))
        self.threshold = nn.Parameter(torch.zeros(width))
        self.b_enc = nn.Parameter(torch.zeros(width))
        self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed where they do not exceed ``threshold``."""
        pre_acts = input_acts @ self.W_enc + self.b_enc  # [batch_size, seq_len, low_rank_dimension]
        above_threshold = pre_acts > self.threshold
        return above_threshold * torch.nn.functional.relu(pre_acts)

    def decode(self, acts):
        """Map feature activations back to the embedding space."""
        return acts @ self.W_dec + self.b_dec  # [batch_size, seq_len, embed_dim]

    def forward(self, base, source=None, subspaces=None):
        """
        base: Residual stream activity x, shape [batch_size, seq_len, embed_dim]
        subspaces: Dictionary containing 'idx', 'mag' and 'max_act'
        """
        feature_idx = subspaces["idx"]

        acts = self.encode(base)
        # Whatever the SAE fails to reconstruct is carried over unchanged.
        error_x = base - self.decode(acts)

        # Overwrite the chosen feature(s) at every position.
        acts_modified = acts.clone()
        feature_acts = subspaces["mag"] * subspaces["max_act"]
        acts_modified[:, :, feature_idx] = feature_acts.to(base.dtype)

        return self.decode(acts_modified) + error_x


class DictionaryMinClampingIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Anthropic's intervention method.

    Like the dictionary addition intervention, but the selected feature is
    only raised: it becomes the elementwise maximum of its current activation
    and ``mag * max_act``.

    For smaller models, we just gave up on this ...
    But feel free to try it and see if it works for you.
    """
    def __init__(self, **kwargs):
        # Everything starts at zero because pre-trained SAE weights are loaded
        # afterwards. Use a proper initialization if you train your own SAE.
        super().__init__(**kwargs, keep_last_dim=True)
        width = kwargs["low_rank_dimension"]
        self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, width))
        self.W_dec = nn.Parameter(torch.zeros(width, self.embed_dim))
        self.threshold = nn.Parameter(torch.zeros(width))
        self.b_enc = nn.Parameter(torch.zeros(width))
        self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))

    def encode(self, input_acts):
        """Return ReLU pre-activations, zeroed where they do not exceed ``threshold``."""
        pre_acts = input_acts @ self.W_enc + self.b_enc  # [batch_size, seq_len, low_rank_dimension]
        above_threshold = pre_acts > self.threshold
        return above_threshold * torch.nn.functional.relu(pre_acts)

    def decode(self, acts):
        """Map feature activations back to the embedding space."""
        return acts @ self.W_dec + self.b_dec  # [batch_size, seq_len, embed_dim]

    def forward(self, base, source=None, subspaces=None):
        """
        base: Residual stream activity x, shape [batch_size, seq_len, embed_dim]
        subspaces: Dictionary containing 'idx', 'mag' and 'max_act'
        """
        feature_idx = subspaces["idx"]

        acts = self.encode(base)
        # Whatever the SAE fails to reconstruct is carried over unchanged.
        error_x = base - self.decode(acts)

        floor = subspaces["mag"] * subspaces["max_act"]

        # Clamp from below: keep the current value when it already exceeds the floor.
        acts_modified = acts.clone()
        clamped = torch.max(acts[:, :, feature_idx], floor)
        acts_modified[:, :, feature_idx] = clamped.to(base.dtype)

        return self.decode(acts_modified) + error_x


class JumpReLUSAECollectIntervention(
    CollectIntervention
):
    """To collect SAE latent activations"""
    def __init__(self, **kwargs):
        # Everything starts at zero because pre-trained SAE weights are loaded
        # afterwards. Use a proper initialization if you train your own SAE.
        super().__init__(**kwargs, keep_last_dim=True)
        width = kwargs["low_rank_dimension"]
        self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, width))
        self.W_dec = nn.Parameter(torch.zeros(width, self.embed_dim))
        self.threshold = nn.Parameter(torch.zeros(width))
        self.b_enc = nn.Parameter(torch.zeros(width))
        self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))

    def forward(self, base, source=None, subspaces=None):
        """Return the encoder activations of ``base``.

        Pre-activations go through a ReLU and are zeroed wherever they do not
        exceed ``self.threshold``. ``source`` and ``subspaces`` are unused.
        """
        pre_acts = base @ self.W_enc + self.b_enc
        above_threshold = pre_acts > self.threshold
        return above_threshold * torch.nn.functional.relu(pre_acts)
    

class ProbeIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h, while also reporting the linear latent h@v.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Project ``base`` onto the selected direction without modifying it.

        Args:
            base: 3-D tensor (bs, s, h); ``torch.bmm`` is applied to it.
            source: Unused.
            subspaces: Mapping; if it contains ``"subspaces"``, each entry
                selects one row of ``self.proj.weight`` per batch element,
                otherwise row 0 is used for every batch element.

        Returns:
            ``InterventionOutput`` with ``output=base`` and ``latent=[h@v]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1

        # Raw (un-rectified) projection of every position.
        latent = torch.bmm(base, direction).squeeze(dim=-1)  # bs, s

        return InterventionOutput(
            output=base,
            latent=[latent]
        )
    

class SparseProbeIntervention(
    # We still inherit from these classes to keep it as close as possible to the LsReFT impl.
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h, while also reporting Mean(TopK(ReLU(h@v))) and the latents.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)
        with torch.no_grad():
            self.proj.weight.fill_(0.01)
            self.proj.bias.fill_(0)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Compute top-k probe statistics; ``base`` is returned unchanged.

        Args:
            base: 3-D tensor (bs, s, h); ``torch.bmm`` is applied to it.
            source: Unused.
            subspaces: Mapping that must contain ``"k"``. If it contains
                ``"subspaces"``, each entry selects one row of
                ``self.proj.weight`` per batch element; otherwise row 0 is used.

        Returns:
            ``InterventionOutput`` with ``output=base`` and
            ``latent=[max_mean_latent, non_topk_latent, latent]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1

        # Per-token activation along the direction.
        latent = torch.relu(torch.bmm(base, direction)).squeeze(dim=-1)  # bs, s
        topk_acts, topk_indices = latent.topk(k=subspaces["k"], dim=-1, sorted=False)

        # Copy of the latent with the top-k positions zeroed out.
        non_topk_latent = latent.clone()
        non_topk_latent.scatter_(-1, topk_indices, 0)

        # One score per example: mean of its top-k activations.
        max_mean_latent = topk_acts.mean(dim=-1, keepdim=False)  # bs

        return InterventionOutput(
            output=base,
            latent=[max_mean_latent, non_topk_latent, latent]
        )
    

class SteeringVectorIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + v
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)
        with torch.no_grad():
            self.proj.weight.fill_(0.01)
            self.proj.bias.fill_(0)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Add the selected weight row to every position of ``base``.

        Args:
            base: 3-D tensor (bs, s, h); ``torch.bmm`` is applied to it.
            source: Unused.
            subspaces: Mapping; if it contains ``"subspaces"``, each entry
                selects one row of ``self.proj.weight`` per batch element,
                otherwise row 0 is used for every batch element.

        Returns:
            ``InterventionOutput`` with the steered tensor in ``base``'s dtype
            and ``latent=[ReLU(h@v)]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1

        # Reported only; the latent does not influence the output.
        latent = torch.relu(torch.bmm(base, direction)).squeeze(dim=-1)  # bs, s

        steered = base + direction.permute(0, 2, 1)  # (bs, 1, h) broadcast over s

        return InterventionOutput(
            output=steered.to(base.dtype),
            latent=[latent]
        )
    

class ConceptVectorIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    Phi(h) = h + a*v, where a is the optional steering factor (1 if absent)
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        rank = kwargs["low_rank_dimension"]
        self.proj = torch.nn.Linear(self.embed_dim, rank)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Add the (optionally scaled) selected weight row to ``base``.

        Args:
            base: 3-D tensor (bs, s, h); ``torch.bmm`` is applied to it.
            source: Unused.
            subspaces: Mapping with optional keys ``"subspaces"`` (row of
                ``self.proj.weight`` per batch element, default row 0) and
                ``"steering_factor"`` (multiplier on the vector).

        Returns:
            ``InterventionOutput`` with the steered tensor in ``base``'s dtype
            and ``latent=[ReLU(h@v)]``.
        """
        weight = self.proj.weight
        if "subspaces" in subspaces:
            rows = [weight[concept] for concept in subspaces["subspaces"]]
        else:
            rows = [weight[0] for _ in range(base.shape[0])]
        direction = torch.stack(rows, dim=0).unsqueeze(dim=-1)  # bs, h, 1

        # Reported only; the latent does not influence the output.
        latent = torch.relu(torch.bmm(base, direction)).squeeze(dim=-1)  # bs, s

        steering_vec = direction.permute(0, 2, 1)  # bs, 1, h
        if "steering_factor" in subspaces:
            factor = subspaces["steering_factor"].unsqueeze(dim=-1).unsqueeze(dim=-1)
            steering_vec = factor * steering_vec
        steered = base + steering_vec

        return InterventionOutput(
            output=steered.to(base.dtype),
            latent=[latent]
        )
    

class PreferenceVectorIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    """Steering-vector intervention with an optional per-example steering factor.

    Without ``"steering_factor"`` the (dropout-regularised) vector is simply
    added to ``base``. With it, examples whose factor is exactly 0 have their
    rectified activation along ``v`` removed, and all other examples receive
    ``(factor + bias) * v``; a random subset of positions can be dropped.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"])
        # Element-wise dropout applied to the steering vector itself.
        dropout = kwargs.get("dropout", 0.0)
        self.dropout = torch.nn.Dropout(dropout)
        # Probability threshold used in forward() to zero the factor at random positions.
        self.intervention_positions_dropout = kwargs.get("intervention_positions_dropout", 0.0)
        with torch.no_grad():
            self.proj.bias.fill_(0)

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Return an ``InterventionOutput`` with the steered ``base`` and ``latent=[latent]``.

        ``subspaces`` may contain ``"subspaces"`` (row of ``self.proj.weight`` per
        batch element, default row 0) and ``"steering_factor"``. ``source`` is unused.
        """
        # One weight row per batch element.
        v = []
        if "subspaces" in subspaces:
            for subspace in subspaces["subspaces"]:
                v += [self.proj.weight[subspace]]
        else:
            for i in range(base.shape[0]):
                v += [self.proj.weight[0]]
        v = torch.stack(v, dim=0).unsqueeze(dim=-1) # bs, h, 1
        v_norm = torch.norm(v, dim=1, keepdim=True) # bs, 1, 1
        # NOTE(review): the whole bias vector (length low_rank_dimension) is added here, so the
        # squeeze only removes the last axis when low_rank_dimension == 1 - confirm intended.
        latent = torch.relu((torch.bmm(base, v) + self.proj.bias).squeeze(dim=-1)) # bs, s (when low_rank_dimension == 1)
        steering_vec = v.permute(0, 2, 1) # bs, 1, h
        steering_vec = self.dropout(steering_vec)
        
        if "steering_factor" in subspaces:
            steering_factor = subspaces["steering_factor"].unsqueeze(dim=-1).unsqueeze(dim=-1) # bs, 1, 1
            zero_mask = steering_factor == 0.0 # bs, 1, 1, this is only for null it out training
            nonzero_mask = steering_factor != 0.0 # bs, 1, 1
            # Null-it-out coefficient: -latent / ||v||^2, where latent = ReLU(h@v + bias);
            # it is kept only for examples whose steering factor is exactly 0.
            null_it_out_steering_factor = -(latent.unsqueeze(dim=-1) / v_norm**2)*zero_mask # bs, s, 1 * bs, 1, 1 = bs, s, 1
            # The bias is added only for examples with a non-zero factor; for factor == 0 the second term is 0.
            combined_steering_factor = null_it_out_steering_factor + (steering_factor + self.proj.bias*nonzero_mask) # bs, s, 1
            # apply position based dropout
            # NOTE(review): this mask does not check self.training, so it is applied in eval mode
            # too, and dropped positions are not rescaled - confirm intended.
            dropout_mask = torch.rand_like(combined_steering_factor.float()) > self.intervention_positions_dropout
            combined_steering_factor *= dropout_mask
            steering_vec = steering_vec * combined_steering_factor # bs, s, d
        output = base + steering_vec

        return InterventionOutput(
            output=output.to(base.dtype),
            latent=[latent]
        )
    

class LoraIntervention(
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """
    LoRA(h') = h' + BAx

    ``x`` is the original input of the wrapped module (taken from
    ``kwargs["args"][0]`` in ``forward``) and ``h'`` is ``base``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.r = kwargs["low_rank_dimension"]
        # NOTE(review): lora_alpha is stored but never applied in forward()
        # (no alpha / r scaling) - confirm this is intended.
        self.lora_alpha = kwargs.get("alpha", kwargs["low_rank_dimension"])
        if "dropout" in kwargs and kwargs["dropout"] > 0.0:
            self.lora_dropout = nn.Dropout(p=kwargs["dropout"])
        else:
            # nn.Identity instead of a lambda keeps the module picklable and
            # makes both branches a registered submodule.
            self.lora_dropout = nn.Identity()

        # Initialize A the same way as the default for nn.Linear and B to zero,
        # then store both trainable matrices in bfloat16.
        lora_a = torch.zeros(self.embed_dim, kwargs["low_rank_dimension"])
        lora_b = torch.zeros(kwargs["low_rank_dimension"], self.embed_dim)
        nn.init.kaiming_uniform_(lora_a, a=math.sqrt(5))
        nn.init.zeros_(lora_b)
        self.lora_A = nn.Parameter(lora_a.to(torch.bfloat16))
        self.lora_B = nn.Parameter(lora_b.to(torch.bfloat16))

    def forward(
        self, base, source=None, subspaces=None, **kwargs
    ):
        """Add the low-rank update computed from the module's original input.

        Args:
            base: Output of the wrapped module (``h'``).
            source: Unused.
            subspaces: Unused.
            **kwargs: Must contain ``"args"``, whose first element is the
                original input ``x`` of the wrapped module.

        Returns:
            ``base + dropout(x) @ A @ B``.
        """
        original_input = kwargs["args"][0]
        return base + self.lora_dropout(original_input) @ self.lora_A @ self.lora_B