using Distributions
using LinearAlgebra
using Random, Statistics
using Clustering, Distances

"""
    upsampling_balancing_data(x1, x2, subset_num; seed=2025, EPS=1e-6)

Upsample (a random subset of) `x1` until it has as many samples as `x2`.

The global RNG is reseeded with `seed` (after argument validation), so results are
reproducible for a fixed `seed`, but the call also resets the global random stream.

If `subset_num` is `0` or larger than `length(x1)`, all of `x1` is used; otherwise
`subset_num` distinct samples of `x1` are drawn without replacement. With `n1` selected
samples and `n2 = length(x2)`, the output starts with unmodified copies of the `n1`
selected samples, followed by noisy copies:

- if `n2 ≤ 2n1`: `n2 - n1` distinct selected samples, chosen at random;
- otherwise: `div(n2, n1) - 1` further passes over the whole subset (in order), then
  `n2 - div(n2, n1) * n1` distinct selected samples chosen at random.

Each noisy copy adds independent `EPS * randn()` noise to every coordinate of its source.

# Returns
A tuple `(balanced, ids)`: `balanced::Vector{Vector{Float64}}` has `length(x2)` entries,
and `ids[k]` is the index into `x1` of the sample that `balanced[k]` was derived from.

# Throws
- `ArgumentError` if `x1` is empty, `subset_num` is negative, or the selected subset is
  larger than `x2`.
"""
function upsampling_balancing_data(
        x1::Vector{Vector{Float64}},
        x2::Vector{Vector{Float64}},
        subset_num::Int64;
        seed=2025,
        EPS=1e-6
    )
    N1 = length(x1)
    N1 > 0 || throw(ArgumentError("x1 must contain at least one sample"))
    subset_num >= 0 ||
        throw(ArgumentError("subset_num must be non-negative, got $subset_num"))

    # 0, or a request larger than x1, means "use every sample of x1".
    if subset_num == 0 || subset_num > N1
        subset_num = N1
    end
    n1 = subset_num
    n2 = length(x2)
    n1 <= n2 || throw(ArgumentError(
        "n1 must be smaller than or equal to n2 (n1 = $n1, n2 = $n2)"))

    # Seed only after validation so a rejected call leaves the global RNG untouched.
    Random.seed!(seed)
    sub_ids = n1 == N1 ? collect(1:N1) : sample(1:N1, n1, replace=false)
    x1_sub = x1[sub_ids]

    balanced = Vector{Vector{Float64}}(undef, n2)
    ids1 = Vector{Int}(undef, n2)   # indices into x1_sub; mapped back via sub_ids
    for i in 1:n1
        balanced[i] = copy(x1_sub[i])
        ids1[i] = i
    end

    offset = n1
    if n2 <= 2 * n1
        # At most n1 extra samples needed: perturb n2 - n1 distinct, randomly chosen ones.
        for id in shuffle(1:n1)[1:(n2 - n1)]
            offset += 1
            balanced[offset] = _upsample_jitter(x1_sub[id], EPS)
            ids1[offset] = id
        end
    else
        rep_count = div(n2, n1)
        # rep_count - 1 further noisy passes over the whole subset, in order ...
        for _ in 2:rep_count, i in 1:n1
            offset += 1
            balanced[offset] = _upsample_jitter(x1_sub[i], EPS)
            ids1[offset] = i
        end
        # ... then fill the remaining (n2 mod n1) slots with distinct random samples.
        for id in shuffle(1:n1)[1:(n2 - rep_count * n1)]
            offset += 1
            balanced[offset] = _upsample_jitter(x1_sub[id], EPS)
            ids1[offset] = id
        end
    end
    return balanced, sub_ids[ids1]
end

# Copy `base`, adding independent `EPS * randn()` noise to every entry. Iterates over
# `base`'s own indices, so vectors of differing lengths are handled safely.
function _upsample_jitter(base::AbstractVector, EPS)
    noisy = similar(base)
    for j in eachindex(base)
        noisy[j] = base[j] + EPS * randn()
    end
    return noisy
end


