"""Utilities for comparing Dei-XC targets across reference methods.

This script loads cached reference data generated for the Dei-XC project, runs
pairwise comparisons between a designated reference method and one or more
comparison methods, and reports summary statistics as well as per-sample CSV
records.  The Sacred/SEML configuration in ``base_config`` (defined near the
top of the file) exposes the most common experiment flags.
"""

import os
from dataclasses import dataclass
from typing import Any, Callable, Mapping, Sequence

import jax.numpy as jnp
import numpy as onp
import pandas as pd
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from deixc.dataset import DEIXCTargets
from deixc.training.loss.xc_potential import analytic_xc_potential_loss
from egxc.dataloading import BaseDataset, key_to_dataset
from egxc.dataloading.datasets.base import RawInput
from egxc.dataloading.io import auxiliary_data_paths
from egxc.dataloading.utils import IndexWrapper
from egxc.solver.fock import get_coulomb_fn, mean_field_energy
from egxc.systems.base import Grid, System
from egxc.systems.preload import preload_system_using_pyscf
from egxc.training.loss.density import get_density_mean_field_error_fn
from egxc.utils.typing import (
    Alignment,
    MethodKey,
    NpDensityMatrix,
    NpFloatBxB,
    NpFloatRefSCFxBxB,
)

# JAX must be configured before anything is traced/compiled: enable 64-bit
# arrays (the comparisons below cast values to float64) and set the default
# matmul precision.
config.update('jax_enable_x64', True)
config.update('jax_default_matmul_precision', 'float32')
# Persist compiled JAX executables on disk so repeated runs reuse them.
compilation_cache.set_cache_dir('./caches/jax_compile/compute_deixc_targets')

# Sacred/SEML experiment. `base_config`, `get_dataset` and
# `ComparisonRunner.__init__` are registered on / captured by it, so it must be
# created before those decorators run.
ex = Experiment()

# Lower bound for the denominator of relative errors in
# `ComparisonRunner._compute_errors`.
EPSILON = 1e-15

# Extra dataset kwargs merged into `base_config()['dataset']['kwargs']`.
# NOTE(review): presumably disables an initial reference-density method on the
# dataset side — confirm against the dataset constructor.
DEFAULT_DATA_KWARGS = {
    'initial_ref_density_method_key': None,
    'initial_ref_density_method_kwargs': None,
}

# --- Experiment configuration -------------------------------------------------


