import os
import yaml
import sys
import os
import torch
import json
import shutil
import datetime
import numpy as np
import logging
import pandas as pd
from typing import Any, Dict, List, Optional, Tuple, Set

def get_project_root() -> str:
    """Return the project root directory.

    The root is taken to be three directory levels above this source file's
    absolute path, so the result does not depend on the current working
    directory.
    """
    root = os.path.abspath(__file__)
    for _ in range(3):
        root = os.path.dirname(root)
    return root

def get_path(levels, main_dir):
    """Build ``main_dir/<level1>/<level2>/...``, creating each directory as needed.

    Args:
        levels: Iterable of path components appended, in order, beneath
            ``main_dir``.
        main_dir: Base directory. Missing directories along the way
            (including ``main_dir`` itself) are created.

    Returns:
        The joined path. When ``levels`` is non-empty the result ends with a
        trailing ``"/"`` (kept for callers that concatenate file names onto
        it). When ``levels`` is empty, ``main_dir`` is returned unchanged and
        nothing is created.
    """
    path = main_dir
    for item in levels:
        path = os.path.join(path, item + "/")
        # makedirs(exist_ok=True) also creates a missing main_dir and avoids the
        # exists()/mkdir() race when several processes build the same tree.
        os.makedirs(path, exist_ok=True)
    return path

def get_data_dir():
    """Return the path of the project's ``data`` directory.

    The directory is created, with any missing parents, if it does not exist
    yet. An existing directory is left untouched.
    """
    root = get_project_root()
    data_dir = os.path.join(root, 'data')
    os.makedirs(data_dir, exist_ok=True)
    return data_dir

def get_results_dir():
    """Return the path of the project's ``results`` directory.

    Missing directories are created on demand. An existing directory is
    reused as-is.
    """
    results_dir = os.path.join(get_project_root(), 'results')
    os.makedirs(results_dir, exist_ok=True)
    return results_dir

def ensure_dir(path: str) -> None:
    """Create ``path`` and any missing parents if it does not already exist.

    An empty ``path`` is a no-op. This is what ``os.path.dirname`` returns for
    a bare file name such as ``"out.jsonl"``, and it means the current
    directory, which always exists. ``os.makedirs("")`` would raise
    ``FileNotFoundError`` instead.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


# JSON
def _json_default(x: Any) -> Any:
    if isinstance(x, (np.ndarray,)):
        return x.tolist()
    if isinstance(x, (np.floating, np.integer)):
        return x.item()
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    return str(x)

def stable_json_dumps(obj: Any) -> str:
    """Serialize ``obj`` to a deterministic, single-line JSON string.

    Keys are sorted, so dicts with equal contents produce identical text.
    Non-ASCII characters are emitted as-is rather than escaped. Values the
    standard encoder cannot handle are passed to ``_json_default``.
    """
    options = dict(sort_keys=True, ensure_ascii=False, default=_json_default)
    return json.dumps(obj, **options)

def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file into a list of records.

    A missing file yields ``[]``. Blank (whitespace-only) lines are ignored.

    A malformed *last* non-blank line is skipped. That is the shape left
    behind when a process dies in the middle of an append, and skipping it
    lets resume logic continue from the intact records. A malformed line
    followed by further data is real corruption, so it still raises
    ``json.JSONDecodeError``.
    """
    if not os.path.exists(path):
        return []
    out: List[Dict[str, Any]] = []
    torn: Optional[json.JSONDecodeError] = None  # error from the previous line, if any
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if torn is not None:
                # A bad line followed by more data cannot be a torn tail.
                raise torn
            try:
                out.append(json.loads(line))
            except json.JSONDecodeError as err:
                torn = err
    return out

def append_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """
    Append records to a JSON Lines file, flush + fsync for robustness.

    Each record is serialized with ``stable_json_dumps`` onto its own line.
    The whole batch is serialized before the file is opened. As a result, a
    record that fails to serialize raises without appending a partial batch.
    The parent directory is created if needed. A bare file name (no directory
    component) writes to the current directory.
    """
    parent = os.path.dirname(path)
    if parent:  # dirname("out.jsonl") == "" and os.makedirs("") would raise
        ensure_dir(parent)
    payload = "".join(stable_json_dumps(r) + "\n" for r in records)
    with open(path, "a", encoding="utf-8") as f:
        f.write(payload)
        f.flush()
        os.fsync(f.fileno())
        
def infer_next_indices(existing: List[Dict[str, Any]]) -> Set[Tuple[int, int, int]]:
    """
    Build a set of completed (theta_idx, sim_idx, n) triples for resume.

    Each record must contain the keys ``"theta_idx"``, ``"sim_idx"`` and
    ``"n"``. Their values are converted with ``int()``, so numeric strings and
    floats are accepted. A missing key raises ``KeyError``.
    """
    # typing.Set (not builtin set[...]) keeps the annotation valid before
    # Python 3.9 and matches load_done_keys_from_csv.
    return {(int(r["theta_idx"]), int(r["sim_idx"]), int(r["n"])) for r in existing}


