using LinearAlgebra
using Random
using OptimalTransport
using Tulip
using Distances
using Statistics
using StatsBase
using Munkres
using PyCall


"""
    ot_map_indices(data1, data2, ids1, ids2, p)

Solve a discrete optimal-transport problem between the rows of `data1` and the rows of
`data2`. Both sides get uniform weights `1/n`. For each row `i` of `data1`, return the
row index of `data2` that receives the largest share of its mass in the plan.

The ground cost between rows `i` and `j` is `sum(abs(data1[i, d] - data2[j, d])^p for d)`.
The plan is computed by POT's `ot.emd` (called through PyCall) with `numItermax=500000`.
For `p == 2` the cost is evaluated as `x * x`, which is the squared-Euclidean cost.

# Arguments
- `data1`, `data2`: one point per row. They must have the same number of rows, and
  `data2` must have at least as many columns as `data1`.
- `ids1`, `ids2`: accepted for call-site compatibility but not used. The returned
  indices are always positions in `1:size(data2, 1)`.
- `p`: exponent applied to each per-coordinate absolute difference.

# Returns
A `Vector{Int}` of length `size(data1, 1)`.
NOTE(review): with equal-size uniform marginals the plan from `ot.emd` is normally a
scaled permutation matrix, so the result is normally a permutation. Confirm if relied on.

# Throws
- `AssertionError` if the row counts differ.
- `DimensionMismatch` if `data2` has fewer columns than `data1`.
"""
function ot_map_indices(data1::Matrix{<:Real}, data2::Matrix{<:Real},
                        ids1::AbstractVector{<:Integer}, ids2::AbstractVector{<:Integer},
                        p::Real)
    @assert size(data1, 1) == size(data2, 1) "Both datasets must have the same number of points"
    n = size(data1, 1)
    D = size(data1, 2)
    size(data2, 2) >= D ||
        throw(DimensionMismatch("data2 must have at least $D columns, got $(size(data2, 2))"))

    # With zero or one point there is nothing to optimize, so skip the Python round-trip.
    n <= 1 && return collect(1:n)

    # Per-coordinate cost. The p == 2 branch keeps the exact x*x arithmetic of the
    # squared-Euclidean case.
    coordcost(x) = p == 2 ? x * x : abs(x)^p

    cost_matrix = Matrix{Float64}(undef, n, n)
    # Every index is in range: i, j ∈ 1:n and d ∈ 1:D, with D ≤ size(data2, 2) checked above.
    @inbounds for j in 1:n, i in 1:n
        s = 0.0
        for d in 1:D
            s += coordcost(data1[i, d] - data2[j, d])
        end
        cost_matrix[i, j] = s
    end

    pot = pyimport("ot")
    μ = fill(1 / n, n)
    ν = fill(1 / n, n)
    ot_plan = pot.emd(μ, ν, cost_matrix, numItermax=500000)

    # Row i of the plan is the mass leaving point i of data1; keep its largest target.
    return [argmax(view(ot_plan, i, :)) for i in axes(ot_plan, 1)]
end


"""
    block_ot_map_indices(data1, data2, p, block_size)

Approximate an OT matching between the rows of `data1` and `data2` by solving many
small OT problems instead of one large one.

Both row sets are shuffled independently and cut into aligned blocks of `block_size`
rows. Block `b` of `data1` is matched to block `b` of `data2` with `ot_map_indices`.
Rows left after the last full block (fewer than `block_size`) are paired with the
leftover rows of `data2` by a uniformly random permutation.

# Returns
A `Vector{Int}` of length `size(data1, 1)` whose entries are row indices into `data2`.

# Throws
- `AssertionError` if the row counts differ.
- `ArgumentError` if `block_size` is not positive.
"""
function block_ot_map_indices(data1::Matrix{<:Real}, data2::Matrix{<:Real}, p::Real, block_size::Int)
    n1, n2 = size(data1, 1), size(data2, 1)
    @assert n1 == n2 "Both datsets must have the same number of points (block)"
    block_size > 0 || throw(ArgumentError("block_size must be positive, got $block_size"))
    T = zeros(Int, n1)

    ids1_rand = randperm(n1)
    ids2_rand = randperm(n2)

    num_blocks = div(n1, block_size)

    for b in 1:num_blocks
        block_range = ((b - 1) * block_size + 1):(b * block_size)
        block_ids1 = ids1_rand[block_range]
        block_ids2 = ids2_rand[block_range]

        local_map = ot_map_indices(data1[block_ids1, :], data2[block_ids2, :],
                                   collect(1:block_size), block_ids2, p)
        # ot_map_indices returns positions inside the block (1:block_size).
        # Translate them back to row indices of data2.
        T[block_ids1] = block_ids2[local_map]
    end

    remainder_range = (num_blocks * block_size + 1):n1
    rem_ids1 = ids1_rand[remainder_range]
    rem_ids2 = ids2_rand[remainder_range]
    T[rem_ids1] = rem_ids2[randperm(length(rem_ids2))]

    return T
