import torch

def inference_mode(model):
    """Put ``model`` in eval mode and turn off gradients for its parameters.

    The model is modified in place: ``model.eval()`` is called and every
    tensor yielded by ``parameters()`` has ``requires_grad`` cleared.

    Returns:
        Whatever ``model.eval()`` returned (for a ``torch.nn.Module`` that is
        the module itself).
    """
    frozen = model.eval()
    for weight in frozen.parameters():
        weight.requires_grad_(False)
    return frozen

class StitchingLayer(torch.nn.Module):
    """Linear map from a ``d_A``-wide input to a ``d_B``-wide output.

    ``forward`` applies ``self.projection`` (a ``torch.nn.Linear(d_A, d_B)``)
    and multiplies the result by ``self.beta``.

    Args:
        d_A: Input feature size of the projection.
        d_B: Output feature size of the projection.
        device: Device on which the projection and ``beta`` are created.
        use_bias: Whether the projection has a bias term. When present the
            bias is initialised to zeros.
        force_orthogonal: If true, the projection is wrapped with
            ``torch.nn.utils.parametrizations.orthogonal``.
        norm_adjustment: Stored as ``self.norm_adjustment`` (``None`` is
            normalised to ``False``). It does not affect how ``beta`` is
            initialised: ``beta`` always starts as ``ones(1)`` without
            gradient.
    """

    def __init__(self, d_A, d_B, device='cuda', use_bias=False, force_orthogonal=False, norm_adjustment=False):
        super().__init__()
        self.device = device
        self.use_bias = use_bias
        self.d_A = d_A
        self.d_B = d_B

        linear = torch.nn.Linear(d_A, d_B, bias=use_bias, device=self.device)
        if force_orthogonal:
            linear = torch.nn.utils.parametrizations.orthogonal(linear)
        self.projection = linear
        if use_bias:
            torch.nn.init.zeros_(self.projection.bias)

        if norm_adjustment is None:
            norm_adjustment = False
        self.norm_adjustment = norm_adjustment

        # Registered as a buffer so that Module.to()/.half()/.cuda() move and
        # cast ``beta`` together with the projection; a plain tensor attribute
        # would stay on the construction device/dtype. ``persistent=False``
        # keeps it out of ``state_dict`` so checkpoint keys are unchanged.
        # NOTE: later assignments to ``self.beta`` must be a tensor (or None).
        self.register_buffer(
            'beta',
            torch.ones(1, device=self.device, requires_grad=False),
            persistent=False,
        )

    def forward(self, residual_stream_a):
        """Return ``projection(residual_stream_a) * beta``."""
        return self.projection(residual_stream_a) * self.beta

class BidirectionalStitchingLayer(StitchingLayer):
    """Stitching layer that can also map from the B side back to the A side.

    The A->B direction is the inherited ``self.projection`` scaled by
    ``self.beta``. The B->A direction is chosen by ``method``:

    * ``'transpose'``: multiply by ``projection.weight`` (the transpose of
      the matrix applied in the A->B direction).
    * ``'pinv'``: multiply by the pseudo-inverse of ``projection.weight.T``.
    * ``'separate_mat'``: apply an independent ``torch.nn.Linear(d_B, d_A)``
      stored as ``self.projection2``.

    Any other ``method`` value makes ``inv_projection`` return ``None``.

    Args:
        d_A: Feature size on the A side.
        d_B: Feature size on the B side.
        device: Device on which the layers are created.
        use_bias1: Whether the A->B projection has a bias.
        use_bias2: Whether ``projection2`` has a bias (only used when
            ``method == 'separate_mat'``); initialised to zeros.
        force_orthogonal: Forwarded to the parent constructor.
        method: How the B->A direction is computed (see above).
        norm_adjustment: Forwarded to the parent constructor.
    """

    def __init__(self, d_A, d_B, device='cuda', use_bias1=False, use_bias2=False, force_orthogonal=False, method='transpose', norm_adjustment=False):
        super().__init__(d_A, d_B, device, use_bias1, force_orthogonal, norm_adjustment)
        self.method = method
        if method == 'separate_mat':
            self.use_bias2 = use_bias2
            self.projection2 = torch.nn.Linear(d_B, d_A, bias=use_bias2, device=self.device)
            if self.use_bias2:
                torch.nn.init.zeros_(self.projection2.bias)

    def inv_projection(self, x):
        """Map ``x`` from the B side to the A side.

        The scaling by ``beta`` is undone first, and when the A->B projection
        has a bias that bias is subtracted (tied) before the B->A matrix is
        applied, except in ``'separate_mat'`` mode when ``projection2`` has
        its own bias.

        Returns:
            The projected tensor, or ``None`` if ``self.method`` is not one
            of ``'transpose'``, ``'pinv'`` or ``'separate_mat'``.
        """
        if self.method == 'separate_mat':
            if self.use_bias and not self.use_bias2:
                # Only the A->B bias exists: tie it by removing it first.
                return self.projection2((x / self.beta) - self.projection.bias)
            # No biases at all, or projection2 carries its own bias.
            return self.projection2(x / self.beta)

        if self.method == 'transpose':
            kernel = self.projection.weight
        elif self.method == 'pinv':
            kernel = torch.pinverse(self.projection.weight.T)
        else:
            return None

        if self.use_bias:
            # Tied bias: undo the scale, remove the bias, then apply the matrix.
            return (x / self.beta - self.projection.bias) @ kernel
        return x @ kernel / self.beta

    def forward(self, residual_stream_a, residual_stream_b=None, run_inverse=False):
        """Project A->B and, when a B stream is given, B->A.

        Args:
            residual_stream_a: Input on the A side.
            residual_stream_b: Optional input on the B side. When ``None``
                the B->A results are returned as ``None`` instead of raising.
            run_inverse: If true, additionally map each projection back to
                its original side.

        Returns:
            ``(projected_a, projected_b)`` or, with ``run_inverse``,
            ``(projected_a, projected_b, inv_a, inv_b)`` where ``inv_a`` is
            ``inv_projection(projected_b)`` and ``inv_b`` is the A->B
            projection of ``projected_a``. ``projected_a`` and ``inv_b`` are
            ``None`` when there is no B-side result to work with.
        """
        projected_residual_stream_b = self.projection(residual_stream_a) * self.beta

        # The default ``residual_stream_b=None`` must not reach the tensor
        # arithmetic in ``inv_projection``.
        if residual_stream_b is None:
            projected_residual_stream_a = None
        else:
            projected_residual_stream_a = self.inv_projection(residual_stream_b)

        if not run_inverse:
            return projected_residual_stream_a, projected_residual_stream_b

        inv_residual_stream_a = self.inv_projection(projected_residual_stream_b)
        if projected_residual_stream_a is None:
            inv_residual_stream_b = None
        else:
            inv_residual_stream_b = self.projection(projected_residual_stream_a) * self.beta
        return projected_residual_stream_a, projected_residual_stream_b, inv_residual_stream_a, inv_residual_stream_b
