"""Manage paths / filenames for experiments, datasets, results, and logs."""

import torch
from torch import Tensor
import os
import os.path as osp
from utils.log import get_timestamp
from utils.config import SPLITS
from typing import Optional

# Base directory for every path below. An empty string yields relative paths
# (e.g. "datasets"), which resolve against the current working directory.
ROOT = ""
DATASETS_PATH = osp.join(ROOT, "datasets")  # dataset splits live here
RESULT_PATH = osp.join(ROOT, "results")  # experiments, plots and result data
LOG_PATH = osp.join(ROOT, "logs")  # log files (see get_log_filepath)
WANDB_PATH = osp.join(ROOT, "wandb")


def get_split_dataset_path(split: str) -> str:
    """Get datasets path based on dataset split type: `DATASETS_PATH/split`.

    Args:
        split: Name of the dataset split; must be a member of `SPLITS`.

    Returns:
        The joined path `DATASETS_PATH/split`.

    Raises:
        ValueError: If `split` is not in `SPLITS`.
    """
    if split not in SPLITS:
        # Derive the list from SPLITS itself so the message cannot drift out
        # of sync with the configured splits.
        supported = ", ".join(f"'{s}'" for s in SPLITS)
        raise ValueError(
            f"Unsupported split: {split}. "
            f"Supported types are {supported}."
        )
    return os.path.join(DATASETS_PATH, split)


def get_exp_path(model_name: str, expid: str) -> str:
    """Return the experiment directory `RESULT_PATH/model_name/expid`.

    Args:
        model_name: Model sub-directory under `RESULT_PATH`.
        expid: Experiment identifier sub-directory under the model directory.
    """
    model_dir = osp.join(RESULT_PATH, model_name)
    return osp.join(model_dir, expid)


def _get_result_subpath(model_name, expid, task_type, suffix: Optional[str] = None):
    """Join `model_name/expid/task_type`, appending `suffix` unless it is None."""
    parts = [model_name, expid, task_type]
    if suffix is not None:
        parts.append(suffix)
    return osp.join(*parts)


def get_result_plot_path(
    model_name: str, expid: str, task_type: str, suffix: Optional[str] = None
) -> str:
    """Get path for result plots:
    `RESULT_PATH/plots/model_name/expid/task_type[/suffix]`

    Args:
        model_name: Model sub-directory.
        expid: Experiment identifier sub-directory.
        task_type: Task sub-directory.
        suffix: Optional final path component; omitted when None.

    Returns:
        The joined plot path (the directory is not created).
    """
    subpath = _get_result_subpath(
        model_name=model_name,
        expid=expid,
        task_type=task_type,
        suffix=suffix,
    )

    path = os.path.join(RESULT_PATH, "plots", subpath)
    return path


def get_result_data_path(
    model_name: str, expid: str, task_type: str, suffix: Optional[str] = None
) -> str:
    """Get path for result data:
    `RESULT_PATH/data/model_name/expid/task_type[/suffix]`

    Args:
        model_name: Model sub-directory.
        expid: Experiment identifier sub-directory.
        task_type: Task sub-directory.
        suffix: Optional final path component; omitted when None.

    Returns:
        The joined data path (the directory is not created).
    """
    subpath = _get_result_subpath(
        model_name=model_name,
        expid=expid,
        task_type=task_type,
        suffix=suffix,
    )
    path = os.path.join(RESULT_PATH, "data", subpath)
    return path


def get_log_filepath(
    group_name: str,
    expid: str,
    prefix: Optional[str] = None,
    logid: Optional[str] = None,
    suffix: Optional[str] = None,
) -> str:
    """Build a log file path, creating its parent directory if needed.

    The returned path has the form
    `LOG_PATH/group_name/expid/[prefix_]logid[_suffix].log`.

    Args:
        group_name: First sub-directory under `LOG_PATH`.
        expid: Second sub-directory under `group_name`.
        prefix: Optional leading filename component (skipped when falsy).
        logid: Filename core; when falsy, `get_timestamp()` is used instead.
        suffix: Optional trailing filename component (skipped when falsy).
    """
    log_dir = osp.join(LOG_PATH, group_name, expid)
    # Side effect: the directory exists by the time the path is returned.
    os.makedirs(log_dir, exist_ok=True)

    stem = f"{logid}" if logid else get_timestamp()
    if prefix:
        stem = f"{prefix}_{stem}"
    if suffix:
        stem = f"{stem}_{suffix}"

    return osp.join(log_dir, f"{stem}.log")