end

"""
    partial_ot_map_indices(data1::Matrix, data2::Matrix, p, T_current, ratio, block_size=nothing)

Refine the assignment `T_current` (row `i` of `data1` ↦ row `T_current[i]` of `data2`)
by re-solving OT on a random part of it.

`floor(Int, n * ratio)` rows of `data1` are drawn without replacement, together with the
rows of `data2` they currently point to. OT is solved between these two point sets, using
`block_ot_map_indices` when `block_size` is given and `ot_map_indices` otherwise. The
selected rows are then re-assigned among that same set of targets. All other rows keep
their assignment.

If the subset has zero or one element, nothing can change and a copy of `T_current` is
returned without calling a solver.

# Returns
A new `Vector{Int}`; `T_current` is not modified.

# Throws
`AssertionError` if `ratio ∉ (0, 1]` or `length(T_current) != size(data1, 1)`.
"""
function partial_ot_map_indices(data1::Matrix{<:Real}, 
                                data2::Matrix{<:Real}, 
                                p::Real, 
                                T_current::Vector{Int}, 
                                ratio::Float64,
                                block_size::Union{Nothing,Int}=nothing)

    @assert 0 < ratio <= 1.0 "Ratio must be in the range (0, 1]"
    n = size(data1, 1)
    @assert length(T_current) == n "T_current must match the number of points in data1"

    subset_size = floor(Int, n * ratio)
    # A single point can only be matched to its own target, and an empty subset has
    # nothing to do. Both cases used to crash or needlessly call the OT solver.
    subset_size <= 1 && return copy(T_current)

    data1_subids = sample(1:n, subset_size, replace=false)
    data2_subids = T_current[data1_subids]
    data1_subset = data1[data1_subids, :]
    data2_subset = data2[data2_subids, :]

    # Indices into the subset rows (1:subset_size).
    local_map = if block_size !== nothing
        block_ot_map_indices(data1_subset, data2_subset, p, block_size)
    else
        # Pass concrete index vectors; a bare range does not match the Vector{Int} signature.
        ot_map_indices(data1_subset, data2_subset,
                       collect(1:subset_size), collect(1:subset_size), p)
    end

    updated_ids = copy(T_current)
    updated_ids[data1_subids] = data2_subids[local_map]
    return updated_ids
end

"""
    partial_ot_map_indices(data1::Vector{Vector{Float64}}, data2::Vector{Vector{Float64}},
                           p, T_current, ratio, block_size=nothing)

Point-list variant of `partial_ot_map_indices`. Each element of `data1`/`data2` is one
point.

`floor(Int, n * ratio)` points of `data1` are chosen as the first entries of a random
permutation, together with the points of `data2` they currently point to. Both groups
are stacked into row matrices, and OT is solved between them (blocked when `block_size`
is given). The chosen points are then re-assigned among that same set of targets.

If the subset has zero or one element, a copy of `T_current` is returned unchanged.

# Returns
A new `Vector{Int}`; `T_current` is not modified.

# Throws
`AssertionError` if `ratio ∉ (0, 1]` or `length(T_current) != length(data1)`.
"""
function partial_ot_map_indices(data1::Vector{Vector{Float64}}, 
    data2::Vector{Vector{Float64}}, 
    p::Real, 
    T_current::Vector{Int}, 
    ratio::Float64,
    block_size::Union{Nothing,Int}=nothing)

    @assert 0 < ratio <= 1.0 "Ratio must be in the range (0, 1]"
    n = length(data1)
    @assert length(T_current) == n "T_current must match the number of points in data1"

    subset_size = floor(Int, n * ratio)
    # Nothing to re-match. The stacking below would also fail on an empty subset,
    # or produce a non-Matrix for a single point.
    subset_size <= 1 && return copy(T_current)

    data1_subids = randperm(n)[1:subset_size]
    data2_subids = T_current[data1_subids]

    # Stack the selected points as rows of a Matrix{Float64} (O(k·D), unlike a fold of vcat).
    data1_subset = permutedims(reduce(hcat, data1[data1_subids]))
    data2_subset = permutedims(reduce(hcat, data2[data2_subids]))

    # Indices into the subset rows (1:subset_size).
    local_map = if block_size !== nothing
        block_ot_map_indices(data1_subset, data2_subset, p, block_size)
    else
        ot_map_indices(data1_subset, data2_subset,
                       collect(1:subset_size), collect(1:subset_size), p)
    end

    updated_ids = copy(T_current)
    updated_ids[data1_subids] = data2_subids[local_map]
    return updated_ids
