using Pkg
using Roots
using Statistics, Plots, StatsPlots, Random, Distributions, StatsBase, DataFrames, CSV, LaTeXStrings, Measures


"""
    Arm(avg_reward)

State of a single bandit arm, used by the `Arm`-based `pull_arm!` method and the
Sequential Halving code.

# Fields
- `avg_reward`: true parameter passed to the reward distribution on every pull.
  NOTE(review): this field is untyped (so `Any`); every constructor call visible in
  this file passes a `Float64` — consider `avg_reward::Float64` for type stability.
- `num_pulls`: pulls recorded since construction or the last `reset!`.
- `eprc_reward`: running empirical mean of those pulls.
"""
mutable struct Arm
    avg_reward
    num_pulls::Int
    eprc_reward::Float64

    # New arms start with zero pulls and an empirical mean of 0.0.
    function Arm(avg_reward)
        return new(avg_reward, 0, 0.0)
    end
end

"""
    pull_arm!(Dist, arm, pull_times, marms, npulls, true_mean)

Pull arm number `arm` `pull_times` times and record the outcome.

Rewards are drawn from the family named by `Dist`, parameterised by
`ta = true_mean[arm]`:

- `"Normal"`: `Normal(ta, 0.5^2)`
- `"Bernoulli"`: `Bernoulli(ta)`
- `"Poisson"`: `Poisson(ta)`
- `"Power"`: `Kumaraswamy(1/(1/ta - 1), 1)`
- `"Exponential"`: `Exponential(ta)`
- `"Deterministic"`: every pull returns exactly `ta`

The summed reward is added to `marms[arm]` and `pull_times` is added to
`npulls[arm]`.

# Returns
`(marms, npulls)`, both modified in place.

# Throws
`ArgumentError` if `Dist` is not one of the names above; neither `marms` nor
`npulls` is modified in that case.
"""
function pull_arm!(Dist::String, arm::Int64, pull_times::Int64, marms::AbstractVector, npulls::AbstractVector, true_mean::AbstractVector)
    ta = true_mean[arm]
    # Sample before touching the counters so an unsupported `Dist` cannot leave
    # `npulls` incremented without a matching reward in `marms`.
    if Dist == "Normal"
        # NOTE(review): `Normal(μ, σ)` takes a standard deviation, so this is
        # σ = 0.25 (variance 0.0625); confirm that variance 0.25 was not intended.
        reward = sum(rand(Normal(ta, 0.5^2), pull_times))
    elseif Dist == "Bernoulli"
        reward = sum(rand(Bernoulli(ta), pull_times))
    elseif Dist == "Poisson"
        reward = sum(rand(Poisson(ta), pull_times))
    elseif Dist == "Power"
        reward = sum(rand(Kumaraswamy(1/(1/ta - 1), 1), pull_times))
    elseif Dist == "Exponential"
        reward = sum(rand(Exponential(ta), pull_times))
    elseif Dist == "Deterministic"
        reward = ta*pull_times
    else
        throw(ArgumentError("unsupported reward distribution: $Dist"))
    end
    npulls[arm] += pull_times
    marms[arm] += reward
    return marms, npulls
end

"""
    pull_arm!(Dist, arms::Vector{Int64}, pull_times, marms, npulls, true_mean)

Pull every arm index in `arms` `pull_times` times.

For each arm, the summed reward is drawn from the family named by `Dist` with
parameter `true_mean[arm]` (same families as the single-arm method) and added to
`marms[arm]`. Then `npulls[arms] .+= pull_times` is applied once, so a
duplicated index in `arms` has its count raised only once.

# Returns
`(marms, npulls)`, both modified in place.

# Throws
`ArgumentError` if `Dist` is unsupported (checked on the first arm, before
anything is modified).
"""
function pull_arm!(Dist::String, arms::Vector{Int64}, pull_times::Int64, marms::AbstractVector, npulls::AbstractVector, true_mean::AbstractVector)
    for arm in arms
        ta = true_mean[arm]
        if Dist == "Normal"
            # NOTE(review): σ = 0.25 here (Normal takes a standard deviation).
            reward = sum(rand(Normal(ta, 0.5^2), pull_times))
        elseif Dist == "Bernoulli"
            reward = sum(rand(Bernoulli(ta), pull_times))
        elseif Dist == "Poisson"
            reward = sum(rand(Poisson(ta), pull_times))
        elseif Dist == "Power"
            reward = sum(rand(Kumaraswamy(1/(1/ta - 1), 1), pull_times))
        elseif Dist == "Exponential"
            reward = sum(rand(Exponential(ta), pull_times))
        elseif Dist == "Deterministic"
            reward = ta*pull_times
        else
            throw(ArgumentError("unsupported reward distribution: $Dist"))
        end
        marms[arm] += reward
    end
    # Counters are updated after sampling so an unsupported `Dist` leaves them untouched.
    npulls[arms] .+= pull_times
    return marms, npulls
end

