import os
from collections import defaultdict
from typing import Dict, List, Literal

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from deixc.training.loss.base import (
    density_hessian_diagonal_loss,
    force_loss,
    orbital_rotation_direction_loss,
    xc_potential_linear_response_loss,
    xc_potential_loss,
)
from pyscf import dft, gto

from deixc.data_generation.hessian import linear_response
from deixc.dataset import DEIXCTargets, RawSample
from egxc import dataloading
from egxc.dataloading import io
from egxc.dataloading.datasets.base import BaseDataset
from egxc.discretization import (
    GTOBasis,
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import Grid, System
from egxc.systems.preload import preload_system_using_pyscf
from egxc.training.utils.loss_components import energy_loss
from egxc.utils.typing import Alignment, ElectRepTensorType, cast_to_integer_tuple

# JAX config for deterministic CPU, double precision.
# Applied at import time so that JAX computations in this script run on the CPU
# backend with float64 enabled (the targets built below are cast to jnp.float64).
jax.config.update('jax_platforms', 'cpu')
jax.config.update('jax_enable_x64', True)


# Configuration: edit these constants as needed
# Dataset root directory (placeholder; replace with a real path).
DATA_DIR = 'ANONYMOUS_DIR'
DATASET = 'qm9'  # 'qm9' | 'md17' | 'des370k'
# Forwarded to build_dataset; only the 'qm9' branch uses it.
HEAVY_ATOMS_THRESH = 6
# Number of leading dataset positions evaluated; also the divisor of the mean losses.
N_SAMPLES = 10
# Basis set used for PySCF, for the System, and for locating stored targets.
BASIS = 'def2-TZVPD'
# Methods evaluated at each reference density ('dft_<xc>' -> PySCF xc label '<xc>').
# Further option: 'dft_wB97M-V'.
METHODS = ['dft_LDA', 'dft_SCAN', 'dft_B3LYP']
# References whose stored targets are loaded (further option: 'dft_wB97M-V').
REFS = ['dft_B3LYP']
# Intended integration-grid level. NOTE(review): some grid constructions in this
# file appear to pass a literal 1 instead of reading this constant; keep them in sync.
GRID_LEVEL = 1
# Passed as integral_norm to the XC-potential losses: 'L1' | 'L2'.
INTEGRAL_NORM = 'L1'
# Electron-repulsion tensor representation used when preloading each System.
ERT_TYPE = ElectRepTensorType.DENSITY_FITTED


def from_reference_density(
    mol: gto.Mole,
    xc_str: str,
    fixed_density: onp.ndarray,
    ref_mo_coeffs: onp.ndarray,
    occ: onp.ndarray,
) -> DEIXCTargets:
    """Compute DeiXCTargets evaluating a functional at a fixed reference density.

    No SCF is run. Energies, Fock/XC-potential matrices, linear response and
    nuclear gradients are evaluated at the supplied density and reference
    orbitals. They are packaged as a one-step "SCF trajectory" (leading axis of
    length 1).

    Parameters
    ----------
    mol: pyscf.gto.Mole
        Molecule.
    xc_str: str
        XC functional label for PySCF (e.g., 'b3lyp', 'scan', 'wb97m-v').
    fixed_density: np.ndarray (B, B)
        AO density matrix to evaluate at.
    ref_mo_coeffs: np.ndarray (B, B)
        Reference MO coefficients to define occupied/virtual subspace.
    occ: np.ndarray
        Occupation numbers of the reference MOs (presumably shape (B,)). They are
        assigned to ``mf.mo_occ`` and used with ``ref_mo_coeffs`` for the nuclear
        gradient. They are expected to reproduce ``fixed_density``; this is not
        checked here.

    Returns
    -------
    DEIXCTargets
        Targets whose per-step arrays have a leading axis of length 1.
    """
    mf = dft.RKS(mol, xc=xc_str)
    mf.mo_occ = occ
    # XC potential and Fock at fixed density; grid level follows the module config.
    mf.grids.level = GRID_LEVEL
    mf.grids.build()

    # Energies at fixed density. The XC part is isolated as the difference to a
    # re-evaluation with the XC functional switched off (mf.xc = '').
    xc_label = mf.xc
    e_tot = mf.energy_tot(dm=fixed_density)
    F_tot = mf.get_fock(dm=fixed_density)
    mf.xc = ''
    e_no_xc = mf.energy_tot(dm=fixed_density)
    F_no_xc = mf.get_fock(dm=fixed_density)
    mf.xc = xc_label
    e_xc_val = e_tot - e_no_xc
    v_xc = F_tot - F_no_xc

    # Shapes as 1-step SCF arrays
    mo_coeffs = onp.asarray([ref_mo_coeffs])
    fock_matrices_along_scf = onp.asarray([F_tot])
    total_energies = onp.asarray([e_tot])
    xc_energies = onp.asarray([e_xc_val])
    xc_potential_matrices = onp.asarray([v_xc])

    # Linear response and Hessian diagonal
    gradients, linear_responses_xc_pot, hessian_diagonal = linear_response.get(
        mf, mo_coeffs, fock_matrices_along_scf
    )

    # Nuclear gradient at the fixed density. PySCF's Gradients.kernel accepts MO
    # data (mo_energy, mo_coeff, mo_occ), not a ``dm`` keyword. ``mf`` never ran
    # an SCF, so its own mo_coeff/mo_energy cannot be relied on, and the
    # reference MOs are passed explicitly. Orbital energies are taken as the
    # diagonal of F_tot in the reference MO basis. This matches the true
    # eigenvalues only if ref_mo_coeffs diagonalises F_tot.
    # NOTE(review): PySCF returns dE/dR; confirm the sign DEIXCTargets.forces expects.
    mo_energy = onp.einsum('pi,pq,qi->i', ref_mo_coeffs, F_tot, ref_mo_coeffs)
    forces = mf.nuc_grad_method().kernel(
        mo_energy=mo_energy, mo_coeff=ref_mo_coeffs, mo_occ=occ
    )

    out = DEIXCTargets(
        mo_coeffs=jnp.asarray(mo_coeffs, dtype=jnp.float64),
        total_energies=jnp.asarray(total_energies, dtype=jnp.float64),
        xc_energies=jnp.asarray(xc_energies, dtype=jnp.float64),
        xc_potential_matrices=jnp.asarray(xc_potential_matrices, dtype=jnp.float64),
        linear_response_xc_pot=jnp.asarray(linear_responses_xc_pot, dtype=jnp.float64),
        density_hessian_diagonal=jnp.asarray(hessian_diagonal, dtype=jnp.float64),
        forces=jnp.asarray(forces, dtype=jnp.float64),
    )
    return out


def with_dft_from_reference(
    sample: RawSample,
    fixed_density: onp.ndarray,
    ref_mo_coeffs: onp.ndarray,
    occ: onp.ndarray,
    xc_str: str,
    basis: str,
) -> DEIXCTargets:
    """Build a PySCF molecule for ``sample`` and evaluate ``xc_str`` at a fixed density.

    Parameters
    ----------
    sample: RawSample
        ``(idx, (nuc_pos, atom_z, charge, spin, aux), extra)``. Only the index
        and the first four fields of the inner tuple are used.
    fixed_density, ref_mo_coeffs, occ:
        Forwarded unchanged to ``from_reference_density``.
    xc_str: str
        PySCF XC functional label.
    basis: str
        Basis-set label for the PySCF molecule.
    """
    sample_idx, raw_input, _ = sample
    nuc_pos, atom_z, charge, spin, _ = raw_input
    banner = (
        f'##### Computing idx {sample_idx} with DFT functional {xc_str} '
        f'at reference density #####'
    )
    print(banner, flush=True)
    # NOTE(review): gto.M reads coordinates in its default unit (Angstrom) since
    # no ``unit`` is given; confirm that nuc_pos is stored in that unit.
    molecule = gto.M(
        atom=[(z, pos) for z, pos in zip(atom_z, nuc_pos)],
        basis=basis,
        spin=spin,
        charge=charge,
    )
    return from_reference_density(molecule, xc_str, fixed_density, ref_mo_coeffs, occ)


def build_dataset(
    dataset_key: str, data_dir: str, heavy_atoms_thresh: Literal['debug'] | int
) -> BaseDataset:
    if dataset_key == 'qm9':
        # Keep fluorine by default consistent with data_gen script
        ds = dataloading.QM9(
            data_dir=data_dir,
            heavy_atoms_thresh=heavy_atoms_thresh,
            exclude_fluorine=False,
        )
        ds, _, _ = ds.random_split(val_fraction=0.0, seed=0)
        return ds
    elif dataset_key == 'md17':
        # Example: ethanol train split
        return dataloading.MD17(name='ethanol', data_dir=data_dir, train=True)
    elif dataset_key == 'des370k':
        return dataloading.DES370K(data_dir=data_dir)
    else:
        raise ValueError(f'Unknown dataset key: {dataset_key}')


def load_targets_for_method(
    aux_root: str, method: str, basis: str, idx: int
) -> DEIXCTargets:
    """Load the stored DEIXC targets of sample ``idx`` for ``method`` in ``basis``.

    Raises
    ------
    FileNotFoundError
        If ``<deixc_targets dir>/<idx>.npz`` does not exist.
    """
    target_dir = io.auxiliary_data_paths(aux_root, method, basis, 'deixc_targets')
    npz_path = os.path.join(target_dir, f'{idx}.npz')
    if os.path.exists(npz_path):
        # NOTE(review): allow_pickle=True permits object arrays to be unpickled;
        # only acceptable for trusted, locally generated files.
        archive = onp.load(npz_path, allow_pickle=True)
        # Do not align SCF trajectories, we only use final values for losses here
        return DEIXCTargets.create(
            archive, align_scf_trajectory=None, shift_dispersion=False
        )
    raise FileNotFoundError(f'{npz_path} not found')


def build_system(raw_input, basis: str) -> System:
    """Preload a spin-restricted ``System`` for one raw sample and attach a grid.

    Parameters
    ----------
    raw_input:
        ``(nuc_pos, atom_z, charge, spin, aux_data)`` as stored in the dataset.
    basis: str
        Basis-set label.

    Returns
    -------
    System
        The preloaded system with a ``Grid`` built from the quadrature
        coordinates/weights and the AO values and gradients on that grid.
    """
    nuc_pos, atom_z, charge, spin, aux_data = raw_input
    psys = preload_system_using_pyscf(
        idx=0,
        nuc_pos=nuc_pos,
        atom_z=atom_z,
        charge=charge,
        spin=spin,
        aux_data=aux_data,
        basis=basis,
        spin_restricted=True,
        ert_type=ERT_TYPE,
        alignment=Alignment(atom=1, basis=1, grid=1),
        base_initial_density_guess='minao',
        cache_pyscf_mole=False,
    )
    # Select the masked atoms once and use them for every call below. The grid
    # builder previously received the unmasked element set, while the basis
    # preloader used the masked one. atom_mask presumably excludes padding atoms;
    # with Alignment(atom=1) it is expected to be all-True, so the default
    # configuration is unaffected.
    atom_mask = psys.atom_mask
    real_pos = psys.nuc_pos[atom_mask]
    real_z = psys.atom_z[atom_mask]

    # Local equivalent of test helper: build grid and basis fns
    grid_fn = get_grid_fn(1, set(real_z), 1)
    coords, weights = grid_fn(real_pos, cast_to_integer_tuple(real_z))

    l_max, pvec_basis_fn_factory = get_gto_preloader(basis, set(real_z))
    vec_basis_fns = GTOBasis.from_preloaded(pvec_basis_fn_factory(real_z))
    basis_fn = get_gto_grid_eval_fn(1, l_max)
    aos, grad_aos = basis_fn(
        coords,
        real_pos,
        vec_basis_fns.primitives,
        vec_basis_fns.compile_statics,
    )
    # Named ``system`` rather than ``sys`` to avoid shadowing the stdlib module name.
    system = System.from_preloaded(psys, Grid.create(coords, weights, aos, grad_aos))
    return system


def _method_to_xc_label(method: str) -> str:
    return method.split('_', 1)[1] if method.lower().startswith('dft_') else method


def predict_with_generation(
    method: str, basis: str, ref: DEIXCTargets, sample
) -> DEIXCTargets:
    """Evaluate ``method``'s functional at the density stored in ``ref``.

    The fixed density is ``ref.density_matrix``. The orbitals are the last entry
    of ``ref.mo_coeffs``, and the occupations are ``ref.occupancies``. All are
    converted to NumPy before being handed to PySCF.
    """
    return with_dft_from_reference(
        sample=sample,
        fixed_density=onp.asarray(ref.density_matrix),
        ref_mo_coeffs=onp.asarray(ref.mo_coeffs[-1]),
        occ=onp.asarray(ref.occupancies),
        xc_str=_method_to_xc_label(method),
        basis=basis,
    )


def compute_component_losses(
    pred: DEIXCTargets, ref: DEIXCTargets, sys: System
) -> Dict[str, float]:
    """Evaluate every DEIXC loss component of ``pred`` against ``ref``.

    Parameters
    ----------
    pred: DEIXCTargets
        Targets of the evaluated functional.
    ref: DEIXCTargets
        Stored reference targets.
    sys: System
        Supplies the electron count, the integration grid and the AO overlap.

    Returns
    -------
    Dict[str, float]
        Loss per component, keyed (in this order) 'energy', 'forces',
        'xc_potential', 'orbital_rotation_direction',
        'xc_potential_linear_response', 'density_hessian_diagonal'.
    """
    n_elec = sys.n_electrons
    grid = sys.grid
    overlap = sys.fock_tensors.overlap

    losses = {}
    losses['energy'] = energy_loss(ref.xc_energy, pred.xc_energy, n_elec)
    losses['forces'] = force_loss(ref.forces, pred.forces)
    losses['xc_potential'] = xc_potential_loss(
        ref.xc_potential_matrix,
        pred.xc_potential_matrix,
        grid,
        n_elec,
        reference_basis_is_same=True,
        integral_norm=INTEGRAL_NORM,  # type: ignore[arg-type]
    )
    # Uses the last stored reference minimisation direction and MO coefficients.
    losses['orbital_rotation_direction'] = orbital_rotation_direction_loss(
        ref.xc_based_minimization_directions[-1],
        pred.xc_potential_matrix,
        ref.mo_coeffs[-1],
        int(ref.n_occ),
        reference_basis_is_same=True,
    )
    # Linear responses are taken in the AO basis, last stored entry of each.
    losses['xc_potential_linear_response'] = xc_potential_linear_response_loss(
        ref.get_linear_response_xc_pot_in_ao_basis(overlap)[-1],
        pred.get_linear_response_xc_pot_in_ao_basis(overlap)[-1],
        grid,
        n_elec,
        reference_basis_is_same=True,
        integral_norm=INTEGRAL_NORM,  # type: ignore[arg-type]
    )
    losses['density_hessian_diagonal'] = density_hessian_diagonal_loss(
        ref.density_hessian_diagonal,
        pred.density_hessian_diagonal,
        reference_basis_is_same=True,
    )
    return {name: float(value.item()) for name, value in losses.items()}


def _plot_vxc_errors(v_xc_ref, v_xcs: List, methods: List[str], out_path: str) -> None:
    """Save one heatmap per method of ``v_xc(method) - v_xc_ref`` on a shared colour range.

    Parameters
    ----------
    v_xc_ref:
        Reference XC-potential matrix.
    v_xcs:
        One XC-potential matrix per entry of ``methods`` (same order).
    methods:
        Method labels, used as panel titles.
    out_path:
        Image file to write; its directory must already exist.
    """
    n_methods = len(methods)
    if n_methods == 0:
        # plt.subplots rejects ncols=0; there is nothing to plot.
        return
    # squeeze=False keeps ``axes`` 2-D with shape (1, n_methods) even for a single
    # method. The default would return a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(
        1, n_methods, figsize=(4.5 * n_methods, 4), squeeze=False
    )
    delta = onp.array(v_xcs) - v_xc_ref
    min_delta = delta.min()
    max_delta = delta.max()
    for m, method in enumerate(methods):
        ax = axes[0, m]
        im = ax.imshow(
            delta[m],
            aspect='auto',
            cmap='viridis',
            vmin=min_delta,
            vmax=max_delta,
        )
        ax.set_title(f'Error {method}')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _print_mean_losses(sums, refs: List[str], methods: List[str], n_samples: int) -> None:
    """Print the per-component mean loss of each method against each reference.

    The 'relative' column divides by the mean energy loss. It is printed as 'n/a'
    when that mean is zero, instead of raising ZeroDivisionError after all
    samples have been computed.
    """
    print('\n ######### Mean losses per method vs reference (over samples):', flush=True)
    for ref in refs:
        print(f'\nReference: {ref}')
        for method in methods:
            comp_sums = sums[ref][method]
            means = {k: comp_sums[k] / n_samples for k in sorted(comp_sums)}
            print(f'\t{method}:')
            energy_mean = means.get('energy', 0.0)
            for k, v in means.items():
                relative = f'{v / energy_mean:.2e}' if energy_mean != 0.0 else 'n/a'
                print(f'\t\t{k}: {v:.2e}, relative: {relative}')