def get_pred_filename(
    nc, nt, valid_x_counts: Tensor | int, valid_y_counts: Tensor | int
) -> str:
    """Generate figure name for prediction.

    Tensor counts are encoded as the concatenation of their sorted unique
    values (via `torch.unique`); plain ints are encoded with `str`. The
    encoded counts plus `nc` and `nt` are passed to `params_to_filename`
    under the keys "x", "y", "nc" and "nt".
    """

    def _encode_counts(counts: Tensor | int) -> str:
        if not isinstance(counts, Tensor):
            return str(counts)
        unique_vals = torch.unique(counts).tolist()
        return "".join(map(str, unique_vals))

    return params_to_filename(
        {
            "x": _encode_counts(valid_x_counts),
            "y": _encode_counts(valid_y_counts),
            "nc": nc,
            "nt": nt,
        }
    )


def _convert_value_to_str(val):
    """Convert a value to a compact, filename-friendly string.

    Conversion rules:
        - str: returned as-is.
        - None: "none".
        - bool: "1" for True, "0" for False.
        - int: decimal representation.
        - float: scientific notation with no fractional digits, with "-" and
          "." removed (e.g. 0.04 -> "4e-02" -> "4e02"; 400.0 -> "4e+02").
        - list / tuple: elements converted recursively and concatenated
          (e.g. ["a", "b", "c"] -> "abc").
    In every case, spaces and "/" are stripped from the result.

    Raises:
        TypeError: If `val` is of any other type.
    """
    if isinstance(val, str):
        val_str = val
    elif val is None:
        val_str = "none"
    elif isinstance(val, bool):
        # Must precede the int branch: bool is a subclass of int.
        val_str = "1" if val else "0"
    elif isinstance(val, int):
        val_str = str(val)
    elif isinstance(val, float):
        # 0.04 -> "4e-02" -> "4e02"
        # NOTE(review): stripping every "-" also drops the sign of negative
        # values, so -0.04 and 0.04 both map to "4e02".
        val_str = f"{val:.0e}"
        val_str = val_str.replace("-", "")  # result was previously discarded
        val_str = val_str.replace(".", "")
    elif isinstance(val, (list, tuple)):
        val_str = "".join(_convert_value_to_str(v) for v in val)
    else:
        raise TypeError(
            f"Unsupported type {type(val)} for value '{val}'. "
            "Supported types are None, bool, int, float, str, list, and tuple."
        )

    # Remove spaces and slashes so the result is safe as a path component.
    val_str = val_str.replace(" ", "").replace("/", "")

    return val_str


def params_to_filename(
    params: dict,
    preffix: Optional[dict] = None,
    suffix: Optional[dict] = None,
    params_map: Optional[dict] = None,
    exclude: Optional[list[str]] = None,
) -> str:
    """Generate a filename string summarizing config attributes.

    Items from `preffix`, then `params`, then `suffix` are emitted in that
    order; within each dict, entries are visited in sorted key order. Each
    entry becomes `<key><value>`, where the value is rendered by
    `_convert_value_to_str`. Items are joined with "_".

    Args:
        params: Main key/value pairs. A key is abbreviated through
            `params_map` when present there; otherwise it is truncated to its
            first 4 characters.
        preffix: Key/value pairs emitted before `params` (keys truncated to
            4 characters; `params_map` is not applied). The misspelled name is
            kept for backward compatibility with keyword callers.
        suffix: Key/value pairs emitted after `params` (same key handling as
            `preffix`).
        params_map: Optional mapping from a `params` key to its abbreviation;
            mapping a key to "" omits that entry.
        exclude: Keys skipped in all three dicts.

    Returns:
        The joined filename string ("" when nothing is emitted).

    Examples:
        `Grid1_T64_gs1_fq1_fp1_tb0_rt_norm_ratio_n0_1_sg0_0e+00_t32`
    """
    # None defaults replace the former shared mutable `{}` / `[]` defaults.
    excluded = set(exclude or [])

    def _kv2str(key, val):
        """Join key and value, or return None when the key is empty."""
        if key == "":
            return None
        return f"{key}{val}"

    def _append(
        cur_list: list, new_dict: Optional[dict], key_map: Optional[dict] = None
    ) -> list:
        """Create string items from `new_dict` and append them to `cur_list`."""
        if not new_dict:
            return cur_list
        if key_map is None:
            key_map = {}

        for key, val in sorted(new_dict.items()):
            if key in excluded:
                continue

            key_str = key_map.get(key, key[:4])
            val_str = _convert_value_to_str(val)
            item_str = _kv2str(key_str, val_str)

            if item_str is not None:
                cur_list.append(item_str)

        return cur_list

    item_list = []
    item_list = _append(item_list, preffix)
    item_list = _append(item_list, params, params_map)
    item_list = _append(item_list, suffix)

    return "_".join(item_list)
