"""
Utilities for running the experiments.
"""

import csv
import itertools
import json
import os
import random
from argparse import Namespace
from typing import Any

import numpy as np
import torch


def seed_everywhere(seed: int) -> None:
    """
    Fix the seed of every random number generator used by the experiments.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global RNG and PyTorch (``torch.manual_seed``), so that runs
    started with the same seed draw the same random numbers from them.

    Args:
        seed (int): Value used to seed all generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def convert_to_list(data: Any) -> Any:
    """
    Recursively convert array-like values into plain, JSON-friendly Python objects.

    Conversion rules, applied at every nesting level (not only below a dict):
      * ``dict`` (and subclasses) -> new ``dict`` with converted values
      * ``list``  -> new ``list`` with converted items
      * ``tuple`` -> new ``tuple`` with converted items
      * ``np.ndarray`` -> nested lists via ``ndarray.tolist()``
      * NumPy scalars (``np.generic``, e.g. ``np.float32``) -> Python scalars
      * ``torch.Tensor`` -> nested lists (detached and moved to CPU first)
      * anything else is returned unchanged

    The input is not modified; converted containers are new objects.

    Args:
        data (Any): The (possibly nested) value to convert.

    Returns:
        Any: The converted value (a dict when ``data`` is a dict).
    """
    if isinstance(data, dict):
        return {key: convert_to_list(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_to_list(item) for item in data]
    if isinstance(data, tuple):
        return tuple(convert_to_list(item) for item in data)
    if isinstance(data, np.ndarray):
        return data.tolist()
    if isinstance(data, np.generic):
        # e.g. np.float32 / np.int64 are not accepted by json.dump.
        return data.item()
    if isinstance(data, torch.Tensor):
        # detach/cpu so tensors that require grad or live on a GPU also convert.
        return data.detach().cpu().tolist()
    return data


def save_args(args: Namespace, save_path: str) -> None:
    """
    Write the parsed arguments to ``<save_path>/args.json``.

    The target directory (and any missing parents) is created first. The
    file contains ``vars(args)`` dumped as JSON with a two-space indent.

    Args:
        args (Namespace): Parsed arguments whose attribute dict is written.
        save_path (str): Directory that will hold ``args.json``.
    """
    os.makedirs(save_path, exist_ok=True)

    destination = f"{save_path}/args.json"
    arg_dict = vars(args)
    with open(destination, "w") as handle:
        json.dump(arg_dict, handle, indent=2)


def save_json(metrics: dict, save_path: str, tag: str = "") -> None:
    """
    Save the metrics as JSON to ``<save_path>/metrics<tag>.json``.

    NumPy arrays are converted to lists (via ``convert_to_list``) before
    saving. The JSON text is fully built before the file is opened, so a
    value that cannot be serialized raises without truncating or corrupting
    an already existing metrics file.

    Args:
        metrics (dict[str, list[float]]): The dictionary of metrics.
        save_path (str): The directory to save the metrics; created if missing.
        tag (str): Suffix appended to the file name, as in ``save_csv``.
            Defaults to "" which keeps the name ``metrics.json``.

    Raises:
        TypeError: If a value cannot be serialized to JSON.
    """
    # Serialize first: opening with "w" truncates the file immediately, so a
    # failure inside json.dump would otherwise leave a half-written file.
    serialized = json.dumps(convert_to_list(metrics), indent=2)

    os.makedirs(save_path, exist_ok=True)
    with open(f"{save_path}/metrics{tag}.json", "w") as f:
        f.write(serialized)


def save_csv(metrics: dict, save_path: str, tag: str = "") -> None:
    """
    Save the metrics as CSV to ``<save_path>/eval_logs<tag>.csv``.

    Each key becomes a column header and its sequence of values the column
    below it. Columns of unequal length are padded with empty cells, rather
    than every column being silently cut to the length of the shortest one.

    Args:
        metrics (dict[str, list[float]]): The dictionary of metrics.
        save_path (str): The directory to save the metrics; created if missing.
        tag (str): The tag to append to the file name.
    """
    os.makedirs(save_path, exist_ok=True)
    headers = list(metrics.keys())

    # Transpose the columns into rows. zip_longest keeps every value even when
    # columns differ in length (plain zip would drop the tail of longer ones).
    rows = itertools.zip_longest(*(metrics[key] for key in headers), fillvalue="")

    with open(f"{save_path}/eval_logs{tag}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(rows)