end

"""
    propose_T(data1::Matrix, data2::Matrix; p=2, T_current=nothing, method="partial_ot",
              ot_ratio=0.0, swap_iter=1, swap_set_size=2, block_size=nothing)

Propose a new assignment vector `T`, where row `i` of `data1` maps to row `T[i]` of
`data2`.

# Methods
- `"ot"`: full OT matching, blocked when `block_size` is given. Ignores `T_current`.
- `"partial_ot"`: re-solve OT on a random `ot_ratio` fraction of `T_current`.
  Requires `T_current` and `ot_ratio > 0`.
- `"random"`: a uniformly random permutation of `1:n`.
- `"random_replace"`: `n` independent uniform draws from `1:n`. In general this is not a
  permutation.
- `"single_random"`: copy of `T_current` with one random position set to a random value
  in `1:length(T_current)`.
- `"sequential"`: `swap_iter` rounds. Each round shuffles the values at `swap_set_size`
  distinct random positions.
- `"sequential_binary"`: `swap_iter` swaps of the values at two distinct random positions.

Any other `method` string raises an `ErrorException`.
Methods that modify `T_current` work on a copy.
"""
function propose_T(data1::Matrix{<:Real}, data2::Matrix{<:Real};
                   p::Real=2,
                   T_current::Union{Nothing,Vector{Int}}=nothing,
                   method::String="partial_ot",
                   ot_ratio::Float64=0.0,
                   swap_iter::Int64=1, 
                   swap_set_size::Int64=2,
                   block_size::Union{Nothing,Int}=nothing)

    if method == "ot"
        if block_size !== nothing
            return block_ot_map_indices(data1, data2, p, block_size)
        end
        n = size(data1, 1)
        # Pass concrete Vector{Int} index lists; a bare range may not match the
        # ot_map_indices signature.
        return ot_map_indices(data1, data2, collect(1:n), collect(1:n), p)

    elseif method == "partial_ot"
        @assert T_current !== nothing "T_current must be provided for 'partial_ot'"
        @assert ot_ratio > 0.0 "ot_ratio must be > 0 for partial OT"
        return partial_ot_map_indices(data1, data2, p, T_current, ot_ratio, block_size)

    elseif method == "random"
        return randperm(size(data1, 1))

    elseif method == "random_replace"
        n = size(data1, 1)
        return rand(1:n, n)

    elseif method == "single_random"
        @assert T_current !== nothing "T_current must be provided for 'single_random'"
        T = copy(T_current)
        idx = rand(1:length(T))
        T[idx] = rand(1:length(T))
        return T

    elseif method == "sequential"
        @assert T_current !== nothing "T_current must be provided for 'sequential'"
        T = copy(T_current)
        for _ in 1:swap_iter
            idx = sample(1:length(T), swap_set_size, replace=false)
            T[idx] = shuffle(T[idx])
        end
        return T

    elseif method == "sequential_binary"
        @assert T_current !== nothing "T_current must be provided for 'sequential_binary'"
        n = length(T_current)
        T = copy(T_current)
        for _ in 1:swap_iter
            i, j = sample(1:n, 2, replace=false)
            T[i], T[j] = T[j], T[i]
        end
        return T

    else
        error("Method not implemented. Choose from 'ot', 'partial_ot', 'random', 'random_replace', 'sequential', 'sequential_binary', 'single_random'.")
    end
