from Requierments import *


def generate_teacher_student_tasks(
    input_dim=10,
    shared_dim=5,
    output_dim_1=2,
    output_dim_2=2,
    aligned=True,
    num_samples=10,
    seed=42
):
    """Build a shared map ``W_S`` and two read-outs ``W1`` / ``W2`` on top of it.

    The global NumPy RNG is reseeded with ``seed`` at the start, so for a fixed
    seed the returned matrices are deterministic.

    Args:
        input_dim: number of columns of ``W_S``.
        shared_dim: width of the shared layer (rows of ``W_S``, columns of
            ``W1`` / ``W2``).
        output_dim_1: rows of ``W1`` (when ``output_dim_1 <= shared_dim``).
        output_dim_2: rows of ``W2`` when ``aligned`` is False. Ignored when
            ``aligned`` is True, because ``W2`` is then a copy of ``W1``.
        aligned: True -> ``W2 = W1``; False -> ``W2`` is built from a random
            matrix projected onto the orthogonal complement of ``W1``'s rows.
        num_samples: rows of a Gaussian block drawn only to advance the global
            RNG; no samples are returned.
        seed: value passed to ``np.random.seed``.

    Returns:
        ``(W_S, W1, W2)``. ``W1`` and ``W2`` have orthonormal rows and
        ``shared_dim`` columns; ``W_S`` has shape ``(shared_dim, input_dim)``.
        When ``input_dim >= shared_dim``,
        ``W_S @ W_S.T == W1.T @ W1 + W2.T @ W2`` up to rounding.

    NOTE: with ``aligned=False``, orthogonality of ``W2``'s rows to ``W1`` is
    only guaranteed when ``output_dim_1 + output_dim_2 <= shared_dim``.
    Otherwise the projected candidate is rank-deficient and the QR step is
    not guaranteed to stay inside the complement.
    """
    np.random.seed(seed)

    # W1: orthonormal rows, taken from the Q factor of a random matrix.
    Q1, _ = np.linalg.qr(np.random.randn(output_dim_1, shared_dim).T)
    W1 = Q1.T  # (output_dim_1, shared_dim)

    if aligned:
        W2 = W1.copy()
    else:
        # Remove W1's row space from a random candidate, then re-orthonormalise.
        P_orth = np.eye(shared_dim) - W1.T @ W1
        W2_proj = np.random.randn(output_dim_2, shared_dim) @ P_orth
        Q2, _ = np.linalg.qr(W2_proj.T)
        W2 = Q2.T  # (output_dim_2, shared_dim)

    # M = U diag(eigvals) U^T, so R = U diag(sqrt(eigvals)) satisfies R R^T = M.
    M = W1.T @ W1 + W2.T @ W2
    eigvals, U = np.linalg.eigh(M)
    # Clip tiny negative round-off eigenvalues before taking the square root.
    R = U @ np.diag(np.sqrt(np.clip(eigvals, 0, None)))  # (shared_dim, shared_dim)

    if input_dim >= shared_dim:
        # Rows of an orthogonal matrix: Q Q^T = I, hence W_S W_S^T = R R^T = M.
        O, _ = np.linalg.qr(np.random.randn(input_dim, input_dim))
        Q = O[:shared_dim, :]
    else:
        # Only input_dim orthonormal columns fit; W_S W_S^T is then a
        # rank-limited version of M.
        O, _ = np.linalg.qr(np.random.randn(shared_dim, shared_dim))
        Q = O[:, :input_dim]

    W_S = R @ Q  # (shared_dim, input_dim)

    # This draw does not feed into the returned values. It advances the global
    # RNG by one (num_samples, input_dim) Gaussian block, so code drawing from
    # np.random after this call sees a stream fixed by `seed`.
    np.random.randn(num_samples, input_dim)

    return W_S, W1, W2


