import numpy as np
from itertools import combinations_with_replacement
from sklearn.feature_selection import f_regression, r_regression, mutual_info_regression
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
import numpy as np
from itertools import combinations_with_replacement, combinations
from scipy.special import legendre, eval_chebyt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.gaussian_process.kernels import RBF
import matplotlib.pyplot as plt
from joblib import Parallel, delayed

import sympy as sp
import numpy as np
from itertools import product
import time


def gram_schmidt_process(vectors, inner_product):
    """Orthonormalise ``vectors`` with classical Gram-Schmidt under ``inner_product``.

    Parameters:
    - vectors (iterable): Elements to orthonormalise, processed in order
      (typically sympy expressions).
    - inner_product (callable): Two-argument function ``(a, b) -> scalar``.
      This parameter shadows the module-level ``inner_product``, which takes
      a third ``variables`` argument; wrap that one, e.g.
      ``lambda a, b: inner_product(a, b, variables)``.

    Returns:
    - list: The orthonormal elements. Inputs whose residual is zero (i.e.
      linearly dependent on earlier inputs) are skipped.
    """
    orthonormal = []
    for v in vectors:
        # Remove the projection of v onto every basis element found so far.
        w = v - sum(inner_product(v, u) * u for u in orthonormal)
        if w == 0:
            continue  # structurally zero residual: v is dependent
        norm_sq = inner_product(w, w)
        # sympy ``==`` is structural, so a residual such as x*(x + 1) - x**2 - x
        # is not caught above. Its squared norm still simplifies to zero, and
        # dividing by sqrt(0) would inject zoo/nan into the basis.
        if sp.simplify(norm_sq) == 0:
            continue
        orthonormal.append(w / sp.sqrt(norm_sq))
    return orthonormal

def inner_product(p1, p2, variables):
    """Integrate ``p1 * p2`` over [-1, 1] in each of ``variables``, one after another.

    Parameters:
    - p1, p2: Expressions to multiply (typically sympy expressions).
    - variables (iterable): Symbols integrated out in order, each over [-1, 1].

    Returns:
    - The result of the iterated definite integrals; with no variables this
      is simply ``p1 * p2``.
    """
    result = p1 * p2
    limits = [(symbol, -1, 1) for symbol in variables]
    for limit in limits:
        result = sp.integrate(result, limit)
    return result
    
