import logging
import os
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp
import numpy as onp
from flax.struct import dataclass
from numpy.typing import NDArray

from deixc import orbital_transforms
from egxc.dataloading import io
from egxc.dataloading.datasets.base import (
    BaseDataset,
    PartiallySplitDataset,
    PresplitDataset,
    RawInput,
    SupportsIndex,
    UnsplitDataset,
)
from egxc.utils.linalg import coeffs_to_density_matrix
from egxc.utils.pad import calc_padding_size
from egxc.utils.typing import (
    Float1,
    FloatAx3,
    FloatB,
    FloatBxB,
    FloatRefSCF,
    FloatRefSCFxBxB,
    FloatRefSCFxOxV,
    MethodKey,
    UIntB,
)


def pad_or_crop_scf_trajectory(tensor: NDArray, padding_size: int) -> NDArray:
    """
    Pads or crops the leading (SCF) dimension of `tensor` by `padding_size` entries.

    - If `padding_size` is positive, `padding_size` copies of the last entry are
      appended (edge padding, i.e. the final iterate is repeated; NOT zero padding).
    - If `padding_size` is negative, `-padding_size` entries are removed from the
      START, so the trailing entries (including the final iterate) are kept.
    - If `padding_size` is zero, the tensor is returned unchanged.

    In every case the last entry along the leading dimension is preserved, which is
    what downstream code reading `[-1]` (the final SCF iterate) relies on.

    Args:
        tensor: Array (or array-like) whose leading dimension indexes the SCF trajectory.
        padding_size: Number of entries to append (> 0) or to drop from the front (< 0).

    Returns:
        A NumPy array whose leading dimension is `tensor.shape[0] + padding_size`
        (clamped at 0 when cropping more entries than exist).
    """
    arr = onp.asarray(tensor)
    if padding_size > 0:
        # Pad only the leading axis; all remaining axes are left untouched.
        pad_width = [(0, padding_size)] + [(0, 0)] * (arr.ndim - 1)
        return onp.pad(arr, pad_width, mode='edge')
    if padding_size == 0:
        return arr
    # padding_size < 0: drop the earliest iterates and keep the tail.
    return arr[-padding_size:]


