from typing import Callable, Sequence, Tuple

import grain.python as grain
import jax
import jax.numpy as jnp

from egxc.dataloading.datasets.base import RawSample, Targets
from egxc.discretization import (
    GTOBasis,
    GTOGridEvalFn,
    GTOPreloader,
    PreloadedGTOBasis,
    QuadratureGridFn,
)
from egxc.systems import Grid, System
from egxc.systems.preload import PreloadSystem, preload_system_using_pyscf
from egxc.utils.typing import (
    Alignment,
    BaseInitialGuess,
    Float2xBxB,
    FloatBxB,
    NpFloatAx3,
    PRNGKey,
    cast_to_integer_tuple,
)


class PreloadTransform(grain.MapTransform):
    """Grain map step that turns one raw dataset sample into preloaded inputs.

    For every sample, :meth:`map` builds the preloaded vectorized basis functions
    with ``vec_basis_fn_factory`` and a :class:`PreloadSystem` with
    ``preload_system_using_pyscf``, forwarding the settings captured at
    construction time.
    """

    def __init__(
        self,
        basis: str,
        spin_restricted: bool,
        alignment: Alignment,
        use_density_fitting: bool,
        center: bool,
        base_initial_density_guess: BaseInitialGuess,
        vec_basis_fn_factory: GTOPreloader,
    ):
        # Options forwarded unchanged to ``preload_system_using_pyscf`` in ``map``.
        self.basis = basis
        self.spin_restricted = spin_restricted
        self.alignment = alignment
        self.use_density_fitting = use_density_fitting
        self.center = center
        self.base_initial_density_guess: BaseInitialGuess = base_initial_density_guess
        # Called with the atomic numbers of each sample; its result is returned
        # from ``map`` as the preloaded basis functions.
        self.vec_basis_fn_factory = vec_basis_fn_factory

    def map(  # type: ignore
        self, raw_sample: RawSample
    ) -> Tuple[PreloadSystem, PreloadedGTOBasis, Targets]:
        """Preload a single sample.

        Args:
            raw_sample: Nested tuple
                ``(idx, (nuc_pos, atom_z, charge, spin, reference_density), targets)``.

        Returns:
            ``(preloaded_system, preloaded_basis_fns, targets)``. ``targets`` is
            passed through untouched.
        """
        sample_idx, molecule, targets = raw_sample
        nuc_pos, atom_z, charge, spin, reference_density = molecule

        # TODO: where to assign the max basis fns
        preloaded_basis = self.vec_basis_fn_factory(atom_z)

        pyscf_options = dict(
            charge=charge,
            spin=spin,
            basis=self.basis,
            reference_density=reference_density,
            spin_restricted=self.spin_restricted,
            alignment=self.alignment,
            base_initial_density_guess=self.base_initial_density_guess,
            center=self.center,
            use_density_fitting=self.use_density_fitting,
        )
        preloaded_system = preload_system_using_pyscf(
            int(sample_idx), nuc_pos, atom_z, **pyscf_options
        )
        return preloaded_system, preloaded_basis, targets


def get_preload_transform(
    batch_size: int,
    basis: str,
    spin_restricted: bool,
    alignment: Alignment,
    use_density_fitting: bool,
    base_initial_density_guess: BaseInitialGuess,
    center: bool,
    basis_fn_preloader: GTOPreloader,
) -> Sequence[grain.Transformation]:
    """Assemble the grain transformations that preload (and optionally batch) samples.

    Args:
        batch_size: Number of samples per batch. Values ``<= 1`` disable the
            batching stage entirely.
        basis, spin_restricted, alignment, use_density_fitting,
        base_initial_density_guess, center: Forwarded to :class:`PreloadTransform`.
        basis_fn_preloader: Passed to :class:`PreloadTransform` as its
            ``vec_basis_fn_factory``.

    Returns:
        ``(PreloadTransform,)`` when ``batch_size <= 1``; otherwise
        ``(PreloadTransform, grain.Batch)`` with ``drop_remainder=False``.
    """
    preload = PreloadTransform(
        basis=basis,
        spin_restricted=spin_restricted,
        alignment=alignment,
        use_density_fitting=use_density_fitting,
        center=center,
        base_initial_density_guess=base_initial_density_guess,
        vec_basis_fn_factory=basis_fn_preloader,
    )

    if batch_size <= 1:
        # No batching stage: samples flow through one at a time.
        return (preload,)

    # Keep the trailing partial batch rather than discarding it.
    return (preload, grain.Batch(batch_size=batch_size, drop_remainder=False))


