import numpy as np
import pandas as pd

def load_theta_and_ids(theta_path, ids_path):
    """Load theta vectors and their IDs, flattening each theta to 1D.

    Parameters
    ----------
    theta_path : str or path-like
        Path to a ``.npy`` file whose first axis indexes samples. All trailing
        axes are flattened, e.g. ``(N, 3, 20) -> (N, 60)``.
    ids_path : str or path-like
        Path to a ``.npy`` file holding exactly one ID per sample (1-D array).

    Returns
    -------
    thetas : np.ndarray
        Array of shape ``(N, D)``.
    ids : np.ndarray
        1-D array of length ``N``; byte-string IDs are decoded to ``str``.

    Raises
    ------
    ValueError
        If the IDs are not 1-D or their count differs from the number of
        thetas (e.g. when ``ids_path`` points at a theta file by mistake).
    """
    thetas = np.load(theta_path)
    ids = np.load(ids_path)

    # Flatten trailing axes explicitly; reshape(N, -1) cannot infer the size
    # when N == 0, so compute the flattened width directly.
    n_samples = thetas.shape[0]
    flat_dim = int(np.prod(thetas.shape[1:]))
    thetas = thetas.reshape(n_samples, flat_dim)

    # Fail early with a clear message instead of a confusing downstream error
    # (or silently misaligned IDs).
    if ids.ndim != 1 or ids.shape[0] != n_samples:
        raise ValueError(
            f"IDs loaded from {ids_path!r} have shape {ids.shape}; expected a "
            f"1-D array of length {n_samples} to match thetas from {theta_path!r}"
        )

    # Convert byte-string IDs to regular Python strings (UTF-8 decode).
    if ids.dtype.type is np.bytes_:
        ids = np.array([x.decode() for x in ids], dtype=str)
    return thetas, ids


def cosine_similarity_matrix(A, B):
    """
    Pairwise cosine similarity between the rows of B and the rows of A.

    A has shape (N, D) and B has shape (M, D). The result has shape (M, N),
    where entry (i, j) is cos_sim(B[i], A[j]). A constant 1e-8 is added to
    every product of norms, so a zero-norm row yields 0 rather than a
    division by zero.
    """
    numerators = B @ A.T  # (M, N) raw dot products
    row_norms_b = np.linalg.norm(B, axis=1)  # (M,)
    row_norms_a = np.linalg.norm(A, axis=1)  # (N,)
    denominators = np.outer(row_norms_b, row_norms_a) + 1e-8  # (M, N)
    return numerators / denominators


def sgd_test_with_ids(
    ref_theta_path, ref_ids_path,
    test_theta_path, test_ids_path,
    reference_avg,
    conf_interval,
    out_csv=None,
    verbose=True
):
    """Score each test sample against a reference set and check it against an interval.

    For every test theta, compute its mean cosine similarity to all reference
    thetas (``rho_avg``). Subtract ``reference_avg`` to get ``rho_final``.
    Then mark whether ``rho_final`` lies inside the closed interval
    ``conf_interval``.

    Parameters
    ----------
    ref_theta_path, ref_ids_path : path-like
        ``.npy`` files with the reference thetas and their IDs.
    test_theta_path, test_ids_path : path-like
        ``.npy`` files with the test thetas and their IDs.
    reference_avg : float
        Value subtracted from each test sample's average similarity.
    conf_interval : tuple of (float, float)
        ``(min_bound, max_bound)``; both ends are inclusive.
    out_csv : path-like, optional
        If given, the results table is written there as CSV (no index column).
    verbose : bool
        If True, print shapes, a preview of the table, a summary line and the
        CSV path (when saved).

    Returns
    -------
    pandas.DataFrame
        One row per test sample with columns ``image_id``, ``rho_avg``,
        ``rho_final``, ``inside_interval`` (bool) and ``point``
        (1 inside / 0 outside).

    Raises
    ------
    ValueError
        If the reference set is empty, or if the reference and test thetas
        have different flattened dimensionality.
    """
    # 1. Load data (reference IDs are loaded but not used in the output)
    A, _ = load_theta_and_ids(ref_theta_path, ref_ids_path)
    B, B_ids = load_theta_and_ids(test_theta_path, test_ids_path)
    if verbose:
        print(f"Reference shape: {A.shape}, Test shape: {B.shape}")

    # An empty reference set would make every average NaN; a dimension
    # mismatch would fail inside the dot product with an opaque message.
    if A.shape[0] == 0:
        raise ValueError("Reference set is empty; cannot compute average similarities.")
    if A.shape[1] != B.shape[1]:
        raise ValueError(
            f"Reference and test thetas differ in dimension: "
            f"{A.shape[1]} (reference) vs {B.shape[1]} (test)"
        )

    # 2. Cosine similarity matrix, shape (num_test, num_ref)
    sim_matrix = cosine_similarity_matrix(A, B)
    # 3. Average similarity of each test sample to the whole reference set
    rho_avg = np.mean(sim_matrix, axis=1)

    # 4. Subtract reference average
    rho_final = rho_avg - reference_avg
    # 5. Inclusive interval check
    min_bound, max_bound = conf_interval
    inside_mask = (rho_final >= min_bound) & (rho_final <= max_bound)
    n_total = len(rho_final)
    n_inside = int(inside_mask.sum())
    # Avoid a NaN percentage (and RuntimeWarning) for an empty test set.
    percent_inside = 100.0 * n_inside / n_total if n_total else 0.0
    # 6. Assign points (1 = inside, 0 = outside)
    points = inside_mask.astype(int)

    # 7. Tabulate results with IDs and points
    results_df = pd.DataFrame({
        "image_id": B_ids,
        "rho_avg": rho_avg,
        "rho_final": rho_final,
        "inside_interval": inside_mask,
        "point": points
    })
    if verbose:
        print(results_df.head())
        print(f"\nInterval: [{min_bound:.8f}, {max_bound:.8f}] | Inside: {n_inside} / {n_total} ({percent_inside:.2f}%)")
    if out_csv is not None:
        results_df.to_csv(out_csv, index=False)
        # Honor `verbose` like every other message in this function.
        if verbose:
            print(f"Results saved to {out_csv}")
    return results_df

# -------------------- USAGE EXAMPLE --------------------
if __name__ == "__main__":
    # Messidor 1 reference vs. EyePACS test domain example
    sgd_test_with_ids(
        ref_theta_path="/drive2/Kuntal/Pysindy-experiment/M1-output/messidor_all_thetas.npy",
        ref_ids_path="/drive2/Kuntal/Pysindy-experiment/M1-output/messidor_theta_ids.npy",
        test_theta_path="/drive2/Kuntal/Pysindy-experiment/eyepacs_theta_data/eyepacs_all_thetas.npy",     # <-- test domain
        # Must be the IDs file, not the thetas file (passing the thetas file
        # makes the "IDs" a float theta array). Name follows the reference
        # set's "*_theta_ids.npy" convention -- TODO confirm it exists on disk.
        test_ids_path="/drive2/Kuntal/Pysindy-experiment/eyepacs_theta_data/eyepacs_theta_ids.npy",        # <-- test ids
        reference_avg=0.04047960306748797,
        conf_interval=(-0.03824305970024752, 0.03824305970024752),
        # NOTE(review): output name says "aptos" but the test set above is
        # EyePACS -- confirm the intended file name.
        out_csv="/drive2/Kuntal/Pysindy-experiment/pysindy/SDG/sgd_aptos_results.csv",
        verbose=True
    )