"""
    pull_arm!(Dist, arm::Arm; pull_times=1)

Pull `arm` `pull_times` times (used by the BSH / Sequential Halving code).

Rewards are drawn from the family named by `Dist` with parameter
`arm.avg_reward` (same families as the index-based methods). The arm's running
mean `eprc_reward` and pull counter `num_pulls` are updated in place.

# Returns
The new value of `arm.num_pulls`.

# Throws
`ArgumentError` if `Dist` is unsupported; `arm` is not modified in that case.
"""
function pull_arm!(Dist::String, arm::Arm; pull_times::Int = 1) ## For BSH
    μ = arm.avg_reward
    if Dist == "Normal"
        # NOTE(review): σ = 0.25 here (Normal takes a standard deviation).
        reward = sum(rand(Normal(μ, 0.5^2), pull_times))
    elseif Dist == "Bernoulli"
        reward = sum(rand(Bernoulli(μ), pull_times))
    elseif Dist == "Poisson"
        reward = sum(rand(Poisson(μ), pull_times))
    elseif Dist == "Power"
        reward = sum(rand(Kumaraswamy(1/(1/μ - 1), 1), pull_times))
    elseif Dist == "Exponential"
        reward = sum(rand(Exponential(μ), pull_times))
    elseif Dist == "Deterministic"
        reward = μ*pull_times
    else
        throw(ArgumentError("unsupported reward distribution: $Dist"))
    end
    # Running-mean update over all pulls so far, including these `pull_times`.
    arm.eprc_reward = (arm.eprc_reward * arm.num_pulls + reward) / (arm.num_pulls + pull_times)
    arm.num_pulls += pull_times
    return arm.num_pulls
end



####Uniform sampling

"""
    uniform_sampling(true_mean, budget, Dist; print_seq=true)

Round-robin baseline: pull arms `1, 2, …, n, 1, 2, …` for `budget` pulls and
recommend the arm with the highest empirical mean seen so far. Ties are resolved
in favour of the arm just pulled.

# Returns
- `print_seq == true`: `(ba, ba_diff)` where `ba_diff[i]` is
  `maximum(true_mean) - true_mean[recommendation after pull i]`.
- otherwise: `ba`, the final recommended arm.
"""
function uniform_sampling(true_mean::AbstractVector, budget::Int64, Dist::String ; print_seq::Bool = true)
    n = length(true_mean)
    marms = zeros(n)
    npulls = zeros(n)

    maxmean = maximum(true_mean)
    # Before the first recommendation the regret is taken to be the worst possible gap.
    ba_diff = ones(budget) .* (maxmean - minimum(true_mean))
    ba = 1
    for i in 1:budget
        a = mod1(i, n)   # cycle through arms 1..n
        pull_arm!(Dist, a, 1, marms, npulls, true_mean)
        # The recommendation is tracked regardless of `print_seq`, so the
        # `print_seq=false` result is the actual best empirical arm.
        if marms[ba] / npulls[ba] <= marms[a] / npulls[a]
            ba = a
            if print_seq
                ba_diff[i:end] .= maxmean - true_mean[a]
            end
        end
    end

    if print_seq
        return ba, ba_diff
    else
        return ba
    end
end



##### Box Thirding


# Box Thirding growth factor: `r0` is the root of r + r^1.5 = 4 found by
# `find_zero` on the bracket (0, 10). Promotions from level r pull arms ceil(r0^r) times.
f(r) = r + r^(1.5) - 4
# `const` so that the many reads of `r0` inside BoxThirding's loops and the
# thirding helpers are type-stable (a plain global is inferred as `Any`).
const r0 = find_zero(f, (0, 10))

"""
    BoxUnit(armidx, count)

One unit of a Box Thirding level. `BoxThirding` runs a thirding step on a unit
while its `count` is at least `3k`.
"""
mutable struct BoxUnit
    armidx::Vector{Int64}   # arm indices currently held by this unit
    count::Int64            # bookkeeping counter of arms in `armidx` (maintained by keep!/up!/emptying!/BoxThirding)
end

"""
    Box(k, r)
    Box(boxunits, level, fullcount)

One level of the Box Thirding hierarchy: a list of units, the level number, and
a counter of arms that have entered the level.

`Box(k, r)` creates level `r` holding a single empty unit (`k` is accepted
but not used). The three-argument form sets every field directly.
"""
mutable struct Box
    boxunits::Vector{BoxUnit}
    level::Int64
    fullcount::Int64

    Box(k::Int64, r::Int64) = new([BoxUnit(Int64[], 0)], r, 0)

    Box(units, lvl, cnt) = new(units, lvl, cnt)
end

"""
    thirding(v, arms; k=1)

Rank positions by ascending `v` and split the first `3k` of that ranking into
three groups of `k`, returned as arm labels taken from `arms`:
`(lowest k, middle k, highest k)`.
"""
function thirding(v::AbstractVector, arms; k::Int64=1)
    ranked = sortperm(v)
    lowest = arms[ranked[1:k]]
    middle = arms[ranked[k+1:2k]]
    highest = arms[ranked[2k+1:3k]]
    return lowest, middle, highest
end

"""
    up!(Boxes, up, r, R; k=1)

Promote the arms in `up` from level `r` to level `r + 1`.

When `r == R`, a new level `R + 1` is pushed whose single unit holds `up` (the
vector itself, not a copy). Otherwise `up` is appended to the first unit of level
`r + 1`, and that unit's `count` and the level's `fullcount` grow by `k`.
Returns `Boxes`.
"""
function up!(Boxes::Vector{Box}, up::Vector{Int64},r::Int64, R::Int64; k::Int64 = 1)
    if r != R
        target = Boxes[r+1]
        first_unit = target.boxunits[1]
        append!(first_unit.armidx, up)
        first_unit.count += k
        target.fullcount += k
        return Boxes
    end
    push!(Boxes, Box([BoxUnit(up, k)], R+1, k))
    return Boxes