# Function 1: Create all symbolic basis functions
def create_all_symbolic_basis_functions(dims=3, use_fourier=True, use_rbf=True, use_ortho=True,
                                        means=None, stds=None, n_harmonics=3):
    """Build an ordered, duplicate-free library of symbolic candidate basis functions.

    The variables are sympy symbols ``f_0 .. f_{dims-1}``. The library always
    starts with every monomial of degree 1, 2 and 3. Optional families are
    appended in this order:

    - ``use_ortho``: Chebyshev polynomials T_0..T_5 of each variable, followed
      by all pairwise products (with repetition) of those terms.
    - ``use_rbf`` with ``means``/``stds``: Gaussian bumps of width 0.1 at 10
      centres per variable spread over ``mean - 2*std .. mean + 2*std``.
    - ``use_fourier`` with ``means``/``stds``: ``sin``/``cos`` of
      ``i*pi*(f - mean)/std`` for i = 1..3.
    - ``use_fourier``: ``cos(n*f)`` for n = 1..``n_harmonics``.
    - ``use_rbf``: Gaussian bumps of width 0.1 at 10 centres per variable on [-1, 1].

    Duplicate expressions (e.g. T_1(f) == f, or T_0 == 1 for every variable)
    are removed, keeping each first occurrence.

    Parameters:
    - dims (int): Number of input variables.
    - use_fourier, use_rbf, use_ortho (bool): Toggle the optional families.
    - means, stds (sequence of float or None): Per-variable statistics for the
      data-scaled RBF/Fourier terms. When either is None, those data-scaled
      terms are skipped (the unscaled RBF/Fourier terms are still added).
    - n_harmonics (int): Highest harmonic of the unscaled ``cos(n*f)`` terms.

    Returns:
    - tuple: (list of unique sympy expressions, list of sympy symbols).
    """
    variables = sp.symbols([f'f_{i}' for i in range(dims)])
    have_stats = means is not None and stds is not None
    basis_functions = []

    # Monomials of degree 1, 2 and 3.
    basis_functions += list(variables)
    basis_functions += [x * y for x, y in combinations_with_replacement(variables, 2)]
    basis_functions += [x * y * z for x, y, z in combinations_with_replacement(variables, 3)]

    if use_ortho:
        poly_order = 5
        # Chebyshev T_0..T_poly_order for each variable (variable-major order).
        chebyshev_terms = [sp.chebyshevt(n, var)
                           for var in variables
                           for n in range(poly_order + 1)]
        basis_functions += chebyshev_terms
        # Pairwise cross/square terms of the Chebyshev polynomials.
        basis_functions += [x * y for x, y in combinations_with_replacement(chebyshev_terms, 2)]

    n_rbf_centers = 10
    rbf_width = 1 / n_rbf_centers

    if use_rbf and have_stats:
        # Data-scaled Gaussian RBFs: centres cover +/- 2 std around each mean.
        for var, mean, std in zip(variables, means, stds):
            for center in np.linspace(mean - 2 * std, mean + 2 * std, n_rbf_centers):
                basis_functions.append(sp.exp(-((var - center) ** 2) / (2 * rbf_width ** 2)))

    if use_fourier and have_stats:
        n_terms = 3
        for var, mean, std in zip(variables, means, stds):
            normalized_var = (var - mean) / std  # standardise before applying the harmonics
            for i in range(1, n_terms + 1):
                basis_functions.append(sp.sin(i * sp.pi * normalized_var))
                basis_functions.append(sp.cos(i * sp.pi * normalized_var))

    if use_fourier:
        # Unscaled cosine series.
        for var in variables:
            for n in range(1, n_harmonics + 1):
                basis_functions.append(sp.cos(n * var))

    if use_rbf:
        # Unscaled Gaussian RBFs with centres on [-1, 1].
        centers = np.linspace(-1, 1, n_rbf_centers)
        for var in variables:
            for center in centers:
                basis_functions.append(sp.exp(-((var - center) ** 2) / (2 * rbf_width ** 2)))

    # Drop duplicates while preserving first-occurrence order; duplicated
    # columns would make the downstream feature matrix rank-deficient.
    unique_basis_functions = list(dict.fromkeys(basis_functions))

    print(f"    Total number of unique basis functions: {len(unique_basis_functions)}")
    return unique_basis_functions, variables

def ordered_set(iterable):
    """Return the distinct items of ``iterable`` as a list, in first-seen order.

    Items must be hashable; equality follows ``==``/``hash`` (so 1, 1.0 and
    True collapse to the first one encountered).
    """
    seen = set()
    unique_items = []
    for item in iterable:
        if item in seen:
            continue
        seen.add(item)
        unique_items.append(item)
    return unique_items

# Function 2: Select a subset of basis functions
def select_symbolic_basis_functions(all_basis_functions, indices):
    """Pick the basis functions at ``indices`` (in the order given) from ``all_basis_functions``."""
    chosen = []
    for position in indices:
        chosen.append(all_basis_functions[position])
    return chosen

# Function 3: Compute numerical values for selected basis functions
def compute_numerical_values(selected_basis, variables, x_values):
    """Substitute one input point into each selected basis expression.

    Parameters:
    - selected_basis (list): Expressions exposing ``.subs(mapping)`` (e.g. sympy).
    - variables (list): Symbols; ``variables[k]`` is bound to ``x_values[k]``.
    - x_values (sequence): Values for the variables (extra trailing values are ignored).

    Returns:
    - list: The result of ``.subs`` for each expression, in order.
    """
    bindings = {}
    for position, symbol in enumerate(variables):
        bindings[symbol] = x_values[position]
    return [expr.subs(bindings) for expr in selected_basis]


# Function 4: Compute numerical values for all trajectories
def compute_numerical_values_for_all_trajs(selected_basis, variables, trajs):
    """
    Evaluate the selected basis expressions at every point of every trajectory
    by symbolic substitution (slow, but needs no compilation step).

    Parameters:
    - selected_basis (list): Expressions exposing ``.subs(mapping)`` (e.g. sympy).
    - variables (list): Symbols; column k of each trajectory is bound to ``variables[k]``.
    - trajs (list of np.ndarray): One 2-D array per trajectory, one row per state.

    Returns:
    - feats (list of np.ndarray): For each trajectory, a float array of shape
      (num_states, len(selected_basis)).
    """
    results = []
    for traj in trajs:
        values = np.zeros((traj.shape[0], len(selected_basis)))
        for row, point in enumerate(traj):
            bindings = {symbol: point[col] for col, symbol in enumerate(variables)}
            row_values = [expr.subs(bindings) for expr in selected_basis]
            # Coerce the substituted (possibly symbolic-number) results to floats.
            values[row, :] = np.array(row_values, dtype=float)
        results.append(values)
    return results


