# Write ground-truth profiles for the CIR model, varying the initial noise δ.
"""
    cir_truth_content_noise(δ)

Return `(content_string, file_name)` for a Python profile that simulates the
CIR ground truth with 200 paths on a 41-point time grid over [0, 2].
The initial condition `u0 = 2` is perturbed by `randn * sqrt(δ)` (variance `δ`)
and any entry that ends up negative is reset to `1e-3`.

Dots in `δ` are replaced by underscores in `file_name` (presumably so the
profile is importable as a Python module — confirm against the loader).

Throws a `DomainError` if `δ` is negative or not finite.
"""
function cir_truth_content_noise(δ)
    # The generated code takes sqrt(δ): a negative or non-finite δ would
    # silently give NaN/inf initial conditions, which the `u0 < 0` clamp misses.
    (isfinite(δ) && δ >= 0) ||
        throw(DomainError(δ, "noise variance δ must be finite and non-negative"))
    content_string = """
import torch
a, b, σ = 1, 5, 0.5
n_samples = 200
t_size = 41
δ = $(δ)
ts = torch.linspace(0, 2, t_size)

torch.manual_seed(0)
# initial condition
u0 = torch.ones(n_samples, 1) * 2
noise = torch.randn(u0.size()) * torch.sqrt(torch.tensor(δ))
u0 += noise
u0[u0 < 0] = 1e-3

truth_label = 'truth_delta_$(δ)'
u_truth_savepath = 'data/cix_truth_delta_$(δ).pt'
"""
    file_name = "profiles/cix_truth_delta_$(replace("$δ", "." => "_")).py"
    return content_string, file_name
end

# Noise levels 0.0, 0.1, …, 1.0, materialised as a Vector.
δs = collect(0.0:0.1:1.0)
foreach(δs) do noise
    text, path = cir_truth_content_noise(noise)
    open(io -> println(io, text), path, "w")
end

# Write NSDE profiles for the CIR model (file prefix `cix`), varying the hidden size.
"""
    cir_nsde_content_hidden_size(hidden_size)

Return `(content_string, file_name)` for a CIR neural-SDE profile whose
network width is `hidden_size`; every other hyperparameter is fixed.
"""
function cir_nsde_content_hidden_size(hidden_size)
    label = "cix_nsde_width_$(hidden_size)"
    body = """
# Stores the parameters for the cix_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = 200, 1, 1
hidden_size = $(hidden_size)
η = 0.002
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = '$(label)'
checkpoint_freq = 100
"""
    return body, "profiles/$(label).py"
end

hidden_sizes = [16, 32, 64, 128]
foreach(hidden_sizes) do width
    text, path = cir_nsde_content_hidden_size(width)
    open(path, "w") do io
        println(io, text)
    end
end
    
"""
    cir_nsde_content_layer(layer)

Return `(content_string, file_name)` for a CIR neural-SDE profile with
`layers = layer` and a fixed hidden size of 32.
"""
function cir_nsde_content_layer(layer)
    label = "cix_nsde_layers_$(layer)"
    path = "profiles/" * label * ".py"
    body = """
# Stores the parameters for the cix_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = 200, 1, 1
hidden_size = 32
layers = $(layer)
η = 0.002
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = '$(label)'
checkpoint_freq = 100
"""
    return body, path
end

# Layer-count sweep: 2 through 8 layers.
layers = collect(2:8)
foreach(layers) do depth
    text, path = cir_nsde_content_layer(depth)
    open(io -> println(io, text), path, "w")
end

"""
    cir_nsde_resnet_content_layer(layer)

Return `(content_string, file_name)` for a CIR neural-SDE profile with
`layers = layer`, `resnet = True` and a fixed hidden size of 32.
"""
function cir_nsde_resnet_content_layer(layer)
    label = "cix_nsde_resnet_layers_$(layer)"
    body = """
# Stores the parameters for the cix_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = 200, 1, 1
hidden_size = 32
layers = $(layer)
resnet = True
η = 0.002
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = '$(label)'
checkpoint_freq = 100
"""
    return body, string("profiles/", label, ".py")
end

# Residual-network sweep: 2 through 5 layers (rebinds the global `layers`).
layers = collect(2:5)
foreach(layers) do depth
    text, path = cir_nsde_resnet_content_layer(depth)
    open(path, "w") do io
        println(io, text)
    end
end