end

"""
    keep!(Boxes, keep, r, u, U; k=1)

Move the arms in `keep` to the unit after `u` on level `r`.

When `u == U`, a new unit holding `keep` (not a copy) with count `k` is pushed.
Otherwise `keep` is appended to unit `u + 1` and its count grows by `k`.
Returns `Boxes`.
"""
function keep!(Boxes::Vector{Box}, keep::Vector{Int64}, r::Int64, u::Int64, U::Int64; k::Int64 = 1)
    units = Boxes[r].boxunits
    if u != U
        next_unit = units[u+1]
        append!(next_unit.armidx, keep)
        next_unit.count += k
    else
        push!(units, BoxUnit(keep, k))
    end
    return Boxes
end

"""
    emptying!(Boxes, r, u, arms; k=1)

Remove `arms` from unit `u` of level `r` (via `setdiff!`) and lower that unit's
count by `3k`. Returns `Boxes`.
"""
function emptying!(Boxes::Vector{Box}, r::Int64, u::Int64, arms::Vector{Int64}; k::Int64 = 1)
    unit = Boxes[r].boxunits[u]
    setdiff!(unit.armidx, arms)
    unit.count -= 3k
    return Boxes
end

"""
    thirding_operation!(Boxes, marms, npulls, r, u, R, U, Dist, true_mean; k=1)

Apply one thirding step to the first `3k` arms of unit `u` on level `r`.

The arms are ranked by `marms`. The lowest `k` are dropped. The middle `k` move
to the next unit of level `r` (`keep!`). The top `k` are pulled
`ceil(Int, ceil(r0^r)*k)` more times and promoted to level `r + 1` (`up!`).
Finally all `3k` are removed from unit `u` (`emptying!`). `R` and `U` are
compared with `r` and `u` to decide whether a new level or unit must be created.

Returns `Boxes`.
"""
function thirding_operation!(Boxes::Vector{Box}, marms::Vector{Float64}, npulls::Vector{Int64}, r::Int64, u::Int64, R::Int64, U::Int64, Dist::String, true_mean::AbstractVector; k::Int64 = 1)
    arms = Boxes[r].boxunits[u].armidx[1:3k]
    _, keep, up = thirding(marms[arms], arms; k = k)   # lowest third is discarded
    # `k` is passed through so the counters move by the number of arms actually
    # transferred (k to keep!/up!, 3k removed). The previously hard-coded `k=1`
    # only matched the arm lists when k == 1.
    keep!(Boxes, keep, r, u, U; k = k)
    pull_arm!(Dist, up, ceil(Int, ceil(r0^r)*k), marms, npulls, true_mean)
    up!(Boxes, up, r, R; k = k)
    return emptying!(Boxes, r, u, arms; k = k)
end


"""
    thirding_operation2!(Boxes, marms, npulls, r, u, R, U, Dist, true_mean, t0; k=1)

Same as `thirding_operation!`, except that the promoted arms are pulled
`ceil(Int, ceil(r0^r)*k*t0)` times, i.e. the per-level pull count is scaled by `t0`.

Returns `Boxes`.
"""
function thirding_operation2!(Boxes::Vector{Box}, marms::Vector{Float64}, npulls::Vector{Int64}, r::Int64, u::Int64, R::Int64, U::Int64, Dist::String, true_mean::AbstractVector, t0::Int64; k::Int64 = 1)
    arms = Boxes[r].boxunits[u].armidx[1:3k]
    _, keep, up = thirding(marms[arms], arms; k = k)   # lowest third is discarded
    # Pass `k` through so the unit/level counters agree with the k-sized arm groups.
    keep!(Boxes, keep, r, u, U; k = k)
    pull_arm!(Dist, up, ceil(Int, ceil(r0^r)*k*t0), marms, npulls, true_mean)
    up!(Boxes, up, r, R; k = k)
    return emptying!(Boxes, r, u, arms; k = k)
end


