import json
import logging
from pathlib import Path

import h5py
import jax
import numpy as np
import tqdm.auto as tqdm
from scipy.special import logsumexp

# Project-local imports
from neural_pfaffian.logger import Logger
from neural_pfaffian.systems import Systems
from neural_pfaffian.vmc import VMC, VMCState


def _log_mean_exp(log_values: np.ndarray) -> np.ndarray:
    """Return ``log(mean(exp(log_values), axis=0))`` computed in log space.

    ``logsumexp`` is used so that columns consisting solely of ``-inf`` (a mean
    of exactly zero) yield ``-inf`` instead of the NaN that a manual
    "subtract the max, exponentiate, average" computation produces from
    ``-inf - (-inf)``.
    """
    return logsumexp(log_values, axis=0) - np.log(log_values.shape[0])


def evaluate(
    key: jax.Array,
    vmc: VMC,
    state: VMCState,
    systems: Systems,
    epochs: int,
    mcmc_steps: int,
    logger: Logger,
):
    """Run evaluation steps, log their outputs and summarise the logged data.

    The sampler of ``vmc`` is replaced by one using ``mcmc_steps`` steps, then
    ``vmc.eval_step`` is called ``epochs`` times. After every step the energy
    arrays (and whichever overlap arrays are not ``None``) are fetched with
    ``jax.device_get`` and handed to ``logger.log_matrix`` with prefix
    ``'eval'``. Afterwards ``eval.h5`` is read from the first directory
    reported by ``logger.log_directories()`` and reduced to summary statistics.
    NOTE(review): ``eval.h5`` is presumably the file produced by the
    ``log_matrix`` calls above -- confirm against ``Logger``.

    Args:
        key: PRNG key; split once per evaluation step.
        vmc: VMC driver providing ``eval_step`` and a replaceable sampler.
        state: VMC state; its ``sharded`` attribute is passed to ``eval_step``.
        systems: Systems to evaluate; rebound to the value returned by each
            ``eval_step`` call.
        epochs: Number of evaluation steps to run.
        mcmc_steps: Value assigned to the sampler's ``steps`` field.
        logger: Receives the per-step matrices and the final summary.

    Returns:
        The summary dictionary (plain lists / ``None`` values) that is also
        passed to ``logger.log`` and written to ``summary.json``.

    Raises:
        ValueError: If the logger reports no log directory.
    """
    vmc = vmc.replace(sampler=vmc.sampler.replace(steps=mcmc_steps))
    required_fields = ('energy', 'energy_sum', 'energy_squared_sum')
    optional_fields = ('overlap', 'log_density_overlap', 'log_abs_wf_overlap')

    for _epoch in tqdm.trange(epochs):
        key, subkey = jax.random.split(key)
        systems, eval_data = vmc.eval_step(
            subkey,
            state.sharded,
            systems.sharded,
        )

        # Pull results to the host before handing them to the logger.
        for name in required_fields:
            logger.log_matrix(
                name, jax.device_get(getattr(eval_data, name)), prefix='eval'
            )
        # Overlap quantities are only logged when the step produced them.
        for name in optional_fields:
            value = getattr(eval_data, name)
            if value is not None:
                logger.log_matrix(name, jax.device_get(value), prefix='eval')

        logger.reschedule_hook()

    logging.info('Evaluating done')

    directories = logger.log_directories()
    if not directories:
        raise ValueError('No log directory found.')
    path = Path(directories[0])

    logging.info('Computing summary')
    with h5py.File(path / 'eval.h5', 'r') as f:
        energy_data, energy_sum_data, energy_squared_sum_data = (
            np.array(f[name], dtype=np.float64) for name in required_fields
        )
        overlap_data, log_density_overlap_data, log_abs_wf_overlap_data = (
            np.array(f[name], dtype=np.float64) if name in f else None
            for name in optional_fields
        )

    # Pool the per-step energy statistics over all iterations.
    mean_energy, mean_energy_std, mean_energy_std_error = _pooled_stats_from_sums(
        energy_data,
        energy_sum_data,
        energy_squared_sum_data,
        systems.electrons.shape[0],
    )

    # Average the overlaps over all iterations (axis 0 of the logged data).
    mean_overlap = overlap_data.mean(axis=0) if overlap_data is not None else None

    mean_abs_wf_overlap = None
    if log_abs_wf_overlap_data is not None:
        mean_abs_wf_overlap = np.exp(_log_mean_exp(log_abs_wf_overlap_data))

    mean_density_overlap = None
    normalized_mean_density_overlap = None
    if log_density_overlap_data is not None:
        log_mean_density_overlap = _log_mean_exp(log_density_overlap_data)
        mean_density_overlap = np.exp(log_mean_density_overlap)

        # Normalise entry (i, j) by sqrt(diag_i * diag_j), done in log space.
        log_diag = np.diagonal(log_mean_density_overlap, axis1=-2, axis2=-1)
        normalized_mean_density_overlap = np.exp(
            log_mean_density_overlap
            - 0.5 * (log_diag[..., :, None] + log_diag[..., None, :])
        )

    def _to_list(values):
        # JSON cannot hold arrays; missing quantities stay None.
        return values.tolist() if values is not None else None

    # Create a summary dictionary converting arrays to lists for JSON-serialization.
    summary = {
        'mean_energy': mean_energy.tolist(),
        'mean_energy_std': mean_energy_std.tolist(),
        'mean_energy_std_error': mean_energy_std_error.tolist(),
        'charges': [
            [int(charge) for charge in molecule_charges]
            for molecule_charges in systems.charges
        ],
        'mol_ids': [int(idx) for idx in systems.mol_ids],
        'mean_overlap': _to_list(mean_overlap),
        'mean_density_overlap': _to_list(mean_density_overlap),
        'normalized_mean_density_overlap': _to_list(normalized_mean_density_overlap),
        'mean_abs_wf_overlap': _to_list(mean_abs_wf_overlap),
    }
    logger.log(summary, prefix='summary')
    # FIXME: Redundant -> Add some json capability to the logger
    with open(path / 'summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=4)

    return summary


def _pooled_stats_from_sums(
    energy_means: np.ndarray,
    energy_sum: np.ndarray,
    energy_sum_sq: np.ndarray,
    n_walkers_per_step: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute pooled mean/std/SEM using a numerically stable Welford merge."""
    means = np.asarray(energy_means, dtype=np.float64)
    sums = np.asarray(energy_sum, dtype=np.float64)
    sums_sq = np.asarray(energy_sum_sq, dtype=np.float64)
    n_b = np.asarray(n_walkers_per_step, dtype=np.float64)

    mean = np.zeros(sums.shape[1], dtype=np.float64)
    m2 = np.zeros_like(mean)
    n_total = 0.0

    for step in range(sums.shape[0]):
        sum_b = sums[step]
        sum_sq_b = sums_sq[step]
        mean_b = means[step]
        m2_b = sum_sq_b - (sum_b * sum_b) / n_b
        m2_b = np.maximum(m2_b, 0.0)
        if n_total == 0.0:
            mean = mean_b
            m2 = m2_b
            n_total = n_b
            continue

        delta = mean_b - mean
        n = n_total + n_b
        mean = mean + delta * (n_b / n)
        m2 = m2 + m2_b + (delta * delta) * (n_total * n_b / n)
        n_total = n

    m2 = np.maximum(m2, 0.0)
    if n_total <= 1:
        nan = np.full_like(mean, np.nan)
        return nan, nan, nan

    var = m2 / (n_total - 1.0)
    std = np.sqrt(var)
    std_error = std / np.sqrt(n_total)
    return mean, std, std_error


def parse_epochs_and_walkers_from_config(cfg: dict, systems: Systems) -> tuple[int, int]:
    """Derive the number of epochs and the walkers per molecule from ``cfg``.

    Walkers per molecule default to ``systems.electrons.shape[0]``. When
    ``cfg['num_total_walker']`` is positive it is divided evenly among
    ``systems.n_mols`` molecules and rounded down to a multiple of
    ``jax.device_count()``.

    A non-negative ``cfg['epochs']`` is returned unchanged. Otherwise the
    epoch count is ``cfg['total_samples_per_energy'] // walkers_per_mol``.

    Args:
        cfg: Configuration mapping; reads the optional keys ``epochs``,
            ``total_samples_per_energy`` and ``num_total_walker``.
        systems: Provides ``electrons.shape[0]`` and ``n_mols``.

    Returns:
        ``(epochs, walkers_per_mol)``.

    Raises:
        ValueError: If ``epochs`` is not given and either
            ``total_samples_per_energy`` is not positive or the walker count
            per molecule is not positive.
    """
    walkers_per_mol = systems.electrons.shape[0]
    num_total_walker = cfg.get('num_total_walker', -1)
    if num_total_walker > 0:
        n_devices = jax.device_count()
        walkers_per_mol = num_total_walker // systems.n_mols
        # Make sure walkers_per_mol is divisible by number of devices
        walkers_per_mol = (walkers_per_mol // n_devices) * n_devices

    epochs = cfg.get('epochs', -1)
    if epochs >= 0:
        return epochs, walkers_per_mol

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    # and would let a negative epoch count through.
    samples = cfg.get('total_samples_per_energy', -1)
    if samples <= 0:
        raise ValueError(
            "Either a non-negative 'epochs' or a positive "
            "'total_samples_per_energy' must be configured."
        )
    if walkers_per_mol <= 0:
        raise ValueError(
            'Cannot derive epochs: the number of walkers per molecule is '
            f'{walkers_per_mol}.'
        )
    return samples // walkers_per_mol, walkers_per_mol
