using Pkg
Pkg.develop(path="/home/video_cauasl/SymbolicRegression.jl")
Pkg.build("SymbolicRegression")
Pkg.precompile()
using DynamicExpressions, Random
using CSV, DataFrames, NPZ, SymbolicUtils, Symbolics, Distances, StatsBase, DynamicAxisWarping

import SymbolicRegression: SRRegressor, MultitargetSRRegressor
import MLJ: machine, fit!, predict, report


"""
    evaluate_equation(equation, X)

Evaluate the formula in `equation` at every value of `X`, binding each value to the
variable `a`, and return the results as a `Vector{Float64}` of the same length.

The formula is parsed with `Meta.parse` and compiled once into a one-argument function
inside the `Base` module, so only names visible from `Base` resolve. A point whose
evaluation throws (e.g. a `DomainError`) or whose result cannot be converted to
`Float64` is recorded as `NaN`. A formula that cannot be parsed or compiled at all
yields `NaN` for every point instead of raising.

NOTE(review): this runs `eval` on text that the caller may have read from a data file.
Only use it with a trusted equation source.
"""
function evaluate_equation(equation::AbstractString, X::AbstractVector{<:Real})
    # Start from all-NaN; successful evaluations overwrite their slot.
    y_pred = fill(NaN, length(X))

    # Build `a -> <formula>` once rather than re-evaluating the expression per sample.
    f = try
        Core.eval(Base, :(a -> $(Meta.parse(equation))))
    catch err
        err isa InterruptException && rethrow()
        return y_pred
    end

    for (i, val) in enumerate(X)
        try
            # `invokelatest` is needed because `f` was defined after this function
            # started running (world-age rule).
            y_pred[i] = Base.invokelatest(f, Float64(val))
        catch err
            err isa InterruptException && rethrow()
            # Any other failure keeps the preset NaN for this sample.
        end
    end
    return y_pred
end

"""
    compute_dtw(y_pred, y_target)

Return the cost reported by `dtw` (with `SqEuclidean()` and `transportcost = 1`)
between `y_pred` and `y_target`.

Both series are first truncated to their common length. Positions where `y_pred` is
`NaN` are then dropped from both series. `NaN`s in `y_target` are not filtered.
"""
function compute_dtw(y_pred::Vector{Float64}, y_target::Vector{Float64})
    # Truncate before masking: a boolean mask built from `y_pred` can only index
    # `y_target` when both have the same length.
    n = min(length(y_pred), length(y_target))
    pred = view(y_pred, 1:n)
    target = view(y_target, 1:n)

    # Keep only the samples where the prediction is a number.
    keep = .!isnan.(pred)

    cost, _, _ = dtw(pred[keep], target[keep], SqEuclidean(); transportcost=1)
    return cost
end

"""
    rank_equations(X, y_target, equations, top_k=5)

Rank candidate formulas by how closely their evaluation over `X` follows `y_target`.

For each entry of `equations` the formula is evaluated with `evaluate_equation`,
`y_target` is rescaled with `min_max_normalization`, and the distance is computed with
`compute_dtw`. Formulas whose evaluation is `NaN` at every point are skipped. Each
formula and its distance are printed to standard output.

Returns `(best_equations_expr, best_distances)`: the parsed expressions of the closest
formulas in ascending distance order, and their distances. At most `top_k` entries are
returned; fewer if not enough formulas could be evaluated.
"""
function rank_equations(
    X::Vector{Float64},
    y_target::Vector{Float64},
    equations::AbstractVector{<:AbstractString},
    top_k::Int=5,
)
    distances = Float64[]
    valid_equations = String[]

    for equation in equations
        # Convert so that non-`String` string types (e.g. from a CSV column) work.
        y_pred = evaluate_equation(String(equation), X)
        println(equation)

        # Nothing to compare against if the formula never produced a number.
        all(isnan, y_pred) && continue

        rescaled_y_target = min_max_normalization(y_target, y_pred)
        dtw_distance = compute_dtw(y_pred, rescaled_y_target)
        println("dtw distance: ", dtw_distance)
        push!(distances, dtw_distance)
        push!(valid_equations, equation)
    end

    # Never ask for more results than there are valid candidates.
    k = clamp(top_k, 0, length(distances))
    best = sortperm(distances)[1:k]

    best_equations_expr = [Meta.parse(eq) for eq in valid_equations[best]]
    best_distances = distances[best]

    return best_equations_expr, best_distances
end

# function min_max_normalization(y_target::Vector{Float64}, y_pred::Vector{Float64})
#     min_val = minimum(y_target)
#     max_val = maximum(y_target)
#     return 2 .* ((y_target .- min_val) ./ (max_val - min_val)) .- 1
# end

