import warnings
from typing import Any, Dict, Literal

from .base import BaseEnergyFunctional, EnergyDensityToFunctionalWrapper
from .classical import GGA, LDA, LSDA, Hybrid, MetaGGA, wB97M_V
from .classical.hybrid import e_xc_b3lyp
from .learnable import DEIXC, EGXC, Nagai2020, Nagai2022, Skala, XCDiff
from .learnable.egxc import NonLocalGridFeatureMode
from .learnable.nn import (
    BaseGNN,
    NequIP,
    Nequix,
    NumericDecoder,
    NumericEncoder,
    PaiNN,
    SpatialReweighting,
)


def _get_traditional_functional_by_name(
    name: str,
    spin_restricted: bool | None = None,
    use_density_fitting: bool | None = None,
) -> BaseEnergyFunctional:
    """
    Build a classical (non-learnable) XC functional from its name.

    Semi-local functionals ignore ``spin_restricted`` and ``use_density_fitting``.
    The hybrids (``b3lyp``, ``pbe0``, ``wb97m-v``) require both to be given.

    Raises:
        AssertionError: a hybrid is requested without both flags.
        ValueError: ``name`` is not a known traditional functional.
    """
    # Semi-local functionals: name -> zero-argument constructor (resolved lazily).
    semi_local = {
        'lda': lambda: LDA(key='vwn5'),
        'vwn5': lambda: LDA(key='vwn5'),
        'pz81': lambda: LDA(key='pz81'),
        'pw92_spin_restricted': lambda: LDA(key='pw92_spin_restricted'),
        'pw92': lambda: LSDA(key='pw92'),
        'pbe': lambda: GGA(key='pbe'),
        'b88': lambda: GGA(key='b88'),
        'lyp': lambda: GGA(key='lyp'),
        'local_b3lyp': lambda: GGA(key='local_b3lyp'),
        'scan': lambda: MetaGGA(key='scan'),
    }
    # Keys whose exchange/correlation pairing is announced with a warning.
    pairing_notes = {
        'b88': 'Using B88 exchange and PBE correlation',
        'lyp': 'Using PBE exchange and LYP correlation',
    }

    builder = semi_local.get(name)
    if builder is not None:
        if name in pairing_notes:
            warnings.warn(pairing_notes[name])
        return builder()

    if name not in ('b3lyp', 'pbe0', 'wb97m-v'):
        raise ValueError(f'Unknown traditional functional: {name}')

    # Every hybrid needs both flags.
    assert spin_restricted is not None and use_density_fitting is not None, (
        'must be specified for hybrid functionals'
    )
    if name == 'wb97m-v':
        return wB97M_V(
            use_density_fitting=use_density_fitting,
            spin_restricted=spin_restricted,
        )
    # 'b3lyp' and 'pbe0' share the generic Hybrid class, keyed by name.
    return Hybrid(name, use_density_fitting, spin_restricted)


def _get_learnable_local_functional_by_name(
    name: str,
    local_hidden_dim: int | None = None,
    local_n_layers: int | None = None,
) -> BaseEnergyFunctional:
    # Handle size suffixes for learnable local functionals
    if '_xsmall' in name:
        assert local_hidden_dim is None
        local_hidden_dim = 16
        name = name.replace('_xsmall', '')
    elif '_small' in name:
        assert local_hidden_dim is None
        local_hidden_dim = 32
        name = name.replace('_small', '')
    elif '_medium' in name:
        assert local_hidden_dim is None
        local_hidden_dim = 64
        name = name.replace('_medium', '')
    elif '_large' in name:
        assert local_hidden_dim is None
        local_hidden_dim = 128
        name = name.replace('_large', '')

    match name:
        case 'nagai2020' | 'nnmgga':
            if local_hidden_dim is None:
                local_hidden_dim = 100
            if local_n_layers is None:
                local_n_layers = 4
            return Nagai2020(hidden_dim=local_hidden_dim, n_layers=local_n_layers)
        case 'nagai2020_orbital_free':
            if local_hidden_dim is None:
                local_hidden_dim = 100
            if local_n_layers is None:
                local_n_layers = 4
            return Nagai2020(
                hidden_dim=local_hidden_dim, n_layers=local_n_layers, orbital_free=True
            )
        case 'xcdiff_orbital_free':
            if local_hidden_dim is None:
                local_hidden_dim = 16
            if local_n_layers is None:
                local_n_layers = 4
            return XCDiff(
                hidden_dim=local_hidden_dim,
                n_layers=local_n_layers,
                orbital_free=True,
            )
        case 'dick2021' | 'xcdiff':
            if local_hidden_dim is None:
                local_hidden_dim = 16
            if local_n_layers is None:
                local_n_layers = 4
            return XCDiff(hidden_dim=local_hidden_dim, n_layers=local_n_layers)
        case 'nagai2022':
            if local_hidden_dim is None:
                local_hidden_dim = 100
            if local_n_layers is None:
                local_n_layers = 4
            return Nagai2022(hidden_dim=local_hidden_dim, n_layers=local_n_layers)
        case 'skala_mgga':
            if local_hidden_dim is None:
                local_hidden_dim = 256
            if local_n_layers is None:
                local_n_layers = 6
            return Skala(hidden_dim=local_hidden_dim, n_layers=local_n_layers)
        case _:
            raise ValueError(f'Unknown learnable local functional: {name}')


