"""
Utility functions for file and directory operations to enforce consistent naming
conventions and paths across the library.
"""

import hashlib
import os
import pickle
import shutil
import warnings
from os.path import exists, join
from pathlib import Path
from typing import Any, Dict, Literal, SupportsIndex

import numpy as np
import yaml

from egxc.utils.typing import AuxDataKey, MethodKey


def str_to_directory_name(string: str) -> str:
    """Return a directory-friendly variant of ``string``.

    Parentheses are removed, while hyphens and spaces are each replaced by an
    underscore. All other characters are left untouched.
    """
    # One translation table performs all four substitutions in a single pass.
    substitutions = str.maketrans({'(': None, ')': None, '-': '_', ' ': '_'})
    return string.translate(substitutions)


def hash_dictionary(dictionary: Dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of the YAML dump of ``dictionary``.

    Args:
        dictionary: The mapping to fingerprint; it must be serializable by
            ``yaml.dump``.

    Returns:
        The hexadecimal digest string of the encoded YAML text.
    """
    serialized = yaml.dump(dictionary)
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()


def pickle_dictionary(dictionary: Dict[str, Any], path: str) -> None:
    """Serialize ``dictionary`` with pickle and write it to ``path``.

    Args:
        dictionary: The object to pickle.
        path: Target file; it is opened in binary write mode, so any existing
            content is replaced.
    """
    with open(path, mode='wb') as handle:
        pickle.dump(dictionary, handle)


def unpickle_dictionary(path: str) -> Dict[str, Any]:
    """Load and return the pickled object stored at ``path``.

    Args:
        path: File to read; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # NOTE(security): unpickling can execute arbitrary code, so only use this
    # on files from a trusted source.
    with open(path, mode='rb') as handle:
        restored = pickle.load(handle)
    return restored


def copy_file_to_directory(
    source: str | os.PathLike[str],
    destination_dir: str | os.PathLike[str],
) -> None:
    """Copy a file into a destination directory preserving its metadata."""
    src_path = Path(source)
    dest_dir = Path(destination_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest_path = dest_dir / src_path.name
    shutil.copy2(src_path, dest_path)


def copy_directory_to_directory(
    source_dir: str | os.PathLike[str],
    destination_dir: str | os.PathLike[str],
    overwrite: bool = False,
) -> None:
    """Copy a directory into a destination directory preserving its metadata."""
    src_dir = Path(source_dir)
    dest_dir = Path(destination_dir)
    shutil.copytree(src_dir, dest_dir, dirs_exist_ok=overwrite)


def checkpoint_directory(
    root: str,
    method: str,
    basis: str,
    name: str,
    data_split_seed: SupportsIndex,
) -> str:
    """Build the checkpoint directory path.

    The layout is ``root/<method>/<basis>/<name>/split_<seed>``, where method,
    basis and name are each sanitized with ``str_to_directory_name``.

    Args:
        root (str): The root directory.
        method (str): The method name.
        basis (str): The basis set name.
        name (str): The checkpoint name.
        data_split_seed (int): The data split seed; converted with ``int``.

    Returns:
        str: The joined path. Nothing is created on disk.
    """
    sanitized = [str_to_directory_name(part) for part in (method, basis, name)]
    return join(root, *sanitized, f'split_{int(data_split_seed)}')


def checkpoint_best_path(ckp_dir: str, prefix: str = '') -> str:
    """
    Build the path of the best-checkpoint file inside ``ckp_dir``.

    The file name follows the pattern ``best_<prefix>_params.flax``; with the
    default empty prefix this yields ``best__params.flax``.

    Args:
        ckp_dir (str): The directory containing the checkpoints.
        prefix (str): The prefix for the checkpoint file.

    Returns:
        str: The joined path.
    """
    file_name = f'best_{prefix}_params.flax'
    return join(ckp_dir, file_name)


def checkpoint_step_path(ckp_dir: str, step: SupportsIndex, prefix: str = '') -> str:
    """
    Build the path of the checkpoint file for a given training step.

    The file name follows the pattern ``params_<prefix>_<step>.flax``.

    Args:
        ckp_dir (str): The directory containing the checkpoints.
        step (int): The step number; converted with ``int`` before formatting.
        prefix (str): The prefix for the checkpoint file.

    Returns:
        str: The joined path.
    """
    step_number = int(step)
    file_name = f'params_{prefix}_{step_number}.flax'
    return join(ckp_dir, file_name)


def checkpoint_config_path(ckp_dir: str, name: str) -> str:
    """
    Build the path of the checkpoint configuration YAML file.

    The file name is ``name`` sanitized with ``str_to_directory_name`` plus a
    ``.yaml`` extension.

    Args:
        ckp_dir (str): The directory containing the checkpoints.
        name (str): The name of the checkpoint.

    Returns:
        str: The joined path.
    """
    config_file = f'{str_to_directory_name(name)}.yaml'
    return join(ckp_dir, config_file)


def results_directory(
    base_dir: str | os.PathLike[str],
    name: str,
    exists_ok: bool = False,
) -> Path:
    """Create (if needed) and return the run-specific results directory.

    The resolved ``base_dir`` is always created when missing. The run directory
    is ``base_dir`` joined with ``name`` sanitized by ``str_to_directory_name``.

    Args:
        base_dir: Parent directory for all runs.
        name: Run name used to derive the sub-directory name.
        exists_ok: Forwarded as ``exist_ok`` when creating the run directory;
            when False, an already existing run directory makes ``mkdir`` raise.

    Returns:
        The path of the run directory.
    """
    base = Path(base_dir).resolve()
    base.mkdir(parents=True, exist_ok=True)

    target = base / str_to_directory_name(name)
    target.mkdir(parents=True, exist_ok=exists_ok)
    return target


def results_metrics_csv_path(
    run_dir: str | os.PathLike[str],
    csv_relative_path: str | None,
    default_filename: str = 'metrics.csv',
) -> Path:
    """Resolve the metrics CSV path within a run directory."""
    run_path = Path(run_dir)
    relative = Path(csv_relative_path) if csv_relative_path else Path(default_filename)
    csv_path = run_path / relative
    assert csv_path.parent.exists(), f'Parent directory {csv_path.parent} does not exist'
    return csv_path


def results_summary_path(
    run_dir: str | os.PathLike[str], filename: str = 'summary.txt'
) -> Path:
    """Return the metrics summary text file path."""
    return Path(run_dir) / filename


def write_to_npz(path: str, data: Dict, overwrite: bool = False) -> None:
    """Write ``data`` to a compressed ``.npz`` archive unless it already exists.

    Args:
        path: Target file. As with ``np.savez_compressed``, a ``.npz`` suffix is
            appended when the name does not already end with it.
        data: Mapping of array names to arrays; passed as keyword arguments to
            ``np.savez_compressed``.
        overwrite: When True, an existing archive is replaced. When False, an
            existing archive is left untouched and a warning is emitted.
    """
    target = os.fspath(path)
    # np.savez_compressed appends '.npz' to names lacking it, so the existence
    # check must look at the file that will actually be written. Otherwise the
    # overwrite guard never triggers for suffix-less paths.
    if not target.endswith('.npz'):
        target += '.npz'

    if overwrite or not exists(target):
        np.savez_compressed(target, **data)
    else:
        warnings.warn(f'File {target} already exists')


def auxiliary_data_directory(
    aux_dir: str,
    data_type: AuxDataKey,
    method_key: MethodKey,
    **kwargs: Any,
) -> str:
    """
    Build the directory path for auxiliary data files.

    The path starts with ``aux_dir/<data_type>/<method_key>`` and is extended by
    one sub-directory per method-specific modifier. Only ``'ks_dft'`` is
    supported; its modifiers are the keyword arguments of the nested helper
    (basis, xc_str, use_eri_density_fitting, use_exchange_density_fitting,
    spin_restricted, quadrature_grid_level, backend and optional n_cycles).

    Args:
        aux_dir (str): The base auxiliary directory.
        data_type (str): The type of the auxiliary data.
        method_key (str): The method key used for calculations.
        kwargs: Additional method-specific modifiers that affect data generation.

    Returns:
        str: The joined directory path. Nothing is created on disk.

    Raises:
        NotImplementedError: If ``method_key`` is not ``'ks_dft'``.
    """
    aux_dir = join(aux_dir, data_type, method_key)
    if method_key != 'ks_dft':
        raise NotImplementedError(f'Method key {method_key} not implemented.')

    def dft_modifiers_to_path_key_words(
        _aux_dir: str,
        basis: str,
        xc_str: str,
        use_eri_density_fitting: bool,
        use_exchange_density_fitting: bool,
        spin_restricted: bool,
        quadrature_grid_level: int,
        backend: Literal['pyscf', 'custom'],
        n_cycles: int | None = None,
    ) -> str:
        # One path segment per modifier, in a fixed order.
        segments = [
            str_to_directory_name(basis),
            str_to_directory_name(xc_str),
            'eri_df' if use_eri_density_fitting else 'exact',
            'x_df' if use_exchange_density_fitting else 'exact',
            'rks' if spin_restricted else 'uks',
            f'grid_lvl_{quadrature_grid_level}',
            f'with_{backend}',
        ]
        # The cycle count only contributes a segment when it is given.
        if n_cycles is not None:
            segments.append(f'n_cycles_{int(n_cycles)}')
        return join(_aux_dir, *segments)

    return dft_modifiers_to_path_key_words(aux_dir, **kwargs)


def auxiliary_data_path(aux_dir: str, idx: SupportsIndex) -> str:
    """
    Build the path of the auxiliary data file for one sample.

    Args:
        aux_dir (str): The directory containing the auxiliary data.
        idx (int): The index of the sample; formatted as-is into ``<idx>.npz``.

    Returns:
        str: The joined path.
    """
    file_name = f'{idx}.npz'
    return join(aux_dir, file_name)


def auxiliary_data_save(
    aux_dir: str, idx: SupportsIndex, data: Dict, overwrite=False
) -> None:
    """
    Save the auxiliary data of one sample to its ``.npz`` file.

    The target path comes from ``auxiliary_data_path`` and the writing is
    delegated to ``write_to_npz``.

    Args:
        aux_dir (str): The directory to save the auxiliary data.
        idx (int): The index of the sample.
        data (Dict): The auxiliary data to save.
        overwrite (bool): Forwarded to ``write_to_npz``; controls whether an
            existing file is replaced.
    """
    write_to_npz(auxiliary_data_path(aux_dir, idx), data, overwrite)


def auxiliary_data_exists(aux_dir: str, idx: SupportsIndex) -> bool:
    """
    Check whether the auxiliary data file of a sample exists and is loadable.

    Every array in the archive except ``'compute_costs'`` is read once to detect
    corruption. If anything raises while loading, the file is deleted and False
    is returned. A message is printed for missing and for deleted files.

    Args:
        aux_dir (str): The directory containing the auxiliary data.
        idx (int): The index of the sample.

    Returns:
        bool: True if the auxiliary data exists and is valid, False otherwise.
    """
    npz_path = auxiliary_data_path(aux_dir, idx)

    # Guard clause: nothing to validate when the file is absent.
    if not exists(npz_path):
        print(f'[MISSING] {npz_path} missing, recomputing...')
        return False

    try:
        # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code; only use on trusted files.
        with np.load(npz_path, allow_pickle=True) as archive:
            to_check = list(archive.files)
            # 'compute_costs' is excluded from the forced read when present.
            if 'compute_costs' in to_check:
                to_check.remove('compute_costs')
            # Accessing an entry forces it to be read from disk.
            for entry in to_check:
                _ = archive[entry]
    except Exception as err:
        print(f'[CORRUPT] Error loading {npz_path}: {err}. Deleting file...')
        os.remove(npz_path)
        return False

    return True