end

"""
    propose_T(data1::Vector{Vector{Float64}}, data2::Vector{Vector{Float64}}; p=2,
              T_current=nothing, method="partial_ot", ot_ratio=0.0, swap_iter=1,
              swap_set_size=2, block_size=nothing)

Point-list variant of `propose_T`. Each element of `data1`/`data2` is one point.

The moves are the same as in the matrix method: `"ot"`, `"partial_ot"`, `"random"`,
`"random_replace"`, `"single_random"`, `"sequential"` and `"sequential_binary"`. For
`"ot"` the points are first stacked into row matrices (one point per row). Any other
`method` string raises an `ErrorException`. Methods that modify `T_current` work on a
copy.
"""
function propose_T(data1::Vector{Vector{Float64}}, data2::Vector{Vector{Float64}};
                   p::Real=2,
                   T_current::Union{Nothing,Vector{Int}}=nothing,
                   method::String="partial_ot",
                   ot_ratio::Float64=0.0,
                   swap_iter::Int64=1,
                   swap_set_size::Int64=2,
                   block_size::Union{Nothing,Int}=nothing)

    if method == "ot"
        # Stack points as matrix rows. Use new names so data1/data2 keep their declared
        # type. reduce(hcat, ·) is linear-time and always yields a Matrix{Float64}.
        X1 = permutedims(reduce(hcat, data1))
        X2 = permutedims(reduce(hcat, data2))
        if block_size !== nothing
            return block_ot_map_indices(X1, X2, p, block_size)
        end
        # Concrete Vector{Int} index lists; a bare range may not match ot_map_indices.
        return ot_map_indices(X1, X2, collect(1:size(X1, 1)), collect(1:size(X2, 1)), p)

    elseif method == "partial_ot"
        @assert T_current !== nothing "T_current must be provided for 'partial_ot'"
        @assert ot_ratio > 0.0 "ot_ratio must be > 0 for partial OT"
        return partial_ot_map_indices(data1, data2, p, T_current, ot_ratio, block_size)

    elseif method == "random"
        return randperm(length(data1))

    elseif method == "random_replace"
        n = length(data1)
        return rand(1:n, n)

    elseif method == "single_random"
        @assert T_current !== nothing "T_current must be provided for 'single_random'"
        T = copy(T_current)
        idx = rand(1:length(T))
        T[idx] = rand(1:length(T))
        return T

    elseif method == "sequential"
        @assert T_current !== nothing "T_current must be provided for 'sequential'"
        T = copy(T_current)
        for _ in 1:swap_iter
            idx = sample(1:length(T), swap_set_size, replace=false)
            T[idx] = shuffle(T[idx])
        end
        return T

    elseif method == "sequential_binary"
        @assert T_current !== nothing "T_current must be provided for 'sequential_binary'"
        n = length(T_current)
        T = copy(T_current)
        for _ in 1:swap_iter
            i, j = sample(1:n, 2, replace=false)
            T[i], T[j] = T[j], T[i]
        end
        return T

    else
        error("Method not implemented. Choose from 'ot', 'partial_ot', 'random', 'random_replace', 'sequential', 'sequential_binary', 'single_random'.")
    end
end

"""
    propose_E(E_current, n) -> Vector{Int}

Propose a neighbouring index set by exchanging one member of `E_current` for one
element of `1:n` that is not in it.

The member to drop and the replacement are both chosen uniformly at random. The
replacement takes the dropped member's slot (its first occurrence). `E_current` itself
is left untouched.

# Throws
`AssertionError` if `E_current` is empty or already contains every element of `1:n`.
"""
function propose_E(E_current::Vector{Int}, n::Int)::Vector{Int}
    @assert !isempty(E_current) "E_current must be non-empty"
    candidates = setdiff(1:n, E_current)
    @assert !isempty(candidates) "No elements outside E_current to swap"

    leaving = rand(E_current)
    entering = rand(candidates)

    proposal = copy(E_current)
    slot = findfirst(==(leaving), proposal)
    proposal[slot] = entering
    return proposal
end