def _get_learnable_non_local_functional_by_name(
    name: str,
    with_graph_readout: bool,
    non_local_grid_feature_mode: NonLocalGridFeatureMode,
    local_model: BaseEnergyFunctional | None = None,
) -> BaseEnergyFunctional:
    """
    Build an EGXC non-local learnable functional by name.

    ``'egxc2024'`` is the base configuration. It combines a NequIP GNN, a
    NumericEncoder, an optional NumericDecoder and an optional
    SpatialReweighting, wrapped in EGXC around ``local_model``. When
    ``local_model`` is ``None``, the ``'xcdiff'`` learnable local functional
    is used.

    The other names are presets. Each recurses into ``'egxc2024'`` with some
    of ``with_graph_readout``, ``non_local_grid_feature_mode`` and
    ``local_model`` fixed:

    - ``egxc2024_ex_reweighting``: graph readout on, ``'local_only'`` mode.
    - ``egxc2024_ex_graph_readout``: graph readout off,
      ``'reweighting_with_mGGA_feats'`` mode.
    - ``egxc2024_scan``: caller's mode and readout, SCAN as the local model.
    - ``egxc2024_nagai2020``: caller's readout, ``'local_only'`` mode,
      ``nagai2020_small`` as the local model.
    - ``egxc2024_b3lyp``: graph readout on, ``'local_only'`` mode, and the
      ``e_xc_b3lyp`` energy density wrapped as the local model.

    NOTE(review): the two ``egxc2024_ex_*`` presets ignore all caller
    arguments, including ``local_model``. ``'deixc_global'`` is accepted by
    ``_get_functional_by_name`` but has no case here, so it raises ValueError.

    Raises:
        ValueError: ``name`` is not one of the cases above.
    """
    match name:
        case 'egxc2024':
            # Feature irreps shared by the GNN and the encoder, plus the irreps
            # passed to NequIP as its output.
            irreps_str = '32x0e + 32x1o + 32x2e'
            output_irreps_str = '16x0e + 16x1o + 16x2e'
            spatial_feature_dim = 16  # output spatial feature dimension for decoder
            # NOTE(review): if these constructors draw from a global RNG, their
            # order determines parameter initialisation. Keep the order unless
            # confirmed otherwise.
            gnn = NequIP(
                output_irreps_str=output_irreps_str,
                message_cutoff=5.0,
                layers=3,
                n_radial_basis=8,
                irreps_str=irreps_str,
                energy_graph_readout_hidden_dims=(256, 256, 1),
                init_graph_readout_to_zero=True,
            )
            encoder = NumericEncoder(
                irreps_str=irreps_str,
                cutoff=5.0,
                num_radial_filters=33,  # 16 sin + 16 cos + 1 constant
                _quadrature_points_per_atom_scaling=12,  # conservative large scaling factor
            )

            # In 'local_only' mode no decoder is built.
            if non_local_grid_feature_mode == 'local_only':
                decoder = None
            else:
                decoder = NumericDecoder(spatial_feature_dim=spatial_feature_dim)

            # Any mode whose name contains 'reweighting' gets a SpatialReweighting
            # module. TODO: document what the positional args (3, 16) mean.
            if 'reweighting' in non_local_grid_feature_mode:
                spatial_reweighting = SpatialReweighting(3, 16)
            else:
                spatial_reweighting = None

            # Default local model; built after the non-local modules above.
            if local_model is None:
                local_model = _get_learnable_local_functional_by_name('xcdiff')
            return EGXC(
                local_model=local_model,
                encoder=encoder,
                gnn=gnn,
                non_local_grid_feature_mode=non_local_grid_feature_mode,
                decoder=decoder,
                non_local_reweighting=spatial_reweighting,
                graph_readout=with_graph_readout,
            )
        case 'egxc2024_ex_reweighting':
            # Ablation preset: graph readout only, no grid decoder/reweighting.
            return _get_learnable_non_local_functional_by_name(
                'egxc2024',
                with_graph_readout=True,
                non_local_grid_feature_mode='local_only',
            )
        case 'egxc2024_ex_graph_readout':
            # Ablation preset: grid reweighting without graph readout.
            return _get_learnable_non_local_functional_by_name(
                'egxc2024',
                with_graph_readout=False,
                non_local_grid_feature_mode='reweighting_with_mGGA_feats',
            )
        case 'egxc2024_scan':
            # The SCAN local model is built before the non-local modules (argument evaluation).
            return _get_learnable_non_local_functional_by_name(
                'egxc2024',
                non_local_grid_feature_mode=non_local_grid_feature_mode,
                with_graph_readout=with_graph_readout,
                local_model=_get_traditional_functional_by_name('scan'),
            )
        case 'egxc2024_nagai2020':
            return _get_learnable_non_local_functional_by_name(
                'egxc2024',
                with_graph_readout=with_graph_readout,
                non_local_grid_feature_mode='local_only',
                local_model=_get_learnable_local_functional_by_name('nagai2020_small'),
            )
        case 'egxc2024_b3lyp':
            # Wraps the e_xc_b3lyp energy density so it can act as the local model.
            return _get_learnable_non_local_functional_by_name(
                'egxc2024',
                with_graph_readout=True,
                non_local_grid_feature_mode='local_only',
                local_model=EnergyDensityToFunctionalWrapper(e_xc_b3lyp),
            )
        case _:
            raise ValueError(f'Unknown learnable non-local functional: {name}')