def main():
    """Compare functionals evaluated at reference densities against stored reference targets.

    For each of the first N_SAMPLES dataset positions and each reference in REFS,
    every method in METHODS is evaluated at the reference density. Per-component
    losses are printed and accumulated. A heatmap of the XC-potential-matrix
    errors is saved to ./scripts/plots/<ref>_<idx>.png. Mean losses per
    (reference, method) are printed at the end.
    """
    dataset = build_dataset(DATASET, DATA_DIR, HEAVY_ATOMS_THRESH)
    aux_root = dataset.auxiliary_data_directory

    methods: List[str] = METHODS
    refs: List[str] = REFS

    plot_dir = './scripts/plots'
    # savefig does not create missing directories.
    os.makedirs(plot_dir, exist_ok=True)

    # Accumulators: {ref: {method: {component: sum}}}
    sums: Dict[str, Dict[str, Dict[str, float]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(float))
    )

    for position in range(N_SAMPLES):
        sample = dataset[position]
        # The sample carries its own index, which may differ from ``position``
        # (e.g. after a random split). That index selects the stored targets
        # and names the plot.
        idx, raw_input, _ = sample

        # Build System once per sample
        system = build_system(raw_input, BASIS)

        # Load references per sample
        ref_targets: Dict[str, DEIXCTargets] = {
            ref: load_targets_for_method(aux_root, ref, BASIS, int(idx)) for ref in refs
        }

        # Compare each method to each ref using generation routine predictions
        for ref in refs:
            ref_t = ref_targets[ref]
            v_xcs = []
            for method in methods:
                pred_targets = predict_with_generation(method, BASIS, ref_t, sample)
                comp = compute_component_losses(pred_targets, ref_t, system)
                # NOTE(review): despite the label, this prints the ratio ref/pred of
                # the XC energies, not a difference.
                print(
                    f'## energy offset:{ref} {method} {ref_t.xc_energy / pred_targets.xc_energy}'
                )
                v_xcs.append(pred_targets.xc_potential_matrix)
                for k, v in comp.items():
                    sums[ref][method][k] += v
                print(f'{ref} {method} {comp}')

            _plot_vxc_errors(
                ref_t.xc_potential_matrix,
                v_xcs,
                methods,
                f'{plot_dir}/{ref}_{idx}.png',
            )

    _print_mean_losses(sums, refs, methods, N_SAMPLES)


if __name__ == '__main__':
    # Run the comparison only when executed as a script, not on import.
    main()
