#!/usr/bin/env python3

import os
import traceback

import numpy as np
from jax import config
from reference_densities import DEFAULT_DATA_KWARGS, ExperimentWrapper, get_dataset
from tqdm import tqdm

from egxc.dataloading import io

# Global JAX settings. JAX expects these to be set at startup, before it is used.
# NOTE(review): these run after `reference_densities` and `egxc` are imported
# above; if either module creates JAX arrays at import time, those arrays will
# not see these settings — confirm or move the updates above those imports.
# Enable 64-bit (double-precision) dtypes in JAX.
config.update('jax_enable_x64', True)
# Force JAX onto the CPU backend.
config.update('jax_platform_name', 'cpu')

# Configuration: edit these constants as needed
# Dataset root passed to the QM9 dataset config; 'ANONYMOUS_DIR' looks like an
# anonymized placeholder that must be replaced with a real path.
DATA_DIR = 'ANONYMOUS_DIR'
# Passed as `heavy_atoms_thresh` to the QM9 dataset config
# (presumably a cap on heavy atoms per molecule — confirm in reference_densities).
HEAVY_ATOMS_THRESH = 7
# Basis-set name; passed to ExperimentWrapper and used in the auxiliary-data path.
BASIS = 'def2-SVP'
# Method identifier; passed to ExperimentWrapper and used in the auxiliary-data path.
METHOD = 'dft_lda'
# Passed to ExperimentWrapper; presumably the number of work chunks and the index
# of the chunk this run handles — verify against ExperimentWrapper.
N_CHUNKS = 1
CHUNK_ID = 0


def _needs_recompute(file_path):
    """Return True if the cached ``.npz`` at *file_path* must be regenerated.

    A file needs recomputing when it does not exist, when loading it or
    reading any of its members raises, or when the archive holds no arrays.
    The reason is reported with ``tqdm.write`` so an active progress bar is
    not garbled.
    """
    if not os.path.exists(file_path):
        tqdm.write(f'[MISSING] {file_path} missing, recomputing...')
        return True
    try:
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute code embedded in the file. This is acceptable only
        # because these files are presumably written by io.save_aux_data in this
        # pipeline — never point this at untrusted data.
        with np.load(file_path, allow_pickle=True) as data:
            # An archive with no members cannot contain a saved sample.
            if not data.files:
                raise ValueError('archive contains no arrays')
            # For .npz files np.load returns a lazily-loading NpzFile; touching
            # each member forces the actual read, so corrupt entries raise here.
            for arr_name in data.files:
                _ = data[arr_name]
    except Exception as e:  # any load failure means the cache entry is unusable
        tqdm.write(f'[CORRUPT] Error loading {file_path}: {e}. Recomputing...')
        return True
    return False


def _recompute_sample(wrapper, sample, aux_dir, file_path):
    """Recompute *sample* with *wrapper* and save the result into *aux_dir*.

    Returns True on success. Any exception raised while computing or saving is
    reported (message plus traceback) and turned into a False return, so one
    bad sample does not abort the whole validation run.
    """
    try:
        idx_new, density, energy = wrapper.compute(sample)
        # Remove the stale/corrupt file so the save below is a clean overwrite.
        # NOTE(review): this assumes idx_new == sample[0]; if compute() can
        # return a different index, file_path is deleted but a differently
        # named file is written — confirm against ExperimentWrapper.compute.
        if os.path.exists(file_path):
            os.remove(file_path)
        io.save_aux_data(aux_dir, idx_new, density, energy)
    except Exception as e:
        tqdm.write(f'[ERROR] Failed to recompute for index {sample[0]}: {e}')
        traceback.print_exc()
        return False
    tqdm.write(f'[FIXED] Recomputed and saved density sample for index {idx_new}')
    return True


def main():
    """Validate every cached initial-guess ``.npz`` file and regenerate bad ones.

    For each sample of the configured QM9 dataset, checks
    ``<aux_dir>/<idx>.npz``. If the file is missing or unreadable, recomputes
    it with ``ExperimentWrapper`` and saves it via ``io.save_aux_data``.
    Finally prints how many files were fixed, plus how many could not be
    recomputed (if any).
    """
    # Build dataset configuration
    data_conf = (
        'qm9',
        {
            'data_dir': DATA_DIR,
            'heavy_atoms_thresh': HEAVY_ATOMS_THRESH,
            **DEFAULT_DATA_KWARGS,
        },
    )
    # Prepare dataset and computation wrapper
    dataset = get_dataset(data_conf)
    wrapper = ExperimentWrapper(BASIS, METHOD, N_CHUNKS, CHUNK_ID)

    # Determine auxiliary data directory
    aux_dir = io.auxiliary_data_paths(
        dataset.auxiliary_data_directory, METHOD, BASIS, 'initial_guess'
    )
    os.makedirs(aux_dir, exist_ok=True)

    total = len(dataset)
    fixed = 0
    failed = 0

    for sample in tqdm(dataset):  # type: ignore
        # sample[0] is used as the sample index / cache file stem.
        file_path = os.path.join(aux_dir, f'{sample[0]}.npz')
        if not _needs_recompute(file_path):
            continue
        if _recompute_sample(wrapper, sample, aux_dir, file_path):
            fixed += 1
        else:
            failed += 1

    print(f'Validation completed: {fixed} files fixed out of {total} samples.')
    if failed:
        print(f'{failed} samples could not be recomputed; see [ERROR] messages above.')


# Run the validation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
