"""Simplified energy-only evaluation script for comparing model predictions against dataset labels.

This script evaluates trained DEI-XC checkpoints by running SCF and comparing predicted
total energies against original dataset labels (e.g., CCSD(T) for MD17).
"""

import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple

import grain.python as grain
import jax
import jax.numpy as jnp
import pandas as pd
from jax import config as jax_config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from egxc import dataloading, discretization, systems, xc_energy
from egxc.dataloading.dataloader import get_psys_and_dataloaders
from egxc.dataloading.datasets.base import Targets
from egxc.dataloading.datasets.ensemble import DatasetEnsemble
from egxc.dataloading.io import (
    checkpoint_best_path,
    copy_directory_to_directory,
    results_directory,
    results_metrics_csv_path,
    results_summary_path,
)
from egxc.solver.scf import SelfConsistentFieldSolver
from egxc.utils.typing import Alignment, BaseInitialGuess

# Persist JAX compilation artifacts between runs so the jitted SCF solver does not
# have to be recompiled for every evaluation job.
compilation_cache.set_cache_dir('./caches/jax_compile/deixc_eval_energy')

# Process-global JAX settings, applied at import time before any computation:
# run on GPU, enable 64-bit dtypes, and set the default matmul precision to 'float32'.
jax_config.update('jax_platform_name', 'gpu')
jax_config.update('jax_enable_x64', True)
jax_config.update('jax_default_matmul_precision', 'float32')

# Sacred/SEML experiment; config scopes, captured methods and `main` register on it.
ex = Experiment()

# NOTE(review): not referenced elsewhere in this script -- checkpoint locations come
# from the `evaluation_checkpoints.root_dir` config entry. Confirm before relying on it.
CHECKPOINT_ROOT = 'ANONYMOUS_DIR'


@ex.config
def default_config():
    """Default Sacred/SEML configuration for energy-only evaluation.

    Each dict below is read by the ``ex.capture(prefix=...)`` method of
    ``EnergyEvaluationExperiment`` with the same name. ``data.key`` and
    ``data.data_set_kwargs`` are captured by ``_init_dataset`` but have no default
    here, so they must be supplied by the experiment configuration.
    """
    # Output location: `dir` and `name` go to results_directory; `output_csv` is the
    # per-sample results file name used inside each checkpoint's output directory.
    logging = {
        'dir': None,
        'name': None,
        'output_csv': 'results.csv',
    }
    # Forwarded as keyword arguments to xc_energy.functionals.get_functional.
    model = {
        'name': None,
    }
    # `splits` names the dataloader attributes to evaluate (e.g. 'test').
    base = {
        'use_density_fitting': True,
        'spin_restricted': True,
        'splits': ('test',),
    }
    # `chkp_type`: 'finetune' uses `<root_dir>/<name>`; 'finetuned_epoch' loads
    # `<root_dir>/params_epoch_{idx}.pkl`; any other value (e.g. 'dynamic_train') uses
    # `<root_dir>/<name>_<idx>` and is forwarded to checkpoint_best_path.
    # `indices` must be set to the checkpoint indices to evaluate before running.
    evaluation_checkpoints = {
        'chkp_type': 'dynamic_train',
        'root_dir': None,
        'name': None,
        'indices': None,
    }
    # Basis set name and the derivative order used when evaluating basis functions on the grid.
    basis = {
        'name': 'def2-SVP',
        'derivative': 1,
    }
    # Passed positionally to Alignment(atom, basis, grid).
    alignment = {
        'atom': 1,
        'basis': 1,
        'grid': 32 * 1024,
    }
    # Integration grid level passed to discretization.get_grid_fn.
    quadrature = {
        'level': 1,
    }
    # `initial_ref_density_method` is None or a dict with 'key' and 'kwargs' entries.
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',
        'initial_ref_density_method': None,
        'min_ref_density_interpolation': 0.0,
        'max_ref_density_interpolation': 0.0,
        'noise_eps': 0.0,
    }
    # NOTE: `shuffle` is stored but the dataloaders are always built with shuffle=False.
    # `split` is forwarded as keyword arguments to DatasetEnsemble.infer_split.
    data = {
        'batch_size': 1,
        'shuffle': False,
        'workers': 3,
        'data_split_seed': 0,
        'n_test_samples': None,
        'preload': {
            'center': False,
        },
        'split': {'val_fraction': 0.1},
    }
    # `kwargs` go to SelfConsistentFieldSolver. `convergence_threshold` is not passed to
    # the solver; it is applied afterwards to consecutive total-energy differences.
    solver = {
        'method': 'scf',
        'convergence_threshold': 1e-7,
        'kwargs': {
            'cycles': 15,
            'convergence_acceleration_method': 'DIIS',
        },
    }
    # Constant shift subtracted from (target - predicted) before the summary MAE.
    metrics = {
        'energy_shift_mode': 'none',  # median or none
    }