@ex.config
def base_config():
    # Sacred config scope: every local variable assigned here becomes a config
    # entry that captured functions (`get_dataset`, `ComparisonRunner.__init__`)
    # receive by parameter name. Do not introduce helper locals here.
    align_scf_trajectory = 15  # forwarded to DEIXCTargets.create when loading targets
    shift_dispersion = False  # forwarded to DEIXCTargets.create
    output_path = 'ANONYMOUS_DIR'  # per-sample CSV destination; None disables writing (placeholder path — set per run)
    only_last = True  # compare only the final entry of SCF-trajectory metrics

    # Shorthands reused below (they also appear as top-level config entries).
    FUNCTIONAL = 'B3LYP'
    BASIS = 'def2-TZVPD'
    KEY = 'ks_dft'
    # Arguments for `get_dataset` (captured with prefix='dataset').
    dataset = {
        'key': 'qm9',  # key into key_to_dataset
        'n_samples': 10,  # random subsample size; None keeps the whole dataset
        'seed': 0,  # seeds both the QM9 split and the subsample permutation
        'kwargs': {  # forwarded to the dataset constructor
            'data_dir': 'ANONYMOUS_DIR',
            'heavy_atoms_thresh': 4,
            'exclude_fluorine': False,
            **DEFAULT_DATA_KWARGS,
        },
    }
    # Reference method. Its kwargs locate the cached targets (via
    # auxiliary_data_paths), and the runner also reads 'basis',
    # 'spin_restricted' and 'use_eri_density_fitting' from them.
    reference_method = {
        'label': '',
        'key': KEY,
        'kwargs': {
            'xc_str': FUNCTIONAL,
            'basis': BASIS,
            'backend': 'pyscf',
            'use_eri_density_fitting': True,
            'use_exchange_density_fitting': True,
            'spin_restricted': True,
            'quadrature_grid_level': 1,
        },
    }
    # Methods compared against the reference; 'label' becomes the row name in
    # the printed summary and the column suffix in the per-sample CSV.
    # Here 't' and 'c' differ only in the 'backend' setting.
    comparison_methods = [
        {
            'label': 't',
            'key': KEY,
            'kwargs': {
                'xc_str': FUNCTIONAL,
                'basis': BASIS,
                'backend': 'pyscf',
                'use_eri_density_fitting': True,
                'use_exchange_density_fitting': True,
                'spin_restricted': True,
                'quadrature_grid_level': 1,
            },
        },
        {
            'label': 'c',
            'key': KEY,
            'kwargs': {
                'xc_str': FUNCTIONAL,
                'basis': BASIS,
                'backend': 'custom',
                'use_eri_density_fitting': True,
                'use_exchange_density_fitting': True,
                'spin_restricted': True,
                'quadrature_grid_level': 1,
            },
        },
    ]
    # DEIXCTargets attribute names to compare. Each must be present and
    # non-None (and not a method) on both the reference and comparator targets,
    # otherwise the comparison raises.
    metrics = (
        # Along SCF trajectory:
        'total_energy',
        'xc_energy',
        'xc_potential_matrix',
        'mo_coeffs',  # a.k.a. C  (SCF, B, B)
        'density_matrix',
        'linear_response_xc_pot',  # (SCF, O, V)
        # From ground-state density only:
        # "density_hessian_diagonal",
        # "forces",
    )
    # Aggregations over per-sample absolute errors; keys of SUMMARY_AGGREGATORS.
    summary_statistics = ('mean', 'median', 'std', 'max')


# Scalar energy metrics; their per-sample CSV columns carry an `_Ha` suffix.
ENERGY_METRICS = {'total_energy', 'xc_energy'}

# Density-matrix metrics, scored through the mean-field energy they induce
# (see ComparisonRunner._evaluate_density_mean_field_error).
DENSITY_METRICS = {'density_matrix', 'density_matrices'}

# Attributes whose leading axis enumerates the SCF iterations. When
# `only_last` is enabled we compare only the converged values for these metrics.
SCF_TRAJECTORY_METRICS = {
    'mo_coeffs',
    'total_energies',
    'xc_energies',
    'xc_potential_matrices',
    'linear_response_xc_pot',
    'density_matrices',
    'direct_minimization_directions',
}


# Tensor metrics evaluated with an analytic loss instead of a plain norm
# difference; all of them currently share the XC-potential loss.
ANALYTIC_METRIC_FUNS = dict.fromkeys(
    (
        'xc_potential_matrix',
        'xc_potential_matrices',
        'linear_response_xc_pot',
    ),
    analytic_xc_potential_loss,
)


@dataclass(frozen=True)
class MetricStats:
    """Outcome of comparing one metric for one sample and one comparator.

    ``target``/``comparator`` summarise the reference and comparator values,
    ``absolute`` is the error and ``relative`` the error scaled by the target.
    """

    target: float
    comparator: float
    relative: float
    absolute: float

    def record(self) -> dict[str, float]:
        """Return the four fields as a plain dict (for DataFrame records)."""
        field_names = ('target', 'comparator', 'relative', 'absolute')
        return {name: getattr(self, name) for name in field_names}


# Named reductions over a Series of per-sample errors; std/var use the
# population (ddof=0) definition.
SUMMARY_AGGREGATORS: Mapping[str, Callable[[pd.Series], float]] = dict(
    mean=lambda series: float(series.mean()),
    std=lambda series: float(series.std(ddof=0)),
    var=lambda series: float(series.var(ddof=0)),
    median=lambda series: float(series.median()),
    max=lambda series: float(series.max()),
    min=lambda series: float(series.min()),
)


