
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn.modules.utils import _pair
from scipy import ndimage
import torch.nn.functional as F
from models import configs
from tool.create_orthgonal import create_precise_random_standard_orthogonal_matrix


class new_orth_Linear(nn.Module):
    """Frozen ``nn.Linear`` with optional trainable LoRA and orthogonal-basis adapters.

    With ``enable_lora=False`` the layer is just a frozen linear map.  With
    ``enable_lora=True`` two trainable terms are added to the frozen output::

        y = mlp(x) + scaling * (x @ Lora_A @ Lora_B) + x @ (U @ diag(r_house) @ V)

    ``U`` / ``V`` are built from an externally supplied square matrix ``orth``
    (see :meth:`get_orth_tensor`).  ``r_house`` is a trainable vector of length
    ``min(size_in, size_out)`` that rescales each basis direction.

    Args:
        size_in: Input feature dimension.
        size_out: Output feature dimension.
        bias: Whether the frozen ``nn.Linear`` has a bias term.
        enable_lora: Create and use the trainable adapter parameters.
        FFN: Stored as ``self.FFN``; not read anywhere in this class.
        rank: LoRA rank, i.e. the inner dimension of ``Lora_A @ Lora_B``.
            Defaults to 1, the previously hard-coded value.
    """

    def __init__(self, size_in, size_out, bias=True, enable_lora=False, FFN=False, rank=1):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.enable_lora = enable_lora
        self.FFN = FFN
        self.size_in = size_in
        self.size_out = size_out
        self.has_bias = bias
        self.mlp = nn.Linear(size_in, size_out, bias=bias)
        self.rank = rank
        self.num = min(size_in, size_out)
        if self.enable_lora:
            # Layout: x (..., size_in) @ Lora_A (size_in, rank) @ Lora_B (rank, size_out).
            self.Lora_A = nn.Parameter(torch.empty(self.size_in, self.rank), requires_grad=True)
            self.Lora_B = nn.Parameter(torch.empty(self.rank, self.size_out), requires_grad=True)
            self.scaling = 1 / self.rank
            # One trainable scale per basis direction of the orthogonal adapter.
            self.r_house = nn.Parameter(torch.empty(self.num, 1), requires_grad=True)

            # NOTE(review): for a (size_in, rank) tensor kaiming_uniform_ takes
            # fan_in from dim 1 (= rank), so the init bound is much larger than
            # the usual LoRA convention of fan_in = size_in. Confirm intended.
            nn.init.kaiming_uniform_(self.Lora_A)
            # r_house = 1 means the orthogonal adapter starts as U @ V.
            nn.init.ones_(self.r_house)
            # Zero Lora_B so the LoRA term contributes nothing at initialization.
            nn.init.zeros_(self.Lora_B)
        self._frozen_param()

    def using_orth_init_lora(self, orth):
        """Overwrite ``Lora_A`` with the first ``rank`` columns of ``orth``.

        Requires ``enable_lora=True``.  ``orth`` must have ``size_in`` rows and
        at least ``rank`` columns, so that the slice matches ``Lora_A``'s
        ``(size_in, rank)`` shape.

        Raises:
            RuntimeError: if the layer was built with ``enable_lora=False``.
        """
        if not self.enable_lora:
            raise RuntimeError("using_orth_init_lora requires enable_lora=True")
        with torch.no_grad():
            # Intended as `orth.T[:rank]` (first `rank` columns of orth as rows),
            # transposed back into Lora_A's (size_in, rank) layout.
            self.Lora_A.copy_(orth[:, : self.rank])

    def _frozen_param(self):
        """Disable gradients for the wrapped ``nn.Linear`` (weight and bias)."""
        for param in self.mlp.parameters():
            param.requires_grad = False

    def forward(self, x, orth=None):
        """Apply the frozen linear map plus, if enabled, both adapter terms.

        Args:
            x: Input whose last dimension is ``size_in``.
            orth: Square ``(num, num)`` matrix; required when ``enable_lora``.

        Returns:
            Tensor whose last dimension is ``size_out``.
        """
        if not self.enable_lora:
            return self.mlp(x)
        U, V = self.get_orth_tensor(orth)
        # reshape(-1) rather than squeeze(): keeps r 1-D even when num == 1.
        r = self.r_house.reshape(-1)
        # U * r scales U's columns by r, i.e. U @ diag(r), without building diag(r).
        weight = (U * r) @ V
        result = self.mlp(x)
        result = result + (x @ self.Lora_A @ self.Lora_B) * self.scaling
        result = result + x @ weight
        return result

    def get_orth_tensor(self, orth):
        """Build the factors ``U`` (size_in, num) and ``V`` (num, size_out) from ``orth``.

        ``orth`` must be a square ``(num, num)`` matrix, where
        ``num = min(size_in, size_out)``.  It is presumably orthogonal; confirm
        with the caller.

        If the layer is square, ``U = orth.T`` and ``V = orth``.  Otherwise
        ``orth`` is tiled ``ratio = max(size_in, size_out) // num`` times along
        the longer side and divided by ``sqrt(ratio)``.  For an orthogonal
        ``orth`` this keeps the rows of ``V`` (or the columns of ``U``)
        orthonormal.  For ``ratio == 4`` this is exactly ``cat([...] * 4) / 2``.

        Raises:
            ValueError: if ``orth`` is None, or if the larger dimension is not an
                integer multiple of the smaller one.
        """
        if orth is None:
            raise ValueError("error orth matrix, please check the enable of orth")
        if self.size_in == self.size_out:
            return orth.T, orth

        ratio, remainder = divmod(max(self.size_in, self.size_out), self.num)
        if remainder:
            raise ValueError(
                f"size_in={self.size_in} and size_out={self.size_out} must be integer multiples of each other"
            )
        scale = math.sqrt(ratio)
        if self.size_in < self.size_out:
            U = orth.T
            V = (torch.cat([orth] * ratio, dim=1) / scale).contiguous()
        else:
            U = (torch.cat([orth.T] * ratio, dim=0) / scale).contiguous()
            V = orth
        return U, V