"""
    mask(T_current=nothing, subset_num=2)

Return a copy of `T_current` in which a random set of target values is permuted.

`subset_num` distinct values are drawn from `1:length(T_current)`. The positions of
`T_current` that hold any of those values are found in increasing order. Those positions
are then overwritten, in order, with a shuffled copy of the drawn values. If the number of
positions and values differ (e.g. `T_current` is not a permutation), only the first
`min` of them are paired, matching `zip` semantics.

NOTE(review): the default `T_current = nothing` is not usable; `length(nothing)` throws.
"""
function mask(
    T_current::Union{Nothing,Vector{Int}}=nothing,
    subset_num::Int64=2)
    n = length(T_current)
    T = deepcopy(T_current)
    targets = sample(1:n, subset_num, replace=false)

    positions = [k for k in eachindex(T) if T[k] in targets]
    reassigned = shuffle(targets)
    for k in 1:min(length(positions), length(reassigned))
        T[positions[k]] = reassigned[k]
    end
    return T
end

"""
    aggregate_T_T0(T=nothing, T0=nothing, E=nothing)

Return a copy of `T` in which every position listed in `E` is replaced by the value
`T0` holds at that position. All other positions keep their value from `T`.
Neither `T` nor `T0` is modified.
"""
function aggregate_T_T0(
    T::Union{Nothing,Vector{Int}}=nothing,
    T0::Union{Nothing,Vector{Int}}=nothing,
    E::Union{Nothing,Vector{Int}}=nothing)

    merged = deepcopy(T)
    foreach(E) do k
        merged[k] = T0[k]
    end
    return merged
end


"""
    compute_energy(data1::Matrix, data2::Matrix; p=2, T, temperature=1.0)

Return the Boltzmann-style weight `exp(-E / temperature)` of the assignment `T`.
`E` is the mean over rows `i` of `sum(abs(data1[i, d] - data2[T[i], d])^p for d)`.

Only the first `size(data1, 1)` entries of `T` and the first `size(data1, 2)` columns of
`data2` are read. With zero rows the mean is `0/0`, so the result is `NaN`.

# Throws
`BoundsError` if `T` is shorter than `size(data1, 1)`, if any used `T[i]` is not a row of
`data2`, or if `data2` has fewer columns than `data1`.
"""
function compute_energy(data1::Matrix{<:Real}, data2::Matrix{<:Real};
                                    p::Real=2, T::Vector{Int}, temperature::Real=1.0)::Real
    N, D = size(data1)
    total = zero(eltype(data1))
    # Deliberately bounds-checked. T comes from callers (including random proposals),
    # so an invalid index must raise BoundsError rather than read arbitrary memory.
    for i in 1:N
        j = T[i]
        s = zero(eltype(data1))
        for d in 1:D
            s += abs(data1[i, d] - data2[j, d])^p
        end
        total += s
    end
    energy_mean = total / N
    return exp(-energy_mean / temperature)
end

"""
    compute_energy(data1::Vector{Vector{Float64}}, data2::Vector{Vector{Float64}};
                   p=2, T, temperature=1.0)

Point-list variant: return `exp(-E / temperature)`, where `E` is the mean over points `i`
of `sum(abs(data1[i][j] - data2[T[i]][j])^p for j in eachindex(data1[i]))`.

With no points the mean is `0/0`, so the result is `NaN`. This matches the matrix
method; previously the empty case threw on `data1[1]`.

# Throws
`BoundsError` if `T` is shorter than `data1`, if any used `T[i]` is not an index of
`data2`, or if a paired `data2` point is shorter than its `data1` point.
"""
function compute_energy(data1::Vector{Vector{Float64}}, data2::Vector{Vector{Float64}};
                                    p::Real=2, T::Vector{Int}, temperature::Real=1.0)::Real
    N = length(data1)
    total = 0.0
    # Deliberately bounds-checked. T and the point lengths come from the caller, so a
    # bad index must raise BoundsError rather than read arbitrary memory.
    for i in 1:N
        v1 = data1[i]
        v2 = data2[T[i]]
        s = 0.0
        for j in eachindex(v1)
            s += abs(v1[j] - v2[j])^p
        end
        total += s
    end
    energy_mean = total / N
    return exp(-energy_mean / temperature)
end