def MTL_tasks(input_dim=5, task1_dim=1, task2_dim=1, num_samples=100, alpha_scales=(1,1), seed=55, alignment_factor=1):
    """Generate whitened inputs and two linear-teacher targets with tunable alignment.

    ``W1`` takes the first ``task1_dim`` directions of a random orthonormal
    basis of R^input_dim. The next ``task1_dim`` directions (``W_perp``) are
    orthogonal to ``W1``.

    ``W2 = a*W1 + (1 - a)*W_perp``, where ``a = alignment_factor``, is then
    rescaled to ``W1``'s Frobenius norm. The Frobenius cosine between ``W1``
    and ``W2`` is therefore ``a / sqrt(a**2 + (1 - a)**2)``:

    * 1 for ``a = 1`` (identical);
    * 0 for ``a = 0`` (orthogonal);
    * ``-1/sqrt(5)`` for ``a = -1``.

    Args:
        input_dim: number of input features.
        task1_dim: rows of ``W1`` and also of ``W2``.
        task2_dim: currently unused; ``W2`` always has ``task1_dim`` rows.
        num_samples: number of input samples.
        alpha_scales: multipliers applied to ``W1`` / ``W2`` before the product.
        seed: seeds both the global NumPy RNG (used for ``X``) and the
            generator used for the teacher weights.
        alignment_factor: ``a`` above.

    Returns:
        ``(X, Y1, Y2)``:

        * ``X`` is ``(num_samples, input_dim)``, whitened via
          ``whiten_input(..., numpy=True)``;
        * ``Y1`` is ``(task1_dim, num_samples)``;
        * ``Y2`` is ``(task1_dim, num_samples)``.

    Raises:
        ValueError: if ``2 * task1_dim > input_dim``. ``W_perp`` cannot then
            be carved out of the basis, and the old code failed with a
            broadcast error or returned NaNs.
    """
    if 2 * task1_dim > input_dim:
        raise ValueError(
            f"MTL_tasks needs input_dim >= 2 * task1_dim "
            f"(got input_dim={input_dim}, task1_dim={task1_dim})"
        )

    np.random.seed(seed)
    X = np.random.randn(num_samples, input_dim)
    X = whiten_input(X, numpy=True)

    # Seed the weight generator too. An unseeded default_rng() made W1/W2
    # differ on every call regardless of `seed`.
    rng = np.random.default_rng(seed)

    # Orthonormal basis of R^input_dim (columns of Q_full).
    Q_full, _ = np.linalg.qr(rng.standard_normal((input_dim, input_dim)))
    W1 = Q_full[:, :task1_dim].T                    # (task1_dim, input_dim)
    W_perp = Q_full[:, task1_dim:2 * task1_dim].T   # (task1_dim, input_dim), rows orthogonal to W1
    W2 = alignment_factor * W1 + (1 - alignment_factor) * W_perp
    W2 = W2 / np.linalg.norm(W2, 'fro') * np.linalg.norm(W1, 'fro')

    # Shapes are already (task1_dim, num_samples); no reshape needed.
    Y1 = alpha_scales[0] * W1 @ X.T
    Y2 = alpha_scales[1] * W2 @ X.T

    return X, Y1, Y2
    