"""
    BoxThirding(true_mean, budget, Dist; print_seq=false, k=1)

Box Thirding best-arm identification on arms with parameters `true_mean`.
It spends at most `budget` pulls, with rewards from the family named by `Dist`
(see `pull_arm!`).

Arms live in levels `Boxes[1], Boxes[2], …`, and each level is a list of units.
Whenever a unit holds at least `3k` arms and the level's cost fits in the
remaining budget, `thirding_operation!` does the following:
- drops the lowest third (ranked by accumulated reward `marms`);
- moves the middle third to the next unit of the same level;
- promotes the top third one level up after extra pulls.

New arms enter level `base` after `ceil(r0^(base-1))` pulls. The final
recommendation is the arm with the largest `marms` in the first unit of the
highest level.

# Returns
- `print_seq == true`: `(best_arm, ba_diff)` with
  `ba_diff = maximum(true_mean) .- true_mean[seq]`.
- otherwise: `(best_arm, seq)`.

Here `seq` (length `budget`) is the recommendation over time. Each time a
recommendation is formed, it is written to all positions still covered by the
remaining budget, and the array is reversed at the end.

NOTE(review): `thirding_operation!` is always called with `k=1` below, while the
unit-size test uses the caller's `k`; only `k == 1` looks consistent.
NOTE(review): with `budget <= 0` the final candidate list is empty and `argmax`
throws.
"""
function BoxThirding(true_mean::AbstractVector, budget::Int64, Dist::String ; print_seq::Bool = false, k::Int64 = 1)
    n = length(true_mean)
    marms = zeros(n)          # accumulated (summed) reward per arm
    npulls = zeros(Int, n)    # pulls per arm
    Boxes = [Box(3k,1)]       # a single, empty level 1
    infarm = argmin(true_mean)
    # Until a top-level candidate exists, the recommendation defaults to the
    # arm with the smallest true mean.
    ba_seq = fill(infarm, budget)
    # NOTE(review): the `elseif` branch below assigns `break_ix` (typo), so
    # `break_idx` is never set to true and both `!break_idx` guards always pass.
    # Fixing only the typo would send the next iteration to the final `else`
    # and end the run right after `base` is raised — confirm intent first.
    break_idx = false
    budget_left = 0           # not used below
    arms = zeros(Int, 3k)     # not used below
    base = 1                  # level at which newly inserted arms enter
    unpulled = 0              # arms with npulls <= unpulled are candidates for insertion
    while budget > 0
        #println(budget)
        R = length(Boxes)
        # Sweep levels from the top down to `base`, and units from last to first,
        # thirding every unit holding >= 3k arms while the level's cost
        # ceil(r0^r) is still affordable (`break` only leaves the innermost while).
        for r in R:-1:base
            U = length(Boxes[r].boxunits)
            for u in U:-1:1
                while Boxes[r].boxunits[u].count >= 3k 
                    if budget < ceil(Int,r0^r)
                        break
                    end
                    thirding_operation!(Boxes, marms, npulls, r, u, R, U, Dist, true_mean; k=1)
                    budget -=ceil(Int,r0^r)
                end
            end
        end

        # Refresh the recommendation from the highest level and write it to every
        # position still covered by the remaining budget.
        if Boxes[end].fullcount > 0
            candi = Boxes[end].boxunits[1].armidx
            ba = argmax(marms[candi])
            ba_seq[1:budget] .= candi[ba]
        end
        
        # Either insert one more arm at level `base`, or (when no arm is at or
        # below the `unpulled` threshold) raise `base`, or stop when the budget
        # cannot cover the next step.
        # NOTE(review): `Boxes[base]` assumes that level already exists; if no
        # promotion ever created it (e.g. fewer than 3k arms), this raises a BoundsError.
        if any(npulls .<= unpulled) && !break_idx && budget >= ceil(Int,r0^(base-1))
            newarm = findfirst(npulls .<= unpulled)
            pull_arm!(Dist, newarm, ceil(Int,r0^(base-1)), marms, npulls, true_mean)
            append!(Boxes[base].boxunits[1].armidx, newarm)
            Boxes[base].boxunits[1].count += 1
            Boxes[base].fullcount += 1
            budget -= ceil(Int,r0^(base-1))
            break_idx = false
        elseif !break_idx && budget >= ceil(Int,r0^base)
            #println(npulls)
            unpulled = minimum(npulls)
            base += 1
            #println("the baseline is now $base and unpulled is $unpulled")
            break_ix = true
        else
            budget = 0
        end
    end


    # Final recommendation: best accumulated reward in the top level's first unit.
    candidates = Boxes[end].boxunits[1].armidx
    ba = argmax(marms[candidates])
    best_arm = candidates[ba]

    if print_seq
        M = maximum(mean.(true_mean))
        ba_diff = M .- view(true_mean, reverse(ba_seq))
        return best_arm, ba_diff
    else
        return best_arm, reverse(ba_seq)
    end
end


"""
    reset!(arm::Arm)

Clear `arm`'s pull statistics: zero pulls and an empirical mean of 0.0.
Returns `0.0`.
"""
reset!(arm::Arm) = (arm.num_pulls = 0; arm.eprc_reward = 0.0)

