import pytest
import numpy as np
import torch
from proj import proj
from concept_erasure import LeaceEraser
from utils import est_Cov

@pytest.fixture
def sample_data():
    """Provide a reproducible random dataset ``(X, z, y)``.

    The global NumPy RNG is seeded with 42 before drawing, so the three
    arrays are identical on every invocation:

    * ``X`` -- standard-normal draws of shape (100, 10),
    * ``z`` -- Bernoulli(0.5) draws of shape (100, 1),
    * ``y`` -- Bernoulli(0.5) draws of shape (100, 1).
    """
    np.random.seed(42)
    n_rows, n_cols = 100, 10
    features = np.random.randn(n_rows, n_cols)
    # z is drawn before y; the order determines which values each one gets.
    z_col = np.random.binomial(1, 0.5, size=(n_rows, 1))
    y_col = np.random.binomial(1, 0.5, size=(n_rows, 1))
    return features, z_col, y_col

@pytest.fixture
def sample_orthogonal_data():
    """Provide whitened data plus two targets built from orthogonal directions.

    The feature matrix is centered and whitened so that its sample
    covariance (with a ``1/n`` normalisation) is the identity. Two unit
    vectors ``a`` and ``b`` with ``a . b == 0`` are then used to build the
    targets as linear functions of the whitened features.

    Returns:
        A tuple ``(X_white, y, z)`` -- note the order, ``y`` comes before
        ``z`` -- where ``X_white`` has shape (100, 2) and ``y = X_white @ A``
        and ``z = X_white @ B`` both have shape (100, 1).
    """
    # Seed the global RNG (as the ``sample_data`` fixture does) so the draw,
    # and therefore any failure of a test using it, is reproducible.
    np.random.seed(42)
    n, p = 100, 2

    # 1. Generate X ~ N(0, I) with n samples in R^p.
    X = np.random.randn(n, p)

    # 2. Center X (for numerical precision).
    X_centered = X - np.mean(X, axis=0)

    # 3. Compute the sample covariance of X_centered (1/n normalisation).
    S = (X_centered.T @ X_centered) / n

    # 4. Whiten X: compute S^{-1/2} via eigen-decomposition.
    eigvals, eigvecs = np.linalg.eigh(S)
    eps = 1e-10  # floor on the eigenvalues to avoid division by zero
    eigvals = np.clip(eigvals, eps, None)
    S_inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T

    # X_white has the identity as its (1/n) sample covariance.
    X_white = X_centered @ S_inv_sqrt

    # 5. Generate direction vectors a and b in R^p such that a^T b = 0.
    a = np.random.randn(p)
    b = np.random.randn(p)
    # Orthogonalize b relative to a (one Gram-Schmidt step).
    b = b - (np.dot(a, b) / np.dot(a, a)) * a
    # Normalize so that both have unit norm.
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    # Reshape to column vectors of shape (p, 1).
    A = a.reshape(p, 1)
    B = b.reshape(p, 1)

    # 6. Construct y and z.
    y = X_white @ A  # shape: (n, 1)
    z = X_white @ B  # shape: (n, 1)

    return X_white, y, z
    

def test_est_Cov(sample_data):
    """Check ``est_Cov(X, z)`` against a hand-computed cross-covariance.

    The result must have shape ``(n_features_X, n_features_z)`` and agree,
    to five decimals, with the centered cross-product divided by ``n - 1``.
    """
    X, z, _ = sample_data
    Cov = est_Cov(X, z)
    assert Cov.shape == (X.shape[1], z.shape[1])

    # Reference value: center both inputs, then use the unbiased (n - 1)
    # normalisation of the cross-product.
    X_centered = X - np.mean(X, axis=0)
    z_centered = z - np.mean(z, axis=0)
    Cov_np = np.dot(X_centered.T, z_centered) / (X.shape[0] - 1)

    np.testing.assert_almost_equal(Cov, Cov_np, decimal=5)

def test_est_W(sample_data):
    """``est_W`` must return a square matrix that whitens ``X``.

    Whitening is checked by requiring the sample covariance of ``X @ W``
    to equal the identity to five decimals.
    """
    features, _, _ = sample_data
    n_features = features.shape[1]

    whitener = proj().est_W(features)
    assert whitener.shape == (n_features, n_features)

    # Columns are variables, rows are observations.
    whitened_cov = np.cov(features @ whitener, rowvar=False)
    identity = np.eye(whitened_cov.shape[0])
    np.testing.assert_almost_equal(whitened_cov, identity, decimal=5)