def generate_regression_tasks(
    input_dim:   int,         # D, number of features
    task1_dim:   int,         # N1, outputs of task 1
    task2_dim:   int,         # N2, outputs of task 2
    num_samples: int = 1,     # number of rows of X / Y1 / Y2
    rho:         float = 0.0, # accepted for compatibility; not used
    alphas:      tuple = (1.0, 1.0),  # (α1, α2) scales for each task
    noise_std:   float = 0.0, # STD of extra Gaussian noise added to Y2
    seed:        int   = None,
    sigma_eps:   float=0.5,
    normalize:   bool = False,
    aligned:     bool = True
):
    """Generate whitened Gaussian inputs and two sample-wise regression targets.

    The targets are drawn independently of ``X``:

    * ``B ~ N(0, 3**2)``, shape ``(num_samples, task1_dim)``;
    * ``Y1 = α1 * B``;
    * ``Y2 = α2 * (B + eps2)`` if ``aligned``, else ``Y2 = α2 * eps2``,
      with ``eps2 ~ N(0, sigma_eps**2)``;
    * if ``noise_std > 0``, extra ``N(0, noise_std**2)`` noise is added to
      ``Y2`` only;
    * if ``normalize``, each target is divided by its Frobenius norm
      (an all-zero target then yields NaNs).

    ``rho`` is not used.

    Returns:
        ``(X, Y1, Y2)``:

        * ``X`` is ``(num_samples, input_dim)``, whitened via
          ``whiten_input(..., numpy=True)``;
        * ``Y1`` is ``(num_samples, task1_dim)``;
        * ``Y2`` is ``(num_samples, task2_dim)``.

    Raises:
        ValueError: if ``aligned`` and ``task1_dim != task2_dim``. ``Y2`` then
            reuses ``B`` and cannot have a different width.
    """
    if aligned and task1_dim != task2_dim:
        raise ValueError(
            "aligned=True builds Y2 from the same B as Y1, so task1_dim and "
            f"task2_dim must match (got {task1_dim} and {task2_dim})"
        )

    print("seed", seed)
    if seed is not None:
        np.random.seed(seed)

    X = np.random.randn(num_samples, input_dim)
    X = whiten_input(X, numpy=True)

    alpha1, alpha2 = alphas

    B = np.random.normal(0, 3, size=(num_samples, task1_dim))
    # This draw is unused. It is kept so eps2 and the optional noise are
    # identical for a given seed.
    np.random.normal(0, sigma_eps, size=(num_samples, task1_dim))
    eps2 = np.random.normal(0, sigma_eps, size=(num_samples, task2_dim))

    Y1 = alpha1 * B
    Y2 = alpha2 * (B + eps2) if aligned else alpha2 * eps2

    if noise_std > 0:
        Y2 = Y2 + np.random.normal(0, noise_std, size=Y2.shape)

    if normalize:
        Y1 = Y1 / np.linalg.norm(Y1, 'fro')
        Y2 = Y2 / np.linalg.norm(Y2, 'fro')

    return X, Y1, Y2