"""
    Sequential_halving_fs(avg_reward, budget)

State of one Sequential Halving run over arms with true parameters
`avg_reward` and total budget `budget`. `sh_onestep_forward!` advances it one
pull at a time.

The first phase gives each arm
`floor(budget / (num_arms * ceil(log2(num_arms))))` pulls. `pull_index` starts
at 0 and is used 0-based (the step function pulls `alive_arm[pull_index + 1]`).

NOTE(review): with `num_arms == 1`, `ceil(Int, log2(1)) == 0`, so the division
yields `Inf`/`NaN` and `floor(Int, …)` throws `InexactError`.
"""
mutable struct Sequential_halving_fs
    alive_arm::Vector{Arm}        # List of arms that are still alive
    budget::Int                   # Budget for pulling arms
    num_arms::Int                 # Number of arms
    max_reward::Float64           # Maximum reward among the arms
    sh_done::Bool                 # Indicates if sequential halving is finished
    pull_count_all::Int           # Total count of all pulls (for debugging)
    output_eprc::Float64          # Estimated reward of the best arm
    output_true::Float64          # True reward of the best arm
    pull_index::Int               # 0-based index of the arm currently being pulled
    pull_count_perarm::Int        # Count of pulls for the current arm
    pull_count_perphase::Int      # Count of pulls for the current phase
    phase::Int                    # Current phase of the process
    num_pull_arm::Int             # Number of pulls per arm in the current phase
    num_pull_phase::Int           # Total pulls assigned for the current phase

    function Sequential_halving_fs(avg_reward::Vector{Float64}, budget::Int)
        num_arms = length(avg_reward)
        max_reward = maximum(avg_reward)
        initial_num_pull_arm = floor(Int, budget / (num_arms * ceil(Int, log2(num_arms))))
        initial_num_pull_phase = initial_num_pull_arm * num_arms
        alive_arm = [Arm(reward) for reward in avg_reward]
        # Counters start at zero and phase at 1.
        return new(alive_arm, budget, num_arms, max_reward, false, 0, 0.0, 0.0, 0, 0, 0, 1, initial_num_pull_arm, initial_num_pull_phase)
    end
end

"""
    sh_onestep_forward!(sh::Sequential_halving_fs, Dist)

Advance the Sequential Halving run `sh` by exactly one pull.

When the current phase's quota `num_pull_phase` is used up, the step does the following:
- keeps the `floor(len/2)` alive arms with the highest empirical mean;
- resets their statistics;
- starts the next phase with a per-arm quota of
  `floor(budget / (len * ceil(log2(num_arms))))`.

Within a phase, arms are pulled in order, `num_pull_arm` times each.
After the last pull of the second arm in phase `ceil(log2(num_arms))`, it
records the better of the two remaining arms in `output_true` / `output_eprc` and
sets `sh_done`. Does nothing once `sh_done` is set.
"""
function sh_onestep_forward!(sh::Sequential_halving_fs, Dist::String)
    if sh.sh_done
        return
    end

    if sh.pull_count_perphase == sh.num_pull_phase
        # Phase over: keep the better half (by empirical mean), reject the rest.
        num_keep = floor(Int, length(sh.alive_arm) / 2)
        eprc_reward_alive_arm = [arm.eprc_reward for arm in sh.alive_arm]
        idx = partialsortperm(eprc_reward_alive_arm, 1:num_keep, rev = true)
        sh.alive_arm = [sh.alive_arm[i] for i in idx]

        # Survivors start the new phase with fresh statistics.
        for arm in sh.alive_arm
            reset!(arm)
        end

        sh.phase += 1
        # `pull_index` is 0-based (we pull alive_arm[pull_index + 1]), so a new
        # phase restarts at 0. Restarting at 1 skipped the first survivor and
        # ran past the end of `alive_arm`.
        sh.pull_index = 0
        sh.pull_count_perarm = 0
        sh.pull_count_perphase = 0
        sh.num_pull_arm = floor(Int, sh.budget / (length(sh.alive_arm) * ceil(log2(sh.num_arms))))
        sh.num_pull_phase = sh.num_pull_arm * length(sh.alive_arm)
    end

    # Current arm has received its quota: move on to the next one.
    if sh.pull_count_perarm == sh.num_pull_arm
        sh.pull_count_perarm = 0
        sh.pull_index += 1
    end

    pull_arm!(Dist, sh.alive_arm[sh.pull_index + 1])
    sh.pull_count_perarm += 1
    sh.pull_count_perphase += 1
    sh.pull_count_all += 1

    # Final phase complete (second of the two remaining arms has its quota).
    if sh.phase == ceil(Int, log2(sh.num_arms)) && sh.pull_index == 1 && sh.pull_count_perarm == sh.num_pull_arm
        @assert length(sh.alive_arm) == 2

        eprc_reward_alive_arm = [arm.eprc_reward for arm in sh.alive_arm]
        output_index = argmax(eprc_reward_alive_arm)
        sh.output_true = sh.alive_arm[output_index].avg_reward
        sh.output_eprc = sh.alive_arm[output_index].eprc_reward
        sh.sh_done = true
    end
end


"""
    Bracket(avg_reward)

One bracket of the BSH procedure. It repeatedly runs Sequential Halving on the
same arms, starting with budget `ceil(Int, num_arms * log2(num_arms))`.
`onestep_forward!(::Bracket, ::String)` doubles the budget after every completed
run. `output_eprc` / `output_true` hold the result of the most recent completed run.
"""
mutable struct Bracket
    avg_reward::Vector{Float64}   # Vector of average rewards for each arm
    num_arms::Int                 # Number of arms
    max_reward::Float64           # Maximum reward value among the arms
    output_eprc::Float64          # Estimated expected reward for the best arm
    output_true::Float64          # True reward for the best arm
    budget_rcvd::Int              # Total budget received for pulling arms
    num_sr_done::Int              # Number of completed Sequential Halving runs
    budget::Int                   # Budget for the current Sequential Halving instance
    SH::Sequential_halving_fs     # Sequential Halving instance

    function Bracket(avg_reward::Vector{Float64})
        num_arms = length(avg_reward)
        max_reward = maximum(avg_reward)
        initial_budget = ceil(Int, num_arms * log2(num_arms))  # Starting budget
        sh_instance = Sequential_halving_fs(avg_reward, initial_budget)
        return new(avg_reward, num_arms, max_reward, 0.0, 0.0, 0, 0, initial_budget, sh_instance)
    end
