using Pkg
using Roots
using Statistics, Plots, StatsPlots, Random, Distributions, StatsBase, DataFrames, CSV, LaTeXStrings, Measures

"""
    rank(true_mean, wtf)

Return the position of index `wtf` when the entries of `true_mean` are ordered
from largest to smallest (position 1 is the largest entry).

Returns `nothing` when `wtf` does not occur in the sort permutation of
`true_mean`.
"""
function rank(true_mean::AbstractVector, wtf::Int64)
    # Indices of `true_mean`, best entry first.
    descending = sortperm(true_mean; rev = true)
    return findfirst(==(wtf), descending)
end

"""
    rank_and_gap(true_mean, wtf::Int64)

For the entry of `true_mean` at index `wtf`, return the tuple `(rank, gap)`:

- `rank`: position of `wtf` when `true_mean` is ordered from largest to
  smallest (`nothing` if `wtf` is not found in the sort permutation);
- `gap`: `maximum(true_mean) - true_mean[wtf]`.
"""
function rank_and_gap(true_mean::AbstractVector, wtf::Int64)
    descending = sortperm(true_mean; rev = true)
    position = findfirst(==(wtf), descending)
    shortfall = maximum(true_mean) - true_mean[wtf]
    return position, shortfall
end

"""
    rank_and_gap(true_mean, wtf::AbstractFloat)

Variant taking a value instead of an index. Return the tuple `(rank, gap)`:

- `rank`: number of entries of `true_mean` that are greater than or equal to
  `wtf`;
- `gap`: `maximum(true_mean) - wtf`.
"""
function rank_and_gap(true_mean::AbstractVector, wtf::AbstractFloat)
    at_least_as_good = count(>=(wtf), true_mean)
    shortfall = maximum(true_mean) - wtf
    return at_least_as_good, shortfall
end

"""
    ba_diff_qc(ba_diff; alpha = [0.05, 0.95])

Column-wise summary statistics for each matrix in `ba_diff`.

Every element of `ba_diff` is reduced along its first dimension. The results
are stored as columns, one per element of `ba_diff`, with as many rows as the
first element has columns.

Returns `(m, v, q1, q2)`:

- `m`: column means;
- `v`: column standard deviations;
- `q1`: column quantiles at level `alpha[1]`;
- `q2`: column quantiles at level `alpha[2]`.
"""
function ba_diff_qc(ba_diff::AbstractVecOrMat; alpha::AbstractVector = [0.05, 0.95])
    nsteps, nsets = size(ba_diff[1], 2), length(ba_diff)
    m = zeros(nsteps, nsets)
    v = zeros(nsteps, nsets)
    q1 = zeros(nsteps, nsets)
    q2 = zeros(nsteps, nsets)

    lower(col) = quantile(col, alpha[1])
    upper(col) = quantile(col, alpha[2])

    for (k, runs) in enumerate(ba_diff)
        m[:, k] = vec(mean(runs; dims = 1))
        v[:, k] = vec(std(runs; dims = 1))
        q1[:, k] = vec(mapslices(lower, runs; dims = 1))
        q2[:, k] = vec(mapslices(upper, runs; dims = 1))
    end

    return m, v, q1, q2
end

"""
    get_q(m)

Return a matrix with one row per column of `m`; the first entry of each row is
that column's 25% quantile and the second its 75% quantile.
"""
function get_q(m::AbstractMatrix)
    # 2 x ncols: row 1 holds the lower quartiles, row 2 the upper quartiles.
    quartiles = mapslices(col -> quantile(col, [0.25, 0.75]), m; dims = 1)
    return quartiles'
end
"""
    get_m(m)

Return a vector holding the mean of each column of `m`.
"""
function get_m(m::AbstractMatrix)
    column_means = mapslices(mean, m; dims = 1)
    return vec(column_means)
end
"""
    get_med(m)

Return a vector holding the median of each column of `m`.
"""
function get_med(m::AbstractMatrix)
    column_medians = mapslices(median, m; dims = 1)
    return vec(column_medians)