class MultiMNISTGenerator:
    """Stream pairs of overlaid digit images on a 36x36 canvas.

    Each sample is ``(combined, (y1, y2))``:

    * ``combined`` is the pixel-wise maximum of two shifted 28x28 digits
      (pixel values divided by 255);
    * the digits come from different classes;
    * ``y1`` / ``y2`` are their labels.

    When ``alignment`` is False, the second canvas is first transformed as
    ``R @ canvas @ R.T`` with a fixed random orthogonal ``R``.

    ``random`` is a seed or ``np.random.Generator``. When given, all sampling
    uses ``self.random``; when None, the global ``np.random`` state is used.
    """

    def __init__(self, data, labels, train=True, shift_range=4, pairs_per_image=10, alignment=True, max_iou=1, corrupt_p=0.0,
                 random=None):  # shift_range controls overlap: 0 = perfect overlap, 1 almost perfect, 4 medium, 8 = largest shift that fits the 36x36 canvas (boxes still overlap by >= 20x20 px)
        # Keep the first 50,000 items when train=True, the first 1,000 otherwise.
        # NOTE(review): both slices start at index 0. If the same array is
        # passed for train and test, the test items are a subset of the
        # training items -- confirm callers pass a separate test array.
        N = 50000 if train else 1000
        self.data = data[:N]
        self.labels = labels[:N]
        self.shift_range = shift_range
        self.pairs_per_image = pairs_per_image
        self.base_size = 28
        self.multi_size = 36
        self.alignment = alignment
        self.max_iou = max_iou
        self.corrupt_p = corrupt_p  # stored but not applied yet
        self.random = np.random.default_rng(random)
        # Only use the dedicated Generator when the caller supplied a seed.
        # Otherwise keep the global np.random stream, so np.random.seed-based
        # reproducibility is unchanged.
        self._seeded = random is not None

        # Precompute indices by class
        self.by_class = {c: np.where(self.labels == c)[0] for c in range(10)}
        # Size from the data actually present. N * pairs_per_image would index
        # past the end when fewer than N items are passed.
        self.total = len(self.data) * pairs_per_image

        if not self.alignment:
            # random orthonormal multi_size x multi_size matrix
            Q, _ = np.linalg.qr(self._standard_normal((self.multi_size, self.multi_size)))
            self.R = Q
        else:
            self.R = np.eye(self.multi_size)

    def _choice(self, options):
        """Pick one element of ``options`` using the configured RNG."""
        if self._seeded:
            return self.random.choice(options)
        return np.random.choice(options)

    def _integers(self, high, size=2):
        """Draw ``size`` integers uniformly from ``[0, high)`` using the configured RNG."""
        if self._seeded:
            return self.random.integers(0, high, size)
        return np.random.randint(0, high, size)

    def _standard_normal(self, shape):
        """Draw a standard-normal array of ``shape`` using the configured RNG."""
        if self._seeded:
            return self.random.standard_normal(shape)
        return np.random.randn(*shape)

    def __len__(self):
        return self.total

    def __iter__(self):
        base, size = self.base_size, self.multi_size
        for idx in range(self.total):
            img_idx = idx // self.pairs_per_image
            img1 = self.data[img_idx].reshape(base, base)
            y1 = self.labels[img_idx]
            # pick a partner from a different class
            other_class = self._choice([c for c in range(10) if c != y1])
            partner_idx = self._choice(self.by_class[other_class])
            img2 = self.data[partner_idx].reshape(base, base)
            y2 = other_class
            # TODO: corrupt_p (label noise on y2) is stored but not applied yet.

            canvas1 = np.zeros((size, size), dtype=np.float32)
            canvas2 = np.zeros((size, size), dtype=np.float32)

            sh1, sh2 = self._sample_shifts(self.max_iou, self.shift_range)

            canvas1[sh1[0]:sh1[0] + base, sh1[1]:sh1[1] + base] = img1 / 255.
            canvas2[sh2[0]:sh2[0] + base, sh2[1]:sh2[1] + base] = img2 / 255.
            if not self.alignment:
                canvas2 = self.R @ canvas2 @ self.R.T
            combined = np.maximum(canvas1, canvas2)

            yield combined, (y1, y2)

    def _iou(self, sh1, sh2):
        """Intersection-over-Union of two base_size x base_size boxes at offsets sh1, sh2."""
        side = self.base_size
        dy = max(0, side - abs(sh1[0] - sh2[0]))
        dx = max(0, side - abs(sh1[1] - sh2[1]))
        inter = dy * dx
        union = side * side * 2 - inter
        return inter / union  # 0 = none, 1 = perfect overlap

    def _sample_shifts(self, max_iou, shift_range):
        """Draw two (row, col) offsets in [0, shift_range] whose IoU is <= max_iou.

        Raises:
            ValueError: if no offsets in range can satisfy ``max_iou``. The
                rejection loop would otherwise never terminate.
        """
        # (0, 0) vs (shift_range, shift_range) is the least-overlapping pair.
        # If even that exceeds max_iou, no draw can ever be accepted.
        min_iou = self._iou((0, 0), (shift_range, shift_range))
        if min_iou > max_iou:
            raise ValueError(
                f"max_iou={max_iou} is unreachable with shift_range={shift_range} "
                f"(minimum achievable IoU is {min_iou:.4f})"
            )
        while True:
            sh1 = self._integers(shift_range + 1)
            sh2 = self._integers(shift_range + 1)
            if self._iou(sh1, sh2) <= max_iou:
                return sh1, sh2

    def batchmaker(self, batchsize, generator):
        """Take ``batchsize`` samples from ``generator``, flatten and whiten them.

        Args:
            batchsize: number of samples to take (fewer if ``generator`` runs out).
            generator: iterator yielding ``(image, (y1, y2))``, e.g. ``iter(self)``.

        Returns:
            ``(X, y1, y2)`` as float32 torch tensors:

            * ``X`` is ``(n, H*W)`` (``(n, 1296)`` for samples from this
              class), centred and ZCA-whitened over the batch;
            * ``y1`` / ``y2`` are one-hot with shape ``(n, 10)``.
        """
        from itertools import islice

        n_classes = 10
        one_hot = np.eye(n_classes, dtype=np.float32)
        images, y1_rows, y2_rows = [], [], []
        # islice yields exactly `batchsize` items. Checking `i >= batchsize`
        # after appending would collect batchsize + 1.
        for img, (y1, y2) in islice(generator, batchsize):
            images.append(img)
            y1_rows.append(one_hot[y1])
            y2_rows.append(one_hot[y2])

        X_flat = np.stack(images).reshape(len(images), -1).astype(np.float32)  # flatten X

        eps = 1e-5  # numerical-stability constant
        mu = X_flat.mean(axis=0, keepdims=True)
        X_centered = X_flat - mu  # zero-mean per pixel
        N = X_centered.shape[0]
        cov = (X_centered.T @ X_centered) / (N - 1)
        eigvals, eigvecs = np.linalg.eigh(cov)
        D_inv_sqrt = np.diag(1.0 / np.sqrt(eigvals + eps))
        whitening_matrix = eigvecs @ D_inv_sqrt @ eigvecs.T
        X_whitened = X_centered @ whitening_matrix
        cov_whitened = (X_whitened.T @ X_whitened) / (N - 1)
        print(np.trace(cov_whitened))

        X_train = torch.tensor(X_whitened).float()
        y1_train = torch.tensor(np.stack(y1_rows)).float()
        y2_train = torch.tensor(np.stack(y2_rows)).float()
        return X_train, y1_train, y2_train