end

"""
    onestep_forward!(bracket::Bracket, Dist)

Advance the bracket's current Sequential Halving run by one pull.

When that run finishes, the bracket does the following:
- copies the run's outputs into `bracket`;
- increments `num_sr_done`;
- doubles `bracket.budget`;
- starts a fresh run with the doubled budget.
"""
function onestep_forward!(bracket::Bracket, Dist::String)
    sh_onestep_forward!(bracket.SH, Dist)

    # `sh_done` is a Bool: test it directly instead of comparing with 1.
    if bracket.SH.sh_done
        bracket.num_sr_done += 1
        bracket.output_eprc = bracket.SH.output_eprc
        bracket.output_true = bracket.SH.output_true
        bracket.budget *= 2  # next run gets twice the budget

        bracket.SH = Sequential_halving_fs(bracket.avg_reward, bracket.budget)
    end
end


"""
    Instance(avg_reward)

Top-level BSH state: the arms' true parameters and a growing list of brackets,
together with each bracket's latest estimate. `pull_index` is the 0-based index
of the bracket stepped next.
"""
mutable struct Instance
    avg_reward::Vector{Float64}   # Vector to store average rewards for each arm
    num_arms::Int                 # Number of arms
    max_reward::Float64           # Maximum reward value among arms
    output_eprc::Float64          # Estimated expected reward for the best arm
    output_true::Float64          # True reward for the best arm
    bracket::Vector{Bracket}      # List of Bracket instances
    eprc_reward_allbracket::Vector{Float64}  # List of estimated rewards for all brackets
    pull_index::Int               # 0-based index of the bracket to step next

    function Instance(avg_reward::AbstractVector)
        max_reward = maximum(avg_reward)
        # Start with no brackets; a typed `Bracket[]` instead of an untyped `[]`.
        return new(avg_reward, length(avg_reward), max_reward, 0.0, 0.0, Bracket[], Float64[], 0)
    end
end

"""
    subsample_n!(instance, n)

Open a bracket on `n` arms sampled without replacement. Its estimate slot starts
at 0.0. Returns the sampled arm indices.
"""
function subsample_n!(instance::Instance, n::Int)
    chosen = StatsBase.sample(1:instance.num_arms, n; replace = false)
    push!(instance.bracket, Bracket(instance.avg_reward[chosen]))
    push!(instance.eprc_reward_allbracket, 0.0)
    return chosen
end

"""
    include_all!(instance)

Open a bracket over all arms, with an estimate slot of 0.0. Returns the updated
`eprc_reward_allbracket` vector.
"""
function include_all!(instance::Instance)
    push!(instance.bracket, Bracket(instance.avg_reward))
    return push!(instance.eprc_reward_allbracket, 0.0)
end

"""
    onestep_forward!(instance::Instance, Dist)

Step the bracket at (0-based) `pull_index` by one pull and store its latest estimate.
Then set the instance outputs from the bracket with the highest estimate, and
advance `pull_index` round-robin over the brackets.
"""
function onestep_forward!(instance::Instance, Dist::String)
    slot = instance.pull_index + 1
    active = instance.bracket[slot]
    onestep_forward!(active, Dist)

    estimates = instance.eprc_reward_allbracket
    estimates[slot] = active.output_eprc

    winner = argmax(estimates)
    instance.output_eprc = estimates[winner]
    instance.output_true = instance.bracket[winner].output_true

    instance.pull_index = mod(slot, length(instance.bracket))
end


"""
    log_sr(x)

Return `1/2 + Σ_{i=2}^{x} 1/i`, the normaliser used by the Successive Rejects
budget split. For `x <= 1` the sum is empty and the result is `0.5`.
"""
function log_sr(x::Int)
    # Explicit accumulation: `sum` over an empty generator (x <= 1) throws.
    tail = 0.0
    for i in 2:x
        tail += 1 / i
    end
    return 0.5 + tail
end

"""
    Successive_rej(avg_reward, budget)

State of a Successive Rejects run over arms `avg_reward` with total budget
`budget`. The first-phase per-arm quota is
`ceil(Int, (budget - num_arms) / (num_arms * log_sr(num_arms)))`. Presumably,
`num_arms + 1 - 1` is the `K + 1 - k` factor of the Successive Rejects schedule
at phase `k = 1`.

NOTE(review): no method that advances a `Successive_rej` is visible in this
part of the file.
"""
mutable struct Successive_rej
    alive_arm::Vector{Arm}        # List of arms that are still alive (active)
    budget::Int                   # Budget for pulling arms
    num_arms::Int                 # Number of arms (initial number)
    max_reward::Float64           # Maximum reward from the initial arms
    sr_done::Bool                 # Indicates whether the successive rejection process is done
    pull_count_all::Int           # Total number of pulls made across all arms
    output_eprc::Float64          # Estimated expected reward of the best arm
    output_true::Float64          # True reward of the best arm
    pull_index::Int               # Index of the arm to be pulled next
    pull_count_perarm::Int        # Number of pulls for the current arm
    pull_count_perphase::Int      # Total pulls for the current phase
    phase::Int                    # Current phase of the successive rejection
    num_pull_arm::Int             # Number of pulls per arm in the current phase
    num_pull_phase::Int           # Total number of pulls allowed in the current phase

    function Successive_rej(avg_reward::Vector{Float64}, budget::Int)
        num_arms = length(avg_reward)
        max_reward = maximum(avg_reward)
        initial_num_pull_arm = ceil(Int, (budget - num_arms) / ((num_arms + 1 - 1) * log_sr(num_arms)))
        initial_num_pull_phase = initial_num_pull_arm * num_arms
        alive_arm = [Arm(reward) for reward in avg_reward]
        return new(alive_arm, budget, num_arms, max_reward, false, 0, 0.0, 0.0, 0, 0, 0, 1, initial_num_pull_arm, initial_num_pull_phase)
    end
