import itertools
import os

from nesim.utils.json_stuff import dict_to_json

config_folder = "./configs"
# NOTE(review): not referenced by any active code in this script.
common_checkpoint_step = 30500

# Start from a clean output folder: create it if needed and delete JSON configs
# left over from a previous run, so stale indices from an earlier (larger) sweep
# don't linger next to the freshly generated ones. Plain `os` calls are used
# instead of shelling out to `rm`: `os.system` only returns an exit status, so
# failures there went unnoticed, and a missing folder was never created.
os.makedirs(config_folder, exist_ok=True)
for _entry in os.listdir(config_folder):
    _path = os.path.join(config_folder, _entry)
    if _entry.endswith(".json") and os.path.isfile(_path):
        os.remove(_path)

checkpoint_root = "../gpt_neo_125m/checkpoints/"
# Label -> training-run directory name under `checkpoint_root`. Uncomment
# entries to include more runs in the sweep.
run_names = {
    # "ours_scale_0.1": "apply_nesim_every_n_steps_10_nesim_config_scale_0.1_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    "ours_scale_1": "apply_nesim_every_n_steps_10_nesim_config_scale_1_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_10": "apply_nesim_every_n_steps_10_nesim_config_scale_10_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_100": "apply_nesim_every_n_steps_10_nesim_config_scale_100_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_1000": "apply_nesim_every_n_steps_10_nesim_config_scale_1000_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    "baseline": "apply_nesim_every_n_steps_10_nesim_config_baseline_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
}

# checkpoint_step = {}
# for name in run_names:
#     checkpoint_step[name] = common_checkpoint_step

# Label -> path of the weights file that each generated config fine-tunes from.
checkpoint_filenames = {
    # "pretrained": None
}


def _checkpoint_step(entry_name):
    """Return the trailing integer of *entry_name*, or -1 if it has none.

    E.g. ``"checkpoint-30500"`` -> ``30500``; ``"runs"`` -> ``-1``.
    """
    prefix = entry_name.rstrip("0123456789")
    digits = entry_name[len(prefix):]
    return int(digits) if digits else -1


def _latest_checkpoint_file(run_dir):
    """Return the ``pytorch_model.bin`` path of the latest checkpoint in *run_dir*.

    ``os.listdir`` returns entries in arbitrary order, so the latest checkpoint
    is chosen as the sub-directory with the largest numeric suffix rather than
    by listing position. Assumes checkpoint directories end in their step
    number (e.g. ``checkpoint-<step>``) -- TODO confirm against the trainer.

    Raises:
        FileNotFoundError: if *run_dir* has no sub-directories, or the selected
            checkpoint directory contains no ``pytorch_model.bin``.
    """
    subdirs = [
        entry
        for entry in os.listdir(run_dir)
        if os.path.isdir(os.path.join(run_dir, entry))
    ]
    if not subdirs:
        raise FileNotFoundError(f"No checkpoint directories in: {run_dir}")
    # Tie-break on the name so the choice is deterministic.
    latest = max(subdirs, key=lambda entry: (_checkpoint_step(entry), entry))
    filename = os.path.join(run_dir, latest, "pytorch_model.bin")
    if not os.path.exists(filename):
        raise FileNotFoundError(f"Invalid filename: {filename}")
    return filename


for key, run_name in run_names.items():
    checkpoint_filenames[key] = _latest_checkpoint_file(
        os.path.join(checkpoint_root, run_name)
    )

## Settings copied verbatim into every generated config. LoRA rank/alpha,
## target modules, batch size, gradient accumulation, warmup steps and learning
## rate are deliberately absent here: they come from the `all_params` grid.
## A previously used target-module set was
## ["transformer.h.11.mlp.c_proj", "lm_head"].
constants = dict(
    hf_model_name="EleutherAI/gpt-neo-125m",
    hf_tokenizer_name="EleutherAI/gpt-neo-125m",
    hf_dataset_name="dair-ai/emotion",
    lora_dropout=0.05,
    lora_bias="none",
    lora_task_type="CAUSAL_LM",
    max_steps=1000,
    logging_steps=1,
    max_tokens_per_dataset_item=1024,
)

## Module paths of the attention query / value projections in blocks 0..11.
all_q_layers = [f"transformer.h.{layer_idx}.attn.attention.q_proj" for layer_idx in range(12)]
all_v_layers = [f"transformer.h.{layer_idx}.attn.attention.v_proj" for layer_idx in range(12)]

## Sweep grid: each key maps to the list of values to try. The config loop
## emits one config per combination of these values and each checkpoint.
all_params = dict(
    lora_r=[1],
    lora_alpha=[8],
    lora_target_modules=[
        # Alternative entry, currently disabled:
        # {
        #     "name": "layer_10_and_11_c_proj",
        #     "modules": ["transformer.h.10.mlp.c_proj", "transformer.h.11.mlp.c_proj"]
        # },
        dict(name="all_q_and_v_proj_layers", modules=all_q_layers + all_v_layers),
    ],
    batch_size=[32],
    gradient_accumulation_steps=[4],
    warmup_steps=[400],
    learning_rate=[1e-3],
)

## Expand the sweep: one config per combination of grid values and checkpoint,
## written to `<config_folder>/<index>.json` with a running index.
## `itertools.product` varies its last argument fastest, matching the order of
## a nested loop with the checkpoint loop innermost.
count = 0
for (
    lora_r,
    lora_alpha,
    lora_target_modules,
    batch_size,
    gradient_accumulation_steps,
    warmup_steps,
    learning_rate,
    checkpoint_name,
) in itertools.product(
    all_params["lora_r"],
    all_params["lora_alpha"],
    all_params["lora_target_modules"],
    all_params["batch_size"],
    all_params["gradient_accumulation_steps"],
    all_params["warmup_steps"],
    all_params["learning_rate"],
    checkpoint_filenames,
):
    # Copy so per-run keys are written into a fresh dict instead of mutating
    # the shared `constants` mapping on every iteration.
    config = dict(constants)
    config["lora_r"] = lora_r
    config["lora_alpha"] = lora_alpha
    config["lora_target_modules"] = lora_target_modules["modules"]
    config["batch_size"] = batch_size
    config["gradient_accumulation_steps"] = gradient_accumulation_steps
    config["checkpoint_filename"] = checkpoint_filenames[checkpoint_name]
    config["checkpoint_name"] = checkpoint_name
    config["warmup_steps"] = warmup_steps
    config["learning_rate"] = learning_rate

    # NOTE(review): `name` omits `checkpoint_name`, so configs that differ only
    # in checkpoint share the same name; presumably downstream tells them apart
    # via `checkpoint_name` -- confirm before relying on `name` being unique.
    config[
        "name"
    ] = f"lora_r_{lora_r}_lora_alpha_{lora_alpha}_lora_target_modules_{lora_target_modules['name']}_batch_size_{batch_size}_gradient_accumulation_steps_{gradient_accumulation_steps}_warmup_steps_{warmup_steps}_learning_rate_{learning_rate}"

    filename = os.path.join(config_folder, f"{count}.json")
    dict_to_json(config, filename=filename)
    count += 1
print(f"Saved {count} configs")
