from dataclasses import dataclass
from types import SimpleNamespace
from typing import Annotated, Any, Dict, Literal, Tuple

import numpy as onp
from flax.typing import FrozenVariableDict
from jaxtyping import Array, Bool, Float, Int, PyTree, UInt
from numpy.typing import NDArray

### TYPING LEGEND: ###
# N: grid points
# A: atoms
# Z: elements in the periodic table
# B: basis functions = O + V
# O: occupied orbitals
# V: virtual orbitals
# OV: flattened occupied-virtual pairs (O * V entries)
# E: electrons
# G: Gaussians for a basis function
# RP: Radial primitives
# S: Shells
# C_SPH: Cartesian spherical harmonics
# M_SPH: real spherical harmonics
# Q: density fitting auxiliary basis
# RBF: radial basis functions (see RBFType)
# SCF: number of SCF iterations
# RefSCF: number of SCF iterations in the reference calculation
# F: node features (atom features)
# T: Wildcard / Generic
# NOTE: the axes L and P used by the basis-related aliases are not yet documented here.


# Neural-network parameters: any JAX pytree, a Flax FrozenVariableDict, or a
# plain string-keyed dict.
NnParams = PyTree | FrozenVariableDict | Dict[str, Any]

# Allowed names of the base initial-guess method.
BaseInitialGuess = Literal['minao', 'atom', 'hcore']
# Allowed radial-basis-function families.
RBFType = Literal['trigonometric', 'polynomial', 'bessel', 'smooth_finite']

# Allowed method identifiers.
MethodKey = Literal['hf', 'ks_dft', 'ccsd']
# Allowed auxiliary-data keys.
AuxDataKey = Literal['deixc', 'initial_guess']
# PRNG keys are typed as plain jaxtyping Arrays (no dedicated key type).
PRNGKey = Array

# dataloading
# NumPy array aliases for the data-loading side. Shapes are recorded only as
# ``Annotated`` metadata strings (axis letters follow the legend at the top of
# this file); Python does not enforce that metadata at runtime.
# typing definitions merged from src/deixc/typing.py
NpBoolA = Annotated[NDArray[onp.bool_], 'shape=(A,)']
NpBoolB = Annotated[NDArray[onp.bool_], 'shape=(B,)']
NpBoolQ = Annotated[NDArray[onp.bool_], 'shape=(Q,)']
NpUIntA = Annotated[NDArray[onp.uint8], 'shape=(A,)']
NpUIntT = Annotated[NDArray[onp.uint16], 'shape=(T,)']
NpFloatAx3 = Annotated[NDArray[onp.float64], 'shape=(A, 3)']
NpUIntB = Annotated[NDArray[onp.uint8], 'shape=(B,)']
NpBool2xB = Annotated[NDArray[onp.bool_], 'shape=(2, B)']
NpFloatOV = Annotated[
    NDArray[onp.float64], 'shape=(OV,)'
]  # flattened 1-D vector with occupied * virtual (O * V) entries
NpFloatBxB = Annotated[NDArray[onp.float64], 'shape=(B, B)']
NpFloatOVxOV = Annotated[
    NDArray[onp.float64], 'shape=(OV, OV)'
]  # (O * V) x (O * V) matrix, i.e. (O * V)^2 entries
NpFloatRefSCFxBxB = Annotated[NDArray[onp.float64], 'shape=(RefSCF, B, B)']
NpFloatB = Annotated[NDArray[onp.float64], 'shape=(B,)']
NpFloatOxV = Annotated[NDArray[onp.float64], 'shape=(O, V)']
NpFloatRefSCFxOxV = Annotated[NDArray[onp.float64], 'shape=(RefSCF, O, V)']
NpFloat2xBxB = Annotated[NDArray[onp.float64], 'shape=(2, B, B)']
NpFloatQxBxB = Annotated[NDArray[onp.float64], 'shape=(Q, B, B)']
NpFloatBxBxBxB = Annotated[NDArray[onp.float64], 'shape=(B, B, B, B)']
NpFloatN = Annotated[NDArray[onp.float64], 'shape=(N,)']
NpFloatNx3 = Annotated[NDArray[onp.float64], 'shape=(N, 3)']
NpFloatNxB = Annotated[NDArray[onp.float64], 'shape=(N, B)']
NpFloatNxBx3 = Annotated[NDArray[onp.float64], 'shape=(N, B, 3)']
NpFloatRefSCF = Annotated[NDArray[onp.float64], 'shape=(RefSCF,)']

# Either a single (B, B) matrix or a stacked pair of them with shape (2, B, B).
NpDensityMatrix = NpFloatBxB | NpFloat2xBxB