end

"""
    compare_algorithms(budget, true_mean, Dist; itnum = 1000)

Run `uniform_sampling`, `BUCB`, `BSH` and `BoxThirding` `itnum` times each and
collect their results.

In iteration `i` the global RNG is seeded with `i` and the arm means are
reshuffled before the four algorithms are called. The shuffling is done on a
private copy, so `true_mean` itself is left untouched and repeated calls with
the same arguments start from the same arm order.

# Arguments
- `budget`: passed to every algorithm; also the row length of the sequences
  they return.
- `true_mean`: arm means handed to the algorithms (after shuffling).
- `Dist`: passed through to every algorithm unchanged.

# Keywords
- `itnum`: number of repetitions.

# Returns
`(ranks, gaps, ba_diffs)` where

- `ranks` and `gaps` are `itnum x 4` matrices filled from `rank_and_gap`
  applied to the first return value of each algorithm. Columns are, in order,
  uniform sampling, BUCB, BSH, BoxThirding. Column 3 (BSH) is never written
  and stays zero.
- `ba_diffs` is a vector of four `itnum x budget` matrices holding the second
  return value of each algorithm, one row per repetition (presumably the
  simple-regret trajectory; confirm against the algorithm implementations).
"""
function compare_algorithms(budget::Int64, true_mean::Vector{Float64}, Dist::String; itnum::Int64 = 1000)
    # Shuffle a copy: mutating the caller's vector made a second call start
    # from a different arm order despite the per-iteration seeding.
    arm_means = copy(true_mean)

    ranks = zeros(Int, itnum, 4)
    gaps = zeros(itnum, 4)

    ba_diff1 = zeros(itnum, budget)
    ba_diff2 = zeros(itnum, budget)
    ba_diff3 = zeros(itnum, budget)
    ba_diff4 = zeros(itnum, budget)

    for i in 1:itnum
        Random.seed!(i)
        shuffle!(arm_means)

        best_arm_po, ba_diff1[i, :] = uniform_sampling(arm_means, budget, Dist, print_seq = true)
        ranks[i, 1], gaps[i, 1] = rank_and_gap(arm_means, best_arm_po)

        best_arm_bucb, ba_diff2[i, :] = BUCB(0.1, arm_means, budget, Dist, print_seq = true)
        ranks[i, 2], gaps[i, 2] = rank_and_gap(arm_means, best_arm_bucb)

        # NOTE(review): rank and gap are not recorded for BSH, so column 3 of
        # `ranks` and `gaps` stays zero. Confirm whether this is intended.
        best_arm_bsh, ba_diff3[i, :] = BSH(arm_means, budget, Dist)
        #ranks[i,3], gaps[i,3] = rank_and_gap(arm_means, best_arm_bsh)

        best_arm_ebt, ba_diff4[i, :] = BoxThirding(arm_means, budget, Dist; print_seq = true)
        ranks[i, 4], gaps[i, 4] = rank_and_gap(arm_means, best_arm_ebt)
    end

    return ranks, gaps, [ba_diff1, ba_diff2, ba_diff3, ba_diff4]
end


# Normalizing function
"""
    NormalizeData(data)

Min-max normalise `data`: subtract the minimum and divide by the range, so the
smallest entry maps to 0 and the largest to 1.

Edge cases:
- when all entries are equal the range is zero; the result is then all zeros
  instead of `NaN` from a 0/0 division;
- an empty input yields an empty output.
"""
function NormalizeData(data::AbstractVector{<:Real})
    # Dividing by one keeps the element type consistent with the main path.
    isempty(data) && return data ./ oneunit(eltype(data))

    lo, hi = extrema(data)
    span = hi - lo
    # A zero range would give 0/0; divide by one so constant data maps to zeros.
    scale = iszero(span) ? one(span) : span
    return (data .- lo) ./ scale
end