def set_reproducibility(seed: int) -> None:
    """Seed the global random number generators used by this code base.

    Covers:
    - Python's ``random`` module;
    - NumPy's legacy global RNG (``np.random.seed``);
    - PyTorch's default RNG;
    - every CUDA device's RNG, when CUDA is available.

    Independently created ``np.random.Generator`` objects are not affected.
    """
    import random  # local import: ``random`` is not imported at module level

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Replace the contents of ``path`` with ``data`` atomically.

    The bytes are first written to the sibling file ``path + ".tmp"``, then
    flushed and fsynced. Only after that is the temp file moved over ``path``
    with ``os.replace``, so readers see either the old or the new content,
    never a partial file.

    If anything fails before the replace, the temp file is removed (so no
    stale ``.tmp`` is left next to the target) and the exception propagates.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup. The temp file may not exist, e.g. if open() failed.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

def atomic_save_npz(path: str, **arrays: Any) -> None:
    """
    Atomic npz write (temp then replace). Suitable for checkpoints.

    The compressed archive is written through an open handle to
    ``path + ".tmp.npz"``, then flushed and fsynced. Only after that is it
    moved over ``path`` with ``os.replace``, so a crash cannot leave a
    truncated checkpoint at ``path``. Writing through a file object also
    means NumPy does not append an extra ``.npz`` suffix: the data lands at
    exactly ``path``.

    On failure the temp file is removed and the exception propagates.
    """
    tmp_path = path + ".tmp.npz"
    try:
        with open(tmp_path, "wb") as f:
            np.savez_compressed(f, **arrays)
            # Make the bytes durable before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def configure_logger(log_path: str, verbose: bool) -> logging.Logger:
    """(Re)configure the ``"sim_runner"`` logger with console and file output.

    Any handlers already attached to the logger are removed and closed. Two
    new handlers are then added:
    - a stderr stream handler;
    - a UTF-8 file handler appending to ``log_path``.

    Both share the format ``[time] LEVEL: message``. The logger level is
    DEBUG when ``verbose`` is true, otherwise INFO. Calling this repeatedly
    is safe and does not accumulate handlers or leak file descriptors.

    Args:
        log_path: File to append log records to.
        verbose: Enable DEBUG-level logging.

    Returns:
        The configured ``logging.Logger``.
    """
    level = logging.DEBUG if verbose else logging.INFO
    logger = logging.getLogger("sim_runner")
    logger.setLevel(level)

    # Close previous handlers before dropping them. Clearing the list alone
    # would leak the old FileHandler's open file on every reconfiguration.
    # (StreamHandler.close() does not close sys.stderr.)
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")

    sh = logging.StreamHandler(sys.stderr)
    sh.setLevel(level)
    sh.setFormatter(fmt)
    logger.addHandler(sh)

    # NOTE: the file handler accepts DEBUG, but the logger level above still
    # gates records, so DEBUG lines only reach the file when verbose is true.
    fh = logging.FileHandler(log_path, encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(fmt)
    logger.addHandler(fh)

    return logger


def flush_records_to_csv(records: List[Dict[str, Any]], csv_path: str) -> None:
    """Append ``records`` as rows to ``csv_path``.

    Behavior:
    - An empty ``records`` list is a no-op.
    - List/dict values in ``mmd_quantiles``, ``sink_quantiles`` and
      ``cutoffs`` are stored as sorted-key JSON strings, so each stays in a
      single CSV cell.
    - A header row is written when the file is missing or empty (zero
      bytes).
    - When the file already has a header with the same set of columns, the
      new rows are reordered to match it. Without this, batches built from
      dicts with a different key order would be appended under the wrong
      headers.
    """
    if not records:
        return

    df = pd.DataFrame.from_records(records)

    # Store list-like fields as JSON strings so CSV stays sane
    for col in ["mmd_quantiles", "sink_quantiles", "cutoffs"]:
        if col in df.columns:
            df[col] = df[col].apply(
                lambda x: json.dumps(x, sort_keys=True) if isinstance(x, (list, dict)) else x
            )

    # A zero-byte file (e.g. created but never written) still needs a header.
    header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    if not header:
        existing_cols = list(pd.read_csv(csv_path, nrows=0).columns)
        if set(existing_cols) == set(df.columns):
            df = df[existing_cols]
    df.to_csv(csv_path, mode="a", header=header, index=False)

def load_done_keys_from_csv(csv_path: str) -> Set[Tuple[int, int, int]]:
    """
    Read only the key columns from results.csv and return a set of completed triples.
    Robust to partially-written lines by skipping bad lines.

    Returns an empty set in two cases: the file does not exist, or the file
    is empty (zero bytes, so no header). Rows are dropped when any key is
    missing or non-numeric, e.g. a torn final row. The remaining values are
    truncated to plain Python ``int`` tuples
    ``(theta_idx, sim_idx, n)``.
    """
    if not os.path.exists(csv_path):
        return set()

    # Read minimal columns to reduce memory / time
    usecols = ["theta_idx", "sim_idx", "n"]
    try:
        df = pd.read_csv(
            csv_path,
            usecols=usecols,
            on_bad_lines="skip",   # pandas>=1.3
            engine="c",
        )
    except pd.errors.EmptyDataError:
        # Zero-byte file: nothing has been recorded yet.
        return set()

    # Coerce to numeric so a torn row containing text fragments becomes NaN and
    # is dropped, instead of making astype(int) raise on the whole column.
    keys = df[usecols].apply(pd.to_numeric, errors="coerce").dropna()
    return {
        (int(t), int(s), int(n))
        for t, s, n in keys.itertuples(index=False, name=None)
    }