def test_LEACE(sample_data):
    """After a LEACE fit, projected features have zero covariance with ``z``.

    The check is run on the data used for fitting and again on the same
    data shifted by a random scalar offset, reusing the fitted projection.
    """
    features, z_col, _ = sample_data
    eraser = proj()
    eraser.fit(features, z_col, None, method='LEACE')

    def cov_after_projection(data):
        # Covariance between the projected data and z.
        return est_Cov(eraser.apply_projection(data), z_col)

    # Projection of the original data.
    cov_plain = cov_after_projection(features)
    zeros = np.zeros(cov_plain.shape)
    np.testing.assert_almost_equal(cov_plain, zeros, decimal=5)

    # Same projection applied after adding a scalar mean to every entry.
    offset = np.random.uniform(-1, 1, size=1)
    cov_shifted = cov_after_projection(features + offset)
    np.testing.assert_almost_equal(cov_shifted, zeros, decimal=5)

def test_opt_sep_proj(sample_data):
    """Check the 'opt-sep-proj' method: erase ``z``, keep covariance with ``y``.

    Covariance with ``z`` must vanish both on the fitting data and on a
    copy shifted by a random scalar; covariance with ``y`` must be
    unchanged on the fitting data.
    """
    features, z_col, y_col = sample_data
    projector = proj()
    projector.fit(features, z_col, y_col, method='opt-sep-proj')

    projected = projector.apply_projection(features)

    # z must be uncorrelated with the projected features.
    cov_with_z = est_Cov(projected, z_col)
    np.testing.assert_almost_equal(cov_with_z, np.zeros(cov_with_z.shape), decimal=5)

    # The covariance with y must survive the projection.
    cov_y_before = est_Cov(features, y_col)
    cov_y_after = est_Cov(projected, y_col)
    np.testing.assert_almost_equal(cov_y_before, cov_y_after, decimal=5)

    # Repeat the z check on data with a scalar mean added.
    offset = np.random.uniform(-1, 1, size=1)
    shifted_projected = projector.apply_projection(features + offset)
    cov_with_z_shifted = est_Cov(shifted_projected, z_col)
    np.testing.assert_almost_equal(
        cov_with_z_shifted, np.zeros(cov_with_z_shifted.shape), decimal=5
    )


def test_get_orth_proj_obj():
    """``get_orth_proj(v)`` must be idempotent and must annihilate ``v``."""
    # First standard basis vector of R^3 as a column.
    direction = np.array([[1], [0], [0]])
    projector = proj().get_orth_proj(direction)

    assert projector.shape == (3, 3)

    # Idempotence: applying the projection twice changes nothing.
    np.testing.assert_almost_equal(projector @ projector, projector, decimal=5)

    # The direction itself is mapped to (numerically) zero.
    image_norm = np.linalg.norm(projector @ direction)
    np.testing.assert_almost_equal(image_norm, 0, decimal=5)

def test_get_orthogonal_complement_basis():
    """The complement basis of a vector in R^3 is 3x2, orthogonal to it, orthonormal."""
    direction = np.array([[1], [0], [0]])
    basis = proj().get_orthogonal_complement_basis(direction)

    # Two basis vectors span the complement of one direction in R^3.
    assert basis.shape == (3, 2)

    # Every basis column is orthogonal to the input direction.
    np.testing.assert_almost_equal(direction.T @ basis, 0, decimal=5)

    # The Gram matrix of the columns is the 2x2 identity.
    gram = basis.T @ basis
    np.testing.assert_almost_equal(gram, np.eye(2), decimal=5)

def test_LEACE_package(sample_data):
    """Compare this project's LEACE with ``LeaceEraser`` on mean-shifted data.

    Both erasers are fitted on the same shifted features; the resulting
    covariances with ``z`` and the overall means of the erased features
    must agree to five decimals.
    """
    features, z_col, _ = sample_data

    # Shift every entry by the same random scalar so the data is not centered.
    offset = np.random.uniform(-1, 1, size=1)
    features = features + offset

    # Project-local implementation.
    local_eraser = proj()
    local_eraser.fit(features, z_col, None, method='LEACE')
    erased_local = local_eraser.apply_projection(features)
    cov_local = est_Cov(erased_local, z_col)

    # Reference implementation: fit a LeaceEraser on torch views of the data.
    features_t = torch.from_numpy(features)
    z_t = torch.from_numpy(z_col)
    reference_eraser = LeaceEraser.fit(features_t, z_t)
    erased_reference = reference_eraser(features_t).numpy()
    cov_reference = est_Cov(erased_reference, z_col)

    # Both implementations should leave the same covariance with z ...
    np.testing.assert_almost_equal(cov_reference, cov_local, decimal=5)

    # ... and the same grand mean of the erased features.
    np.testing.assert_almost_equal(erased_local.mean(), erased_reference.mean(), decimal=5)