# Compile time static types
# Plain aliases of ``str``/``int`` that mark values intended to be
# compile-time static. They are the builtin types themselves and add no
# runtime checking.
CompileStaticStr = str
CompileStaticInt = int

# General
# jaxtyping aliases for small, fixed-size JAX arrays (dtype family + shape).
UInt1 = UInt[Array, '1']
Float1 = Float[Array, '1']
Float1x1 = Float[Array, '1 1']
Float3 = Float[Array, '3']

# Structure related
# jaxtyping aliases. Each shape string names its axes with the letters from
# the legend at the top of this file (e.g. 'A 3' = one 3-vector per atom).
BoolA = Bool[Array, 'A']
BoolN = Bool[Array, 'N']
IntA = Int[Array, 'A']
FloatA = Float[Array, 'A']
FloatAx3 = Float[Array, 'A 3']
FloatRefSCFxAx3 = Float[Array, 'RefSCF A 3']
FloatAxA = Float[Array, 'A A']
FloatAxAx1 = Float[Array, 'A A 1']
FloatAxAx3 = Float[Array, 'A A 3']
FloatAxAxRBF = Float[Array, 'A A RBF']
FloatAxN = Float[Array, 'A N']
FloatAxNx1 = Float[Array, 'A N 1']
FloatAxNx3 = Float[Array, 'A N 3']
FloatAxNx4 = Float[Array, 'A N 4']
FloatAxNxRBF = Float[Array, 'A N RBF']
IntT = Int[Array, 'T']
# Basis-function (B) shaped arrays, optionally with a leading wildcard T axis.
UIntB = UInt[Array, 'B']
BoolB = Bool[Array, 'B']
Bool2xB = Bool[Array, '2 B']
FloatB = Float[Array, 'B']
FloatBxE = Float[Array, 'B E']
Float2xBxE = Float[Array, '2 B E']
FloatBxB = Float[Array, 'B B']
Float2xBxB = Float[Array, '2 B B']
FloatTxBxB = Float[Array, 'T B B']
FloatTx2xBxB = Float[Array, 'T 2 B B']
# Occupied x virtual (O, V) arrays.
# NOTE(review): FloatOV has the same 2-D shape 'O V' as FloatOxV, whereas the
# NumPy counterpart NpFloatOV is a flattened 1-D '(OV,)' vector. Confirm
# whether a 1-D 'OV' shape was intended before relying on the distinction.
FloatOV = Float[Array, 'O V']
FloatOxV = Float[Array, 'O V']
Float2xOxV = Float[Array, '2 O V']
FloatTxOxV = Float[Array, 'T O V']
FloatTx2xOxV = Float[Array, 'T 2 O V']
FloatT = Float[Array, 'T']
FloatTxT = Float[Array, 'T T']
# Grid-point (N) shaped arrays.
FloatN = Float[Array, 'N']
FloatNx2 = Float[Array, 'N 2']
FloatNx3 = Float[Array, 'N 3']
FloatNx4 = Float[Array, 'N 4']
FloatNx7 = Float[Array, 'N 7']
FloatNxF = Float[Array, 'N F']
FloatNxT = Float[Array, 'N T']
FloatNxB = Float[Array, 'N B']
FloatNxBx3 = Float[Array, 'N B 3']
FloatNxBxB = Float[Array, 'N B B']
# Auxiliary-basis (Q) and rank-4 basis tensors.
FloatQxBxB = Float[Array, 'Q B B']
FloatBxBxBxB = Float[Array, 'B B B B']
# Per-iteration histories: SCF = iterations of this run,
# RefSCF = iterations of the reference calculation.
FloatSCF = Float[Array, 'SCF']
FloatRefSCF = Float[Array, 'RefSCF']
FloatSCFxSCF = Float[Array, 'SCF SCF']
FloatSCFxBxB = Float[Array, 'SCF B B']
FloatSCFx2xBxB = Float[Array, 'SCF 2 B B']
FloatRefSCFxBxB = Float[Array, 'RefSCF B B']
FloatRefSCFx2xBxB = Float[Array, 'RefSCF 2 B B']
FloatRefSCFxOxV = Float[Array, 'RefSCF O V']
FloatRefSCFx2xOxV = Float[Array, 'RefSCF 2 O V']

# GNN related
# Per-atom node features (F), their 3-vector variant, and a pairwise
# (atom x atom) 3-vector variant.
FloatAxF = Float[Array, 'A F']
FloatAxFx3 = Float[Array, 'A F 3']
FloatAxAx3xF = Float[Array, 'A A 3 F']

