

import numpy as np
from sklearn.linear_model import LinearRegression
from utils import est_Cov, is_one_hot
import sys
import gc

class proj:
    """    
    This class computes projection matrices that remove protected information (z) from data (X).
    In some cases, this also preserves information about a target variable (y).
    
    Attributes:
        eps (float): Small constant for numerical stability (default: 1e-8)
    """
    
    def __init__(self):
        """
        Create an unfitted projector.

        Sets the numerical-stability threshold `eps` and leaves the
        causal-LEACE variant unset (None) until
        `set_causal_LEACE_variant` is called.
        """
        # No causal-LEACE variant is selected by default
        self.causal_LEACE_variant = None

        # Threshold below which eigenvalues / norms are treated as zero
        self.eps = 1e-8

    

    def fit(self, X, z, y, method='LEACE', info_type='Cov', coef=None):
        """
        Determine the P and b for the projection matrix.

        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            z (np.ndarray): Protected attribute to remove, shape (n, 1)
            y (np.ndarray): Target information to preserve, shape (n, k)
            method (str): Method to use for projection (default: 'LEACE')
        
        """
        # demean X
        self.x_mean = np.mean(X, axis=0)
        X_demeaned = X - self.x_mean

        if method == 'LEACE':
            
            # get the projection matrix and bias term from LEACE
            self.P = self.LEACE(X_demeaned, z)
            self.b = (self.x_mean - self.P @ self.x_mean)

        elif method == 'opt-sep-proj':
            # get the projection matrix from optimal separation projection
            self.P = self.opt_sep_proj(X_demeaned, z, y, info_type=info_type, coef=coef)
            self.b = (self.x_mean - self.P @ self.x_mean)

        elif method == 'LEACE-no-whitening':

            # get P based on orthogonal projection
            cov_x_z = est_Cov(X_demeaned, z)
            v = cov_x_z / np.linalg.norm(cov_x_z)
            self.P = self.get_orth_proj(v)

            # get the bias term
            self.b = (self.x_mean - self.P @ self.x_mean)
        
        elif method == 'causal-LEACE':

            include_E_X_y_range = True if self.causal_LEACE_variant == 'range' else False
            self.P, self.reg_X_y, self.x_mean_tilde = self.causal_LEACE(X_demeaned, z, y, include_E_X_y_range=include_E_X_y_range)
            self.b = (self.x_mean - self.P @ self.x_mean)
        
        elif method == 'SAL':

            # get the projection matrix and bias term from SAL
            self.P = self.SAL(X_demeaned, z)
            self.b = np.zeros(X.shape[1])
        
        self.method = method

    def set_causal_LEACE_variant(self, variant):
        """
        Select the causal-LEACE variant used by `fit` and `apply_projection`.

        Args:
            variant: Value stored as-is in `self.causal_LEACE_variant`.
                The rest of the class compares it against 'use_y' and 'range'.
        """
        self.causal_LEACE_variant = variant
        
    def apply_projection(self, X, y=None):
        """
        Apply the projection matrix to the data.
        """

        # if method=='causal-LEACE', then we adapt the application based on the variant
        if self.method == 'causal-LEACE':

            # variant where we use knowledge of y - either based on an oracle or prediction
            if self.causal_LEACE_variant == 'use_y':
                if y is None:
                    sys.exit("Need to provide y for causal LEACE")
                else:

                    # we require here that y is one-hot encoded (even if binary)
                    if y.shape[1] == 1:
                         y = np.concatenate([1 - y, y], axis=1)

                    # get the residuals
                    X_cond_y = self.reg_X_y.predict(y)
                    X_tilde = X - X_cond_y
                    
                    # apply the projection
                    X_proj = X_tilde @ self.P.T + X_cond_y 

            elif self.causal_LEACE_variant == 'range':
                X_proj =  X @ self.P.T + self.b
            
        else:
            X_proj =  X @ self.P.T + self.b

        return X_proj

    def SAL(self, X, z):
        """
        Implement the spectral attribute removal (SAL) method.

        The left singular vectors of the X/z cross-covariance that belong to
        its numerical rank are removed by projecting onto their orthogonal
        complement.

        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            z (np.ndarray): Protected attribute to remove, shape (n, k)

        Returns:
            np.ndarray: Projection matrix P of shape (d, d)
        """
        # Cross-covariance between X and z (est_Cov is a project helper)
        cross_cov = est_Cov(X, z)

        # Left singular vectors of the cross-covariance
        left_vectors = np.linalg.svd(cross_cov)[0]

        # Only the directions within the rank of the cross-covariance are removed
        rank = np.linalg.matrix_rank(cross_cov)
        removed = left_vectors[:, :rank]

        # Orthogonal projector onto the complement of the removed directions
        return np.eye(X.shape[1]) - removed @ removed.T



    def LEACE(self, X, z):
        """
        Determine the projection matrix using LEACE.

        Computes ``P = I - W^+ (v v^+) W`` where W is the whitening matrix of
        X (from `est_W`), v is the cross-covariance between the whitened data
        and z (from the project helper `est_Cov`), and ``^+`` is the
        pseudo-inverse.

        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            z (np.ndarray): Protected attribute to remove, shape (n, k)

        Returns:
            np.ndarray: Projection matrix P of shape (d, d)
        """

        # first, get the Whitening matrix
        W = self.est_W(X)

        # whiten the data (computed once and reused below)
        XW = X @ W

        # check: if z.ndim > 1 and it is one-hot encoded, we can remove the last column
        if z.ndim > 1 and is_one_hot(z):
            z = z[:, :-1]
            print('Removing last column of z, as it is one-hot encoded')

        # second, get the cross-covariance with z and X
        v = est_Cov(XW, z)

        # A 1-D cross-covariance is treated as a single column. The previous
        # 1-D branch evaluated v @ v_pinv to a scalar, which cannot be used in
        # the matrix products below.
        if v.ndim == 1:
            v = v.reshape(-1, 1)

        # orthogonal projector onto the column space of v; pinv equals
        # v.T / ||v||^2 for a single column and yields zeros for an all-zero v
        v_pinv = np.linalg.pinv(v)
        P_proj_v = v @ v_pinv

        # using this, define the projection matrix
        W_pinv = np.linalg.pinv(W)
        I_d = np.eye(W_pinv.shape[0])
        P = I_d - W_pinv @ P_proj_v @ W

        # delete the intermediate variables
        del W, XW, v, v_pinv, P_proj_v
        gc.collect()

        return P

    
    def causal_LEACE(self, X, z, y, include_E_X_y_range=False):
        """
        Fit a projection on the parts of X and z not linearly explained by y.

        X and z are each regressed on (one-hot) y with `LinearRegression`.
        The projection is then estimated from the residuals, either with
        `LEACE` or, when `include_E_X_y_range` is set, with an oblique
        projection that keeps the regression coefficients of X on y in its
        range.

        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            z (np.ndarray): Protected attribute to remove, shape (n, 1)
            y (np.ndarray): Target information to preserve, shape (n, k)
            include_E_X_y_range (bool): Use the oblique projection built from
                `reg_X_y.coef_` instead of plain LEACE (default: False)

        Returns:
            tuple: ``(P, reg_X_y, x_mean_tilde)`` where P is the projection
            matrix of shape (d, d), reg_X_y is the fitted regression of X on
            y, and x_mean_tilde holds the mean residual of X for each unique
            row of y (one row per unique y).
        """

        # we require here that y is one-hot encoded (even if binary)
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        if y.shape[1] == 1:
            y = np.concatenate([1 - y, y], axis=1)

        # linear prediction of X from y
        reg_X_y = LinearRegression().fit(y, X)
        X_cond_y = reg_X_y.predict(y)

        # linear prediction of z from y
        reg_z_y = LinearRegression().fit(y, z)
        z_cond_y = reg_z_y.predict(y)

        # residuals after removing what y explains
        X_tilde = X - X_cond_y
        z_tilde = z - z_cond_y

        if include_E_X_y_range:
            # oblique projection: remove the residual cross-covariance
            # direction while keeping the X-on-y coefficients in the range
            E_X_y = reg_X_y.coef_
            cov_residuals = est_Cov(X_tilde, z_tilde)
            direction = cov_residuals / np.linalg.norm(cov_residuals)
            P = self.get_oblique_proj(direction, E_X_y)
        else:
            # plain LEACE on the residuals
            P = self.LEACE(X_tilde, z_tilde)

        # mean of the X residuals for every distinct row of y
        y_levels = np.unique(y, axis=0)
        x_mean_tilde = np.zeros((y_levels.shape[0], X.shape[1]))
        for row, level in enumerate(y_levels):
            in_level = (y == level).all(axis=1)
            x_mean_tilde[row, :] = np.mean(X_tilde[in_level], axis=0)

        return P, reg_X_y, x_mean_tilde






    def est_W(self, X):
        """
        Estimate the whitening matrix for data matrix X.
        
        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            
        Returns:
            np.ndarray: Whitening matrix of shape (d, d)
        
        Note:
            Uses eigendecomposition and handles numerical stability through thresholding
        """

        # get the covariance matrix
        Cov_X = np.cov(X, rowvar=False)

        # get the eigenvalues and eigenvectors
        eigenvals, eigenvecs = np.linalg.eigh(Cov_X)

        # for numerical stability
        eigenvals[eigenvals < self.eps] = 0

        # Take reciprocal square root of non-zero eigenvalues
        diag = np.zeros_like(eigenvals)
        nonzero = eigenvals > 0
        diag[nonzero] = 1.0 / np.sqrt(eigenvals[nonzero])
        
        # Compute whitening matrix
        W = np.dot(eigenvecs, np.dot(np.diag(diag), eigenvecs.T))

        return W
    
    def find_orthogonal_vector(self, u, v):
        """
        Find a vector orthogonal to u in the span of u and v.
        
        Given two linearly independent vectors u and v, finds a vector w that is:
        1. In the span of u and v
        2. Orthogonal to u
        3. Unit length
        
        Args:
            u (np.ndarray): First vector of shape (d, 1)
            v (np.ndarray): Second vector of shape (d, 1)
            
        Returns:
            np.ndarray: Orthogonal unit vector of shape (d, 1)
            
        Raises:
            ValueError: If either input vector is zero
            
        Note:
            Uses the construction w = au + bv where a,b are chosen to ensure w⋅u = 0
        """
        if np.allclose(u, 0) or np.allclose(v, 0):
            raise ValueError("Input vectors cannot be zero vectors")
            
        # Want w·u = 0 and w = au + bv
        # This means: (au + bv)·u = 0
        # So: a(u·u) + b(v·u) = 0
        # Let a = -(v·u), b = (u·u)
        a = -np.matmul(v.T, u)
        b = np.matmul(u.T, u)
        
        # Compute w
        w = a * u + b * v
        
        # Normalize
        w = w / np.linalg.norm(w)
        
        return w
    
    def get_orthogonal_complement_basis(self, v):
        """
        Compute orthonormal basis for the orthogonal complement of a vector.
        
        Uses QR decomposition to find d-1 vectors that form an orthonormal basis for
        the space orthogonal to v.
        
        Args:
            v (np.ndarray): Vector of shape (d, k)
            
        Returns:
            np.ndarray: Matrix of shape (d, d-k) whose columns form orthonormal basis
            
        Note:
            Handles the zero vector case by returning standard basis vectors
        """
        d = v.shape[0]
        k = v.shape[1] if v.ndim > 1 else 1

        # resphape if necessary
        if len(v.shape) == 1:
            v = v.reshape(-1, 1)
        
        # Normalize v
        u = v / np.linalg.norm(v, axis=0, keepdims=True)
        
        # Create a matrix with u as the first column and the rest as standard basis vectors
        A = np.eye(d)
        A[:, :k] = u
        
        # Perform QR decomposition
        Q, R = np.linalg.qr(A)
        
        # Return the last d-1 columns of Q, which form the orthogonal complement basis
        return Q[:, k:]
    
    def get_column_space_basis(self, A):
        """
        Get orthonormal basis for column space of matrix A using SVD.
        
        Args:
            A: Matrix of shape (d, d-1)
        Returns:
            Orthonormal basis vectors as columns of matrix
        """
        # Perform SVD
        U, s, Vh = np.linalg.svd(A, full_matrices=False)

        # Threshold for numerical stability
        tol = max(A.shape) * np.spacing(max(s))
        r = sum(s > tol)
        return U[:, :r]
    
    def opt_sep_proj(self, X, z, y, info_type='Cov', coef=None):
        """
        Compute optimal separation projection matrix.

        The data is whitened with `est_W`, an oblique projection is built in
        the whitened space with `get_oblique_proj` (projecting parallel to the
        cross-covariance with z while keeping the y-related directions in its
        range), and the result is mapped back as ``W^+ P' W``. The whitened
        space projection is also stored on the instance as `P_prime`.

        Args:
            X (np.ndarray): Data matrix of shape (n, d)
            z (np.ndarray): Protected attribute to remove, shape (n, 1)
            y (np.ndarray): Target information to preserve, shape (n, 1)
            info_type (str): Type of information to preserve (default: 'Cov').
                'Cov' uses the cross-covariance of the whitened data with y;
                'coef' uses ``W @ coef``.
            coef (np.ndarray): Coefficients for linear regression (optional) of shape (d, k)

        Returns:
            np.ndarray: Projection matrix P of shape (d, d)

        Raises:
            ValueError: If info_type is 'coef' and coef is None, or if
                info_type is neither 'Cov' nor 'coef'.
        """
        # Whitening matrix and whitened data
        W = self.est_W(X)
        XW = X @ W

        # check: if z.ndim > 1 and it is one-hot encoded, we can remove the last column
        if z.ndim > 1 and is_one_hot(z):
            z = z[:, :-1]
            print('Removing last column of z, as it is one-hot encoded')

        # Direction(s) to remove: cross-covariance of whitened X with z
        v = est_Cov(XW, z)

        # Direction(s) to preserve, depending on info_type
        if info_type == 'coef':
            if coef is None:
                raise ValueError("Need to provide coef for info_type='coef'")
            u_k = W @ coef
        elif info_type == 'Cov':
            u_k = est_Cov(XW, y)
        else:
            raise ValueError("info_type must be either 'Cov' or 'coef'")

        # Oblique projection in the whitened space, kept for inspection
        self.P_prime = self.get_oblique_proj(v, u_k)

        # Map the projection back to the original coordinates
        return np.linalg.pinv(W) @ self.P_prime @ W
    


    def get_orth_proj(self, v):
        """
        Compute orthogonal projection matrix projecting parallel to v.
        
        Creates a projection matrix P that:
        1. Projects orthogonally onto the space perpendicular to v
        2. Satisfies P² = P (projection property)
        3. Is symmetric
        
        Args:
            v (np.ndarray): Direction vector of shape (d, 1)
            
        Returns:
            np.ndarray: Projection matrix of shape (d, d)
        """

        # check if the vector is unit
        if np.abs(np.linalg.norm(v) - 1) > self.eps:
            v = v/np.linalg.norm(v)

        # get the dimension of the vector
        d = v.shape[0]

        # get d x (d-1) matrix whose columns are orthonormal basis of the set of all vectors orthogonal to v
        I_d = np.eye(d)
        
        # get the projection matrix
        P = I_d - np.matmul(v, v.T)

        return P


    # Function to orthogonalize a candidate vector against the current basis
    def orthogonalize(self, vec, basis_list):
        for b in basis_list:
            vec = vec - np.dot(b, vec) * b
        return vec
        



    def get_oblique_proj(self, v, u_k):
        """
        Compute oblique projection matrix with specific properties.
        
        Creates a projection matrix P that:
        1. Projects parallel to v
        2. Ensures each column vector in u_k is in the range of the projection matrix
        3. Satisfies P² = P (projection property)
        
        Args:
            v (np.ndarray): Direction to project parallel to, shape (d, k)
            u_k (np.ndarray): Matrix of directions to preserve, shape (d, k)
            
        Returns:
            np.ndarray: Oblique projection matrix of shape (d, d)
        """
        # Normalize v
        if np.abs(np.linalg.norm(v, axis=0) - 1).max() > self.eps:
            v = v / np.linalg.norm(v, axis=0, keepdims=True)
        
        # Normalize each column of u_k
        if np.abs(np.linalg.norm(u_k, axis=0) - 1).max() > self.eps:
            u_k = u_k / np.linalg.norm(u_k, axis=0, keepdims=True)
        
        # Get orthogonal complement basis of v
        V_perp_v = self.get_orthogonal_complement_basis(v)

        # Orthonormalize u_k via QR
        U_fixed, _ = np.linalg.qr(u_k)

        # compute k additional vectors from the span{u_1,...,u_k,v_1,...,v_k}
        # that is orthogonal to all of u_1, ..., u_k

        # Compute additional vector from the span{u_1,...,u_k,v}
        #U_list = [U_fixed[:, i].reshape(-1,1) for i in range(U_fixed.shape[1])]
        u_perp = self.find_orthogonal_vector_multi(U_fixed, v)

        # get an orthonormal basis of the set of all vectors orthogonal to u_perp
        V_perp_u_perp =self.get_orthogonal_complement_basis(u_perp)
     
        # get the projection matrix - first define the matrix to be inverted
        A = np.matmul(V_perp_v.T, V_perp_u_perp)
        A_inv = np.linalg.pinv(A)

        # get the projection matrix
        P = np.matmul(np.matmul(V_perp_u_perp, A_inv), V_perp_v.T)

        return P
    

    # def find_orthogonal_vector_multi(self, u_list, v):
    #     """
    #     Given a list of vectors (each as a (d,1) numpy array), return a vector w that is in the span of v and 
    #     orthogonal to all vectors in u_list.
    #     """
    #     # If no fixed vectors are provided, return v as is.
    #     if len(u_list) == 0:
    #         return v
    #     # Form a matrix whose columns are the u_list vectors
    #     U = np.column_stack(u_list)
    #     # Orthonormalize U via QR
    #     Q, _ = np.linalg.qr(U)
    #     # Project v onto the span of Q
    #     proj = Q @ (Q.T @ v)
    #     w = v - proj
    #     return w
    
    def find_orthogonal_vector_multi(self, U, v):
        """
        Given a list of vectors u_list and matrix v, returns k vectors that are:
        1. In the span of {u_1,...,u_k,v_1,...,v_k}
        2. Strictly orthogonal to all vectors in u_list
        3. Orthonormal to each other
        
        Args:
            u_list (list): List of vectors each of shape (d,1) representing u_1,...,u_k
            v (np.ndarray): Matrix of shape (d,k) representing v_1,...,v_k
            
        Returns:
            np.ndarray: Matrix of shape (d,k) whose columns form orthonormal vectors
        """
       
            
        # Get orthonormal basis for span of U
        Q_u, _ = np.linalg.qr(U)
        
        # Form matrix of all vectors [U V]
        UV = np.column_stack([U, v])
        
        # Get orthonormal basis for combined span
        Q_uv, _ = np.linalg.qr(UV)
        
        # For each vector in Q_uv, project out Q_u components and normalize
        k = v.shape[1]  # number of vectors we need
        result = []

        # turn U into a list of vectors
        u_list = [U[:, i:i+1] for i in range(U.shape[1])]
        
        for i in range(Q_uv.shape[1]):
            q = Q_uv[:, i:i+1]  # keep as matrix
            
            # Project out all components in U direction
            for u in u_list:
                q = q - (u.T @ q) * u
                
            # Check if vector is non-zero after projection
            norm = np.linalg.norm(q)
            if norm > self.eps:
                q = q / norm
                
                # Also ensure it's orthogonal to previously found vectors
                for r in result:
                    q = q - (r.T @ q) * r
                
                # If still non-zero after all projections, keep it    
                if np.linalg.norm(q) > self.eps:
                    result.append(q)
                    
            if len(result) == k:
                break
                
        if len(result) < k:
            raise ValueError(f"Could only find {len(result)} orthogonal vectors instead of requested {k}")
            
        return np.column_stack(result)
   