def test_LEACE_equals_opt_sep_orthogonal(sample_orthogonal_data):
    """Test that LEACE equals opt-sep-proj when covariances are orthogonal.

    With Cov(X, y) orthogonal to Cov(X, z), both methods are expected to
    remove the covariance with ``z``, preserve the covariance with ``y``,
    and produce the same projected data.
    """
    # The fixture returns (X, y, z) in that order.
    X, y, z = sample_orthogonal_data
    proj_obj = proj()

    # Add a scalar mean to X so the data is not centered.
    mean_X = np.random.uniform(-1, 1, size=1)
    X = X + mean_X

    # Verify orthogonality of covariances. The magnitude is what matters:
    # a large negative inner product is just as non-orthogonal as a large
    # positive one, so a one-sided `<` comparison would be too weak.
    cov_x_y = est_Cov(X, y)
    cov_x_z = est_Cov(X, z)
    inprod = cov_x_y.T @ cov_x_z
    np.testing.assert_allclose(
        inprod, 0, atol=1e-5, err_msg="Covariances are not orthogonal"
    )

    # Apply both projections (the same object is refitted for each method).
    proj_obj.fit(X, z, None, method='LEACE')
    X_proj_leace = proj_obj.apply_projection(X)

    proj_obj.fit(X, z, y, method='opt-sep-proj')
    X_proj_opt = proj_obj.apply_projection(X)

    # Check z covariance is removed
    cov_z_leace = est_Cov(X_proj_leace, z)
    cov_z_opt = est_Cov(X_proj_opt, z)
    np.testing.assert_almost_equal(cov_z_leace, 0, decimal=5)
    np.testing.assert_almost_equal(cov_z_opt, 0, decimal=5)

    # Check y covariance is preserved (cov_x_y is the pre-projection value).
    cov_y_leace = est_Cov(X_proj_leace, y)
    cov_y_opt = est_Cov(X_proj_opt, y)
    np.testing.assert_almost_equal(cov_y_leace, cov_x_y, decimal=5)
    np.testing.assert_almost_equal(cov_y_opt, cov_x_y, decimal=5)

    # Check if projections are equivalent
    np.testing.assert_almost_equal(X_proj_leace, X_proj_opt, decimal=5)


def test_find_orthogonal_vector_multi_empty():
    """With an empty list of fixed vectors, ``v`` must come back unchanged."""
    candidate = np.array([[3], [4]])
    no_constraints = []
    outcome = proj().find_orthogonal_vector_multi(no_constraints, candidate)
    np.testing.assert_allclose(outcome, candidate, atol=1e-8)

def test_find_orthogonal_vector_multi_single():
    """With one fixed vector, the result is orthogonal to it and non-zero."""
    fixed = np.array([[1], [0]])
    candidate = np.array([[1], [1]])

    outcome = proj().find_orthogonal_vector_multi([fixed], candidate)

    # Inner product with the fixed vector must vanish.
    inner = np.dot(fixed.T, outcome)
    np.testing.assert_allclose(inner, [[0]], atol=1e-8)

    # The result must not be the zero vector.
    assert np.linalg.norm(outcome) > 1e-8

def test_find_orthogonal_vector_multi_multiple():
    """With e1 and e2 fixed in R^3, the result must lie along the third axis."""
    e1 = np.array([[1], [0], [0]])
    e2 = np.array([[0], [1], [0]])
    candidate = np.array([[1], [1], [1]])

    outcome = proj().find_orthogonal_vector_multi([e1, e2], candidate)

    # Orthogonality against each fixed vector in turn.
    inner_e1 = np.dot(e1.T, outcome)
    inner_e2 = np.dot(e2.T, outcome)
    np.testing.assert_allclose(inner_e1, [[0]], atol=1e-8)
    np.testing.assert_allclose(inner_e2, [[0]], atol=1e-8)

    # span{e1, e2} is the x-y plane, so only the z component can remain.
    # Absolute values are compared so that either sign is accepted.
    z_axis = np.array([[0], [0], [1]])
    np.testing.assert_allclose(
        np.abs(outcome.flatten()), np.abs(z_axis.flatten()), atol=1e-8
    )