# Tuple of candidate initial density matrices carried by a preloaded system.
# A 1-tuple holds a single guess. A 2-tuple holds (base guess, reference density),
# which ``get_initial_density_matrix_fn`` interpolates between.
# FloatBxB vs Float2xBxB presumably distinguishes a single matrix from a stacked
# per-spin pair — TODO confirm against egxc.utils.typing.
DensityMatrices = (
    Tuple[FloatBxB]
    | Tuple[Float2xBxB]
    | Tuple[FloatBxB, FloatBxB]
    | Tuple[Float2xBxB, Float2xBxB]
)
# Signature of the function returned by ``get_jax_transform``: maps a preloaded
# system plus its preloaded basis functions to (initial density matrices, System).
ToJaxTransform = Callable[
    [PreloadSystem, PreloadedGTOBasis], Tuple[DensityMatrices, System]
]


def get_jax_transform(
    grid_fn: QuadratureGridFn,
    basis_fn: GTOGridEvalFn,
    # TODO: accept a ``fock_tensors_fn`` once the integral engine is implemented.
) -> ToJaxTransform:
    """Build the transform that turns preloaded (host) data into JAX inputs.

    Args:
        grid_fn: Called as ``grid_fn(nuc_pos, atom_z)`` and expected to return
            ``(coords, weights)`` of the quadrature grid.
        basis_fn: Called as
            ``basis_fn(coords, nuc_pos, radial_primitives, compile_statics)``.
            Returns either the AO values or an ``(aos, grad_aos)`` tuple.

    Returns:
        A callable mapping ``(PreloadSystem, PreloadedGTOBasis)`` to
        ``(initial_density_matrices, System)``.
    """

    def compute_grid(
        nuc_pos: NpFloatAx3,
        atom_z: Tuple[int, ...],
        basis_fns: GTOBasis,
    ) -> Grid:
        """Evaluate the quadrature grid and the atomic orbitals on it."""
        coords, weights = grid_fn(nuc_pos, atom_z)  # expected to be jitted
        evaluated = basis_fn(  # expected to be jitted
            coords,
            nuc_pos,  # type: ignore
            basis_fns.radial_primitives,
            basis_fns.compile_statics,
        )
        # A tuple result carries AO gradients alongside the AO values.
        if not isinstance(evaluated, tuple):
            return Grid.create(coords, weights, evaluated, None)
        aos, grad_aos = evaluated
        return Grid.create(coords, weights, aos, grad_aos)

    def input_transform(
        psys: PreloadSystem, preloaded_basis_fns: PreloadedGTOBasis
    ) -> Tuple[DensityMatrices, System]:
        basis_fns = GTOBasis.from_preloaded(preloaded_basis_fns)
        nuc_pos = psys.nuc_pos
        atom_z = cast_to_integer_tuple(psys.atom_z)
        grid = compute_grid(nuc_pos, atom_z, basis_fns)
        system = System.from_preloaded(psys, grid=grid)
        # Convert every leaf of the (possibly nested) density-matrix tuple to a
        # JAX array.
        density_matrices = jax.tree.map(jnp.asarray, psys.initial_density_matrices)
        return density_matrices, system  # type: ignore

    return input_transform


