module Analysis

using DataFrames
using CSV
using Statistics
using Random
using Dates
using Plots
using MCTS
using MCTS: DPWSolver

# Set plotting backend
# Selects GR as the active Plots backend when this module is loaded.
Plots.gr()
# NOTE(review): the variable is assigned after `Plots.gr()`; presumably GR only
# reads GKSwstype when it first opens a workstation, so the order is harmless —
# confirm against the GR documentation.
ENV["GKSwstype"] = "100"  # Non-interactive backend

# Note: Configuration should be loaded by run.jl before including this module
# The config will be available as Main.Config

# Import modules
using IBMDPDesigns.GenerativeDesigns: Evidence, QuadraticDistance, DistanceBased, Exponential, Uniform, Variance
using IBMDPDesigns.GenerativeDesigns: perform_ensemble_designs, process_ensemble_results_enhanced
using IBMDPDesigns.GenerativeDesigns: find_max_likelihood_action_sets_with_utility
using IBMDPDesigns.GenerativeDesigns: plot_multiple_max_likelihood_action_sets, plot_ensemble_pareto
using IBMDPDesigns.GenerativeDesigns: validate_ensemble_frequencies

# Random.seed!(42)

# ============= EXPORTS =============
export AnalysisSetup
export setup_analysis, run_single_sample, run_all_samples
export get_representative_samples, generate_summary

# ============= ANALYSIS SETUP STRUCTURE =============
"""
    AnalysisSetup

Container for all analysis components after initialization.

Instances are built by `setup_analysis`; the field descriptions below reflect
how that function populates them.

# Fields
- `data`: the full table loaded from the input CSV file.
- `data_for_sampler`: the subset of `data` restricted to the feature columns,
  the target column and the columns named by the assay keys.
- `sampler`: the `sampler` entry of the `DistanceBased` result.
- `uncertainty`: the `uncertainty` entry of the `DistanceBased` result; it is
  called with an `Evidence` object.
- `weights`: the `weights` entry of the `DistanceBased` result.
- `solver`: the `DPWSolver` configured from `config.ceed.solver_config`.
- `features`: feature column names (`config.dataset.feature_columns`).
- `target`: target column name (`config.dataset.target_column`).
- `assays`: assay specification (`config.experiment.assays`).
- `config`: the full configuration returned by `Main.Config.get_full_config`.
- `results_base`: output directory (`config.output.results_dir`).
"""
struct AnalysisSetup
    data::DataFrame
    data_for_sampler::DataFrame
    sampler::Function
    uncertainty::Function
    weights::Any  # Can be Function or Dict
    solver::DPWSolver
    features::Vector{String}
    target::String
    assays::Dict
    config::NamedTuple
    results_base::String
end

# ============= SETUP FUNCTIONS =============
"""
    setup_analysis(data_path::String)

Build an `AnalysisSetup` from the CSV file at `data_path` and the configuration
exposed by `Main.Config`.

The steps are, in order: validate and fetch the configuration, load the CSV,
restrict the table to the columns needed by the sampler (features, target and
assay keys), build the distance-based sampler, configure the `DPWSolver`, and
create the results directory.

# Arguments
- `data_path`: path of the CSV file to load.

# Returns
A fully populated `AnalysisSetup`.

# Side effects
Prints progress messages and creates `config.output.results_dir` if needed.
"""
function setup_analysis(data_path::String)
    println("\n⚙️ Setting up analysis...")

    # Both configuration calls go through `invokelatest` because `Main.Config`
    # is expected to be loaded by run.jl before this module is included.
    Base.invokelatest(Main.Config.validate_config)
    cfg = Base.invokelatest(Main.Config.get_full_config)

    raw = DataFrame(CSV.File(data_path))
    println("  ✓ Loaded $(nrow(raw)) samples with $(ncol(raw)) columns")

    feature_names = cfg.dataset.feature_columns
    target_name = cfg.dataset.target_column

    # The sampler sees features, target and every column named by an assay key.
    assay_columns = collect(keys(cfg.experiment.assays))
    sampler_columns = unique(vcat(feature_names, [target_name], assay_columns))
    sampler_data = select(raw, sampler_columns)

    feature_distances = setup_distances(feature_names, cfg)

    println("  Setting up sampler...")

    # Prior over historical rows: exponential when requested, uniform otherwise.
    wants_exponential = cfg.ceed.uncertainty_config.prior_type == "exponential"
    prior = wants_exponential ? Exponential(λ = 1.0) : Uniform()

    model = DistanceBased(
        sampler_data;
        target = target_name,
        uncertainty = Variance(),
        similarity = prior,
        distance = feature_distances
    )
    sampler, uncertainty, weights = model.sampler, model.uncertainty, model.weights

    println("  Setting up solver...")
    solver_cfg = cfg.ceed.solver_config
    mcts_solver = DPWSolver(;
        n_iterations = solver_cfg.max_iterations,
        exploration_constant = solver_cfg.exploration_constant,
        depth = solver_cfg.depth,
        tree_in_info = solver_cfg.tree_in_info,
        keep_tree = solver_cfg.keep_tree
    )

    assay_spec = setup_assays(cfg.experiment.assays)

    output_dir = cfg.output.results_dir
    mkpath(output_dir)

    return AnalysisSetup(
        raw,
        sampler_data,
        sampler,
        uncertainty,
        weights,
        mcts_solver,
        feature_names,
        target_name,
        assay_spec,
        cfg,
        output_dir
    )
