from Requierments import *
from Balanced_Inits import *

class QQT_MTL:
    """Step-wise evaluation of a closed-form ``Z @ center @ Z.T`` matrix.

    The constructor stacks the two per-task target matrices (``sigma_1`` over
    ``sigma_2``) and the two per-task initial maps (``W1 @ WS`` over
    ``W2 @ WS``), takes the SVD of each stack and caches the quantities that
    do not depend on time. ``time_step_nondiag`` then evaluates the expression
    at the current internal step counter and advances the counter.

    NOTE(review): the formula looks like the analytical QQ^T solution for
    linear-network learning dynamics adapted to two tasks - confirm against
    the derivation this code was written from.
    """

    def __init__(self, WS, W1, W2, X, sigma_1, sigma_2, weightsonly=False,
                 U1=None, U2=None, S1=None, S2=None, V1=None, V2=None, sim=False):
        """Precompute the time-independent factors.

        Args:
            WS: shared first-layer matrix; right-multiplied onto W1 and W2.
            W1, W2: task-specific matrices; their row counts give the two
                output dimensions.
            X: only its shape is read (rows -> ``batchsize``, columns ->
                input dimension ``n_i``).
            sigma_1, sigma_2: per-task target matrices, stacked row-wise.
                Presumably input-output correlation matrices with
                ``n_o1`` / ``n_o2`` rows and ``n_i`` columns - TODO confirm.
            weightsonly: if true, ``time_step_nondiag`` returns only the
                block ``qqt[n_i:, :n_i]`` instead of the full matrix.
            U1, U2, S1, S2, V1, V2, sim: accepted for backward compatibility;
                they do not influence any stored attribute or result.

        Raises:
            AssertionError: if ``|det(B)| <= 1e-10`` (or is NaN), where
                ``B = U_init.T @ U_Xi + V_init.T @ V_Xi``.
        """
        self.n_i = n_i = X.shape[1]
        self.n_o1 = W1.shape[0]
        self.n_o2 = W2.shape[0]
        self.batchsize = X.shape[0]
        self.weightsonly = weightsonly

        n_o = self.n_o1 + self.n_o2

        # Full SVD of the stacked targets: the extra singular vectors are
        # needed below to build the orthogonal complements.
        XI = np.concatenate((sigma_1, sigma_2), axis=0)
        U_Xi, S_Xi, Vt_Xi = np.linalg.svd(XI)
        V_Xi = Vt_Xi.T

        # When input and total output dimension differ, keep only the leading
        # min(n_i, n_o) singular vectors on the larger side, store the
        # remaining ones as *_perp and pad the smaller side with zeros.
        if n_i < n_o:
            U_perp = U_Xi[:, n_i:]
            V_perp = np.zeros((n_i, n_o - n_i))
            U_Xi = U_Xi[:, :n_i]
        elif n_i > n_o:
            U_perp = np.zeros((n_o, n_i - n_o))
            V_perp = V_Xi[:, n_o:]
            V_Xi = V_Xi[:, :n_o]
        else:
            U_perp = V_perp = None

        self.U_perp = U_perp
        self.V_perp = V_perp
        # Singular values are stored as a diagonal matrix, not a vector.
        self.U_Xi, self.S_Xi, self.V_Xi = U_Xi, np.diag(S_Xi), V_Xi

        print("ranksxi", np.linalg.matrix_rank(self.S_Xi))

        # In analogy to the stacked targets, combine the initial maps of both
        # tasks into one matrix so the dimensions line up.
        init_mat = np.concatenate((W1 @ WS, W2 @ WS), axis=0)
        U_init, S_init, Vt_init = np.linalg.svd(init_mat, full_matrices=False)
        V_init = Vt_init.T
        self.U_init, self.S_init, self.V_init = U_init, np.diag(S_init), V_init

        # A solution requires B to be invertible. The check is an explicit
        # raise (same exception type and message as the former ``assert``)
        # so that it is not stripped when Python runs with -O.
        B = U_init.T @ U_Xi + V_init.T @ V_Xi
        self.B = B
        det_B = np.linalg.det(B)
        if not np.abs(det_B) > 1e-10:
            raise AssertionError(f"B is not invertible det(B) = {det_B}")
        self.B_inv = np.linalg.inv(B)
        self.C = U_init.T @ U_Xi - V_init.T @ V_Xi

        # Identity of size min(n_i, n_o).
        # TODO: still need to work out how this works exactly for MTL.
        self.i = np.identity(min(n_i, n_o))

        self.t = 0  # timestep

    def return_SVDxi(self):
        """Return ``(U_Xi, S_Xi, V_Xi)`` as stored (``S_Xi`` is a diagonal matrix)."""
        return self.U_Xi, self.S_Xi, self.V_Xi

    def time_step_nondiag(self, learning_rate):
        """Evaluate the matrix at the current step and advance the counter.

        Args:
            learning_rate: step size; the effective time is
                ``self.t / (1 / learning_rate)``.

        Returns:
            ``Z @ center @ Z.T`` (square, ``n_i + n_o1 + n_o2`` rows), or its
            block ``[n_i:, :n_i]`` when ``weightsonly`` is set. That block is
            presumably the stacked input-to-output map - TODO confirm.

        Side effects:
            Increments ``self.t`` by one.
        """
        tau = 1. / learning_rate
        fract = self.t / tau

        U_init, V_init = self.U_init, self.V_init
        U_Xi, S_Xi, V_Xi = self.U_Xi, self.S_Xi, self.V_Xi
        U_perp, V_perp = self.U_perp, self.V_perp
        C = self.C
        B_inv = self.B_inv
        i = self.i

        # S_Xi is stored as a diagonal matrix; pull the singular values out.
        xi_singular_values = np.diag(S_Xi)
        exp_Lambda = np.diag(np.exp(-1 * xi_singular_values * fract))
        exp_2_Lambda = np.diag(np.exp(-2 * xi_singular_values * fract))

        # Pseudo-inverse so that zero initial singular values do not blow up.
        S_init_inv = np.linalg.pinv(self.S_init)
        S_XI_inv = np.diag(1. / xi_singular_values)

        if U_perp is None and V_perp is None:
            # Square case: no orthogonal-complement corrections.
            Z = np.vstack([
                V_Xi @ (i - exp_Lambda @ C.T @ B_inv.T @ exp_Lambda),
                U_Xi @ (i + exp_Lambda @ C.T @ B_inv.T @ exp_Lambda),
            ])
            center_right = 0
        else:
            Z = np.vstack([
                V_Xi @ (i - exp_Lambda @ C.T @ B_inv.T @ exp_Lambda)
                + 2. * V_perp @ V_perp.T @ V_init @ B_inv.T @ exp_Lambda,
                U_Xi @ (i + exp_Lambda @ C.T @ B_inv.T @ exp_Lambda)
                + 2. * U_perp @ U_perp.T @ U_init @ B_inv.T @ exp_Lambda,
            ])
            center_right = (
                4 * fract * exp_Lambda @ B_inv
                @ (V_init.T @ V_perp @ V_perp.T @ V_init
                   + U_init.T @ U_perp @ U_perp.T @ U_init)
                @ B_inv.T @ exp_Lambda
            )

        # Watch out for S_init_inv: it is a pseudo-inverse (see above).
        center_left = 4. * exp_Lambda @ B_inv @ S_init_inv @ B_inv.T @ exp_Lambda
        center_center = (
            (i - exp_2_Lambda) @ S_XI_inv
            - exp_Lambda @ B_inv @ C @ (exp_2_Lambda - i) @ S_XI_inv @ C.T @ B_inv.T @ exp_Lambda
        )
        center = np.linalg.inv(center_left + center_center + center_right)

        qqt = Z @ center @ Z.T

        if self.weightsonly:
            # Z stacks the V-part (n_i rows) above the U-part, so this is the
            # lower-left block of the full matrix.
            qqt = qqt[self.n_i:, :self.n_i]

        self.t += 1

        return qqt