# Write profiles for example2d, varying the sample size.
"""
    content(n_samples)

Return the Python source of the example2d ground-truth profile for
`n_samples` paths. It defines the constants `mu1`, `mu2` and `Sigma`, a
101-point time grid over [0, 5], and the initial condition `[1, 0.5]` for
every sample.
"""
function content(n_samples)
    # The closing delimiter sits at column 0 (like the other templates in this
    # file) so no whitespace-only line is appended to the generated profile.
    content_string = """
import torch

mu1 = 0.1
mu2 = 0.2
Sigma = [[0.2, -0.1], [-0.1, 0.1]]
n_samples = $(n_samples)
t_size = 101

ts = torch.linspace(0, 5, t_size)

torch.manual_seed(0)
u0 = torch.ones(n_samples, 2) * torch.tensor([1, 0.5])

truth_label = 'truth_n_samples_$(n_samples)'
u_truth_savepath = 'data/example2d_truth_n_samples_$(n_samples).pt'
"""
    return content_string
end

"""
    nsde_profile_content(n_samples)

Return the Python source of the example2d neural-SDE training profile whose
batch size equals `n_samples`; every other hyperparameter is fixed.
"""
function nsde_profile_content(n_samples)
    # Header names the example2d base profile (same hyperparameters as
    # `example2d_nsde_n_rotate_content`) and starts at column 0, without the
    # stray indent and cix label the template used to emit.
    content_string = """
# Stores the parameters for the example2d_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = $(n_samples), 2, 2
hidden_size = 100
η = 0.0005
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = 'example2d_nsde_n_samples_$(n_samples)'
checkpoint_freq = 100
"""
    return content_string
end

"""
    filename(n_samples)

Path of the example2d ground-truth profile for `n_samples`; any `.` in the
label is turned into `_`.
"""
function filename(n_samples)
    stem = replace(string("truth_n_samples_", n_samples), '.' => '_')
    return "profiles/example2d_" * stem * ".py"
end

"""
    nsde_filename(n_samples)

Path of the example2d NSDE profile for `n_samples`; any `.` in the label is
turned into `_`.
"""
function nsde_filename(n_samples)
    stem = replace(string("nsde_n_samples_", n_samples), '.' => '_')
    return "profiles/example2d_" * stem * ".py"
end

N_samples = [64, 128, 256, 512, 1024]

# For each sample size write the ground-truth profile, then the NSDE profile.
for n in N_samples
    jobs = (filename(n) => content(n), nsde_filename(n) => nsde_profile_content(n))
    for (path, text) in jobs
        open(io -> println(io, text), path, "w")
    end
end

# function wgan_ncde_content(n_samples)
#     content_string ="""
#     batch_size, state_size, brownian_size = $(n_samples), 1, 1
#     hidden_size = 32
#     mlp_size = 32
#     mlp_layers = 2
#     η = 2e-3
#     β = (0.9, 0.999)
#     weight_decay = 0.005
#     n_epoch_per_epoch = 1
#     nsde_label = 'wgan_ncde_n_samples_$(n_samples)'
#     swa_start = 1000
#     """
#     return content_string
# end

# N_samples = [64,128,256,512,1024]

# function wgan_ncde_file_name(n_samples)
#     return "profiles/temporal_OU_wgan_ncde_n_samples_$(n_samples).py"
# end

# for σ in N_samples
#     open(wgan_ncde_file_name(σ), "w") do f
#         println(f, wgan_ncde_content(σ))
#     end
# end


"""
    example2d_nsde_n_rotate_content(n_rotate)

Return `(content_string, file_name)` for an example2d neural-SDE profile that
sets `n_rotate`; the remaining hyperparameters are fixed.
"""
function example2d_nsde_n_rotate_content(n_rotate)
    tag = "example2d_nsde_n_rotate_$(n_rotate)"
    path = "profiles/" * tag * ".py"
    text = """
# Stores the parameters for the example2d_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = 200, 2, 2
hidden_size = 100
n_rotate = $(n_rotate)
η = 0.0005
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = '$(tag)'
checkpoint_freq = 100
"""
    return text, path
end

# One example2d NSDE profile per rotation count 1..9.
foreach(1:9) do k
    text, path = example2d_nsde_n_rotate_content(k)
    open(io -> println(io, text), path, "w")
end
    
"""
    gbm_nsde_n_rotate_content(n_rotate)

Return `(content_string, file_name)` for a gbm neural-SDE profile that sets
`n_rotate`; the remaining hyperparameters are fixed.
"""
function gbm_nsde_n_rotate_content(n_rotate)
    tag = "gbm_nsde_n_rotate_$(n_rotate)"
    text = """
# Stores the parameters for the gbm_nsde_base profile
# Definition of the network see utils/sde_utils.py

batch_size, state_size, brownian_size = 200, 2, 2
hidden_size = 100
n_rotate = $(n_rotate)
η = 0.0005
β = (0.9, 0.999)
weight_decay = 0.005
N_epoch = 2000
nsde_label = '$(tag)'
checkpoint_freq = 100
"""
    return text, string("profiles/", tag, ".py")
end

# One gbm NSDE profile per rotation count 1..9.
foreach(1:9) do k
    text, path = gbm_nsde_n_rotate_content(k)
    open(path, "w") do io
        println(io, text)
    end
end
    