"""
    min_max_normalization(y_target, y_pred; tol=1e-2)

Linearly rescale `y_target` so that its value range matches the range of `y_pred`.

`y_target` is returned unchanged (the same array, not a copy) when rescaling would be
meaningless:

- `y_pred` is empty or contains any `NaN`;
- `y_target` has no non-`NaN` value;
- the range of `y_pred` is smaller than `tol`;
- `y_target` is constant (its range is zero), which would otherwise divide by zero.

`NaN` entries of `y_target` are ignored when computing its range and stay `NaN` in the
result.
"""
function min_max_normalization(
    y_target::Vector{Float64}, y_pred::Vector{Float64}; tol::Real=1e-2
)
    # A prediction with missing points gives no trustworthy range to map onto.
    if isempty(y_pred) || any(isnan, y_pred)
        return y_target
    end

    valid_y_target = filter(!isnan, y_target)
    isempty(valid_y_target) && return y_target

    min_pred, max_pred = extrema(y_pred)
    min_targ, max_targ = extrema(valid_y_target)

    # A (nearly) flat prediction would collapse the target onto a single value.
    if abs(max_pred - min_pred) < tol
        return y_target
    end

    # A flat target has zero range; dividing by it would produce NaN everywhere.
    if max_targ == min_targ
        return y_target
    end

    # Map [min_targ, max_targ] onto [min_pred, max_pred].
    return min_pred .+ ((y_target .- min_targ) ./ (max_targ - min_targ)) .* (max_pred - min_pred)
end


motion_file_path = "" # path to extracted trajectory
time_file_path = "" # path to time dimension
output_directory = "" # output folder

# NOTE(review): the indexing below (`motion_dim[:, :, 1]`, `motion_dim[:, :, 2]`, each
# reshaped to `(seq_len, 1)`) only works for an array of shape [T, 1, 2] — confirm
# against the file that is actually loaded.
motion_dim = npzread(motion_file_path)
time_dim = npzread(time_file_path) # indexed as [T, 1] below

# Check data types and dimensions
println("Motion shape: ", size(motion_dim))
println("Time shape: ", size(time_dim))

seq_len = size(time_dim, 1)

# Use time as the feature (X) and motion as the target (y).
# The feature is named `a` because the formulas in the equation bank are evaluated
# with `a` as their variable.
X = (a = time_dim[:, 1],)

y = (
    position_x = vec(reshape(motion_dim[:, :, 1], (seq_len, 1))),  # First dimension of motion
    position_y = vec(reshape(motion_dim[:, :, 2], (seq_len, 1)))   # Second dimension of motion
)

x_type = eltype(X.a)
println("X data type: ", typeof(X.a))
println("y data type: ", typeof(y.position_x))

# Load all existing equations. The column is converted to plain `String`s because
# `rank_equations` is declared for `Vector{String}`, while CSV.jl may return another
# string type for the column.
df = CSV.read("equation_bank/EquationBank_TimeSeries.csv", DataFrame)
equations = String.(df[:, "Modified_Formula"])


top_k = 5  # Choose how many best equations you want
best_eqs, best_dists = rank_equations(Float64.(X.a), Float64.(y.position_x), equations, top_k)


println("Top-K closest equations based on DTW distance:")
# Iterate over what was actually returned rather than assuming `top_k` entries.
for i in eachindex(best_eqs)
    println("Equation ", i, ": ", best_eqs[i], " (DTW Distance: ", best_dists[i], ")")
end

println(typeof(best_eqs[1]))

# NOTE(review): `neg` is not referenced anywhere in the visible part of this file.
neg(x) = -x
binary_operators = [+, -, *, /, ^]
unary_operators = [cos, sin, exp, log, tan, sqrt]

# Operator set used to turn the ranked expressions into expression trees.
operators = OperatorEnum(;
    binary_operators=binary_operators, unary_operators=unary_operators
)
variable_names = ["a"]

weight_initial_functions = 5 # weight_initial_functions/population_size=alpha

# One tree per ranked equation. Iterating over `best_eqs` (instead of hard-coding
# indices 1..5) keeps this in step with whatever `top_k` is set to.
initial_functions = [
    get_tree(
        parse_expression(
            eq; operators=operators, variable_names=variable_names, node_type=Node{x_type}
        )
    ) for eq in best_eqs
]

# Repeat every tree `weight_initial_functions` times, keeping copies of the same tree
# adjacent to each other.
initial_functions = [func for func in initial_functions for _ in 1:weight_initial_functions]

println(initial_functions)

# Configure the multi-target symbolic regressor, seeded with the trees built above.
# NOTE(review): `initial_functions` is presumably an option of the locally developed
# SymbolicRegression.jl checkout — confirm it is accepted by the version in use.
model = MultitargetSRRegressor(;
    binary_operators,
    unary_operators,
    initial_functions,
    output_directory,
    save_to_file=true,
    niterations=100,
    populations=30,
    population_size=30,
)

# Show the container types that are handed to MLJ.
println("X data type: ", typeof(X))
println("y data type: ", typeof(y))

# Bind the model to the data, train it, and print the fit report.
mach = machine(model, X, y)
fit!(mach)

r = report(mach)
println(r)