# Signature of the function built by ``get_initial_density_matrix_fn``.
# It takes the tuple of candidate density matrices and an optional PRNG key.
# It returns the selected (possibly interpolated and/or noise-perturbed) density
# matrix together with the key, which has been advanced if one was supplied and
# stays ``None`` otherwise.
InitialDensityMatrixFn = Callable[
    [DensityMatrices, PRNGKey | None], Tuple[FloatBxB | Float2xBxB, PRNGKey | None]
]


def get_initial_density_matrix_fn(
    min_ref_density_interpolation: float,
    max_ref_density_interpolation: float,
    noise_eps: float,
) -> InitialDensityMatrixFn:
    """Create a jitted function for generating initial density matrices.

    Args:
        min_ref_density_interpolation: Lower bound of the interpolation factor
            ``beta``, drawn uniformly when ``random_key`` is provided.
        max_ref_density_interpolation: Upper bound of ``beta`` when a key is
            provided. Also the fixed value of ``beta`` when ``random_key`` is ``None``.
        noise_eps: Standard deviation of the symmetric multiplicative noise that
            is applied when a key is provided. ``noise_eps <= 0`` disables it.

    Returns:
        A jitted ``InitialDensityMatrixFn``.

    Raises:
        AssertionError: If the bounds violate ``0 <= min <= max <= 1``.
    """
    # These are closure constants, so validate them once at construction time
    # instead of on every jit trace.
    assert min_ref_density_interpolation >= 0
    assert min_ref_density_interpolation <= max_ref_density_interpolation
    assert max_ref_density_interpolation <= 1

    @jax.jit
    def initial_density_matrix_fn(
        initial_density_matrices: DensityMatrices, random_key: PRNGKey | None
    ) -> Tuple[FloatBxB | Float2xBxB, PRNGKey | None]:
        """
        Generate an initial density matrix for the system.
        The density matrix is produced by the mean field object and may be
        perturbed by noise.
        Args:
            initial_density_matrices: Initial density matrices to be used.
                If only one density matrix is provided, it is used as is.
                If two density matrices are provided, they are interpolated using
                ``(1 - beta) * first + beta * second``. The first density matrix
                comes from the base initialization method, e.g. 'minao'. The
                second density matrix is the reference density matrix, e.g.
                'LDA / def2-SVP'.
            random_key: PRNG key or ``None``. When ``None``, no noise is added
                and the interpolation factor is fixed to ``max_ref_density_interpolation``.
        Returns:
            ``(P_out, random_key)``. The key is advanced for every random draw
            made, and stays ``None`` if ``None`` was passed.
        """
        if len(initial_density_matrices) == 1:
            P_out = initial_density_matrices[0]
        else:
            assert len(initial_density_matrices) == 2, (
                'Only two density matrices are supported'
            )
            if random_key is not None:
                random_key, split = jax.random.split(random_key)
                # Shape (1,) broadcasts against the (..., B, B) matrices.
                beta = jax.random.uniform(
                    split,
                    (1,),
                    minval=min_ref_density_interpolation,
                    maxval=max_ref_density_interpolation,
                )
            else:
                beta = max_ref_density_interpolation
            P_base, P_ref = initial_density_matrices
            P_out = (1.0 - beta) * P_base + beta * P_ref

        if noise_eps > 0 and random_key is not None:
            random_key, split = jax.random.split(random_key)
            noise = noise_eps * jax.random.normal(split, P_out.shape, dtype=P_out.dtype)
            # Symmetrize over the last two (basis) axes only. ``noise.T`` would
            # reverse *all* axes. That breaks stacked matrices such as a
            # (2, B, B) per-spin array, which would become (B, B, 2).
            noise = (noise + jnp.swapaxes(noise, -1, -2)) / 2
            # Perturbation proportional to the signal.
            P_out = P_out * (1 + noise)
        return P_out, random_key

    return initial_density_matrix_fn