end

"""
    setup_distances(features::Vector{String}, config::NamedTuple)

Setup distance metrics for each feature.

# Arguments
- `features`: names of the feature columns that need a distance metric.
- `config`: full configuration; only `config.ceed.uncertainty_config.distance_type`
  is read.

# Returns
A `Dict{String, Any}` mapping every feature name to its own
`QuadraticDistance(λ = 0.1)`.

# Throws
- `ErrorException` when `distance_type` is anything other than `"quadratic"`.
  The check runs once before any metric is built, so an unsupported
  configuration is rejected even when `features` is empty.
"""
function setup_distances(features::Vector{String}, config::NamedTuple)
    # The distance type does not depend on the feature, so validate it up front
    # instead of re-checking (or never checking) inside the loop.
    if config.ceed.uncertainty_config.distance_type != "quadratic"
        error("Only 'quadratic' distance type is currently supported")
    end

    return Dict{String, Any}(
        feature => QuadraticDistance(λ = 0.1) for feature in features
    )
end

"""
    setup_assays(assay_config::Dict)

Return the assay specification to hand to the design routines.

The configuration already stores assays in the required format, so the very
same `Dict` object is returned unchanged (no copy is made).
"""
setup_assays(assay_config::Dict) = assay_config

# ============= SAMPLE SELECTION =============
"""
    get_representative_samples(setup::AnalysisSetup)

Get representative samples based on scenarios or return all samples.

When `setup.config.scenarios.use_scenarios` is true and a `get_scenario`
function is configured, every row is labelled with `get_scenario(row)` and, for
each distinct label, the row whose target value is closest to that scenario's
median target is selected. Otherwise `setup.data` is returned as is.

# Returns
- Scenario mode: a new `DataFrame` with one row per scenario, including a
  `scenario` column holding the label.
- Otherwise: `setup.data` itself (not a copy).

# Notes
`setup.data` is never modified: the `scenario` column is added to a copy, so
repeated calls and other users of `setup.data` are unaffected.
"""
function get_representative_samples(setup::AnalysisSetup)
    config = setup.config

    if !(config.scenarios.use_scenarios && !isnothing(config.scenarios.get_scenario))
        # Return all samples
        return setup.data
    end

    # Annotate a copy so the shared `setup.data` table is left untouched.
    data = copy(setup.data)
    scenarios = String[config.scenarios.get_scenario(row) for row in eachrow(data)]
    data[!, :scenario] = scenarios

    scenario_names = unique(scenarios)
    representatives = DataFrame()
    for scenario in scenario_names
        # Every name comes from `scenarios`, so each subset has at least one row.
        scenario_data = data[scenarios .== scenario, :]
        target_values = scenario_data[:, setup.target]

        # Pick the row closest to the scenario median (first one on ties).
        median_idx = argmin(abs.(target_values .- median(target_values)))
        push!(representatives, scenario_data[median_idx, :])
    end

    println("  ✓ Found $(nrow(representatives)) representative samples across $(length(scenario_names)) scenarios")
    return representatives
end

# ============= MAIN ANALYSIS FUNCTIONS =============
# Write the one-row `sample_metadata.csv` (name, target, uncertainty, features).
function _write_sample_metadata(
    setup::AnalysisSetup,
    sample::DataFrameRow,
    sample_name::AbstractString,
    initial_uncertainty,
    sample_dir::AbstractString
)
    metadata_df = DataFrame(
        sample_name = [sample_name],
        target_initial = [sample[setup.target]],
        initial_uncertainty = [initial_uncertainty]
    )
    for feature in setup.features
        metadata_df[!, Symbol("initial_", feature)] = [sample[feature]]
    end
    CSV.write(joinpath(sample_dir, "sample_metadata.csv"), metadata_df)
    println("  📋 Saved sample metadata")
    return nothing
end