@dataclass(frozen=True)
class MethodSpec:
    """Resolved description of one method whose cached targets are compared."""

    # Display name: row label in the printed summary and column suffix in the
    # per-sample output.
    label: str
    # Method key from the config; passed to `auxiliary_data_paths`.
    key: MethodKey
    # Method kwargs from the config (e.g. xc_str, basis, backend); also passed
    # to `auxiliary_data_paths`.
    kwargs: dict[str, Any]
    # Directory holding the cached per-sample `<sample_idx>.npz` target files.
    aux_dir: str


@ex.capture(prefix='dataset')  # type: ignore
def get_dataset(
    key: str | None,
    n_samples: int | None,
    seed: int,
    kwargs: dict[str, Any],
) -> BaseDataset:
    """Build the dataset described by the ``dataset`` config entry.

    Args:
        key: Entry in ``key_to_dataset``; an empty/None key falls back to QM9.
        n_samples: If given, keep a seeded random subset of at most this size.
        seed: Seed for the QM9 split and for the subsample permutation.
        kwargs: Keyword arguments forwarded to the dataset constructor.

    Returns:
        The (possibly split and subsampled) dataset.
    """
    name = key if key else 'qm9'

    data = key_to_dataset[name](**kwargs)
    if name == 'qm9':
        # Keep only the first partition returned by the split (val_fraction=0).
        data, _, _ = data.random_split(val_fraction=0.0, seed=seed)

    if n_samples is None:
        return data

    order = onp.random.RandomState(seed).permutation(len(data))
    return IndexWrapper(data, order[:n_samples].tolist())


