import os
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
from omegaconf import OmegaConf


def _build_args():
    """Build the resolved experiment arguments from defaults plus CLI overrides.

    Any ``key=value`` pair on the command line overrides the defaults below.
    Interpolations are resolved in place before returning. ``max_dataset_size``
    is then replaced for the known (maze_type, exp_type) combinations; this
    takes precedence over a ``max_dataset_size=`` given on the command line.

    Returns:
        The resolved ``DictConfig`` of experiment arguments.
    """
    args = OmegaConf.create({
        "exp_type": "p2a",
        "maze_type": "medium",
        "goal": 4,
        "n": 1,
        "device_": 0,
        "exp_stem": "dail",
        "max_dataset_size": 2_000_000,
        "config": "dail/config/${exp_type}.yaml",
        "experiment_name": "${exp_stem}_${exp_type}_${maze_type}_${goal}",
        "root_dir": "dail/results/${experiment_name}",
        "inference_task_ids": ["${goal}"],
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())
    OmegaConf.resolve(args)

    # Dataset size per (maze_type, exp_type); unlisted combinations keep
    # whatever value the defaults/CLI produced.
    dataset_sizes = {
        ("umaze", "p2m"): 1_000_000,
        ("medium", "p2m"): 1_000_000,
        ("umaze", "p2a"): 1_000_000,
        ("medium", "p2a"): 2_000_000,
    }
    size = dataset_sizes.get((args.maze_type, args.exp_type))
    if size is not None:
        args.max_dataset_size = size
    return args


def _launch_runs(args, experiment_config_path):
    """Run ``dail/main.py`` ``args.n`` times, sequentially, with identical arguments.

    Each run receives ``config=<experiment_config_path>`` followed by every
    entry of ``args`` as a ``key=value`` override. It runs with
    ``OMP_NUM_THREADS=4`` and ``CUDA_VISIBLE_DEVICES=args.device_``.
    NOTE(review): the arguments are the same for every run, so presumably
    ``main.py`` picks a distinct seed / output subdirectory per run — confirm.

    A run that exits with a non-zero status is reported. It does not stop the
    remaining runs.
    """
    # Identical for every run, so build once outside the loop.
    overrides = [f"{key}={val}" for key, val in args.items()]
    command = [
        "python",
        "dail/main.py",
        f"config={experiment_config_path}",
        *overrides,
    ]
    env = {
        **os.environ,
        "OMP_NUM_THREADS": "4",
        "CUDA_VISIBLE_DEVICES": str(args.device_),
    }
    for run_idx in range(args.n):
        print(overrides)
        # Pass an argument list (no shell) so values reach main.py verbatim:
        # no word-splitting, globbing of e.g. `[4]`, or `$` expansion.
        result = subprocess.run(command, env=env, check=False)
        if result.returncode != 0:
            print(
                f"warning: run {run_idx + 1}/{args.n} exited with "
                f"code {result.returncode}"
            )


def _aggregate_results(root_dir):
    """Summarize ``target_success_rate`` across all run directories in ``root_dir``.

    Reads ``<run_dir>/adapt_eval.csv`` from every subdirectory. Subdirectories
    without that file are skipped with a warning. Writes ``root_dir/result.csv``
    with columns ``epoch,mean,std``. Epochs are taken from the first run (in
    sorted order). ``std`` is the population standard deviation (ddof=0).

    Raises:
        FileNotFoundError: if no subdirectory contains ``adapt_eval.csv``.
        ValueError: if the runs logged different numbers of rows.
    """
    logs = []
    # Sorted so that "the first run" (the source of the epoch column) is deterministic.
    for run_dir in sorted(p for p in root_dir.glob("*") if p.is_dir()):
        log_path = run_dir / "adapt_eval.csv"
        if not log_path.is_file():
            print(f"warning: skipping {run_dir}: no adapt_eval.csv")
            continue
        logs.append(pd.read_csv(log_path))
    if not logs:
        raise FileNotFoundError(
            f"no adapt_eval.csv found in any subdirectory of {root_dir}"
        )

    epoch = logs[0]["epoch"].to_numpy()
    # One row per run, one column per logged row.
    success = np.stack([log["target_success_rate"].to_numpy() for log in logs])
    summary = pd.DataFrame({
        "epoch": epoch,
        "mean": success.mean(axis=0),
        "std": success.std(axis=0),
    })
    summary.to_csv(root_dir / "result.csv", index=False)


def multiseed_train():
    """Launch ``args.n`` training runs of ``dail/main.py`` and aggregate them.

    Steps:
    1. Build arguments from defaults and CLI ``key=value`` overrides.
    2. Merge them over the base YAML named by ``config``.
    3. Save the merged config to ``<root_dir>/config.yaml``.
    4. Run the training script ``n`` times.
    5. Write per-epoch mean/std of ``target_success_rate`` to
       ``<root_dir>/result.csv``.
    """
    args = _build_args()

    config = OmegaConf.load(args.config)
    # `config` only locates the base YAML for this script. The training runs
    # receive the merged file via `config=` instead, so drop it from the
    # forwarded overrides.
    del args["config"]
    config = OmegaConf.merge(config, args)

    root_dir = Path(args.root_dir)
    root_dir.mkdir(parents=True, exist_ok=True)
    experiment_config_path = root_dir / "config.yaml"
    OmegaConf.save(config, experiment_config_path)

    _launch_runs(args, experiment_config_path)
    _aggregate_results(root_dir)


# Guarded so that importing this module does not launch training runs.
if __name__ == "__main__":
    multiseed_train()