@ex.named_config
def scaled_nagai2020():
    # Nagai (2020)-style functional with a 4-layer, width-23 local network.
    model = dict(
        name='nagai2020',
        local_n_layers=4,
        local_hidden_dim=23,
    )


@ex.named_config
def scaled_skala_mgga():
    # Skala meta-GGA functional with a 4-layer, width-22 local network.
    model = dict(
        name='skala_mgga',
        local_n_layers=4,
        local_hidden_dim=22,
    )


class EnergyEvaluationExperiment:
    """Simplified evaluation comparing model energies to original dataset labels.

    Construction reads the ``logging``, ``base``, ``alignment``, ``quadrature``,
    ``basis``, ``evaluation_checkpoints``, ``initial_density_guess``, ``data``,
    ``solver`` and ``model`` config sections (via ``ex.capture``). It builds the
    dataset, dataloaders, grid/basis transforms and a jitted SCF solver. Calling
    the instance evaluates each configured checkpoint and writes a per-sample CSV
    plus a text summary into a per-checkpoint output directory.
    """

    output_dir: Path
    output_csv: Path
    metrics_summary_path: Path

    basis_name: str
    basis_derivative: int
    use_density_fitting: bool

    dataset: dataloading.BaseDataset
    dataloaders: dataloading.DataLoaders
    init_psys: systems.PreloadSystem

    preload_transform: Sequence[grain.Transformation]
    init_preloaded_basis_fns: discretization.PreloadedGTOBasis
    main_thread_transform: dataloading.ToJaxTransform

    alignment: Alignment
    grid_level: int
    spin_restricted: bool

    base_initial_density_guess_key: BaseInitialGuess
    solver: SelfConsistentFieldSolver
    _solver_apply: Callable[..., Any]

    evaluation_splits: Tuple[str, ...]

    @ex.capture(prefix='logging')  # type: ignore
    def __init__(
        self,
        overwrite: int,
        dir: str,
        output_csv: str | None,
        name: str,
    ) -> None:
        """Resolve output paths, capture config sections and build data + solver.

        Args:
            overwrite: Supplied by ``main``; not used by this constructor.
            dir: ``logging.dir``, the base directory handed to ``results_directory``.
            output_csv: ``logging.output_csv``, the per-sample CSV file name.
            name: ``logging.name``, the run name handed to ``results_directory``.
        """
        self.output_dir = results_directory(dir, name, exists_ok=True)
        self.output_csv = results_metrics_csv_path(self.output_dir, output_csv)
        self.metrics_summary_path = results_summary_path(self.output_dir)
        # Order matters. Dataset/dataloader construction (triggered from
        # get_initial_density_matrix_fn) reads the basis, alignment, quadrature and
        # spin settings stored by set_all. init_solver reads the spin and
        # density-fitting flags.
        self.set_all()
        self.get_initial_density_matrix_fn()  # type: ignore
        self.init_solver()  # type: ignore

    def set_all(self) -> None:
        """Capture the base, alignment, quadrature, basis and checkpoint config sections."""
        self.set_base()  # type: ignore
        self.set_alignment()  # type: ignore
        self.set_quadrature()  # type: ignore
        self.set_basis()  # type: ignore
        self.set_evaluation_checkpoints()  # type: ignore

    @ex.capture(prefix='evaluation_checkpoints')  # type: ignore
    def set_evaluation_checkpoints(
        self,
        chkp_type: str,
        root_dir: str | None,
        name: str | None,
        indices: List[int] | None,
    ) -> None:
        """Store where checkpoints live and which indices to evaluate.

        Args:
            chkp_type: 'finetune', 'finetuned_epoch', or a type understood by
                ``checkpoint_best_path``.
            root_dir: Directory containing the checkpoint runs.
            name: Checkpoint run name (suffixed with ``_{idx}`` unless 'finetune').
            indices: Checkpoint indices iterated by ``__call__``.
        """
        self.chkp_type = chkp_type
        self.chkp_root_dir = root_dir
        self.chkp_name = name
        self.chkp_indices = indices

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        use_density_fitting: bool,
        spin_restricted: bool,
        splits: Tuple[str, ...],
    ):
        """Store density-fitting/spin flags and the dataloader splits to evaluate."""
        self.spin_restricted = spin_restricted
        self.use_density_fitting = use_density_fitting
        self.evaluation_splits = splits

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """Store the atom/basis/grid alignment as an ``Alignment``."""
        self.alignment = Alignment(atom, basis, grid)

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """Store the integration grid level used by ``discretization.get_grid_fn``."""
        self.grid_level = level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(self, name: str, derivative: int) -> None:
        """Store the basis set name and the derivative order for grid evaluation."""
        self.basis_name = name
        self.basis_derivative = derivative

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def get_initial_density_matrix_fn(
        self,
        base_initial_density_guess_key: BaseInitialGuess,
        initial_ref_density_method: Dict[str, Any] | None,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float,
    ) -> None:
        """Configure the initial density guess, then build the dataset.

        Despite its name this stores its results on ``self`` and returns ``None``.
        It also triggers ``_init_dataset``, because the dataset needs the
        reference-density method chosen here.

        Args:
            base_initial_density_guess_key: Base guess forwarded to the preload transform.
            initial_ref_density_method: ``None`` or a dict with ``'key'`` and ``'kwargs'``.
            min_ref_density_interpolation: Forwarded to ``get_initial_density_matrix_fn``.
            max_ref_density_interpolation: Forwarded to ``get_initial_density_matrix_fn``.
            noise_eps: Forwarded to ``get_initial_density_matrix_fn``.
        """
        self.base_initial_density_guess_key = base_initial_density_guess_key

        if initial_ref_density_method is None:
            self.initial_ref_density_method_key = None
            self.initial_ref_density_method_kwargs = None
        else:
            self.initial_ref_density_method_key = initial_ref_density_method['key']
            self.initial_ref_density_method_kwargs = initial_ref_density_method['kwargs']

        self.initial_density_matrix_fn = dataloading.get_initial_density_matrix_fn(
            min_ref_density_interpolation,
            max_ref_density_interpolation,
            noise_eps,
        )

        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any],
        data_split_seed: int,
        split: Dict[str, Any],
        n_test_samples: int | None,
        preload: Dict[str, bool],
        batch_size: int,
        shuffle: bool,
        workers: int,
    ) -> None:
        """Instantiate the dataset, then its dataloaders and main-thread transform.

        ``key`` is lower-cased and looked up in ``dataloading.key_to_dataset``;
        ``data_set_kwargs`` are forwarded to the dataset constructor. ``shuffle`` is
        stored but not used: ``_init_dataloader`` always disables shuffling.
        """
        self.data_split_seed = data_split_seed
        self.data_split = split
        self.n_test_samples = n_test_samples
        self.preload = preload
        self.preload_center = preload['center']
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.workers = workers

        # Use base dataset directly - no DEIXCDataset wrapper
        self.dataset = dataloading.key_to_dataset[key.lower()](
            initial_ref_density_method_key=self.initial_ref_density_method_key,
            initial_ref_density_method_kwargs=self.initial_ref_density_method_kwargs,
            **data_set_kwargs,
        )

        self._init_dataloader()
        self._init_main_thread_transform()

    def _init_dataloader(self) -> None:
        """Split the dataset and build the preload transform and dataloaders.

        Also stores ``max_angular_momentum`` (needed by the grid basis evaluator)
        and the preloaded basis functions of ``init_psys``.
        """
        dataset_ensemble = DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=self.data_split_seed, **self.data_split
        )

        self.max_angular_momentum, basis_fn_preloader = discretization.get_gto_preloader(
            self.basis_name, self.dataset.unique_elements
        )
        self.preload_transform = dataloading.get_preload_transform(
            self.batch_size,
            self.basis_name,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.base_initial_density_guess_key,
            basis_fn_preloader=basis_fn_preloader,
            center=self.preload_center,
        )

        # Evaluation order is fixed: shuffling is always off, whatever `data.shuffle` says.
        self.init_psys, self.dataloaders = get_psys_and_dataloaders(
            datasets=dataset_ensemble,
            transformations=self.preload_transform,
            shuffle=False,
            shuffling_seed=0,
            workers=self.workers,
            worker_buffer_size=1,
            n_test_samples=self.n_test_samples,
        )
        self.init_preloaded_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    def _init_main_thread_transform(self) -> None:
        """Build the transform that adds grid and grid-evaluated basis functions per sample."""
        grid_fn = discretization.get_grid_fn(
            self.grid_level, self.dataset.unique_elements, self.alignment.grid
        )
        basis_fn = discretization.get_gto_grid_eval_fn(
            deriv=self.basis_derivative,
            max_angular_momentum=self.max_angular_momentum,
        )
        self.main_thread_transform = dataloading.get_jax_transform(
            grid_fn,
            basis_fn,
        )

    @ex.capture  # type: ignore
    def init_solver(self, solver: Dict[str, Any], model: Dict[str, Any]) -> None:
        """Build the SCF solver around the configured XC functional and jit its ``apply``.

        ``model`` is captured but not used directly; ``_xc_module`` captures it itself.
        ``solver['convergence_threshold']`` is not passed to the solver; it is kept for
        the post-hoc convergence check in ``_evaluate_sample``.

        Raises:
            ValueError: If ``solver['method']`` is not ``'scf'``.
        """
        model_module = self._xc_module()  # type: ignore
        solver_type = solver['method']
        solver_kwargs = solver['kwargs']
        if solver_type != 'scf':
            raise ValueError(f'Unsupported solver type: {solver_type}')
        self.convergence_threshold = solver['convergence_threshold']
        self.solver = SelfConsistentFieldSolver(
            model_module,
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
            **solver_kwargs,
        )
        self._solver_apply = jax.jit(self.solver.apply)

    @ex.capture  # type: ignore
    def _xc_module(self, model: Dict[str, Any]) -> xc_energy.XCModule:
        """Create the XC module: functional from ``model`` config plus density features.

        Spin-resolved features are enabled when the functional exposes a truthy
        ``requires_spin_resolved_features`` attribute.
        """
        functional = xc_energy.functionals.get_functional(**model)
        requires_spin_resolved = getattr(
            functional, 'requires_spin_resolved_features', False
        )
        feature_fn = xc_energy.features.DensityFeatures(
            self.spin_restricted, spin_resolved=requires_spin_resolved
        )
        return xc_energy.XCModule(functional, feature_fn)

    def get_checkpoint_name(self, idx: int) -> str:
        """Return the checkpoint run name: ``name`` for 'finetune', else ``name_{idx}``."""
        if self.chkp_type == 'finetune':
            return self.chkp_name
        return f'{self.chkp_name}_{idx}'

    def get_checkpoint_path(self, idx: int) -> str:
        """Return the split directory of checkpoint run ``idx``.

        Picks the last entry (in sorted order) matching ``split_0*`` under
        ``<root_dir>/<checkpoint name>``. Falls back to ``.../split_0`` when there
        is no match.
        """
        checkpoint_dir = Path(self.chkp_root_dir) / self.get_checkpoint_name(idx)
        # NOTE(review): this is a plain path sort, so 'split_0_10' sorts before
        # 'split_0_2'. Confirm the suffix naming makes "last" the intended one.
        split_dirs = sorted(checkpoint_dir.glob('split_0*'))
        if not split_dirs:
            return str(checkpoint_dir / 'split_0')
        return str(split_dirs[-1])

    @staticmethod
    def _load_pickle(checkpoint_file_path: Path) -> Any:
        """Unpickle and return the contents of ``checkpoint_file_path``.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if not checkpoint_file_path.exists():
            raise FileNotFoundError(f'Checkpoint not found at {checkpoint_file_path}')
        # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
        with checkpoint_file_path.open('rb') as stream:
            return pickle.load(stream)

    def load_params(self, idx: int) -> Any:
        """Load model parameters for checkpoint ``idx``.

        For ``chkp_type == 'finetuned_epoch'`` this reads
        ``<root_dir>/params_epoch_{idx}.pkl``. Otherwise the file comes from
        ``checkpoint_best_path`` inside the run's split directory, and that split
        directory is then copied to ``<output_dir>/<checkpoint name>/checkpoint``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if self.chkp_type == 'finetuned_epoch':
            return self._load_pickle(
                Path(self.chkp_root_dir) / f'params_epoch_{idx}.pkl'
            )

        # Resolve (glob) the split directory once and reuse it for load and copy.
        split_dir = self.get_checkpoint_path(idx)
        params = self._load_pickle(Path(checkpoint_best_path(split_dir, self.chkp_type)))
        copy_directory_to_directory(
            split_dir,
            self.output_dir / self.get_checkpoint_name(idx) / 'checkpoint',
        )
        return params

    def __call__(self) -> None:
        """Evaluate every configured checkpoint index.

        Outputs for each checkpoint go to ``<output_dir>/<checkpoint name>/``. A
        failure on one checkpoint is printed with its traceback, and evaluation
        continues with the next one.

        Raises:
            ValueError: If ``evaluation_checkpoints.indices`` is not configured.
        """
        if self.chkp_indices is None:
            raise ValueError(
                'evaluation_checkpoints.indices must be set to the checkpoint '
                'indices to evaluate'
            )

        print(f'\n{"=" * 80}')
        print(f'Energy-only evaluation: {self.chkp_indices} checkpoint(s)')
        print(f'Checkpoint type: {self.chkp_type}')
        print(f'{"=" * 80}\n')

        for idx in self.chkp_indices:
            checkpoint_name = self.get_checkpoint_name(idx)
            print(f'\n{"=" * 80}')
            print(f'Evaluating: {checkpoint_name}')
            print(f'{"=" * 80}\n')

            checkpoint_output_dir = self.output_dir / checkpoint_name
            checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

            # Temporarily redirect the CSV/summary paths into the checkpoint's directory;
            # `finally` restores them even if this checkpoint fails.
            original_output_csv = self.output_csv
            original_metrics_summary = self.metrics_summary_path
            self.output_csv = checkpoint_output_dir / original_output_csv.name
            self.metrics_summary_path = (
                checkpoint_output_dir / original_metrics_summary.name
            )

            try:
                params = self.load_params(idx)
                self.evaluate(params, self.evaluation_splits)
                print(f'\nCompleted evaluation of training run {idx}')
                print(f'Results saved to: {self.output_csv}')

            except Exception as e:
                # Deliberate best-effort: report and move on to the next checkpoint.
                print(f'Error evaluating training run {idx}: {e}')
                import traceback

                traceback.print_exc()

            finally:
                self.output_csv = original_output_csv
                self.metrics_summary_path = original_metrics_summary

        print(f'\n{"=" * 80}')
        print(f'Evaluation complete. Results saved to: {self.output_dir}')
        print(f'{"=" * 80}')

    def evaluate(
        self,
        params: Any,
        splits: Sequence[str],
    ) -> None:
        """Evaluate every sample of each named dataloader split and write the results.

        Args:
            params: Model parameters passed to the jitted solver.
            splits: Attribute names on ``self.dataloaders`` (e.g. ``'test'``).
        """
        records = [
            self._evaluate_sample(params, split, psys, basis_fns, targets)
            for split in splits
            for psys, basis_fns, targets in getattr(self.dataloaders, split)
        ]
        self._write_csv(records)  # type: ignore

    def _evaluate_sample(
        self,
        params: Any,
        split: str,
        psys: systems.PreloadSystem,
        basis_fns,
        targets: Targets,
    ) -> Dict[str, Any]:
        """Run SCF on one sample and compare its final total energy with the label.

        Returns:
            A record with keys ``split``, ``idx``, ``did_converge``,
            ``cycles_to_convergence``, ``predicted_energy_Eh``, ``target_energy_Eh``
            and ``energy_mae_mEh`` (absolute error in milli-Hartree).

        Raises:
            ValueError: If the sample has no reference energy (``targets.energy is None``).
        """
        density_matrices, sys = self.main_thread_transform(psys, basis_fns)
        initial_dm, _ = self.initial_density_matrix_fn(density_matrices, None)

        # SelfConsistentFieldSolver returns (energies, density_matrices); the energy
        # terms are indexed along their leading axis below (one entry per recorded step).
        (e_hj, e_xc), _ = self._solver_apply(params, initial_dm, sys)
        total_energies = e_xc + e_hj + systems.nuclear_energy_fn(sys._nuc_pos, sys)

        # Convergence check: the first consecutive pair of total energies whose
        # difference is below the threshold. cycles_to_convergence is that pair's
        # index + 1, or the number of recorded energies if no pair qualifies.
        delta_total_pred = jnp.abs(total_energies[1:] - total_energies[:-1])
        converged_mask = delta_total_pred < self.convergence_threshold
        any_converged = jnp.any(converged_mask)
        cycles_to_convergence = jnp.where(
            any_converged,
            jnp.argmax(converged_mask) + 1,
            total_energies.shape[0],
        )
        did_converge = bool(any_converged)
        total_pred = total_energies[-1]

        # Explicit check (not `assert`) so it still runs under `python -O`.
        if targets.energy is None:
            raise ValueError(
                f'Sample {int(psys.idx)} in split {split!r} has no reference energy label'
            )
        target_energy = targets.energy.item()  # type: ignore

        return {
            'split': split,
            'idx': int(psys.idx),
            'did_converge': did_converge,
            'cycles_to_convergence': int(cycles_to_convergence),  # type: ignore
            'predicted_energy_Eh': float(total_pred),
            'target_energy_Eh': target_energy,
            'energy_mae_mEh': float(jnp.abs(total_pred - target_energy) * 1e3),
        }

    @ex.capture(prefix='metrics')  # type: ignore
    def _write_csv(
        self, records: Iterable[Dict[str, Any]], energy_shift_mode: str
    ) -> None:
        """Write per-sample records to ``self.output_csv`` and a summary to ``self.metrics_summary_path``.

        ``energy_shift_mode`` selects the constant subtracted from
        ``target - predicted`` before the summary MAE: ``'median'`` uses the median
        difference, ``'none'`` uses 0. The CSV is written before the mode is
        validated, so per-sample results survive a bad mode.

        Raises:
            ValueError: If ``energy_shift_mode`` is neither ``'median'`` nor ``'none'``.
        """
        fieldnames = [
            'split',
            'idx',
            'did_converge',
            'cycles_to_convergence',
            'predicted_energy_Eh',
            'target_energy_Eh',
            'energy_mae_mEh',
        ]
        df = pd.DataFrame.from_records(list(records), columns=fieldnames)
        df.to_csv(self.output_csv, index=False)

        # Summary statistics (differences are target - predicted, reported in mEh).
        total_samples = len(df)
        converged_samples = df['did_converge'].sum()
        difference = df['target_energy_Eh'] - df['predicted_energy_Eh']
        if energy_shift_mode == 'median':
            optimal_shift = difference.median()
        elif energy_shift_mode == 'none':
            optimal_shift = 0.0
        else:
            raise ValueError(f'Unknown metrics.energy_shift_mode: {energy_shift_mode}')

        mae_shifted = (difference - optimal_shift).abs() * 1e3
        mean_mae = mae_shifted.mean()
        median_mae = mae_shifted.median()
        mean_difference = difference.mean() * 1e3
        median_difference = difference.median() * 1e3
        optimal_shift_mEh = optimal_shift * 1e3

        summary_lines = [
            'Energy-only evaluation summary:',
            f'  Total samples: {total_samples}',
            f'  Converged samples: {converged_samples}',
            f'  Mean energy difference (mEh): {mean_difference:.6f}',
            f'  Median energy difference (mEh): {median_difference:.6f}',
            f'  Optimal constant shift (mEh): {optimal_shift_mEh:.6f}',
            f'  Mean energy MAE (mEh): {mean_mae:.6f}',
            f'  Median energy MAE (mEh): {median_mae:.6f}',
        ]

        print('\n'.join(summary_lines))
        self.metrics_summary_path.write_text('\n'.join(summary_lines) + '\n')


@ex.automain
def main(overwrite: int):
    """SEML/Sacred entry point: build the energy evaluation experiment and run it."""
    EnergyEvaluationExperiment(overwrite)()  # type: ignore