class ComparisonRunner:
    """Load cached targets, compare them against a reference, and report errors.

    ``__init__`` is a Sacred ``ex.capture`` function, so its arguments are filled
    from the experiment config. Calling the instance runs every
    (sample, comparator, metric) comparison and returns a dict with:

    * ``'summary'``: DataFrame indexed by ``(method, metric)`` with one column
      per requested summary statistic of the absolute error;
    * ``'per_sample'``: DataFrame with one row per sample.
    """

    dataset: BaseDataset
    reference: MethodSpec
    comparators: Sequence[MethodSpec]
    metrics: Sequence[str]
    summary_statistics: Sequence[str]
    align_scf_trajectory: int
    shift_dispersion: bool
    only_last: bool
    output_path: str | None
    comparator_labels: Sequence[str]
    spin_restricted: bool
    use_density_fitting: bool
    _system_cache: dict[int, System]

    def _method_spec(self, method_cfg: Mapping[str, Any]) -> MethodSpec:
        """Turn a method config dict into a MethodSpec with its cache directory.

        Requires ``self.dataset`` to be set, since the cache directory is
        resolved relative to the dataset's auxiliary data directory.
        """
        kwargs = dict(method_cfg.get('kwargs', {}))
        aux_dir = auxiliary_data_paths(
            self.dataset.auxiliary_data_directory,
            'deixc',
            method_cfg['key'],
            **kwargs,
        )
        return MethodSpec(
            label=method_cfg.get('label', ''),
            key=method_cfg['key'],
            kwargs=kwargs,
            aux_dir=aux_dir,
        )

    @ex.capture
    def __init__(
        self,
        reference_method,
        comparison_methods,
        metrics: Sequence[str],
        summary_statistics: Sequence[str],
        align_scf_trajectory: int,
        shift_dispersion: bool,
        only_last: bool,
        output_path: str | None,
    ) -> None:
        """Build the runner from the Sacred config (arguments are captured)."""
        # The dataset must exist before resolving method specs (see _method_spec).
        self.dataset = get_dataset()  # type: ignore
        self.reference = self._method_spec(reference_method)
        self.comparators = tuple(
            self._method_spec(method_cfg) for method_cfg in comparison_methods
        )
        self.metrics = tuple(metrics)
        self.summary_statistics = tuple(summary_statistics)
        self.align_scf_trajectory = align_scf_trajectory
        self.shift_dispersion = shift_dispersion
        self.only_last = only_last
        self.output_path = output_path
        self.comparator_labels = tuple(spec.label for spec in self.comparators)
        # System construction and the density error functions follow the
        # reference method's settings.
        self.spin_restricted = bool(self.reference.kwargs.get('spin_restricted', True))
        self.use_density_fitting = self.reference.kwargs['use_eri_density_fitting']
        self._density_mf_error_fn = get_density_mean_field_error_fn(
            spin_restricted=self.spin_restricted,
            use_density_fitting=self.use_density_fitting,
            scale_per_electron=False,
        )
        self._coulomb_fn = get_coulomb_fn(
            spin_restricted=self.spin_restricted,
            use_density_fitting=self.use_density_fitting,
        )
        self._system_cache = {}

    def _load_target(self, spec: MethodSpec, sample_idx: int) -> DEIXCTargets | None:
        """Load the cached targets of ``spec`` for one sample.

        Returns:
            The DEIXCTargets, or ``None`` when ``<aux_dir>/<sample_idx>.npz`` does
            not exist (the caller then skips that sample/comparator).
        """
        path = os.path.join(spec.aux_dir, f'{sample_idx}.npz')
        if not os.path.isfile(path):
            print(f'Skipping sample {sample_idx}: no cached targets at {path}')
            return None
        # NOTE: allow_pickle=True can execute code embedded in the file; only
        # point aux_dir at trusted, locally generated caches.
        with onp.load(path, allow_pickle=True) as payload:
            data = {key: payload[key] for key in payload.files}
        return DEIXCTargets.create(
            data,
            self.align_scf_trajectory,
            shift_dispersion=self.shift_dispersion,
        )

    def _get_system_data(self, sample_idx: int, raw_input: RawInput | None) -> System:
        """Return the System for a sample, building it via PySCF on first use.

        The System is built in the reference basis with unit alignment and an
        empty grid, and memoised per ``sample_idx``.

        Raises:
            ValueError: if the System is not cached and ``raw_input`` is None.
        """
        cached = self._system_cache.get(sample_idx)
        if cached is not None:
            return cached
        if raw_input is None:
            raise ValueError(
                'Raw input (geometry, charge, spin) required to evaluate analytic metrics.'
            )
        nuc_pos, atom_z, charge, spin, reference_density = raw_input
        basis = self.reference.kwargs['basis']

        preload = preload_system_using_pyscf(
            idx=sample_idx,
            nuc_pos=nuc_pos,
            atom_z=atom_z,
            charge=int(charge),
            spin=int(spin),
            reference_density=reference_density,
            basis=basis,
            spin_restricted=self.spin_restricted,
            alignment=Alignment(atom=1, basis=1, grid=1),
            base_initial_density_guess='minao',
            use_density_fitting=self.use_density_fitting,
            cache_pyscf_mole=True,
        )
        sys = System.from_preloaded(preload, Grid.empty())
        self._system_cache[sample_idx] = sys
        return sys

    def _project_linear_response_to_ao(
        self,
        sample_idx: int,
        raw_input: RawInput | None,
        ref_target: DEIXCTargets,
        other_target: DEIXCTargets,
    ) -> tuple[NpFloatRefSCFxBxB, NpFloatRefSCFxBxB]:
        """Project both linear-response potentials into the AO basis using the overlap."""
        system_data = self._get_system_data(sample_idx, raw_input)
        S = system_data.fock_tensors.overlap
        ref_value = onp.asarray(ref_target.get_linear_response_xc_pot_in_ao_basis(S))
        other_value = onp.asarray(other_target.get_linear_response_xc_pot_in_ao_basis(S))
        return ref_value, other_value

    def _maybe_select_scf_iteration(
        self,
        metric: str,
        ref_value: onp.ndarray,
        other_value: onp.ndarray,
    ) -> tuple[onp.ndarray, onp.ndarray]:
        """Keep only the last SCF entry of trajectory metrics when ``only_last`` is set."""
        if self.only_last and metric in SCF_TRAJECTORY_METRICS:
            return ref_value[-1], other_value[-1]
        return ref_value, other_value

    def _evaluate_density_mean_field_error(
        self,
        sample_idx: int,
        raw_input: RawInput | None,
        dm: NpDensityMatrix,
        other_dm: NpDensityMatrix,
    ) -> MetricStats:
        """Score two density matrices by the mean-field energy error (Hartree).

        ``target``/``comparator`` are the mean-field energies of the final
        reference/comparator densities; ``absolute`` is the value of the
        density mean-field error function.

        Raises:
            ValueError: on mismatched or unsupported density matrix shapes.
        """
        sys = self._get_system_data(sample_idx, raw_input)

        def _final_density(matrix: onp.ndarray) -> onp.ndarray:
            array = onp.asarray(matrix)
            if array.ndim == 2:
                return array
            if array.ndim == 3:
                # A leading axis of size 2 over square matrices is treated as
                # spin channels and summed; otherwise the last entry is taken.
                if array.shape[0] == 2 and array.shape[1] == array.shape[2]:
                    return array.sum(axis=0)
                return array[-1]
            if array.ndim == 4:
                # Take the last entry of the leading axis, then resolve as 3-D.
                return _final_density(array[-1])
            raise ValueError(f'Unsupported density matrix shape: {array.shape}')

        def _mean_field_energy(matrix: onp.ndarray) -> float:
            coulomb = self._coulomb_fn(jnp.asarray(matrix), sys.fock_tensors.ert)
            energy = mean_field_energy(
                jnp.asarray(matrix), coulomb, sys.fock_tensors.core_hamiltonian
            )
            return float(energy)

        if dm.shape != other_dm.shape:
            raise ValueError(
                f'Density matrix shape mismatch: reference {dm.shape} vs '
                f'comparator {other_dm.shape}'
            )
        ref_density = _final_density(dm)
        other_density = _final_density(other_dm)

        ref_energy = _mean_field_energy(ref_density)
        other_energy = _mean_field_energy(other_density)
        diff_scalar = float(
            self._density_mf_error_fn(
                sys,
                jnp.asarray(ref_density),
                jnp.asarray(other_density),
            )
        )
        denom = max(abs(ref_energy), 1e-12)
        return MetricStats(
            target=ref_energy,
            comparator=other_energy,
            relative=diff_scalar / denom,
            absolute=diff_scalar,
        )

    def _evaluate_analytic_metric(
        self,
        metric: str,
        analytic_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], Any],
        sample_idx: int,
        raw_input: RawInput | None,
        ref_value: onp.ndarray,
        other_value: onp.ndarray,
    ) -> MetricStats:
        """Evaluate an analytic loss between two (B, B) or (2, B, B) tensors.

        ``analytic_fn`` is called as ``fn(ref, other, overlap, n_electrons)``;
        for a leading axis of size 2 the per-slice losses are averaged.
        ``target``/``comparator`` are the Frobenius norms of the inputs.

        Raises:
            ValueError: on mismatched or unsupported shapes.
        """
        if ref_value.shape != other_value.shape:
            raise ValueError(
                f'Shape mismatch for {metric}: reference {ref_value.shape} vs '
                f'comparator {other_value.shape}'
            )

        system_data = self._get_system_data(sample_idx, raw_input)
        overlap: NpFloatBxB = onp.asarray(system_data.fock_tensors.overlap)
        n_electrons = int(onp.asarray(system_data.n_electrons))

        # Match the overlap to the value dimension: pad with an identity block
        # when the values are larger, truncate when they are smaller.
        target_dim = ref_value.shape[-1]
        current_dim = overlap.shape[-1]
        if current_dim < target_dim:
            pad = target_dim - current_dim
            overlap = onp.pad(overlap, ((0, pad), (0, pad)))
            for diag in range(current_dim, target_dim):
                overlap[diag, diag] = 1.0
        elif current_dim > target_dim:
            overlap = overlap[:target_dim, :target_dim]
        overlap_jnp = jnp.asarray(overlap)
        n_electrons_jnp = jnp.asarray(n_electrons, dtype=jnp.uint32)

        def _loss(a: onp.ndarray, b: onp.ndarray) -> float:
            return float(
                analytic_fn(
                    jnp.asarray(a),
                    jnp.asarray(b),
                    overlap_jnp,
                    n_electrons_jnp,
                )
            )

        if ref_value.ndim == 2:
            diff_scalar = _loss(ref_value, other_value)
        elif ref_value.ndim == 3 and ref_value.shape[0] == 2:
            diff_scalar = float(
                onp.mean([_loss(r, o) for r, o in zip(ref_value, other_value)])
            )
        else:
            raise ValueError(
                f'Analytic metric {metric} expects shape (B, B) or (2, B, B), '
                f'got {ref_value.shape}'
            )

        ref_norm = float(onp.linalg.norm(ref_value))
        other_norm = float(onp.linalg.norm(other_value))
        denom = max(ref_norm, 1e-12)
        return MetricStats(
            target=ref_norm,
            comparator=other_norm,
            relative=diff_scalar / denom,
            absolute=diff_scalar,
        )

    def _compute_errors(
        self,
        sample_idx: int,
        raw_input: RawInput | None,
        ref_target: DEIXCTargets,
        other_target: DEIXCTargets,
        metric: str,
    ) -> MetricStats:
        """Compute error statistics for one metric between reference and comparator.

        Density metrics use the mean-field energy error, metrics in
        ANALYTIC_METRIC_FUNS use their analytic loss, and everything else uses
        the absolute difference (scalars) or the Frobenius norm of the
        difference (arrays).

        Raises:
            KeyError: if ``metric`` is not an attribute of the targets.
            ValueError: if the attribute is None or callable on either target.
        """
        values: list[onp.ndarray] = []
        for target_obj, label in (
            (ref_target, 'reference'),
            (other_target, 'comparator'),
        ):
            if not hasattr(target_obj, metric):
                available = ', '.join(
                    sorted(name for name in dir(target_obj) if not name.startswith('_'))
                )
                raise KeyError(f'Unknown metric {metric}. Available metrics: {available}')
            value = getattr(target_obj, metric)
            if value is None:
                raise ValueError(f'Metric {metric} is None for the {label} target')
            if callable(value):
                raise ValueError(
                    f'Metric {metric} is callable on the {label} target; provide data attribute'
                )
            values.append(onp.asarray(value))

        ref_value, other_value = values
        analytic_fn = ANALYTIC_METRIC_FUNS.get(metric)
        if metric == 'linear_response_xc_pot' and analytic_fn is not None:
            # The analytic loss works in the AO basis, so project first.
            ref_value, other_value = self._project_linear_response_to_ao(
                sample_idx,
                raw_input,
                ref_target,
                other_target,
            )

        ref_value, other_value = self._maybe_select_scf_iteration(
            metric,
            ref_value,
            other_value,
        )

        if metric in DENSITY_METRICS:
            return self._evaluate_density_mean_field_error(
                sample_idx,
                raw_input,
                ref_value,
                other_value,
            )

        if analytic_fn is not None:
            return self._evaluate_analytic_metric(
                metric,
                analytic_fn,
                sample_idx,
                raw_input,
                ref_value,
                other_value,
            )

        ref_arr = onp.asarray(ref_value, dtype=onp.float64)
        other_arr = onp.asarray(other_value, dtype=onp.float64)
        if ref_arr.ndim == 0:
            # Scalars (e.g. energies) keep their sign in target/comparator; the
            # relative error is scaled by |target| so that negative values do
            # not collapse the denominator to EPSILON.
            target = float(ref_arr)
            comparator = float(other_arr)
            absolute = float(onp.abs(other_arr - ref_arr))
            scale = abs(target)
        else:
            target = float(onp.linalg.norm(ref_arr))
            comparator = float(onp.linalg.norm(other_arr))
            absolute = float(onp.linalg.norm(other_arr - ref_arr))
            scale = target
        return MetricStats(
            target=target,
            comparator=comparator,
            relative=absolute / max(scale, EPSILON),
            absolute=absolute,
        )

    def __call__(self) -> dict[str, Any]:
        """Run the configured comparison and return summary/per-sample data.

        Samples whose reference or comparator targets are missing are skipped.
        When ``output_path`` is set, the per-sample table is also written to CSV.

        Raises:
            RuntimeError: if the dataset is empty.
            ValueError: if a dataset item is not a non-empty tuple.
        """
        n_samples = len(self.dataset)
        if n_samples == 0:
            raise RuntimeError('Dataset is empty; nothing to compare.')

        records: list[dict[str, Any]] = []

        for i in range(n_samples):
            sample = self.dataset[i]
            if not isinstance(sample, tuple) or not sample:
                raise ValueError('Dataset must return at least an index per sample.')
            sample_idx = int(sample[0])
            raw_input = sample[1] if len(sample) > 1 else None
            # Only a 5-tuple (nuc_pos, atom_z, charge, spin, reference_density)
            # is usable by _get_system_data; anything else is ignored.
            if raw_input is not None and (
                not isinstance(raw_input, tuple) or len(raw_input) != 5
            ):
                raw_input = None

            ref_target = self._load_target(self.reference, sample_idx)
            if ref_target is None:
                continue

            for spec in self.comparators:
                comp_target = self._load_target(spec, sample_idx)
                if comp_target is None:
                    continue
                for metric in self.metrics:
                    stats = self._compute_errors(
                        sample_idx,
                        raw_input,
                        ref_target,
                        comp_target,
                        metric,
                    )
                    records.append(
                        {
                            'sample_idx': sample_idx,
                            'method': spec.label,
                            'metric': metric,
                            **stats.record(),
                        }
                    )

        if records:
            records_df = pd.DataFrame.from_records(records)
            summary_df = self._compute_summary(records_df)
            per_sample_df = self._build_per_sample_df(records_df)
        else:
            empty_index = pd.MultiIndex.from_arrays([[], []], names=['method', 'metric'])
            summary_df = pd.DataFrame(
                columns=list(self.summary_statistics),  # type: ignore
                index=empty_index,
                dtype=float,
            )
            per_sample_df = pd.DataFrame(columns=['sample_idx'])  # type: ignore

        results: dict[str, Any] = {
            'summary': summary_df,
            'per_sample': per_sample_df,
        }

        if self.output_path is not None:
            self._write_csv(per_sample_df)

        return results

    def _compute_summary(self, records_df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate absolute errors per (method, metric) with the configured statistics.

        Raises:
            KeyError: for a statistic not in SUMMARY_AGGREGATORS.
        """
        grouped = records_df.groupby(['method', 'metric'])['absolute']

        stats: dict[str, pd.Series] = {}
        for stat in self.summary_statistics:
            aggregator = SUMMARY_AGGREGATORS.get(stat)
            if aggregator is None:
                available = ', '.join(sorted(SUMMARY_AGGREGATORS))
                raise KeyError(
                    f'Unknown summary statistic {stat}. Available: {available}'
                )
            stats[stat] = grouped.apply(aggregator)  # type: ignore

        if not stats:
            empty_index = pd.MultiIndex.from_arrays([[], []], names=['method', 'metric'])
            return pd.DataFrame(
                columns=list(self.summary_statistics),  # type: ignore
                index=empty_index,
                dtype=float,
            )

        return pd.DataFrame(stats).sort_index()

    def _build_per_sample_df(self, records_df: pd.DataFrame) -> pd.DataFrame:
        """Pivot long-form records into one row per sample.

        Columns: ``sample_idx``, the reference ``target`` per metric (named by
        the metric), and one ``Δ{metric}_{method}`` column per comparator holding
        the absolute error (``_Ha`` suffix for energy/density metrics).
        """
        if records_df.empty:
            return pd.DataFrame(columns=['sample_idx'])  # type: ignore

        df = records_df.assign(export_value=records_df['absolute'].astype(float))

        # The reference target is the same for every comparator; keep the last.
        target = df.pivot_table(
            index='sample_idx',
            columns='metric',
            values='target',
            aggfunc='last',
        )

        # NOTE: pivot raises if two comparators share a label, since that yields
        # duplicate (sample_idx, metric, method) entries.
        values = df[['sample_idx', 'metric', 'method', 'export_value']].pivot(
            index='sample_idx', columns=['metric', 'method'], values='export_value'
        )
        if not values.empty:
            hartree_metrics = ENERGY_METRICS | DENSITY_METRICS
            values.columns = [
                f'Δ{metric}_{method}_Ha'
                if metric in hartree_metrics
                else f'Δ{metric}_{method}'
                for metric, method in values.columns
            ]

        return (
            pd.concat([target, values], axis=1)
            .reset_index()
            .sort_values('sample_idx')
            .reset_index(drop=True)
        )

    def report(self, results: Mapping[str, Any]) -> None:
        """Pretty-print one table per summary statistic (rows: methods, columns: metrics)."""
        summary_df: pd.DataFrame = results.get('summary', pd.DataFrame())
        if summary_df.empty:
            print('No comparison results (possibly due to missing data).')
            return

        present_metrics = summary_df.index.get_level_values('metric')
        metric_order = [metric for metric in self.metrics if metric in present_metrics]
        if not metric_order:
            metric_order = list(present_metrics.unique())

        printed_any = False
        for stat in self.summary_statistics:
            if stat not in summary_df.columns:
                continue
            stat_values = summary_df[stat].unstack('metric').reindex(columns=metric_order)  # type: ignore
            if stat_values.empty:
                continue
            # Configured comparators first (in config order), then any others.
            ordered_rows = stat_values.reindex(self.comparator_labels)
            extras = stat_values.loc[~stat_values.index.isin(self.comparator_labels)]
            combined = pd.concat([ordered_rows, extras])
            combined = combined.loc[~combined.index.duplicated(keep='first')]
            display = combined.map(
                lambda value: 'n/a' if pd.isna(value) else f'{float(value):.6e}'
            )
            banner = f'\033[94m########## {stat.upper()} ##########\033[0m'
            # Blank line between tables, not before the first one printed.
            if printed_any:
                print()
            print(banner)
            printed_any = True

            header_cells = ['Method', *metric_order]
            data_rows: list[list[str]] = [
                [str(method), *row_values.tolist()]
                for method, row_values in display.iterrows()
            ]

            table = [header_cells, *data_rows]
            widths = [
                max(len(row[col_idx]) for row in table)
                for col_idx in range(len(header_cells))
            ]

            def _format_row(row: list[str]) -> str:
                # First column left-aligned, numeric columns right-aligned.
                cells = [
                    value.ljust(widths[col_idx])
                    if col_idx == 0
                    else value.rjust(widths[col_idx])
                    for col_idx, value in enumerate(row)
                ]
                return ' | '.join(cells)

            separator = '-+-'.join('-' * width for width in widths)

            print(_format_row(header_cells))
            print(separator)
            for row in data_rows:
                print(_format_row(row))

    def _write_csv(self, per_sample_df: pd.DataFrame) -> None:
        """Write per-sample results to ``output_path`` as CSV.

        Columns are ordered per configured metric (target, then one error column
        per comparator); any remaining columns are appended at the end.
        """
        assert self.output_path is not None
        os.makedirs(os.path.dirname(self.output_path) or '.', exist_ok=True)
        hartree_metrics = ENERGY_METRICS | DENSITY_METRICS
        column_order = ['sample_idx']
        for metric in self.metrics:
            column_order.append(f'{metric}')
            suffix = '_Ha' if metric in hartree_metrics else ''
            column_order.extend(
                f'Δ{metric}_{label}{suffix}' for label in self.comparator_labels
            )

        cleaned_df = per_sample_df.loc[:, ~per_sample_df.columns.duplicated()]
        ordered_df = cleaned_df.reindex(columns=column_order)
        extra_columns = [col for col in cleaned_df.columns if col not in column_order]
        if extra_columns:
            ordered_df = pd.concat([ordered_df, cleaned_df[extra_columns]], axis=1)

        ordered_df.to_csv(
            self.output_path,
            index=False,
            float_format='%.2e',
            na_rep='',
        )

        print(f'Wrote per-sample CSV results to {self.output_path}')


@ex.automain
def main() -> None:
    runner = ComparisonRunner()  # type: ignore[arg-type]
    outcomes = runner()
    runner.report(outcomes)
