# Distributed supplies the multi-process tools used by this script
# (`addprocs`, `@everywhere`, `workers`, `rmprocs`).
using Distributed 
# Start 10 additional worker processes before any code is loaded onto them.
addprocs(10)
# Load ProgressMeter on every process (the master and all workers).
@everywhere using ProgressMeter

# Re-evaluate one finished experiment.
#
# `info` is destructured into `(profile, loss_function, repeat)`. When the losses CSV
# for that combination already exists under `data/`, the Python script is launched
# (without any overwrite flag) so that the final model is evaluated again; when the
# CSV is missing, the combination is skipped. Returns `nothing` in both cases.
#
# NOTE: a former training mode, which counted the existing result files over repeats
# 1:30 and launched the script with `--overwrite true`, is disabled in this
# evaluation-only version.
@everywhere function run_profile(info)
    (profile, loss_function, repeat) = info

    losses_csv =
        "data/example2d_nsde_$(profile)_truth_$(profile)_$(loss_function)_repeat_$(repeat)_losses.csv"

    # Guard clause: nothing to evaluate when no result file has been written yet.
    if !isfile(losses_csv)
        @info "Skip the profile $(profile) with loss function $(loss_function) and $(repeat) repeats only for evaluation purpose"
        return
    end

    evaluation = `python example2d_base.py --profile $profile --loss_function $loss_function --repeat $repeat`
    run(pipeline(evaluation))
    @info "The profile $(profile) with loss function $(loss_function) and $(repeat) repeats has been completed. No overwrite, but reevaluate the final model"
    return
end

# profiles1 = [
#    "n_sample_50",
#    "n_sample_100",
#    "n_sample_200",
#    "n_sample_400",
# ]

# Profiles processed with the full worker pool.
# The larger settings "n_samples_512" and "n_samples_1024" are left out of this list.
profiles = ["n_samples_64", "n_samples_128", "n_samples_256"]

# Disabled alternative: a sigma sweep concatenated with `profiles1`.
# profiles2 = [
#     "sigma_$(replace("$i", "." => "_"))" for i in 0.0:0.1:1.0
# ]
# profiles = vcat(profiles1, profiles2)

# Profiles that require a large amount of memory ("n_samples_1024" is disabled).
quad_profiles = ["n_samples_512"]

# Loss functions to cover; "MMD" is currently disabled.
loss_functions = ["W2", "W2_rotated", "sliced_W2", "W2_rotated_corrected"]

repeats = 1:30

# Every (profile, loss_function, repeat) combination, stored as a 3-D array whose
# axes follow that order (profile index varies fastest).
infos = collect(Iterators.product(profiles, loss_functions, repeats))
quad_infos = collect(Iterators.product(quad_profiles, loss_functions, repeats))

@info "The following profiles will be executed:" infos

# Run `run_profile` over every entry of `tasks` on the current worker pool while
# displaying a green progress bar with speed information.
function run_batch(tasks)
    bar = Progress(
        length(tasks); color = :green, showspeed = true, barglyphs = BarGlyphs("[=> ]")
    )
    return progress_pmap(run_profile, tasks; progress = bar)
end

# Write `body` to the scratch file `path`, feed that file to `ssmtp` on stdin, and
# remove the scratch file afterwards, even when `ssmtp` fails.
function send_mail(path, body)
    try
        write(path, body)
        run(pipeline(`ssmtp xiangting.li@ucla.edu`; stdin = path))
    finally
        rm(path; force = true)
    end
    return nothing
end

# Only the task execution is guarded here. Sending the report happens afterwards, so
# that a mail-delivery error after a successful run is not reported as a task failure.
failure = try
    run_batch(infos)

    # Keep only the first four workers for the second batch (`quad_profiles` is the
    # set of profiles flagged as needing a lot of memory).
    rmprocs(workers()[5:end])
    run_batch(quad_infos)
    nothing
catch err
    # The failure mail asks the reader to review the logs, so record the error there.
    @error "wasserstein training tasks failed" exception = (err, catch_backtrace())
    err
end

# Bullet list of everything that was scheduled, including the second batch.
# The same separator is used in both mails so every bullet is rendered alike.
scheduled_profiles = join(vcat(profiles, quad_profiles), "\n    - ")

if failure === nothing
    mail_body = """
    From: "Automata" <automata@ryzen>
    To: "Xiangting Li" <xiangting.li@ucla.edu>
    Subject: wasserstein training task completed

    Hi Xiangting,

    The tasks scheduled in wasserstein training are completed.
    The profiles are:
        - $(scheduled_profiles)

    Please check the result.
    Automata.
    """
    send_mail("mail_success.txt", mail_body)
else
    mail_body = """
    From: "Automata" <automata@ryzen>
    To: "Xiangting Li" <xiangting.li@ucla.edu>
    Subject: wasserstein training task failed

    Hi Xiangting,

    The tasks scheduled in wasserstein training encountered a problem. Please review the logs and fix the problem.
    $(sprint(showerror, failure))
    The profiles are:
        - $(scheduled_profiles)

    Automata.
    """
    send_mail("mail_failure.txt", mail_body)
end

