import pickle
from pathlib import Path

import jax
from jax import jit
from jax import numpy as np


@jit
def sigmoid_cross_entropy(logits, y):
    """Mean binary cross-entropy computed directly from logits.

    Args:
        logits: raw (pre-sigmoid) scores.
        y: targets, presumably labels/probabilities in [0, 1] — TODO confirm
            against callers; broadcast against ``logits``.

    Returns:
        Scalar: negative mean of ``y*log(p) + (1-y)*log(1-p)`` with
        ``p = sigmoid(logits)``.
    """
    log_p = jax.nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) == log(sigmoid(-x)) == log_sigmoid(x) - x, which
    # avoids computing 1 - sigmoid(x) explicitly.
    log_not_p = log_p - logits
    per_element = y * log_p + (1 - y) * log_not_p
    return -np.mean(per_element)


@jit
def cross_entropy(logits, y):
    """Mean cross-entropy given log-probabilities and target distributions.

    Despite the parameter name, ``logits`` is expected to already hold
    log-probabilities (``softmax_cross_entropy`` passes ``log_softmax``
    output here). Classes are summed over axis 1.

    Args:
        logits: log-probabilities, classes along axis 1.
        y: target weights (e.g. one-hot), broadcast against ``logits``.

    Returns:
        Scalar: negative mean over the remaining entries of the per-row
        ``sum(y * logits)``.
    """
    per_example = np.sum(y * logits, axis=1)
    return -np.mean(per_example)


@jit
def softmax_cross_entropy(logits, y):
    """Mean categorical cross-entropy between softmax(``logits``) and ``y``.

    Args:
        logits: unnormalized scores, classes along axis 1.
        y: target weights (e.g. one-hot), broadcast against ``logits``.

    Returns:
        Scalar mean cross-entropy, as computed by ``cross_entropy``.
    """
    # Normalize over the same axis that ``cross_entropy`` sums over (axis 1).
    # log_softmax's default axis=-1 only coincides with it for 2-D inputs.
    return cross_entropy(jax.nn.log_softmax(logits, axis=1), y)


@jit
def mean_squared_error(logits, y):
    """Mean of the element-wise squared difference between predictions and targets.

    Args:
        logits: predictions.
        y: targets; must have exactly the same shape as ``logits``.

    Returns:
        Scalar ``mean((logits - y) ** 2)``.

    Raises:
        ValueError: if the shapes differ. Under ``jit`` shapes are static, so
            this check runs once at trace time.
    """
    # Catch shape mismatches early rather than letting them silently broadcast
    # (e.g. (N,) vs (N, 1) would broadcast to (N, N)). An explicit raise is
    # used instead of ``assert`` so the check survives ``python -O``.
    if logits.shape != y.shape:
        raise ValueError(f"Not matching shapes {logits.shape} and {y.shape}")
    return np.mean((logits - y) ** 2)


def load(checkpoint):
    """Unpickle and return the object stored in ``checkpoint``.

    Args:
        checkpoint: path (or anything ``open`` accepts) to a pickle file.

    Returns:
        The unpickled object.
    """
    print(f"Loading checkpoint from {checkpoint}")
    # NOTE(security): unpickling can execute arbitrary code; only load
    # checkpoints from trusted locations.
    with open(checkpoint, "rb") as handle:
        return pickle.load(handle)


def load_latest(dir: Path):
    """Load the newest ``checkpoint_*`` file in ``dir``.

    "Newest" means the highest step number parsed from the file name
    (``checkpoint_<step>.pkl``, as written by ``save``). Names whose step part
    is not a plain integer sort before all numbered checkpoints, by name.

    Args:
        dir: directory to search (a ``Path`` or anything ``Path()`` accepts).

    Returns:
        Whatever ``load`` returns for the selected file.

    Raises:
        ValueError: if ``dir`` contains no ``checkpoint_*`` entries (the same
            exception type the previous ``max()`` call raised).
    """
    dir = Path(dir)
    candidates = list(dir.glob("checkpoint_*"))
    if not candidates:
        raise ValueError(f"No checkpoint_* files found in {dir}")

    def step_key(path):
        # Compare steps numerically: zero-padded names only sort correctly as
        # strings while the step still fits in the padding width.
        step = path.name[len("checkpoint_"):].split(".", 1)[0]
        return (int(step), path.name) if step.isdecimal() else (-1, path.name)

    # glob() already yields paths rooted at ``dir``; joining them onto ``dir``
    # again would double a relative prefix ("ckpt/ckpt/checkpoint_...").
    return load(max(candidates, key=step_key))


def save(dir: Path, state, step):
    """Pickle ``state`` to ``dir/checkpoint_<step:08d>.pkl``.

    Args:
        dir: output directory (a ``Path`` or anything ``Path()`` accepts);
            created, with parents, if missing.
        state: pytree to save. Device arrays are fetched to the host first.
        step: integer step number used in the file name.
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)
    path = dir / f"checkpoint_{step:08d}.pkl"
    checkpoint_state = jax.device_get(state)
    print(f"Saving checkpoint to {path}")
    # Write to a temporary sibling first, then rename it into place, so an
    # interrupted write never leaves a truncated ``checkpoint_*`` file that
    # ``load_latest`` would pick up. The leading dot keeps the temp file out
    # of the ``checkpoint_*`` glob.
    tmp_path = path.with_name(f".{path.name}.tmp")
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(checkpoint_state, f)
        tmp_path.replace(path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise
