import numpy as np
import random
from sklearn.linear_model import Ridge
import h5py
import argparse
from data import load_boussinesq_data, load_rd_data

# Generate up-to-second-order polynomial terms of the input features
def generate_polynomial_library(X, feature_names, homogeneous_scaler=None):
    """
    Build a library of constant, linear and quadratic terms from the columns of X.

    Column order is: the constant term, every feature in its original order,
    then every product x_i * x_j with i <= j (i varying slowest).

    When `homogeneous_scaler` s is given, the constant column is s itself and
    every quadratic column is divided by s; linear columns are left untouched.
    Zero entries in s therefore yield inf/nan in the quadratic columns.

    Parameters:
        X (numpy array): Input features (n_samples, n_features).
        feature_names (sequence of str or None): One name per column of X.
            Defaults to "x0", "x1", ... when None.
        homogeneous_scaler (array-like or None): Per-sample scale of shape
            (n_samples,) or (n_samples, 1). None is equivalent to all ones.

    Returns:
        library (numpy array): Matrix of shape
            (n_samples, 1 + n_features + n_features * (n_features + 1) / 2).
        extended_feature_names (numpy array): Name of each library column,
            "1" for the constant term and "a*b" for quadratic terms.

    Raises:
        ValueError: If `feature_names` does not have one entry per column of X,
            or if `homogeneous_scaler` is not (n_samples,) / (n_samples, 1).
    """
    n_samples, n_features = X.shape

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(n_features)]
    elif len(feature_names) != n_features:
        # A mismatch would leave names and library columns out of sync.
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but X has {n_features} columns"
        )

    if homogeneous_scaler is None:
        homogeneous_scaler = np.ones((n_samples, 1))
    else:
        homogeneous_scaler = np.asarray(homogeneous_scaler)
        if homogeneous_scaler.ndim == 1:
            homogeneous_scaler = homogeneous_scaler[:, None]
        # Anything wider than one column would broadcast into extra, unnamed
        # library columns, so reject it explicitly.
        if homogeneous_scaler.shape != (n_samples, 1):
            raise ValueError(
                f"homogeneous_scaler must have shape ({n_samples},) or ({n_samples}, 1), "
                f"got {homogeneous_scaler.shape}"
            )

    # Constant term
    constant = np.ones((n_samples, 1)) * homogeneous_scaler

    # Quadratic terms: upper-triangular index pairs (i <= j) in row-major order
    first_idx, second_idx = np.triu_indices(n_features)
    quadratic = X[:, first_idx] * X[:, second_idx] / homogeneous_scaler

    quadratic_names = [
        f"{left}*{right}"
        for i, left in enumerate(feature_names)
        for right in feature_names[i:]
    ]
    extended_feature_names = ["1", *feature_names, *quadratic_names]

    # Concatenate into a single matrix
    return np.hstack([constant, X, quadratic]), np.array(extended_feature_names)

# Perform sparse regression with sequential thresholding
def sparse_regression(X, y, threshold=1e-2, alpha=1e-2, max_iter=20, feature_names=None, homogeneous_scaler=None):
    """
    Fit a sparse model on a quadratic polynomial library via ridge regression
    with sequential thresholding.

    Each iteration fits an intercept-free ridge model on the current library,
    drops every column whose coefficient magnitude is below `threshold`, and
    repeats on the reduced library. Iteration stops as soon as a fit prunes
    nothing, prunes everything, or `max_iter` fits have been performed. If the
    iteration budget runs out right after a pruning step, the returned
    coefficients are those of the last fit (they are not refit on the reduced
    library).

    Parameters:
        X (numpy array): Input features (n_samples, n_features).
        y (numpy array): Target values (n_samples,).
        threshold (float): Coefficients with magnitude below this are pruned.
        alpha (float): Regularization strength for ridge regression.
        max_iter (int): Maximum number of fit-and-prune iterations.
        feature_names (sequence of str or None): Names of the columns of X,
            forwarded to generate_polynomial_library.
        homogeneous_scaler (numpy array or None): Per-sample scale forwarded
            to generate_polynomial_library.

    Returns:
        W (numpy array): Coefficients of the surviving library terms.
        Theta (numpy array): Library columns that survived pruning.
        eq_str (str): Terms formatted as "coef*name" (3 decimals) joined by " + ".
    """
    library, term_names = generate_polynomial_library(X, feature_names, homogeneous_scaler)
    coefs = np.zeros(library.shape[1])  # returned as-is only when max_iter == 0

    for _ in range(max_iter):
        # Ridge fit without intercept; the library carries its own constant column
        ridge = Ridge(alpha=alpha, fit_intercept=False, solver='svd')
        ridge.fit(library, y)
        fitted = ridge.coef_

        # Terms whose coefficient magnitude reaches the threshold survive
        keep = np.abs(fitted) >= threshold
        n_kept = np.count_nonzero(keep)
        n_before = library.shape[1]

        library = library[:, keep]
        coefs = fitted[keep]
        term_names = term_names[keep]

        # Converged (nothing pruned) or degenerate (everything pruned)
        if n_kept == n_before or n_kept == 0:
            break

    terms = [f"{coef:.3f}*{name}" for coef, name in zip(coefs, term_names)]
    eq_str = " + ".join(terms)

    return coefs, library, eq_str