end

# NOTE(review): this method has exactly the same signature as the earlier
# `sh_onestep_forward!(sh::Sequential_halving_fs, Dist::String)` in this file.
# Being evaluated later, it replaces that definition. Possibly it was meant to
# be a separate step function (e.g. for `Successive_rej`) — confirm.
"""
    sh_onestep_forward!(sh::Sequential_halving_fs, Dist)

Advance the Sequential Halving run `sh` by exactly one pull.

When the phase quota `num_pull_phase` is used up, the step does the following:
- keeps the `floor(len/2)` alive arms with the highest empirical mean;
- resets their statistics;
- restarts the 0-based `pull_index` at 0;
- sets the per-arm quota to `floor(budget / (len * ceil(log2(num_arms))))`.

After the last pull of the second arm in phase `ceil(log2(num_arms))`, it
records the better arm in `output_true` / `output_eprc` and sets `sh_done`.
"""
function sh_onestep_forward!(sh::Sequential_halving_fs, Dist::String)

    if sh.sh_done
        return
    end

    if sh.pull_count_perphase == sh.num_pull_phase
        # Phase over: keep the floor(n/2) arms with the highest empirical mean
        # (i.e. reject the rest).
        num_rej = floor(Int, length(sh.alive_arm) / 2)
        eprc_reward_alive_arm = [arm.eprc_reward for arm in sh.alive_arm]
        idx = partialsortperm(eprc_reward_alive_arm, 1:num_rej, rev = true)
        #rej_indices = sorted_indices[1:num_rej]

        # Keep the best-performing arms
        sh.alive_arm = [sh.alive_arm[i] for i in idx]

        # Reset surviving arms
        for arm in sh.alive_arm
            reset!(arm)
        end

        # Update phase (pull_index is 0-based, so it restarts at 0)
        sh.phase += 1
        sh.pull_index = 0
        sh.pull_count_perarm = 0
        sh.pull_count_perphase = 0
        sh.num_pull_arm = floor(Int, sh.budget / (length(sh.alive_arm) * ceil(log2(sh.num_arms))))
        sh.num_pull_phase = sh.num_pull_arm * length(sh.alive_arm)
    end

    # Current arm has received its quota: move on to the next arm.
    if sh.pull_count_perarm == sh.num_pull_arm
        sh.pull_count_perarm = 0
        sh.pull_index += 1
    end

    pull_arm!(Dist,sh.alive_arm[sh.pull_index + 1])
    sh.pull_count_perarm += 1
    sh.pull_count_perphase += 1
    sh.pull_count_all += 1

    # Final phase complete: the second of the two remaining arms has its quota.
    if sh.phase == ceil(Int, log2(sh.num_arms)) && sh.pull_index == 1 && sh.pull_count_perarm == sh.num_pull_arm
        @assert length(sh.alive_arm) == 2

        eprc_reward_alive_arm = [arm.eprc_reward for arm in sh.alive_arm]
        output_index = argmax(eprc_reward_alive_arm)
        sh.output_true = sh.alive_arm[output_index].avg_reward
        sh.output_eprc = sh.alive_arm[output_index].eprc_reward
        sh.sh_done = true
    end
end


"""
    BSH(true_mean, budget, Dist)

Bracketed Sequential Halving for `budget` pulls.

Brackets of sizes `2, 4, 8, …` (random subsets of arms) are opened on a
schedule, and the last one opened covers all arms. Each time step advances one
bracket, round-robin.

# Returns
`(ba, simple_regret)`:
- `simple_regret[t]` is `maximum(true_mean)` minus the true mean of the
  recommendation after step `t`.
- `ba` is the first arm index whose true mean equals the final recommendation's
  (`nothing` if none matches).
"""
function BSH(true_mean::Vector{Float64}, budget::Int, Dist::String)
    instance = Instance(true_mean)
    start_size = 2
    num_bracket = 0
    still_opening = true
    best_mean = maximum(true_mean)
    simple_regret = zeros(budget)

    for t in 1:budget
        # Open the next bracket once t passes num_bracket * 2^num_bracket * start_size / 2.
        if still_opening && t > num_bracket * (2 ^ num_bracket) * start_size / 2
            new_size = (2 ^ num_bracket) * start_size
            if new_size < length(true_mean)
                subsample_n!(instance, new_size)
            else
                include_all!(instance)
                still_opening = false
            end
            num_bracket += 1
        end

        onestep_forward!(instance, Dist)
        simple_regret[t] = best_mean - instance.output_true
    end

    return findfirst(==(instance.output_true), true_mean), simple_regret
