import os
import json
import dataclasses
import datetime
from copy import deepcopy
from typing import Any, Literal, Optional
from dacite import from_dict, Config


def get_list_of_runs(subdir: str = "", skip_incomplete: bool = True) -> list[str]:
    """Return list of run IDs present in the runs folder

    A run is a directory named ``run-<id>`` located directly inside
    ``data/runs/<subdir>`` (relative to the current working directory).
    The returned ID is everything after the leading ``run-`` prefix, so IDs
    that themselves contain ``run-`` are returned intact.

    :param subdir: Sub-directory of ``data/runs/`` to look into, defaults to ""
    :type subdir: str, optional
    :param skip_incomplete: Whether to skip runs that don't have all the files
        (i.e. runs without a ``bias-eval-results.json`` file), defaults to True
    :type skip_incomplete: bool, optional
    :return: The run IDs, in the order reported by the directory scan
    :rtype: list[str]
    """
    prefix = "run-"

    # Path from which to start looking for runs
    runs_dir = os.path.join("data/runs/", subdir)

    # Get list of run IDs. The context manager closes the scandir handle
    # promptly instead of leaving it to the garbage collector.
    with os.scandir(runs_dir) as entries:
        run_ids = [
            entry.name[len(prefix):]
            for entry in entries
            if entry.is_dir() and entry.name.startswith(prefix)
        ]

    # Skip incomplete runs if requested
    if skip_incomplete:
        run_ids = [
            run_id
            for run_id in run_ids
            if os.path.exists(
                os.path.join(runs_dir, f"{prefix}{run_id}", "bias-eval-results.json")
            )
        ]

    return run_ids


def complete_path(filepath: str) -> str:
    """
    Expand a path so that it points inside ``data/runs/run-<RUN_ID>/``.

    A path that already exists is returned untouched. Otherwise the run
    folder, the ``runs`` folder and the ``data`` folder are prepended, each
    one only when its marker is not already a substring of the path.

    When the ``RUN_ID`` environment variable is unset, it is initialised to
    the current timestamp followed by the optional ``RUN_ID_SUFFIX``
    environment variable, and a notice is printed.
    """
    if os.path.exists(filepath):
        return filepath

    if "RUN_ID" not in os.environ:
        suffix = os.environ.get("RUN_ID_SUFFIX", "")
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")
        print(f"No run ID provided. Setting it to {timestamp}{suffix}.")
        os.environ["RUN_ID"] = timestamp + suffix

    run_id = os.environ["RUN_ID"]

    # (marker that must be absent, component to prepend), innermost first.
    prefix_rules = (
        (run_id, f"run-{run_id}"),
        ("runs/", "runs"),
        ("data/", "data"),
    )
    for marker, component in prefix_rules:
        if marker not in filepath:
            filepath = os.path.join(component, filepath)

    return filepath


def dump_file(
    filepath: str,
    data: Any,
    write_mode: Literal["w", "a"] = "w",
    indent: Optional[int] = 2,
) -> None:
    """
    Dump a JSON object or dataclass instance to the data directory.

    Dataclass instances (either ``data`` itself or the top-level items of a
    list) are converted to dictionaries before serialization. The caller's
    object is never modified.

    :param filepath: Destination path; it is resolved with ``complete_path``.
    :param data: JSON-serializable object, dataclass instance, or list thereof.
    :param write_mode: ``"w"`` overwrites the file; ``"a"`` appends the record
        followed by a newline, creating the file if needed.
    :param indent: Indentation passed to the JSON encoder (``None`` for compact).
    :raises TypeError: If ``data`` cannot be serialized to JSON. Nothing is
        written to disk in that case.
    """
    # Work on a private copy so the caller's object is never touched.
    data = deepcopy(data)

    if dataclasses.is_dataclass(data):
        data = dataclasses.asdict(data)
    elif isinstance(data, list):
        data = [
            dataclasses.asdict(item) if dataclasses.is_dataclass(item) else item
            for item in data
        ]

    # Serialize before touching the filesystem: a failure here must not leave
    # a truncated or half-written file behind.
    serialized = json.dumps(data, indent=indent)

    filepath = complete_path(filepath)

    # dirname is "" for a bare filename (e.g. an existing file in the current
    # directory), and os.makedirs("") raises FileNotFoundError.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Mode "a" creates the file when it is missing, so no downgrade to "w" is
    # needed; keeping "a" ensures the first record also gets its trailing
    # newline and later appends are not glued onto it.
    with open(filepath, write_mode, encoding="utf-8") as f:
        f.write(serialized)
        if write_mode == "a":
            f.write("\n")


def load_file(filepath: str, schema: type | None = None) -> Any:
    """
    Read a JSON file from the data directory, optionally as dataclass instances.

    :param filepath: Path to read; it is resolved with ``complete_path``.
    :param schema: ``None`` to return the parsed JSON as-is, or a dataclass
        into which the JSON object (or each item of a JSON list) is converted
        with ``dacite.from_dict`` (type checking disabled).
    :return: The parsed JSON, a dataclass instance, or a list of instances.
    :raises FileNotFoundError: If the resolved path does not exist.
    :raises ValueError: If ``schema`` is neither ``None`` nor a dataclass.
    """
    filepath = complete_path(filepath)

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")

    # Reject an unusable schema before the file is opened.
    wants_dataclass = schema is not None
    if wants_dataclass and not dataclasses.is_dataclass(schema):
        raise ValueError(f"Invalid schema: {schema}")

    with open(filepath, "r") as file:
        payload = json.load(file)

    if not wants_dataclass:
        return payload

    config = Config(check_types=False)

    if isinstance(payload, list):
        return [
            from_dict(data_class=schema, data=entry, config=config)
            for entry in payload
        ]

    return from_dict(data_class=schema, data=payload, config=config)


def extract_json_from_str(s: str) -> Any:
    """
    Robustly extract JSON object from a string, after a wide range of sanitization operations.

    The string is first parsed as-is (ignoring surrounding whitespace). Only
    when that fails are Markdown code fences (```, ```json) handled: either
    the whole string is a single fenced block, or it contains exactly one
    fenced block surrounded by other text.

    :param s: The string to extract JSON object from. It could be, for example, generation by an LLM.
    :type s: str

    :return: The extracted JSON object. None upon failure.
    """
    # Callers may pass a missing generation; honour the "None upon failure" contract.
    if not isinstance(s, str):
        return None

    s = s.strip()
    if not s:
        return None

    # Try the raw text first: valid JSON whose string values contain
    # backticks must not be mangled by the fence handling below.
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        direct_error = e

    if "```" not in s:
        print(f"Failed to extract JSON from string: {direct_error}")
        return None

    # Strip formatting characters (```, ```json, etc.)
    s = s.replace("```json", "```")
    if s.startswith("```") and s.endswith("```"):
        s = s[3:-3]
    elif s.count("```") != 2:
        # Zero, one or several fenced blocks inside prose: ambiguous.
        return None
    else:
        s = s.split("```")[1]

    s = s.strip()

    if not s:
        return None

    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        print(f"Failed to extract JSON from string: {e}")
        return None
