import os
from dataclasses import dataclass


@dataclass()
class RunPaths:
    """Bundle of the identifier and directories belonging to one training run.

    ``get_run_paths`` is the constructor used in this module; it fills every
    field with a string path rooted at ``checkpoints/<experiment_id>``.
    """

    # Underscore-joined hyperparameter identifier from get_experiment_id.
    experiment_id: str
    # Root directory of this particular run (".../run-<run_number>").
    run_directory: str
    # Where model checkpoints go ("models" under run_directory in get_run_paths).
    checkpoint_directory: str
    # Where logs go ("logs" under run_directory in get_run_paths).
    log_directory: str
    # Where plots go ("graphs" under run_directory in get_run_paths).
    graphs_directory: str


def get_experiment_id(args):
    """Return an underscore-joined identifier describing an experiment config.

    The identifier always starts with the model label, dataset, architecture,
    epochs, learning rate, batch size, optimizer, weight decay, and projection
    head settings. Float-like settings (learning rate, weight decay,
    temperature, tau) have every "." replaced by "p" (e.g. ``0.5`` -> ``0p5``).
    Extra tags are appended only for the models listed below:

    * ``temp<...>``: SimCLR, SDMI, MoCo, SimSiam_SDMI, JMI
    * ``pred_dim<...>`` and ``pred_layer<...>``: BYOL, MoCo, SimSiam
    * ``tau<...>``: BYOL, MoCo

    ``seed<...>`` is appended last whenever ``args.seed`` is not None.

    Args:
        args: namespace-like object carrying the attributes read below
            (presumably an ``argparse.Namespace`` -- confirm with callers).
            ``dataset``, ``architecture`` and ``optimizer`` are joined as-is,
            so they must already be strings.

    Returns:
        str: the experiment identifier.

    Note:
        When ``args.supervised`` is truthy the leading label is "Supervised",
        but the optional model-specific tags above are still selected from
        ``args.model_name``.
    """

    def dotless(value):
        # Keep "." out of directory names: "0.001" -> "0p001".
        return str(value).replace(".", "p")

    if args.supervised:
        label = "Supervised"
    else:
        label = args.model_name

    # Every tag is computed up front, so a missing attribute fails fast even
    # when the tag ends up unused for this model.
    epochs_tag = "ep" + str(args.epochs)
    lr_tag = "lr" + dotless(args.initial_lr)
    batch_tag = "bs" + str(args.batch_size)
    decay_tag = "wd" + dotless(args.weight_decay)
    proj_dim_tag = "proj_dim" + str(args.projection_dim)
    proj_layer_tag = "proj_layer" + str(args.projection_layer)
    temperature_tag = "temp" + dotless(args.temperature)
    pred_dim_tag = "pred_dim" + str(args.prediction_dim)
    pred_layer_tag = "pred_layer" + str(args.prediction_layer)
    tau_tag = "tau" + dotless(args.tau)
    seed = args.seed

    segments = [
        label,
        args.dataset,
        args.architecture,
        epochs_tag,
        lr_tag,
        batch_tag,
        args.optimizer,
        decay_tag,
        proj_dim_tag,
        proj_layer_tag,
    ]

    family = args.model_name
    if family in ("SimCLR", "SDMI", "MoCo", "SimSiam_SDMI", "JMI"):
        segments.append(temperature_tag)
    if family in ("BYOL", "MoCo", "SimSiam"):
        segments.extend((pred_dim_tag, pred_layer_tag))
    if family in ("BYOL", "MoCo"):
        segments.append(tau_tag)
    if seed is not None:
        segments.append(f'seed{seed}')

    return "_".join(segments)


def get_run_paths(args, run_number):
    """Create (if needed) and return the directory layout for one run.

    The layout is::

        checkpoints/<experiment_id>/run-<run_number>/models
        checkpoints/<experiment_id>/run-<run_number>/logs
        checkpoints/<experiment_id>/run-<run_number>/graphs

    Paths are relative to the current working directory. Existing
    directories are reused without error (``exist_ok=True``).

    Args:
        args: hyperparameter namespace forwarded to ``get_experiment_id``.
        run_number: value formatted into the ``run-<run_number>`` folder name.

    Returns:
        RunPaths: the experiment ID together with the run and sub-directories.
    """
    experiment_id = get_experiment_id(args)
    run_root = os.path.join('checkpoints', experiment_id, f'run-{run_number}')

    models_dir, logs_dir, graphs_dir = (
        os.path.join(run_root, leaf) for leaf in ('models', 'logs', 'graphs')
    )

    # makedirs also creates run_root and any missing parents along the way.
    for directory in (models_dir, logs_dir, graphs_dir):
        os.makedirs(directory, exist_ok=True)

    return RunPaths(
        experiment_id=experiment_id,
        run_directory=run_root,
        checkpoint_directory=models_dir,
        log_directory=logs_dir,
        graphs_directory=graphs_dir,
    )