from sympy import lambdify
def precompile_basis_functions(selected_basis, variables):
    """
    Turn each symbolic basis expression into a NumPy-backed callable via ``lambdify``.

    Parameters:
    - selected_basis (list): Symbolic expressions to compile.
    - variables (list): Symbols that become the positional arguments of every callable.

    Returns:
    - list: One callable per expression, in the same order.
    """
    compiled = []
    for expr in selected_basis:
        compiled.append(lambdify(variables, expr, 'numpy'))
    return compiled

def compute_numerical_values_for_all_trajs_optimized(selected_basis, variables, trajs):
    """
    Evaluate the selected basis expressions on every trajectory using
    lambdified (vectorised) callables instead of per-point substitution.

    Parameters:
    - selected_basis (list): Symbolic expressions for the basis functions.
    - variables (list): Symbols; column k of each trajectory feeds ``variables[k]``.
    - trajs (list of np.ndarray): One 2-D array per trajectory, one row per state.

    Returns:
    - feats (list of np.ndarray): For each trajectory, an array of shape
      (num_states, len(selected_basis)).
    """
    compiled = precompile_basis_functions(selected_basis, variables)
    n_feats = len(compiled)

    results = []
    for traj in trajs:
        columns = traj.T  # one argument per variable, each a full column of states
        values = np.zeros((traj.shape[0], n_feats))
        for col_idx, fn in enumerate(compiled):
            values[:, col_idx] = fn(*columns)
        results.append(values)
    return results



def compute_numerical_values_pre(precompiled_functions, x_values):
    """
    Call every precompiled basis function on a single input point.

    Parameters:
    - precompiled_functions (list): Callables taking one positional argument per variable.
    - x_values (list or array): The input point, unpacked as positional arguments.

    Returns:
    - evaluated_basis (list): The value returned by each callable, in order.
    """
    return list(map(lambda fn: fn(*x_values), precompiled_functions))

import joblib
from sympy import lambdify, symbols


# Define a function to save symbolic expressions
def save_symbolic_basis_functions_joblib(selected_basis, variables, filepath='symbolic_expressions.joblib'):
    """
    Persist the symbolic basis and its variables to ``filepath`` with joblib.

    The payload is a dict ``{'basis': selected_basis, 'variables': variables}``,
    which is the layout ``load_and_precompile_functions_joblib`` reads back.

    Parameters:
    - selected_basis (list): Symbolic expressions for the selected basis functions.
    - variables (list): Symbols used by those expressions.
    - filepath (str): Destination file.

    Returns:
    - None
    """
    joblib.dump({'basis': selected_basis, 'variables': variables}, filepath)

# Define a function to load symbolic expressions and recompile them
def load_and_precompile_functions_joblib(filepath):
    """
    Load a basis saved by ``save_symbolic_basis_functions_joblib`` and lambdify it.

    Parameters:
    - filepath (str): File holding a dict with keys ``'basis'`` (symbolic
      expressions) and ``'variables'`` (their symbols).

    Returns:
    - list: One NumPy-backed callable per stored expression, in order.

    Security:
    - ``joblib.load`` is pickle-based and can execute arbitrary code while
      loading. Only call this on files from a trusted source.
    """
    # WARNING: unpickling; never point this at an untrusted file.
    data = joblib.load(filepath)
    # Reuse the shared compilation helper so the lambdify settings stay in one place.
    return precompile_basis_functions(data['basis'], data['variables'])

# Define a function to compute values using the precompiled functions
def compute_with_precompiled_functions(precompiled_functions, x_values):
    """
    Evaluate each precompiled function at one input point.

    Parameters:
    - precompiled_functions (list): Callables taking one positional argument per variable.
    - x_values (list or array): The input point, unpacked as positional arguments.

    Returns:
    - list: The callables' results, in the same order as ``precompiled_functions``.
    """
    results = []
    for fn in precompiled_functions:
        results.append(fn(*x_values))
    return results