# Run one ensemble at threshold `tau`, save its CSV and return the collected outputs.
function _run_ensemble_for_tau(
    setup::AnalysisSetup,
    state_init,
    terminal_condition,
    tau,
    ensemble_size,
    sample_dir::AbstractString,
    save_plots::Bool
)
    ensemble_results = perform_ensemble_designs(
        setup.assays;  # Experiments/assays
        sampler = setup.sampler,
        uncertainty = setup.uncertainty,
        thresholds = 11,  # Number of threshold levels
        evidence = state_init,
        weights = setup.weights,
        data = setup.data_for_sampler,
        terminal_condition = terminal_condition,
        realized_uncertainty = true,
        solver = setup.solver,
        repetitions = 0,
        mdp_options = (
            conditional_constraints_enabled = true,
            max_parallel = setup.config.experiment.max_parallel_assays,
            costs_tradeoff = (1.0, 0.0)  # Money-biased by default (MUST be Tuple, not Vector)
        ),
        N = ensemble_size,
        thred_set = [tau]
    )

    # Errors raised while processing are not caught here; they propagate to the caller.
    df, plt_hist, ensemble_front = process_ensemble_results_enhanced(
        ensemble_results, tau;
        save_dir = sample_dir,
        save_plots = save_plots
    )

    csv_filename = joinpath(sample_dir, "ensemble_results_tau_$(tau).csv")
    CSV.write(csv_filename, df)
    println("    📄 Saved CSV: ensemble_results_tau_$(tau).csv")

    # Frequency validation and the max-likelihood action-set / Pareto plots are
    # currently disabled, so nothing further is produced here.
    return (
        df = df,
        plot_hist = plt_hist,
        ensemble_front = ensemble_front,
        ensemble_results = ensemble_results
    )
end

"""
    run_single_sample(setup::AnalysisSetup, sample::DataFrameRow, sample_name::AbstractString; kwargs...)

Run analysis for a single sample.

The sample's feature values form the initial evidence; one ensemble of designs
is then run for every entry of `tau_values`, and the outputs are written to
`joinpath(setup.results_base, sample_name)`.

# Arguments
- `setup`: initialized analysis components.
- `sample`: data row providing the feature values and the target value.
- `sample_name`: label used in log output and as the output directory name.

# Keywords
- `ensemble_size`: passed as `N` to `perform_ensemble_designs`; `nothing`
  (default) falls back to `setup.config.ceed.default_ensemble_size`.
- `save_plots`: forwarded to `process_ensemble_results_enhanced`.
- `tau_values`: thresholds to process; each is passed as `thred_set = [tau]`.

# Returns
A `Dict` mapping every τ to a named tuple with the fields `df`, `plot_hist`,
`ensemble_front` and `ensemble_results`.

# Side effects
Creates the sample directory and writes `sample_metadata.csv` plus one
`ensemble_results_tau_<τ>.csv` per threshold; prints progress messages.
"""
function run_single_sample(
    setup::AnalysisSetup,
    sample::DataFrameRow,
    sample_name::AbstractString;
    ensemble_size::Union{Int,Nothing} = nothing,
    save_plots::Bool = true,
    tau_values::Vector{Float64} = [0.9]
)
    config = setup.config

    println("\n📁 Processing $sample_name")

    sample_dir = joinpath(setup.results_base, sample_name)
    mkpath(sample_dir)

    # Feature values are converted to Float64 for reporting (non-numeric values
    # fail here, before anything is computed); the evidence keeps the raw values.
    initial_values = Float64[sample[feature] for feature in setup.features]
    evidence_pairs = [feature => sample[feature] for feature in setup.features]
    state_init = Evidence(evidence_pairs...)

    initial_uncertainty = setup.uncertainty(state_init)

    # Report in the configured feature order rather than Dict hash order.
    println("  Initial state:")
    for (k, v) in zip(setup.features, initial_values)
        println("    $k = $(round(v, digits=3))")
    end
    println("  Initial uncertainty: $(round(initial_uncertainty, digits=3))")

    _write_sample_metadata(setup, sample, sample_name, initial_uncertainty, sample_dir)

    # Target range to reach, paired with the conditional-weights threshold.
    terminal_condition = (
        Dict(setup.target => [config.rop.target_range.min, config.rop.target_range.max]),
        0.5  # Default conditional weights threshold
    )

    n_ensemble = isnothing(ensemble_size) ? config.ceed.default_ensemble_size : ensemble_size
    results_by_tau = Dict()

    for tau in tau_values
        println("  Processing τ = $tau")
        results_by_tau[tau] = _run_ensemble_for_tau(
            setup, state_init, terminal_condition, tau, n_ensemble, sample_dir, save_plots
        )
    end

    println("  ✅ Completed $sample_name")

    return results_by_tau
end