@dataclass
class DEIXCTargets:
    """
    DEI-XC targets for a single sample.

    mo_coeffs (FloatRefSCFxBxB):                Molecular orbital coefficients, (a.k.a. C)
    total_energies (FloatRefSCF):               Total energies along the RefSCF trajectory
    xc_energies (FloatRefSCF):                  XC energies along the RefSCF trajectory
    xc_potential_matrices (FloatRefSCFxBxB):    XC potentials along the RefSCF trajectory
    linear_response_xc_pot (FloatRefSCFxOxV):   Linear response of the XC potential along the
        normalized direction of steepest descent from direct minimization of the XC energy.
    orbital_energies (FloatB):                  Orbital energies at the final RefSCF iterate
        sorted ascendingly.
    forces (FloatAx3 | None):                   Forces on nuclei. Currently not loaded:
        `create` always sets this to None.

    Convenience properties such as `total_energy`, `xc_energy`, `density_matrix` and
    `xc_potential_matrix` return the last entry along the SCF axis (the final iterate).

    TODO: This class is presently limited to spin-restricted DFT calculations.
    """

    # Along SCF trajectory:
    mo_coeffs: FloatRefSCFxBxB  # a.k.a. C  (SCF, B, B)
    total_energies: FloatRefSCF
    xc_energies: FloatRefSCF
    xc_potential_matrices: FloatRefSCFxBxB
    linear_response_xc_pot: FloatRefSCFxOxV  # (SCF, O, V)
    orbital_energies: FloatB
    # From ground-state density only:
    forces: FloatAx3 | None

    @classmethod
    def create(
        cls,
        data: Dict[str, Any],
        align_scf_trajectory: int | None,
        shift_dispersion: bool,
    ) -> 'DEIXCTargets':
        """
        Construct a DEIXCTargets instance from raw data, with optional SCF trajectory alignment and dispersion shifting.

        Args:
            data (Dict[str, Any]): Mapping (e.g. an `onp.load` NpzFile) with the keys
                'mo_coeffs', 'total_energies', 'xc_energies', 'xc_potential_matrices',
                'linear_response_xc_pot', 'orbital_energies' and, if `shift_dispersion`
                is True, 'd3_dispersion_energy'.
            align_scf_trajectory (int | None):
                If not None, all arrays that are indexed along the SCF trajectory (i.e., with leading dimension SCF)
                are padded or cropped so that their leading dimension matches `align_scf_trajectory`.
                - If the original SCF trajectory is shorter, the arrays are padded at the end
                  using the edge value (i.e., the last value is repeated).
                - If the original SCF trajectory is longer, the earliest iterates are dropped,
                  so the trailing iterates (including the final one) are kept.
                - This is useful for batching and alignment across samples with different numbers of SCF cycles,
                  ensuring all samples have a consistent SCF trajectory length for downstream processing.
                - If None, no padding or cropping is performed and the original SCF trajectory length is preserved.
                `orbital_energies` is not SCF-indexed and is never padded or cropped.
            shift_dispersion (bool):
                If True, subtract the D3 dispersion energy from the total energies and XC energies.
                Forces are not loaded yet, so no force correction is applied.

        Returns:
            DEIXCTargets: An instance with all SCF-indexed fields aligned and optionally
            dispersion-shifted; `forces` is None.
        """
        total_energies = data['total_energies']
        xc_energies = data['xc_energies']
        if shift_dispersion:
            d3_energy = data['d3_dispersion_energy']
            total_energies = total_energies - d3_energy
            xc_energies = xc_energies - d3_energy
        # NOTE: forces (and their D3 correction) are not loaded yet, so `forces` is
        # always None regardless of `shift_dispersion`.

        # Read every SCF-indexed array from `data` exactly once: when `data` is an
        # NpzFile, each key access re-reads the array from the archive.
        scf_arrays = {
            'mo_coeffs': data['mo_coeffs'],
            'total_energies': total_energies,
            'xc_energies': xc_energies,
            'xc_potential_matrices': data['xc_potential_matrices'],
            'linear_response_xc_pot': data['linear_response_xc_pot'],
        }

        if align_scf_trajectory is not None:
            raw_scf_cycles = scf_arrays['mo_coeffs'].shape[0]
            pad_scf = calc_padding_size(raw_scf_cycles, align_scf_trajectory)
            # Never exceed the requested length; a negative value means cropping.
            # NOTE(review): the exact target length promised above also relies on
            # calc_padding_size never leaving the length *below* the target — confirm.
            if raw_scf_cycles + pad_scf > align_scf_trajectory:
                pad_scf = align_scf_trajectory - raw_scf_cycles
            scf_arrays = {
                key: pad_or_crop_scf_trajectory(value, pad_scf)
                for key, value in scf_arrays.items()
            }
        else:
            scf_arrays = {key: onp.asarray(value) for key, value in scf_arrays.items()}

        return cls(
            **scf_arrays,  # type: ignore
            orbital_energies=onp.sort(data['orbital_energies']).astype(onp.float64),  # type: ignore
            forces=None,
        )

    def __repr__(self) -> str:
        return (
            f'DEIXCTargets:\n'
            f'   mo_coeffs: \t\t{self.mo_coeffs.shape}\n'
            f'   total_energies: \t{self.total_energies[-1]:.2f} Ha\n'
            f'   xc_energies: \t{self.xc_energies[-1]:.2f} Ha\n'
            f'   v_xcs: \t\t{self.xc_potential_matrices.shape}\n'
            f'   linear_response_xc_pot: {self.linear_response_xc_pot.shape}\n'
            f'   forces: \t\t{self.forces.shape if self.forces is not None else "None"}\n'
        )

    @property
    def n_occ(self) -> int:
        """Number of occupied orbitals: axis 1 (O) of `linear_response_xc_pot` (SCF, O, V)."""
        return self.linear_response_xc_pot.shape[1]

    @property
    def occupied_virtual_shape(self) -> Tuple[int, int]:
        """Shape (O, V) of a single SCF entry of `linear_response_xc_pot`."""
        return self.linear_response_xc_pot[0].shape  # type: ignore

    def cycles_to_convergence(self, convergence_threshold: float) -> int:
        """
        Return the 1-based index of the first SCF step whose absolute total-energy
        change from the previous step is below `convergence_threshold`.

        Raises AssertionError if no step satisfies the threshold.

        Note: if the trajectory was edge-padded by `create`, the repeated final entries
        have zero energy change, so a padded trajectory always counts as converged.
        """
        delta_total_pred = onp.abs(self.total_energies[1:] - self.total_energies[:-1])
        converged_mask = delta_total_pred < convergence_threshold
        assert onp.any(converged_mask), 'reference SCF trajectory did not converge'
        out = onp.argmax(converged_mask) + 1  # 1-based index of first converged step
        return int(out)

    @property
    def homo_lumo_gap(self) -> float:
        """LUMO minus HOMO energy, taken from the ascendingly sorted `orbital_energies`."""
        return float(
            self.orbital_energies[self.n_occ] - self.orbital_energies[self.n_occ - 1]
        )

    @property
    def n_basis_functions(self) -> int:
        """Number of basis functions B: axis 1 of `mo_coeffs` (SCF, B, B)."""
        return self.mo_coeffs.shape[1]

    @property
    def occupancies(self) -> UIntB:
        """Occupation numbers per orbital: 2 for the first `n_occ` orbitals, 0 otherwise."""
        # TODO: assumes spin-restricted calculations
        out = jnp.zeros(self.n_basis_functions, dtype=jnp.uint8)
        out = out.at[: self.n_occ].set(2)
        return out

    @property
    def density_matrices(self) -> FloatRefSCFxBxB:
        """Symmetrized density matrices for every SCF iterate, built from `mo_coeffs`."""
        out = jax.vmap(coeffs_to_density_matrix, in_axes=(0, None))(
            self.mo_coeffs, self.occupancies
        )
        return jax.vmap(lambda dm: 0.5 * (dm + dm.T))(out)  #  symmetrize

    @property
    def density_matrix(self) -> FloatBxB:
        """Density matrix of the final SCF iterate."""
        return self.density_matrices[-1]

    @property
    def xc_potential_matrix(self) -> FloatBxB:
        """XC potential matrix of the final SCF iterate."""
        return self.xc_potential_matrices[-1]

    @property
    def total_energy(self) -> Float1:
        """Total energy of the final SCF iterate."""
        return self.total_energies[-1]

    @property
    def xc_energy(self) -> Float1:
        """XC energy of the final SCF iterate."""
        return self.xc_energies[-1]

    def get_orbital_rotation_gradient(
        self, fock_or_Vxc_matrices: FloatRefSCFxBxB
    ) -> FloatRefSCFxOxV:
        """
        Depending on the input, this method computes the occupied-virtual gradient direction on the Grassmann manifold,
        corresponding to steepest descent for minimization of the total energy or XC energy.

        The transform is applied per SCF iterate, pairing each input matrix with the
        `mo_coeffs` of the same iterate.
        """
        out = jax.vmap(
            orbital_transforms.dm_gradient_to_orbital_rotation_gradient,
            in_axes=(0, 0, None),
        )(fock_or_Vxc_matrices, self.mo_coeffs, self.n_occ)
        return out

    def get_perturbations_from_direct_minimization_given_potential_matrices(
        self,
        fock_or_Vxc_matrices: FloatRefSCFxBxB,
        overlap: FloatBxB,
        normalize: bool,
    ) -> FloatRefSCFxBxB:
        """
        Compute the AO-basis density matrix perturbations along the occupied-virtual (Grassmann) minimization directions.
        Args:
            fock_or_Vxc_matrices: Fock or XC potential matrices (SCF, B, B)
            overlap: AO overlap matrix (B, B)
            normalize: If True, each perturbation direction is normalized before transformation

        Returns:
            Array of AO-basis perturbation matrices, shape (SCF, B, B)
        """
        directions = self.get_orbital_rotation_gradient(fock_or_Vxc_matrices)
        out = jax.vmap(
            orbital_transforms.ao_density_perturbation_from_occupied_virtual_rotation,
            in_axes=(0, 0, None, None, None),
        )(
            directions,
            self.mo_coeffs,
            overlap,
            self.n_occ,
            normalize,
        )
        return out

    def get_linear_response_xc_pot_in_ao_basis(
        self, overlap: FloatBxB
    ) -> FloatRefSCFxBxB:
        """
        Transform `linear_response_xc_pot` (SCF, O, V) to the AO basis with
        `orbital_transforms.mo_to_ao`, one SCF iterate at a time.
        """
        out = jax.vmap(orbital_transforms.mo_to_ao, in_axes=(0, 0, None, None))(
            self.linear_response_xc_pot, self.mo_coeffs, overlap, self.n_occ
        )
        return out


