import os
import shutil
import copy
import json

# Template training configuration shared by every generated run. The
# generation loop below deep-copies this dict and overrides per-run fields
# (e.g. model, task, delta type, batch size, epochs, metric, output dir).
# Keyword order here is the key order of the emitted JSON files.
base_config = dict(
    dataset_config_name=["en"],
    delta_type="adapter",
    do_eval=True,
    do_test=True,
    do_train=True,
    eval_dataset_config_name=["en"],
    eval_dataset_name="cola",
    eval_steps=200,
    evaluation_strategy="steps",
    greater_is_better=True,
    learning_rate=0.0003,
    max_grad_norm=0.1,
    load_best_model_at_end=True,
    lora_alpha=8,
    lora_r=8,
    logging_steps=100,
    max_seq_length=128,
    metric_for_best_model="eval_matthews_correlation",
    model_name="roberta",
    model_name_or_path="roberta-base",
    non_linearity="gelu_new",
    num_train_epochs=20,
    output_dir="outputs/adapter/roberta-base/cola",
    overwrite_output_dir=True,
    per_device_eval_batch_size=32,
    per_device_train_batch_size=32,
    predict_with_generate=True,
    push_to_hub=False,
    save_steps=200,
    save_strategy="steps",
    save_total_limit=1,
    split_validation_test=True,
    task_name="cola",
    test_dataset_config_name=["en"],
    test_dataset_name="cola",
    tokenizer_name="roberta-base",
    unfrozen_modules=["classifier", "deltas"],
    warmup_ratio=0.06,
    warmup_steps=0,
    weight_decay=0.1,
    disable_tqdm=True,
    report_to=None,
    run_name="",
    local_rank=0,
    ddp_find_unused_parameters=False,
)

# Per-task settings. These four lists are parallel: index i of each one
# describes tasks[i], and they are consumed together with zip().
tasks = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb"]
batch_size = [32] * len(tasks)
# Metric used for best-model selection ("eval_" is prefixed when the config is
# built): cola -> matthews_correlation, stsb -> pearson, the other six -> accuracy.
metrics = ["matthews_correlation"] + ["accuracy"] * 6 + ["pearson"]
epochs = [50, 50, 50, 20, 50, 50, 20, 50]

# Delta types to generate configs for (written to the config as "delta_type").
deltas = ["adapter", "lora"]
# deltas = ["none"]

# Optimizer names; "adamw" is written to the config as optimizer=None.
optimizers = ["adamw", "kfac"]
# Parallel lists: model_names[i] is the "model_name" paired with checkpoint model_paths[i].
model_names = ["roberta"]
model_paths = ["roberta-large"]

# zip() silently stops at its shortest input, so a list that falls out of step
# with `tasks` (or `model_paths` with `model_names`) would silently drop
# configs. Fail loudly instead.
for _list_name, _values in (
    ("batch_size", batch_size),
    ("metrics", metrics),
    ("epochs", epochs),
):
    if len(_values) != len(tasks):
        raise ValueError(
            f"{_list_name} has {len(_values)} entries but tasks has {len(tasks)}"
        )
if len(model_names) != len(model_paths):
    raise ValueError(
        f"model_names has {len(model_names)} entries but model_paths has {len(model_paths)}"
    )


# Write one JSON config per (delta, model, optimizer, task) combination to
# "{delta}/{model_path}/{op}/{task}.json". Each config is a deep copy of
# base_config with the per-run fields overridden.
for delta in deltas:
    for model, model_path_full in zip(model_names, model_paths):
        # Use only the last path component (e.g. "org/model" -> "model") as the
        # directory name; rstrip keeps a trailing "/" from producing "".
        model_path = model_path_full.rstrip("/").split("/")[-1]
        for op in optimizers:
            config_dir = f"{delta}/{model_path}/{op}"
            # Creates any missing parent directories; existing ones are kept.
            os.makedirs(config_dir, exist_ok=True)

            for task, bs, metric, epoch in zip(tasks, batch_size, metrics, epochs):
                # Deep copy so the nested lists in base_config are never shared.
                config = copy.deepcopy(base_config)
                config["model_name"] = model
                config["per_device_eval_batch_size"] = bs
                config["per_device_train_batch_size"] = bs
                config["model_name_or_path"] = model_path_full
                config["tokenizer_name"] = model_path_full
                config["delta_type"] = delta
                config["task_name"] = task
                config["test_dataset_name"] = task
                config["eval_dataset_name"] = task
                # "adamw" is emitted as None; presumably the training script
                # then uses its default optimizer — confirm against the trainer.
                config["optimizer"] = op if op != "adamw" else None
                config["metric_for_best_model"] = f"eval_{metric}"
                # wandb run_name
                config["run_name"] = f"{task}-{delta}-{model_path}-{op}"
                # The optimizer is part of the path so runs that differ only by
                # optimizer do not write into the same output directory.
                config["output_dir"] = f"outputs/{delta}/{model_path}/{task}/{op}"
                config["num_train_epochs"] = epoch

                with open(f"{config_dir}/{task}.json", "w", encoding="utf-8") as f:
                    json.dump(config, f, indent=4)