def _get_functional_by_name(
    name: str,
    spin_restricted: bool | None = None,
    use_density_fitting: bool | None = None,
    local_model: BaseEnergyFunctional | None = None,
    non_local_grid_feature_mode: NonLocalGridFeatureMode = 'reweighting_with_mGGA_feats',
    with_graph_readout: bool = True,
    local_n_layers: int | None = None,
    local_hidden_dim: int | None = None,
) -> BaseEnergyFunctional:
    """
    Resolve a functional name and route it to the matching builder.

    The three families are traditional, learnable-local and
    learnable-non-local. Only learnable local functionals accept
    ``local_hidden_dim`` / ``local_n_layers``. Passing either for any other
    known name raises ValueError. Each builder receives only the arguments
    relevant to its family; the rest are not forwarded.

    NOTE(review): ``'deixc_global'`` is accepted here, but the non-local
    builder has no case for it, so requesting it currently raises ValueError
    there.

    Raises:
        ValueError: the name is unknown, or local dims were given to a
            functional that does not take them.
    """
    traditional_names = frozenset(
        (
            'lda',
            'vwn5',
            'pz81',
            'pw92_spin_restricted',
            'pw92',
            'pbe',
            'b88',
            'lyp',
            'local_b3lyp',
            'scan',
            'b3lyp',
            'pbe0',
            'wb97m-v',
        )
    )
    local_bases = (
        'nnmgga',
        'nagai2020',  # alias for nnmgga
        'nagai2020_orbital_free',
        'xcdiff',
        'xcdiff_orbital_free',
        'dick2021',  # alias for xcdiff
        'nagai2022',
        'skala_mgga',
    )
    size_suffixes = ('', '_xsmall', '_small', '_medium', '_large')
    learnable_local_names = frozenset(
        base + suffix for base in local_bases for suffix in size_suffixes
    )
    non_local_names = frozenset(
        (
            'egxc2024',
            'egxc2024_ex_reweighting',
            'egxc2024_ex_graph_readout',
            'egxc2024_scan',
            'egxc2024_nagai2020',
            'egxc2024_b3lyp',
            'deixc_global',
        )
    )

    # Learnable local functionals are the only family that takes explicit sizes.
    if name in learnable_local_names:
        return _get_learnable_local_functional_by_name(
            name,
            local_hidden_dim=local_hidden_dim,
            local_n_layers=local_n_layers,
        )

    is_traditional = name in traditional_names
    if not is_traditional and name not in non_local_names:
        raise ValueError(f'Unknown functional: {name}')

    # Neither the traditional nor the non-local builders accept local dims.
    if local_hidden_dim is not None or local_n_layers is not None:
        raise ValueError(f'{name} does not accept local_hidden_dim/local_n_layers')

    if is_traditional:
        return _get_traditional_functional_by_name(
            name,
            spin_restricted=spin_restricted,
            use_density_fitting=use_density_fitting,
        )
    return _get_learnable_non_local_functional_by_name(
        name,
        local_model=local_model,
        with_graph_readout=with_graph_readout,
        non_local_grid_feature_mode=non_local_grid_feature_mode,
    )