# Example usage
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Restrict to the datasets handled below so that a typo fails loudly
    # instead of silently producing no output.
    parser.add_argument('--dataset', type=str, default='boussinesq', choices=['boussinesq', 'reac_diff'])
    parser.add_argument('--data_path', type=str, default='../data/boussinesq/boussinesq_1_dt_1e-3.h5')
    parser.add_argument('--subsample', type=float, default=0.02)
    parser.add_argument('--seed', type=int, default=42)
    # dataset options
    parser.add_argument('--align_utt', action='store_true', help='align the u_tt term with u_x^2 in Boussinesq equation')
    # t_start / t_end are only used by the reac_diff branch
    parser.add_argument('--t_start', type=int, default=0)
    parser.add_argument('--t_end', type=int, default=200)
    args = parser.parse_args()

    # Seed both RNGs so that runs are repeatable for a given --seed
    np.random.seed(args.seed)
    random.seed(args.seed)

    data_path = args.data_path
    # data_path = '../data/boussinesq/boussinesq_noise_1e-4_dt_1e-3.h5'

    if args.dataset == 'boussinesq':

        X, I, x_names, invar_names = load_boussinesq_data(data_path, subsample=args.subsample, align_utt=args.align_utt)
        # print(X.shape, I.shape, x_names, invar_names)

        # Regress column 1 of I on the remaining columns
        y_I = I[:, 1]  # I_0x2
        # X_I is the remaining columns
        X_I = np.delete(I, 1, axis=1)
        invar_names = invar_names[0:1] + invar_names[2:]

        # or, exclude all time derivatives
        # X_I = I[:, [4, 8, 11, 13]]
        # invar_names = [invar_names[i] for i in [4, 8, 11, 13]]

        # NOTE(review): column 7 of X is presumably u_x (see --align_utt help);
        # confirm against load_boussinesq_data.
        homogeneous_scaler = X[:, 7] ** 2 if args.align_utt else None
        W, Theta, eq_str = sparse_regression(
            X_I,
            y_I,
            threshold=0.25,
            alpha=0.05,
            feature_names=invar_names,
            homogeneous_scaler=homogeneous_scaler,
        )

        # print("Sparse coefficients (W):", W)
        # print("Number of terms in the final model:", len(W))
        print("Equation: ", eq_str)

    elif args.dataset == 'reac_diff':

        I, invar_names = load_rd_data(data_path, subsample=args.subsample, t_range=[args.t_start, args.t_end])

        # Columns 0 and 6 are the two regression targets; the rest are features
        y1 = I[:, 0]  # I_t
        y2 = I[:, 6]  # E_t
        X_I = np.delete(I, [0, 6], axis=1)
        invar_names = invar_names[1:6] + invar_names[7:]

        W1, Theta1, eq_str1 = sparse_regression(X_I, y1, threshold=0.05, alpha=0.2, feature_names=invar_names)
        W2, Theta2, eq_str2 = sparse_regression(X_I, y2, threshold=0.05, alpha=0.2, feature_names=invar_names)

        # print("Sparse coefficients (W1):", W1)
        # print("Number of terms in the final model (I_t):", len(W1))
        print("Equation 1: I_t =", eq_str1)
        # print("Sparse coefficients (W2):", W2)
        # print("Number of terms in the final model (E_t):", len(W2))
        print("Equation 2: E_t =", eq_str2)