# One sample as returned by `DEIXCDataset.__getitem__`:
# (index from the wrapped dataset, raw input, DEI-XC targets).
RawSample = Tuple[SupportsIndex, RawInput, DEIXCTargets]


class DEIXCDataset(BaseDataset):
    """Dataset wrapper that provides DEI-XC auxiliary targets.

    Each item of the wrapped dataset is paired with a `DEIXCTargets` instance loaded
    from `<aux_dir>/<idx>.npz`. Here `aux_dir` is derived from the wrapped dataset's
    auxiliary data directory and the reference method key/kwargs. The targets
    returned by the wrapped dataset are discarded.
    """

    _dataset: BaseDataset
    deixc_ref_method_key: MethodKey
    deixc_ref_method_kwargs: Dict[str, Any]
    align_scf_trajectory: int | None
    shift_dispersion: bool

    def __init__(
        self,
        dataset: BaseDataset,
        method_key: MethodKey,
        method_kwargs: Dict[str, Any],
        align_scf_trajectory: int | None,
        shift_dispersion: bool,
    ) -> None:
        """
        Args:
            dataset: The dataset to wrap.
            method_key: The reference method whose DEI-XC data is loaded.
            method_kwargs: Additional kwargs for the method (also used to locate the data).
            align_scf_trajectory: The number of SCF cycles to align the trajectory to.
                If None, the trajectory is not aligned.
                Else it is padded or cropped to the given number of cycles.
            shift_dispersion: Whether to subtract the D3 dispersion energy from the
                total and XC energies.
        """
        self._dataset = dataset
        self.deixc_ref_method_key = method_key
        self.deixc_ref_method_kwargs = method_kwargs
        self.align_scf_trajectory = align_scf_trajectory
        self.shift_dispersion = shift_dispersion
        if shift_dispersion:
            logging.info('Shifted dispersion energy and forces')
        self.copy_params_from_dataset(dataset)

    def _wrap_split(self, dataset: BaseDataset) -> 'DEIXCDataset':
        """Wrap one split of the underlying dataset with this wrapper's settings."""
        return DEIXCDataset(
            dataset,
            self.deixc_ref_method_key,
            self.deixc_ref_method_kwargs,
            self.align_scf_trajectory,
            self.shift_dispersion,
        )

    def infer_split(  # type: ignore[override]
        self,
        train_fraction: float | None = None,
        val_fraction: float | None = None,
        data_split_seed: int = 0,
    ) -> Tuple['DEIXCDataset', 'DEIXCDataset', 'DEIXCDataset']:
        """
        Split the wrapped dataset according to its kind and re-wrap every split.

        - PresplitDataset: uses its own split; both fractions must be None.
        - PartiallySplitDataset: random validation split; only `val_fraction` is given.
        - UnsplitDataset: random train/val/test split; both fractions are required.

        Raises:
            AssertionError: if the fractions do not match the dataset kind.
            ValueError: if the wrapped dataset is of none of the kinds above.
        """
        if isinstance(self._dataset, PresplitDataset):
            assert train_fraction is None and val_fraction is None
            train_ds, val_ds, test_ds = self._dataset.split()
        elif isinstance(self._dataset, PartiallySplitDataset):
            assert train_fraction is None and val_fraction is not None
            train_ds, val_ds, test_ds = self._dataset.random_split(
                val_fraction, data_split_seed
            )
        elif isinstance(self._dataset, UnsplitDataset):
            assert train_fraction is not None and val_fraction is not None
            train_ds, val_ds, test_ds = self._dataset.random_split(
                train_fraction, val_fraction, data_split_seed
            )
        else:
            raise ValueError(f'Unknown dataset type: {type(self._dataset)}')
        return (
            self._wrap_split(train_ds),
            self._wrap_split(val_ds),
            self._wrap_split(test_ds),
        )

    def __len__(self) -> int:
        return len(self._dataset)

    @property
    def auxiliary_data_directory(self) -> str:  # type: ignore[override]
        return self._dataset.auxiliary_data_directory

    def __getitem__(self, idx: SupportsIndex) -> RawSample:  # type: ignore[override]
        """Return (index, raw input, DEI-XC targets) for sample `idx`."""
        # Discard the wrapped dataset's targets: energies and forces are taken from
        # the generated DEI-XC data so that all targets are mutually consistent.
        idx0, raw_input, _ = self._dataset[idx]
        aux_dir = io.auxiliary_data_directory(
            self.auxiliary_data_directory,
            'deixc',
            self.deixc_ref_method_key,
            **self.deixc_ref_method_kwargs,
        )
        path = os.path.join(aux_dir, f'{idx0}.npz')
        # The NpzFile keeps the archive open until closed. The context manager
        # releases it once `create` has materialized every array it needs.
        # NOTE(review): allow_pickle=True can execute arbitrary code on load; this is
        # only acceptable for trusted, locally generated files — confirm.
        with onp.load(path, allow_pickle=True) as data:
            deixc_targets = DEIXCTargets.create(
                data, self.align_scf_trajectory, self.shift_dispersion
            )
        return idx0, raw_input, deixc_targets