def _build_custom_functional(kwargs: Dict[str, Any]) -> BaseEnergyFunctional:
    # Validate top-level keys to avoid silently ignoring parameters
    allowed_top = {'local', 'gnn', 'non_locality'}
    unknown_top = set(kwargs.keys()) - allowed_top
    if unknown_top:
        raise ValueError(f'Unknown/unused keys for custom functional: {unknown_top}')

    def _get_gnn(kwargs: Dict[str, Any]) -> BaseGNN:
        assert 'type' in kwargs['gnn']
        assert 'kwargs' in kwargs['gnn']
        key = kwargs['gnn']['type'].lower()
        match key:
            case 'painn':
                return PaiNN(**kwargs['gnn']['kwargs'])
            case 'nequip':
                return NequIP(**kwargs['gnn']['kwargs'])
            case 'nequix':
                return Nequix(**kwargs['gnn']['kwargs'])
            case _:
                raise ValueError(f'Unknown GNN type: {key}')

    def _get_encoder(kwargs: Dict[str, Any]) -> NumericEncoder:
        assert 'encoder' in kwargs['gnn']
        assert 'irreps_str' in kwargs['gnn']['kwargs']
        irreps_str = kwargs['gnn']['kwargs']['irreps_str']
        return NumericEncoder(irreps_str=irreps_str, **kwargs['gnn']['encoder'])

    def _get_decoder(kwargs: Dict[str, Any]) -> NumericDecoder | None:
        assert 'non_locality' in kwargs
        assert 'grid_feature_mode' in kwargs['non_locality']
        if kwargs['non_locality']['grid_feature_mode'] == 'local_only':
            return None

        assert 'decoder' in kwargs['non_locality']
        return NumericDecoder(**kwargs['non_locality']['decoder'])

    def _get_non_local_reweighting(kwargs: Dict[str, Any]) -> SpatialReweighting | None:
        if 'reweighting' not in kwargs['non_locality']['grid_feature_mode']:
            return None

        assert 'reweighting' in kwargs['non_locality']
        return SpatialReweighting(**kwargs['non_locality']['reweighting'])

    assert 'local' in kwargs and 'name' in kwargs['local']
    local_name = kwargs['local']['name'].lower()
    local_kwargs = kwargs['local']
    # Strictly map local kwargs to typed parameters; reject unknowns
    allowed_local = {
        'name',
        'local_n_layers',
        'local_hidden_dim',
    }
    unknown_local = set(local_kwargs.keys()) - allowed_local
    if unknown_local:
        raise ValueError(f'Unknown local DeiXC keys: {unknown_local}')
    if local_name == 'deixc':
        # Remove 'name' from local_kwargs without mutating the original dict
        _local_kwargs = local_kwargs.copy()
        _local_kwargs.pop('name', None)
        local_functional = DEIXC(**_local_kwargs)
    else:
        local_n_layers = local_kwargs.get('local_n_layers')
        local_hidden_dim = local_kwargs.get('local_hidden_dim')
        # For traditional names, dims are invalid and will be checked in _get_functional_by_name
        local_functional = _get_functional_by_name(
            local_name,
            local_n_layers=local_n_layers,
            local_hidden_dim=local_hidden_dim,
        )

    if kwargs.get('gnn', None) is not None:
        gnn = _get_gnn(kwargs)
        encoder = _get_encoder(kwargs)
        decoder = _get_decoder(kwargs)
        non_local_reweighting = _get_non_local_reweighting(kwargs)

        functional = EGXC(
            local_functional,
            encoder,
            gnn,
            kwargs['non_locality']['grid_feature_mode'],
            decoder,
            non_local_reweighting,
            kwargs['non_locality']['graph_readout'],
        )
        return functional
    return local_functional


def get_functional(name: str, **kwargs: Any) -> BaseEnergyFunctional:
    """
    Public entry point: build an energy functional from a case-insensitive name.

    ``'custom'`` builds a functional from the nested configuration passed as
    keyword arguments. Any other name is lower-cased and resolved by the
    named-functional dispatcher, with ``kwargs`` forwarded as its keyword
    parameters.
    """
    key = name.lower()
    if key != 'custom':
        return _get_functional_by_name(key, **kwargs)
    return _build_custom_functional(kwargs)