# H2 calculation
"""
    H2(avg_reward)

Compute `max over i in 1:(K-1) of i / (s[1] - s[i+1])^2`, where `s` is
`avg_reward` sorted from largest to smallest and `K = length(avg_reward)`.

The result is `Inf` when another entry ties with the largest one.

# Throws
- `ArgumentError` if `avg_reward` has fewer than two entries (no gap exists).

NOTE(review): the usual H2 hardness measure in the best-arm-identification
literature appears to weight the arm of rank `k` by `k`, whereas this code
weights the arm of rank `i + 1` by `i`. Confirm which convention is intended;
the formula is deliberately left as it was.
"""
function H2(avg_reward::AbstractVector{<:Real})
    narms = length(avg_reward)
    narms >= 2 ||
        throw(ArgumentError("H2 needs at least two rewards, got $(narms)"))

    sorted_reward = sort(avg_reward; rev = true)  # largest first
    best = sorted_reward[1]
    return maximum(i / (best - sorted_reward[i + 1])^2 for i in 1:(narms - 1))
end
"""
    NY_dataParse(file)

Read the comma-delimited file `file` (first row is the header) and return, per
row, `(funny + somewhat_funny) / votes` computed from the columns of those
names.
"""
function NY_dataParse(file)
    table = DataFrame(CSV.File(file; delim = ',', header = true))
    favourable = table.funny .+ table.somewhat_funny
    return favourable ./ table.votes
end

# Per-algorithm plot styling. Each is a 1 x 4 row whose k-th entry belongs to
# the k-th algorithm.
colorset = Symbol[:dodgerblue :gold :green :tomato]
lineset = Symbol[:dot :dash :dashdot :solid]
labelset = String["US" "BUCB" "BSH" "B3"]
# `mean893` is defined outside this section of the file.
T = length(mean893)
# 1000 x 4 label matrix: column k repeats the k-th algorithm label.
groups = repeat(labelset, 1000, 1);


"""
    draw_plot(ba_diff, ymax; legend = :none, size = (400, 400), the_title = "")

For every matrix in `ba_diff`, plot its column-wise mean (`get_m`) as a curve
with a shaded band running from the column-wise 25% quantile to the 75%
quantile (`get_q`).

# Arguments
- `ba_diff`: collection of matrices, one per algorithm, in the order of
  `labelset`. It may hold at most as many matrices as `colorset`, `lineset`
  and `labelset` have entries.
- `ymax`: upper y-axis limit; the lower limit is fixed at `-0.001`.

# Keywords
- `legend`, `size`, `the_title`: forwarded to Plots as `legend`, `size` and
  `title`.

Returns the plot object.
"""
function draw_plot(ba_diff, ymax; legend = :none, size = (400, 400), the_title = "")
    mdiff = get_m.(ba_diff) # column means of each matrix
    qdiff = get_q.(ba_diff) # 25% / 75% column quantiles of each matrix

    # Plots measures `ribbon` as a distance from the curve, so the absolute
    # quantiles are converted to (distance below, distance above) the mean.
    # Passing the quantile matrix directly would draw mean +/- q25 instead.
    band(k) = (mdiff[k] .- qdiff[k][:, 1], qdiff[k][:, 2] .- mdiff[k])

    pp = plot(mdiff[1], fa = 0.2, lw = 4, ylim = (-0.001, ymax), lc = colorset[1],
        fc = colorset[1], label = labelset[1], ribbon = band(1), legend_columns = 4,
        legend_size = 30, legendfontsize = 12, ylabel = "simple regret", xlabel = "T",
        ls = lineset[1], title = the_title, legend = legend, size = size)

    # Remaining algorithms are drawn onto the same plot.
    for k in 2:length(mdiff)
        plot!(pp, mdiff[k], fa = 0.2, lw = 4, lc = colorset[k], fc = colorset[k],
            label = labelset[k], ls = lineset[k], ribbon = band(k))
    end
    return pp
end

