import copy
import os

import yaml

# All CIR parameter combinations to generate configs for.
# Each entry pairs the parameters of process 1 with those of process 2;
# every pair is flattened below into one 6-tuple:
#   (kappa1, theta1, sigma1, kappa2, theta2, sigma2)
_CIR_PARAM_PAIRS = [
    # (kappa1, theta1, sigma1), (kappa2, theta2, sigma2)
    ((0.1, 0.05, 0.05), (0.2, 0.5, 0.1)),
    ((0.1, 0.1, 0.01), (1.0, 0.3, 0.1)),
    ((0.1, 0.1, 0.1), (0.5, 0.2, 0.05)),
    ((0.1, 0.1, 0.1), (1.0, 0.2, 0.05)),
    ((0.1, 0.2, 0.05), (0.5, 0.4, 0.1)),
    ((0.1, 0.3, 0.01), (0.5, 0.5, 0.05)),
    ((0.1, 0.3, 0.1), (0.1, 0.5, 0.01)),
    ((0.1, 0.3, 0.15), (0.2, 0.5, 0.05)),
    ((0.1, 0.5, 0.01), (0.5, 0.05, 0.1)),
    ((0.2, 0.05, 0.05), (1.0, 0.3, 0.05)),
    ((0.2, 0.1, 0.01), (1.0, 0.4, 0.01)),
    ((0.2, 0.1, 0.1), (0.2, 0.4, 0.05)),
    ((0.2, 0.3, 0.01), (2.0, 0.1, 0.01)),
    ((0.2, 0.3, 0.05), (0.1, 0.05, 0.05)),
]

# Concatenate each (process 1, process 2) pair into a flat 6-tuple.
params_list = [first + second for first, second in _CIR_PARAM_PAIRS]

# Base configuration shared by every generated file. The CIR parameters
# (data.K1/theta1/sigma1/K2/theta2/sigma2) and save.model_name are added per
# parameter set by the generation loop.
base_config = dict(
    # Model and training hyperparameters.
    input_dim=1,
    output_dim=1,
    input_len=100,
    output_len=50,
    d_model=256,
    n_layer=4,
    batch_size=2048,
    lr=1.0e-4,
    num_epochs=100,
    # Presumably train/validation split fractions — confirm against the trainer.
    train_ratio=0.8,
    val_ratio=0.1,
    # Dataset settings; "cir" selects the CIR process.
    data=dict(
        type="cir",
        num_samples=50000,
        save_dir="./data",
    ),
    # Checkpoint and log output settings.
    save=dict(
        save_dir="./checkpoints",
        log_dir="./logs",
        keep_models=3,
    ),
    device="cuda:0",
    test_gpu=0,
)

# Create the output directory (no error if it already exists).
config_dir = "cir_two_sde"
os.makedirs(config_dir, exist_ok=True)

# Write one YAML config per parameter set, numbered from 1 (1.yaml, 2.yaml, ...).
for i, (kappa1, theta1, sigma1, kappa2, theta2, sigma2) in enumerate(params_list, 1):
    # Deep copy the template. A shallow base_config.copy() shares the nested
    # "save" dict, so assigning model_name below would mutate base_config itself.
    config = copy.deepcopy(base_config)

    # Add the CIR parameters of both processes. The mean-reversion speeds are
    # stored under "K1"/"K2" (not "kappa1"/"kappa2"); presumably that is what
    # the data generator reads — confirm before renaming.
    data_cfg = config["data"]
    data_cfg["K1"] = kappa1
    data_cfg["theta1"] = theta1
    data_cfg["sigma1"] = sigma1
    data_cfg["K2"] = kappa2
    data_cfg["theta2"] = theta2
    data_cfg["sigma2"] = sigma2

    # The model name encodes all six parameters (floats rendered via str()).
    model_name = (
        f"cir_dist_kappa1={kappa1}_theta1={theta1}_sigma1={sigma1}"
        f"_kappa2={kappa2}_theta2={theta2}_sigma2={sigma2}"
    )
    config["save"]["model_name"] = model_name

    # Save the config; allow_unicode writes non-ASCII text unescaped.
    config_path = os.path.join(config_dir, f"{i}.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

# Prints "CIR process config files generated!".
print("CIR过程的配置文件生成完成！")