# Basis related
# NumPy and jaxtyping aliases for basis-set data: Gaussians (G), radial
# primitives (RP) and shells (S).
# NOTE(review): the axes L and P are not described in the legend at the top
# of this file; document their meaning.
NpFloatG = Annotated[NDArray[onp.float64], 'shape=(G,)']
NpUIntG = Annotated[NDArray[onp.uint32], 'shape=(G,)']
FloatG = Float[Array, 'G']
FloatNxAxM_SPH = Float[Array, 'N A M_SPH']
NpBoolLxT = Annotated[NDArray[onp.bool_], 'shape=(L, T)']

# NumPy-side radial-primitive / shell arrays.
NpFloatRP = Annotated[NDArray[onp.float64], 'shape=(RP,)']
NpUIntRP = Annotated[NDArray[onp.uint16], 'shape=(RP,)']
NpUIntS = Annotated[NDArray[onp.uint16], 'shape=(S,)']
NpFloatTxT = Annotated[NDArray[onp.float64], 'shape=(T, T)']

# JAX-side counterparts.
UIntRP = UInt[Array, 'RP']
FloatRP = Float[Array, 'RP']
FloatNxRP = Float[Array, 'N RP']
UIntP = UInt[Array, 'P']
FloatP = Float[Array, 'P']
FloatPx3 = Float[Array, 'P 3']
UIntPx3 = UInt[Array, 'P 3']


# dtype names for the two precision levels. Leading underscores mark them as
# private to this module. ``__LOW`` is not referenced by PRECISION below.
__HIGH = 'float64'
__LOW = 'float32'

# Floating-point dtype name used by each numerical component, exposed as
# attributes (e.g. ``PRECISION.solver``). Every component currently uses the
# high-precision dtype.
PRECISION = SimpleNamespace(
    **{
        'basis': __HIGH,
        'forces': __HIGH,
        'xc_energy': __HIGH,
        'quadrature': __HIGH,  # also used in encoding
        'solver': __HIGH,
        'eri_tensor': __HIGH,
        'loss': __HIGH,
        'local_nn': __HIGH,  # used in energy density and reweighting
        'gnn': __HIGH,
        'graph_readout': __HIGH,
        'decoding': __HIGH,
    }
)


def cast_to_integer_tuple(array: NDArray) -> Tuple[int, ...]:
    """Return the items of ``array`` converted to Python ``int`` as a tuple.

    Iteration runs over ``array`` (the first axis, for a NumPy array), and
    ``int`` is applied to every item, so NumPy scalars become builtin ints.
    """
    return tuple(int(element) for element in array)


class _AlignmentValueError(ValueError, AssertionError):
    """Raised when an ``Alignment`` field is not strictly positive.

    Subclasses ``ValueError`` (the idiomatic type for invalid arguments) and
    ``AssertionError`` (what the earlier ``assert``-based checks raised), so
    existing ``except AssertionError`` handlers keep working.
    """


@dataclass
class Alignment:
    """Alignment factors for the atom, basis and grid dimensions.

    Attributes:
        atom: Alignment factor for the atom dimension; must be > 0.
        basis: Alignment factor for the basis dimension; must be > 0.
        grid: Alignment factor for the grid dimension; must be > 0.

    A factor of ``1`` counts as "not aligned" (see ``is_aligned``).

    Raises:
        ValueError: (also an ``AssertionError``) if any factor is not > 0.
    """

    atom: int = 1
    basis: int = 1
    grid: int = 1

    def __post_init__(self):
        # Explicit raises instead of ``assert`` so the checks still run under
        # ``python -O``. Fields are checked in declaration order, so the first
        # offending field is the one reported.
        checks = (
            (self.atom, 'Atom alignment must be greater than 0.'),
            (self.basis, 'Basis alignment must be greater than 0.'),
            (self.grid, 'Grid alignment must be greater than 0.'),
        )
        for value, message in checks:
            # ``not value > 0`` (rather than ``value <= 0``) also rejects NaN,
            # matching the original ``assert value > 0`` semantics.
            if not value > 0:
                raise _AlignmentValueError(message)

    @property
    def is_aligned(self) -> bool:
        """True if at least one factor is greater than 1."""
        return self.atom > 1 or self.basis > 1 or self.grid > 1

    def __hash__(self) -> int:
        # ``@dataclass`` (eq=True, not frozen) keeps an explicitly defined
        # ``__hash__`` instead of setting it to None. Hash and ``__eq__``
        # agree, since both are based on (atom, basis, grid).
        # NOTE(review): instances are mutable, so mutating a field after using
        # the instance as a dict key / static argument changes its hash.
        return hash((self.atom, self.basis, self.grid))
