import numpy as np

# Root of the shared storage volume; DATA_DIR and MODEL_DIR below are built from it.
# NOTE(review): all paths here are absolute and machine-specific — adjust per host.
STORAGE_DIR = "/storage1/fs1/XXXX-1/Active/chess"
# CSV data directory under STORAGE_DIR (meaning of the "z3" suffix is not visible here).
DATA_DIR = f"{STORAGE_DIR}/data/csv_data_z3"
# Trained-model directory under STORAGE_DIR (Leela models, per the path name).
MODEL_DIR = f"{STORAGE_DIR}/leela/trained-models"
# lczero-training network config directory. Unlike the other paths, this one ends in "/".
CFG_DIR = "/home/XXXX-3/leela/lczero-training/tf/nets/configs/"
# Stockfish executable (per its filename: a Stockfish 17 x86-64 SSE4.1/popcnt build).
SF_PATH = "/home/XXXX-3/leela/stockfish17_popcnt/stockfish-ubuntu-x86-64-sse41-popcnt"


def _z_to_wp(z: np.ndarray) -> float:
    """Collapse the first two entries of ``z`` into one score: ``z[0] + z[1] / 2``.

    Only indices 0 and 1 are read; any further entries are ignored.

    NOTE(review): presumably ``z`` holds (win, draw, loss) probabilities, making
    this the expected score with a draw counted as half a win — confirm
    against callers.
    """
    first, second = z[0], z[1]
    return first + second / 2


def _softmax(a: np.ndarray, temp=1.0):
    """Softmax of ``a`` along its last axis, with the logits divided by ``temp``.

    Args:
        a: Array (or array-like) of logits; normalization is over ``axis=-1``.
        temp: Temperature. ``0`` returns the arg-max distribution: probability
            mass split evenly over all entries tied for the maximum and 0
            elsewhere, returned as float32. Negative temperatures are also
            accepted and favour the smallest entries.

    Returns:
        Array with the same shape as ``a`` whose last-axis slices sum to 1.
    """
    if temp == 0:
        max_vals = np.max(a, axis=-1, keepdims=True)
        one_hot = (a == max_vals).astype(np.float32)
        # Ties share the mass equally.
        return one_hot / np.sum(one_hot, axis=-1, keepdims=True)
    # exp((a - c) / temp) normalizes to the same result for any per-row shift c.
    # Pick c so the exponent is always <= 0 and np.exp cannot overflow:
    # the row max for a positive temperature, the row min for a negative one.
    if temp > 0:
        shift = np.max(a, axis=-1, keepdims=True)
    else:
        shift = np.min(a, axis=-1, keepdims=True)
    numerator = np.exp((a - shift) / temp)
    return numerator / np.sum(numerator, axis=-1, keepdims=True)


def import_tf():
    """Import TensorFlow with most of its console noise suppressed, and return it.

    Environment variables for TensorFlow/CUDA are set before the import.
    The Keras "VarianceScaling is unseeded" warning is silenced, TensorFlow's
    Python-side loggers are lowered to ERROR, and deprecation printing is
    turned off.

    Returns:
        The imported ``tensorflow`` module.
    """
    import os
    import logging
    import warnings

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["TF_KERAS_BACKEND_DISABLE_WARNINGS"] = "1"
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
    # Set before TF is imported so it is already in the environment whenever
    # CUDA initializes (previously it was only set after the import).
    os.environ["CUDA_CACHE_DISABLE"] = "1"
    logging.getLogger("tensorflow").setLevel(logging.ERROR)

    # Specifically filter the VarianceScaling warning
    warnings.filterwarnings("ignore", message=".*VarianceScaling is unseeded.*")

    # Additionally wrap the `warn` reachable from Keras' initializers module.
    # That module path is private and differs across Keras versions, so the
    # patch is best-effort: a missing module must not break the TF import.
    try:
        import keras.src.initializers.initializers as keras_init
    except ImportError:
        keras_init = None
    warn_module = getattr(keras_init, "warnings", None)
    if warn_module is not None and not getattr(
        warn_module.warn, "_drops_variance_scaling", False
    ):
        # NOTE(review): `keras_init.warnings` is presumably the stdlib `warnings`
        # module, in which case this replaces `warnings.warn` process-wide.
        original_warning = warn_module.warn

        def filtered_warning(message, *args, **kwargs):
            if "VarianceScaling is unseeded" not in str(message):
                original_warning(message, *args, **kwargs)

        # Marker so repeated calls do not stack wrapper upon wrapper.
        filtered_warning._drops_variance_scaling = True
        warn_module.warn = filtered_warning

    import tensorflow as tf

    tf.get_logger().setLevel("ERROR")
    tf.autograph.set_verbosity(0)

    from tensorflow.python.util import deprecation

    # Private TF flag; silences deprecation messages printed by TF itself.
    deprecation._PRINT_DEPRECATION_WARNINGS = False
    return tf


class EarlyStopper:
    """Track a monitored metric and report when it has stopped improving.

    Args:
        patience: Number of consecutive non-improving values after which a
            stop is signalled. ``None`` never signals a stop; ``0`` signals
            on the very first call.
        minimize: If True, a value improves when strictly lower than the best
            so far; if False, when strictly higher.
    """

    def __init__(self, patience=None, minimize=True):
        self.patience = patience
        self.wait = 0  # consecutive calls since the last improvement
        self.best_value = None  # best value seen so far; None before the first call
        # Strict comparison: a tie with the best value counts as non-improvement.
        self.comparator = lambda value, best_value: (
            value < best_value if minimize else value > best_value
        )
        self.is_finished = False  # latched True once patience is exhausted
        self.at_best = False  # whether the most recent value became the new best

    def should_stop(self, value):
        """Record ``value`` and return whether to stop.

        Returns:
            True on the call that exhausts the patience and on every call after
            it (later values are then ignored and state is frozen); False
            otherwise.
        """
        if self.is_finished:
            return True
        if self.best_value is None or self.comparator(value, self.best_value):
            self.best_value = value
            self.wait = 0
            self.at_best = True
        else:
            self.wait += 1
            self.at_best = False
        if self.patience is not None and self.wait >= self.patience:
            self.is_finished = True
        # Return the latch so the call that exhausts patience already reports
        # True (previously this fell through and returned None).
        return self.is_finished
