from Requierments import *

def zero_balanced_MTL(sigma, input_dim, shared_dim, output_dim_1, output_dim_2, aligned=None, conflict=False, ortho=False):
    """Build a zero-balanced initialisation for a two-head linear network.

    Two head matrices ``W1`` (output_dim_1, shared_dim) and ``W2``
    (output_dim_2, shared_dim) are drawn according to the selected mode.
    A shared layer ``W_S`` (shared_dim, input_dim) is then constructed so that

        W_S @ W_S.T == W1.T @ W1 + W2.T @ W2

    This holds exactly (up to round-off) whenever
    rank(W1.T @ W1 + W2.T @ W2) <= input_dim.

    Head modes, checked in this order:
      * ``aligned == False and ortho == False``: W1 rows are scaled to unit
        norm. W2 is a Gaussian draw projected onto the orthogonal complement
        of W1's row space, then scaled to unit Frobenius norm.
        Requires output_dim_1 < shared_dim.
      * ``aligned == True``: W1 is scaled to unit Frobenius norm and W2 is a
        copy of W1 (so output_dim_1 must equal output_dim_2 for the shapes
        to be consistent).
      * ``ortho == True``: the stacked rows of [W1; W2] are orthonormal.
        Requires output_dim_1 + output_dim_2 <= shared_dim. ``sigma`` is not
        used in this mode.
      * otherwise (e.g. the defaults): W1 and W2 are i.i.d. N(0, sigma**2).

    Args:
        sigma: Standard deviation of the Gaussian draws for the heads.
        input_dim: Number of columns of ``W_S``.
        shared_dim: Number of rows of ``W_S`` / columns of the heads.
        output_dim_1: Number of rows of ``W1``.
        output_dim_2: Number of rows of ``W2``.
        aligned: Mode selector (see above). ``None`` matches neither
            ``True`` nor ``False``.
        conflict: Currently unused; kept for API compatibility.
        ortho: Mode selector (see above).

    Returns:
        Tuple ``(W_S, W1, W2)`` of float32 ``torch`` tensors.

    Raises:
        ValueError: If the requested orthogonality is impossible for the
            given dimensions.

    Side effects:
        Prints a verification line and the residual norm to stdout.
    """
    W1_init = np.random.normal(0, sigma, (output_dim_1, shared_dim))

    # `== False` / `== True` on purpose: the default aligned=None must match
    # neither branch.
    if aligned == False and ortho == False:  # noqa: E712
        if output_dim_1 >= shared_dim:
            raise ValueError(
                "aligned=False needs output_dim_1 < shared_dim so that W2 can "
                "be orthogonal to W1's row space"
            )
        W1_init = W1_init / np.linalg.norm(W1_init, axis=1, keepdims=True)
        W2_init = np.random.normal(0, sigma, (output_dim_2, shared_dim))
        # W1's rows are unit-norm but not mutually orthogonal. Project W2 using
        # an orthonormal basis (shared_dim, output_dim_1) of W1's row space.
        basis, _ = np.linalg.qr(W1_init.T)
        W2_init = W2_init - (W2_init @ basis) @ basis.T
        W2_init = W2_init / np.linalg.norm(W2_init)

    elif aligned == True:  # noqa: E712
        W1_init = W1_init / np.linalg.norm(W1_init)
        W2_init = W1_init.copy()

    elif ortho == True:  # noqa: E712
        n_rows = output_dim_1 + output_dim_2
        if n_rows > shared_dim:
            raise ValueError(
                "ortho=True needs output_dim_1 + output_dim_2 <= shared_dim"
            )
        # Reduced QR of a (shared_dim, n_rows) matrix has orthonormal columns.
        # Transposing gives n_rows orthonormal rows of length shared_dim.
        basis, _ = np.linalg.qr(np.random.randn(shared_dim, n_rows))
        W_full = basis.T
        W1_init = W_full[:output_dim_1, :]
        W2_init = W_full[output_dim_1:, :]

    else:
        W2_init = np.random.normal(0, sigma, (output_dim_2, shared_dim))

    M = W1_init.T @ W1_init + W2_init.T @ W2_init
    eigvals, U = np.linalg.eigh(M)  # eigenvalues in ascending order
    # Round-off can yield tiny negative eigenvalues of a PSD matrix; clip
    # them before taking the square root.
    eigvals = np.clip(eigvals, 0, None)
    R = U @ np.diag(np.sqrt(eigvals))  # R @ R.T == M

    if input_dim == shared_dim:
        W_S = R
    elif input_dim > shared_dim:
        O, _ = np.linalg.qr(np.random.randn(input_dim, input_dim))
        # Q has orthonormal rows (Q @ Q.T == I), so W_S @ W_S.T == R @ R.T.
        Q = O[:shared_dim, :]
        W_S = R @ Q
    else:
        # Keep the input_dim largest eigen-directions (the last columns, as
        # eigh sorts ascending) and mix them with a random orthogonal matrix.
        # This is exact when rank(M) <= input_dim.
        O, _ = np.linalg.qr(np.random.randn(input_dim, input_dim))
        W_S = R[:, -input_dim:] @ O

    LL = W_S @ W_S.T
    print("Verification:", np.allclose(LL, M))
    print("||W_S W_S^T - (W1^T W1 + W2^T W2)|| =", np.linalg.norm(LL - M))

    W_S = torch.tensor(W_S).float()
    W1_init = torch.tensor(W1_init).float()
    W2_init = torch.tensor(W2_init).float()

    return W_S, W1_init, W2_init
