import os
import subprocess
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from omegaconf import DictConfig, OmegaConf


def flatten_args(args: DictConfig):
    """Collapse a nested DictConfig into a one-level dict with dotted keys.

    Values that are themselves ``DictConfig`` nodes are expanded recursively,
    so ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Every other value
    (scalars, lists, ...) is copied through unchanged.

    Args:
        args: The (possibly nested) configuration to flatten.

    Returns:
        A plain ``dict`` mapping dotted key paths to leaf values, in the
        iteration order of ``args``.
    """
    flat = {}
    for name, value in args.items():
        if not isinstance(value, DictConfig):
            # Leaf value: keep it under its own key.
            flat[name] = value
            continue
        # Sub-config: prefix each flattened child key with the parent key.
        for sub_name, sub_value in flatten_args(value).items():
            flat[f"{name}.{sub_name}"] = sub_value
    return flat


# Default ``max_dataset_size`` per (maze_type, exp_type). Applied only when
# the user did not pass ``max_dataset_size`` explicitly on the command line.
_MAX_DATASET_SIZES = {
    ("umaze", "p2m"): 1_000_000,
    ("medium", "p2m"): 1_000_000,
    ("umaze", "p2a"): 1_000_000,
    ("medium", "p2a"): 2_000_000,
}


def _parse_args() -> "DictConfig":
    """Build the launcher config from built-in defaults merged with CLI overrides."""
    defaults = OmegaConf.create({
        "exp_type": "p2a",
        "max_dataset_size": 2_000_000,
        "maze_type": "medium",
        "n": 1,
        "goal": 4,
        "device_": 0,
        "config": "ours/config/${exp_type}.yaml",
        "experiment_name": "${exp_type}_${maze_type}_${goal}",
        "root_dir": "ours/results/${experiment_name}",
        "inference_task_ids": ["${goal}"],
    })
    cli_args = OmegaConf.from_cli()
    args = OmegaConf.merge(defaults, cli_args)
    # Resolve interpolations now so every value forwarded to child runs is concrete.
    OmegaConf.resolve(args)

    # Only fall back to the per-setting default when the user did not choose
    # a size; previously an explicit CLI value was silently overwritten.
    if "max_dataset_size" not in cli_args:
        size = _MAX_DATASET_SIZES.get((args.maze_type, args.exp_type))
        if size is not None:
            args.max_dataset_size = size
    return args


def _launch_runs(args: "DictConfig", experiment_config_path: Path) -> None:
    """Run ``ours/main.py`` ``args.n`` times on GPU ``args.device_``.

    Every key of ``args`` except ``config`` is forwarded as a ``key=value``
    override. ``config`` is dropped so the child actually loads
    ``experiment_config_path``; forwarding it would append a second
    ``config=`` that overrides the first.

    NOTE(review): every run receives identical arguments (no per-run seed is
    passed); presumably ``ours/main.py`` picks its own seed and output
    subdirectory per run — confirm.
    """
    overrides = [
        f"{key}={val}"
        for key, val in flatten_args(args).items()
        if key != "config"
    ]
    # Argument list, no shell: values such as "[4, 5]" contain spaces and
    # would otherwise be split into several broken arguments.
    command = [
        sys.executable, "ours/main.py",
        f"config={experiment_config_path}", *overrides,
    ]
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(args.device_)}
    for run_idx in range(args.n):
        result = subprocess.run(command, env=env, check=False)
        if result.returncode != 0:
            # Keep going: remaining runs may still succeed and be aggregated.
            print(
                f"warning: run {run_idx} exited with code {result.returncode}",
                file=sys.stderr,
            )


def _aggregate_results(root_dir: Path) -> None:
    """Write per-epoch mean/std of ``target_success_rate`` to ``result.csv``.

    Reads ``<run_dir>/adapt/eval.csv`` from every subdirectory of
    ``root_dir``. Subdirectories without that file are skipped with a warning.

    Raises:
        RuntimeError: if no subdirectory provides an ``eval.csv``.
    """
    dfs = []
    for path in sorted(root_dir.glob("*")):
        if not path.is_dir():
            continue
        eval_csv = path / "adapt/eval.csv"
        if not eval_csv.is_file():
            print(f"warning: skipping {path}: no adapt/eval.csv", file=sys.stderr)
            continue
        dfs.append(pd.read_csv(eval_csv))

    if not dfs:
        raise RuntimeError(f"No run results (adapt/eval.csv) found under {root_dir}")

    epochs = dfs[0]["epoch"].values
    # One row per run, one column per epoch; raises if runs have different lengths.
    rates = np.stack([df["target_success_rate"].values for df in dfs])
    mean = np.mean(rates, axis=0)
    std = np.std(rates, axis=0)

    with open(root_dir / "result.csv", "w", encoding="utf-8") as f:
        f.write("epoch,mean,std\n")
        for epoch, m, s in zip(epochs, mean, std):
            f.write(f"{epoch},{m},{s}\n")


def multiseed_train():
    """Launch ``n`` training runs of one experiment and aggregate their results.

    Steps:
      1. Build the launcher config from defaults plus CLI overrides.
      2. Merge it into the base YAML config (``args.config``) and save the
         result to ``<root_dir>/config.yaml``.
      3. Run ``ours/main.py`` ``n`` times against that saved config.
      4. Write per-epoch mean/std of ``target_success_rate`` across runs to
         ``<root_dir>/result.csv``.
    """
    args = _parse_args()

    config = OmegaConf.merge(OmegaConf.load(args.config), args)
    root_dir = Path(args.root_dir)
    root_dir.mkdir(parents=True, exist_ok=True)
    experiment_config_path = root_dir / "config.yaml"
    OmegaConf.save(config, experiment_config_path)

    _launch_runs(args, experiment_config_path)
    _aggregate_results(root_dir)


if __name__ == "__main__":
    # Script entry point: configuration comes from CLI ``key=value`` overrides.
    multiseed_train()