"""
    upsampling_balancing_data_with_centers(x1, x2, subset_num; seed=2025, EPS=1e-6)

Upsample (a random subset of) `x1` to `length(x2)` samples. Leftover slots are filled
with k-medoids of the subset rather than with randomly chosen samples.

The global RNG is reseeded with `seed` (after argument validation). If `subset_num` is
`0` or larger than `length(x1)`, all of `x1` is used; otherwise `subset_num` distinct
samples are drawn without replacement. With `n1` selected samples and
`n2 = length(x2)`, the output rows are:

- `n2 ≤ 2n1`: the subset, then the `n2 - n1` medoids of a k-medoids clustering of the
  subset (squared-Euclidean distances), each with noise added;
- otherwise, with `M = div(n2, n1)`: the subset, then every subset sample repeated
  `M - 1` times back-to-back with noise, then `n2 - M*n1` noisy medoids.

Noise is independent `EPS * randn()` per coordinate. No clustering is run when no
medoids are needed. Only `length(x2)` is read from `x2`.

# Returns
`(data, ids)`: `data::Vector{Vector{Float64}}` has `length(x2)` entries, and `ids[k]` is
the index into `x1` of the sample that `data[k]` was derived from.

# Throws
- `ArgumentError` if `x1` is empty, `subset_num` is negative, or the selected subset is
  larger than `x2`.
"""
function upsampling_balancing_data_with_centers(
        x1::Vector{Vector{Float64}},
        x2::Vector{Vector{Float64}},
        subset_num::Int64;
        seed=2025,
        EPS=1e-6
    )
    N1 = length(x1)
    N1 > 0 || throw(ArgumentError("x1 must contain at least one sample"))
    subset_num >= 0 ||
        throw(ArgumentError("subset_num must be non-negative, got $subset_num"))

    # 0, or a request larger than x1, means "use every sample of x1".
    if subset_num == 0 || subset_num > N1
        subset_num = N1
    end
    n1 = subset_num
    n2 = length(x2)
    n1 <= n2 || throw(ArgumentError(
        "n1 must be smaller than or equal to n2 (n1 = $n1, n2 = $n2)"))

    Random.seed!(seed)
    X = permutedims(reduce(hcat, x1))   # one sample per row
    sub_ids = n1 == N1 ? collect(1:N1) : sample(1:N1, n1, replace=false)
    Xs = X[sub_ids, :]
    d = size(Xs, 2)
    ids1 = collect(1:n1)

    if n2 <= 2 * n1
        # Case 1: n2 <= 2*n1 -- append n2 - n1 noisy medoids.
        extra, extra_ids = _noisy_medoid_rows(Xs, n2 - n1, EPS)
        balanced_x1 = vcat(Xs, extra)
        balanced_ids1 = vcat(ids1, extra_ids)
    else
        # Case 2: n2 > 2*n1 -- M = div(n2, n1) copies of each sample plus noisy medoids.
        repeat_count = div(n2, n1)
        remaining = n2 - repeat_count * n1
        # `inner` repetition places the copies of each row back-to-back
        # (r1, r1, r2, r2, ...), so the ids must be repeated the same way.
        repeated = repeat(Xs, inner=(repeat_count - 1, 1)) .+
                   EPS * randn((repeat_count - 1) * n1, d)
        extra, extra_ids = _noisy_medoid_rows(Xs, remaining, EPS)
        balanced_x1 = vcat(Xs, repeated, extra)
        balanced_ids1 = vcat(ids1, repeat(ids1, inner=repeat_count - 1), extra_ids)
    end

    upsampled_data = [collect(row) for row in eachrow(balanced_x1)]
    return upsampled_data, sub_ids[balanced_ids1]
end

# Return the `k` medoids (squared-Euclidean k-medoids over the rows of `X`) with
# `EPS * randn` noise added, plus their row indices; `k == 0` returns empty results
# without clustering, since k-medoids needs at least one cluster.
function _noisy_medoid_rows(X::AbstractMatrix, k::Integer, EPS)
    if k == 0
        return similar(X, 0, size(X, 2)), Int[]
    end
    DM = pairwise(SqEuclidean(), X; dims=1)
    medoid_idxs = kmedoids(DM, k).medoids
    return X[medoid_idxs, :] .+ EPS * randn(k, size(X, 2)), medoid_idxs
end