def whiten_input(X, numpy=False):
    """ZCA-whiten the rows of ``X`` (samples x features).

    Each feature (column) is centred by its own mean. The data is then
    multiplied by ``E diag(1/sqrt(lambda + eps)) E^T``, where ``lambda`` and
    ``E`` are the eigenvalues and eigenvectors of the unbiased sample
    covariance and ``eps = 1e-5``.

    Args:
        X: 2-D data of shape ``(num_samples, num_features)``. It is a torch
            tensor when ``numpy`` is False and a NumPy array when True.
        numpy: select the NumPy backend instead of torch.

    Returns:
        Whitened data with the same shape and backend as ``X``. The NumPy
        branch also prints the diagonal of the whitened covariance as a
        sanity check.
    """
    eps = 1e-5  # keeps the inverse square root finite for tiny eigenvalues

    if not numpy:
        X_centered = X - X.mean(dim=0, keepdim=True)
        N = X_centered.shape[0]
        cov = (X_centered.t() @ X_centered) / (N - 1)
        eigvals, eigvecs = torch.linalg.eigh(cov)
        D_inv_sqrt = torch.diag(1.0 / torch.sqrt(eigvals + eps))
        whitening_matrix = eigvecs @ D_inv_sqrt @ eigvecs.t()
        return X_centered @ whitening_matrix

    # Centre each feature separately. Subtracting one global scalar mean
    # leaves per-column offsets that the covariance estimate below does not
    # account for, so the output would not have identity covariance.
    X_centered = X - X.mean(axis=0, keepdims=True)
    N = X_centered.shape[0]
    cov = (X_centered.T @ X_centered) / (N - 1)
    eigvals, eigvecs = np.linalg.eigh(cov)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(eigvals + eps))
    whitening_matrix = eigvecs @ D_inv_sqrt @ eigvecs.T
    X_whitened = X_centered @ whitening_matrix

    # Verification: this should be close to the identity matrix.
    cov_whitened = (X_whitened.T @ X_whitened) / (N - 1)
    print("checkwhiten", np.diag(cov_whitened))
    return X_whitened

def subspace_similarity(sigma1, sigma2, r=2):
    """Compare the leading r-dimensional eigenspaces of two matrices.

    ``np.linalg.eigh`` treats each input as symmetric. The top-``r``
    eigenvectors of each matrix span a subspace, and the principal angles
    between the two subspaces are measured with ``subspace_angles``.

    Returns:
        ``(similarity, angles)``:

        * ``angles`` are the principal angles in radians;
        * ``similarity`` is the mean of their cosines (1 for identical
          subspaces).
    """
    def leading_eigvecs(mat):
        # eigh returns eigenvalues in ascending order, so the r largest
        # belong to the trailing r columns of the eigenvector matrix.
        _, vecs = np.linalg.eigh(mat)
        return vecs[:, -r:]

    basis1 = leading_eigvecs(sigma1)
    basis2 = leading_eigvecs(sigma2)
    angles = subspace_angles(basis1, basis2)
    return np.cos(angles).mean(), angles