import copy
import os

import yaml

# All OU-process parameter combinations to generate configs for.
# Tuple layout: (theta1, mu1, sigma1, theta2, mu2, sigma2)
# Each tuple yields one config file, numbered from 1 in list order.
params_list = [
    (0.1, -0.50, 0.10, 0.5, -2.00, 0.10),
    (0.1, -0.50, 0.10, 1.0, -2.00, 0.05),
    (0.1, -0.50, 0.50, 0.2,  5.00, 0.20),
    (0.1, -0.50, 1.00, 2.0, -2.00, 1.00),
    (0.1, -1.00, 0.30, 0.2,  2.00, 0.50),
    (0.1, -1.00, 0.50, 0.1, -5.00, 0.40),
    (0.1, -1.00, 0.50, 0.2, -5.00, 0.50),
    (0.1, -1.00, 1.00, 1.0, -5.00, 0.50),
    (0.1, -2.00, 0.05, 1.0, -0.50, 0.20),
    (0.1, -2.00, 0.40, 0.1,  2.00, 0.40),
    (0.1, -2.00, 0.20, 0.5,  0.00, 0.10),
    (0.1, -2.00, 0.30, 1.0, -5.00, 0.50),
    (0.1, -2.00, 0.40, 0.1,  2.00, 0.40),  # NOTE(review): exact duplicate of the 10th tuple -> 10.yaml and 13.yaml are identical (same model_name); confirm whether a different combination was intended
    (0.1, -2.00, 1.00, 2.0, -1.00, 0.20),
]

# Base configuration shared by every generated file. The generation loop adds
# the per-run OU parameters under "data" and a model name under "save".
base_config = dict(
    input_dim=1,
    output_dim=1,
    input_len=100,
    output_len=50,
    d_model=256,
    n_layer=4,
    batch_size=2048,
    lr=1e-4,
    num_epochs=100,
    train_ratio=0.8,
    val_ratio=0.1,
    data=dict(
        type="ou",  # data source: OU process
        num_samples=50000,
        save_dir="./data",
    ),
    save=dict(
        save_dir="./checkpoints",
        log_dir="./logs",
        keep_models=3,
    ),
    device="cuda:0",
    test_gpu=0,
)

# Output directory for the generated YAML configs (relative to the CWD).
config_dir = "ou_two_sde"
os.makedirs(config_dir, exist_ok=True)

# Write one config per parameter tuple: <config_dir>/1.yaml, 2.yaml, ...
for i, (theta1, mu1, sigma1, theta2, mu2, sigma2) in enumerate(params_list, 1):
    # Deep copy: a shallow copy would share the nested "save" (and "data")
    # dicts with base_config, so writing model_name below would mutate the
    # shared template on every iteration.
    config = copy.deepcopy(base_config)

    # Per-run OU parameters. Insertion order also fixes the model-name layout.
    ou_params = {
        "theta1": theta1,
        "mu1": mu1,
        "sigma1": sigma1,
        "theta2": theta2,
        "mu2": mu2,
        "sigma2": sigma2,
    }
    config["data"].update(ou_params)

    # e.g. "ou_dist_theta1=0.1_mu1=-0.5_sigma1=0.1_theta2=0.5_mu2=-2.0_sigma2=0.1"
    config["save"]["model_name"] = "ou_dist_" + "_".join(
        f"{key}={value}" for key, value in ou_params.items()
    )

    config_path = os.path.join(config_dir, f"{i}.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        # safe_dump emits only standard YAML tags, so the file can be read
        # back with yaml.safe_load. For these plain int/float/str values the
        # output is the same as yaml.dump.
        yaml.safe_dump(config, f, default_flow_style=False, allow_unicode=True)

print("OU过程的配置文件生成完成！")