"""
    run_all_samples(setup::AnalysisSetup; kwargs...)

Run analysis for all representative samples.

Samples come from `get_representative_samples`. A sample is named after its
`scenario` value when that column exists, and `Sample_<index>` otherwise. After
all samples are processed, `generate_summary` is called on the collected
results.

# Keywords
- `ensemble_size`: ensemble size for every sample; `nothing` (default) means
  `setup.config.ceed.default_ensemble_size`.
- `save_plots`: forwarded to `run_single_sample`.
- `tau_values`: forwarded to `run_single_sample`.

# Returns
A `Dict` mapping each sample name to the per-τ results of `run_single_sample`.
"""
function run_all_samples(
    setup::AnalysisSetup;
    ensemble_size::Union{Int,Nothing} = nothing,
    save_plots::Bool = true,
    tau_values::Vector{Float64} = [0.9]
)
    # `generate_summary` needs a concrete size, so resolve the `nothing` default
    # here; forwarding `nothing` would make the summary call fail on its keyword
    # type after all samples have already been processed.
    summary_ensemble_size = isnothing(ensemble_size) ?
        setup.config.ceed.default_ensemble_size : ensemble_size

    representatives = get_representative_samples(setup)
    all_results = Dict()

    for (idx, sample) in enumerate(eachrow(representatives))
        # `string` keeps String labels unchanged and makes non-string scenario
        # values (e.g. from a pre-existing column) usable as a name.
        sample_name = if hasproperty(sample, :scenario)
            string(sample.scenario)
        else
            "Sample_$(idx)"
        end

        println("\n📊 Processing sample $idx/$(nrow(representatives)): $sample_name")

        results = run_single_sample(
            setup, sample, sample_name;
            ensemble_size = ensemble_size,
            save_plots = save_plots,
            tau_values = tau_values
        )

        all_results[sample_name] = results
    end

    # Generate summary
    generate_summary(all_results, setup; ensemble_size = summary_ensemble_size)

    return all_results
end

"""
    generate_summary(all_results::Dict, setup::AnalysisSetup; kwargs...)

Generate summary statistics across all samples.

For every sample and τ with a non-empty result table, one summary row is built
with the mean and standard deviation of `Cost` and `Utility`, the number of
distinct `ActionTaken` values and the fraction of rows with `Converged` set.

# Arguments
- `all_results`: mapping from sample name to the per-τ results returned by
  `run_single_sample`; `nothing` entries and entries without a `df` are skipped.
- `setup`: used for the output directory and the default ensemble size.

# Keywords
- `ensemble_size`: value recorded in the `ensemble_size` column (default `10`).
  `nothing` is accepted and resolves to
  `setup.config.ceed.default_ensemble_size`, matching the convention used by
  `run_single_sample` and `run_all_samples`.

# Returns
The summary `DataFrame` (empty when there was nothing to summarize).

# Side effects
Writes `analysis_summary.csv` into `setup.results_base` when at least one row
was produced, and prints aggregate statistics.
"""
function generate_summary(
    all_results::Dict,
    setup::AnalysisSetup;
    ensemble_size::Union{Int,Nothing} = 10
)
    println("\n📊 Generating analysis summary...")

    reported_size = isnothing(ensemble_size) ?
        setup.config.ceed.default_ensemble_size : ensemble_size

    summary_data = DataFrame()

    for (sample_name, results_by_tau) in all_results
        isnothing(results_by_tau) && continue

        for (tau, results) in results_by_tau
            (isnothing(results) || !haskey(results, :df)) && continue

            df = results.df
            nrow(df) > 0 || continue

            summary_row = DataFrame(
                sample = [sample_name],
                tau = [tau],
                ensemble_size = [reported_size],
                mean_cost = [mean(df.Cost)],
                std_cost = [std(df.Cost)],
                mean_utility = [mean(df.Utility)],
                std_utility = [std(df.Utility)],
                unique_actions = [length(unique(df.ActionTaken))],
                convergence_rate = [sum(df.Converged) / nrow(df)]
            )
            append!(summary_data, summary_row)
        end
    end

    if nrow(summary_data) > 0
        summary_path = joinpath(setup.results_base, "analysis_summary.csv")
        CSV.write(summary_path, summary_data)
        println("  📄 Saved summary to: analysis_summary.csv")

        # The "±" figures are the averages of the per-row standard deviations.
        println("\n📈 Summary Statistics:")
        println("  Total samples analyzed: $(length(unique(summary_data.sample)))")
        println("  Average cost: $(round(mean(summary_data.mean_cost), digits=2)) ± $(round(mean(summary_data.std_cost), digits=2))")
        println("  Average utility: $(round(mean(summary_data.mean_utility), digits=3)) ± $(round(mean(summary_data.std_utility), digits=3))")
        println("  Average convergence rate: $(round(mean(summary_data.convergence_rate) * 100, digits=1))%")
    else
        println("  ⚠️ No valid results to summarize")
    end

    return summary_data
end

end # module Analysis