def compute_analytical_loss_from_plots(analytical_plots, X_train, y1_train, y2_train, input_dim, task1_dim, task2_dim, epochs):
    """Compute per-epoch MSE losses of the analytical weights on the training data.

    Args:
        analytical_plots: indexable container; ``analytical_plots[3][t]`` and
            ``analytical_plots[6][t]`` hold the flattened analytical
            ``W1 @ WS`` and ``W2 @ WS`` maps for epoch ``t`` (objects with a
            ``reshape`` method, presumably NumPy arrays - TODO confirm).
        X_train: float32 tensor whose transpose has ``input_dim`` rows.
        y1_train, y2_train: target tensors; they are transposed before being
            compared with the transposed predictions, so they are assumed to
            have the same layout as the predictions, i.e. one row per output
            unit - TODO confirm against the caller.
        input_dim: number of input features.
        task1_dim, task2_dim: number of output units of task 1 / task 2.
        epochs: number of epochs to evaluate (indices ``0 .. epochs - 1``).

    Returns:
        Three lists of Python floats, one entry per epoch: the task-1 loss,
        the task-2 loss and the equally weighted total
        ``0.5 * loss1 + 0.5 * loss2``.
    """
    loss_task1_ana = []
    loss_task2_ana = []
    loss_total_ana = []
    criterion = nn.MSELoss()

    for t in range(epochs):
        # Restore the (task_dim, input_dim) matrices from their flattened form.
        W1_ana = analytical_plots[3][t].reshape(task1_dim, input_dim)
        W2_ana = analytical_plots[6][t].reshape(task2_dim, input_dim)

        y1_pred = torch.tensor(W1_ana, dtype=torch.float32) @ X_train.T
        y2_pred = torch.tensor(W2_ana, dtype=torch.float32) @ X_train.T

        # The predictions are already float32 tensors. Re-wrapping them in
        # torch.tensor() only made a redundant copy and triggered PyTorch's
        # copy-construct UserWarning; detach() keeps them out of any autograd
        # graph exactly as the copy did.
        loss1 = criterion(y1_pred.T.detach(), y1_train.T)
        loss2 = criterion(y2_pred.T.detach(), y2_train.T)

        loss_task1_ana.append(loss1.item())
        loss_task2_ana.append(loss2.item())
        loss_total_ana.append((0.5 * loss1 + 0.5 * loss2).item())

    return loss_task1_ana, loss_task2_ana, loss_total_ana