end



"""
    Up(times, δ)

Confidence bonus `sqrt(1/times * log(log(times)/δ))` used by the BUCB brackets.
It is defined as `0.0` when `times <= 1`.
"""
function Up(times::Int64, δ::AbstractFloat)
    times <= 1 && return 0.0
    return sqrt(1/times * log(log(times)/δ))
end

"""
    Bracket2(arms)

A BUCB bracket over the global arm indices `arms`.

- `npulls`, `marms`: pull counts and summed rewards, indexed by position in `arms`.
- `best_arm`: current recommendation as a global arm index (initially a
  random member of `arms`).
- `LCB_best_arm`: lower confidence bound of that recommendation. It stays at
  the sentinel `-1000.0` until every arm in the bracket has at least two pulls.
"""
mutable struct Bracket2
    arms::Vector{Int64}
    npulls::Vector{Int64}
    marms::Vector{Float64}

    best_arm::Int64
    LCB_best_arm::Float64
    function Bracket2(arms::Vector{Int64})
        n = length(arms)
        ba = rand(arms)
        return new(arms, zeros(Int, n), zeros(n), ba, -1000.0)
    end
end

"""
    update_bracket2!(brackets, i, true_mean, δ, Dist)

Make one pull in `brackets[i]` and refresh its recommendation.

Sampling:
- while some arm has at most one pull, pull a uniformly random such arm;
- otherwise pull the arm maximising `mean + Up(npulls, δ)`.

Recommendation:
- while some arm has at most one pull, the arm with the largest summed reward,
  with `LCB_best_arm = -1000.0`;
- otherwise the arm maximising `mean - Up(npulls, δ / i^2 / length(arms))`,
  with that maximum stored as its LCB.

Returns the bracket, which is also mutated in place.
"""
function update_bracket2!(brackets::Vector{Bracket2}, i::Int64, true_mean::AbstractVector, δ::AbstractFloat, Dist::String)
    bracket = brackets[i]
    # `npulls` / `marms` alias the bracket's own arrays, so pull_arm! updates the
    # bracket in place. `tm` re-indexes the true means to local positions.
    arms = bracket.arms
    npulls = bracket.npulls
    marms = bracket.marms
    tm = true_mean[arms]

    if any(npulls .<= 1)
        arm = rand(findall(npulls .<= 1))
    else
        m = marms ./ npulls
        u = Up.(npulls, δ)
        _, arm = findmax(m .+ u)
    end
    pull_arm!(Dist, arm, 1, marms, npulls, tm)

    if any(npulls .<= 1)
        L = -1000.0
        # NOTE(review): ranks by summed reward (not mean) while pull counts differ.
        _, ba = findmax(marms)
    else
        m = marms ./ npulls
        u = Up.(npulls, δ/i^2/length(arms))
        L, ba = findmax(m .- u)
    end

    bracket.best_arm = arms[ba]   # local position -> global arm index
    bracket.LCB_best_arm = L
    return bracket
end

"""
    BUCB(δ, true_mean, budget, Dist; print_seq=false)

Bracketed UCB for `budget` pulls.

It starts with one bracket of 2 random arms. Before each round, once
`t >= B * 2^B`, it opens a new bracket of `min(n, 2^(B+1))` random arms and
increments `B`. Each round then gives one pull to every bracket in turn
(`update_bracket2!`), stopping as soon as the budget is spent.

The recommendation is the best arm of the bracket with the highest lower
confidence bound.

# Returns
- `print_seq == true`: `(best_arm, ba_diff)`, where `ba_diff[t]` is the gap
  between the best true mean and the true mean of the recommendation after
  pull `t`.
- otherwise: `best_arm`.
"""
function BUCB(δ::AbstractFloat, true_mean::AbstractVector, budget::Int64, Dist::String; print_seq::Bool = false)
    n = length(true_mean)
    brackets = [Bracket2(StatsBase.sample(1:n, 2, replace = false))]
    rec = [brackets[1].best_arm]       # per-bracket recommended arm
    lcb = [brackets[1].LCB_best_arm]   # per-bracket LCB of that arm
    level = 1
    t = 0
    ba_seq = zeros(Int, budget)

    while t < budget
        if t >= level*2^level
            level += 1
            fresh = Bracket2(StatsBase.sample(1:n, min(n, 2^level), replace = false))
            push!(brackets, fresh)
            push!(rec, fresh.best_arm)
            push!(lcb, -1000.0)
        end
        for i in eachindex(brackets)
            t == budget && break
            t += 1
            brackets[i] = update_bracket2!(brackets, i, true_mean, δ, Dist)
            rec[i] = brackets[i].best_arm
            lcb[i] = brackets[i].LCB_best_arm
            print_seq && (ba_seq[t] = rec[argmax(lcb)])
        end
    end

    best_arm = rec[argmax(lcb)]
    print_seq || return best_arm

    ba_diff = maximum(mean.(true_mean)) .- mean.(true_mean[ba_seq])
    return best_arm, ba_diff
end
