import warnings
from math import isclose
from typing import Dict, Set

import jax.numpy as jnp
import pytest
from utils import ROOT_DATA_DIR, set_jax_testing_config

from deixc.dataset import DEIXCDataset
from egxc import dataloading
from egxc.dataloading import (
    BaseDataset,
    DatasetEnsemble,
    PartiallySplitDataset,
    PresplitDataset,
    Targets,
    UnsplitDataset,
    get_preload_transform,
    get_psys_and_dataloaders,
)
from egxc.dataloading.utils import random_index_split
from egxc.discretization import get_grid_fn, get_gto_grid_eval_fn, get_gto_preloader
from egxc.systems import System, examples
from egxc.systems.preload import Alignment, PreloadSystem
from egxc.utils import constants

# Apply the shared JAX test configuration (helper from the local `utils` test
# module) once at import time, i.e. before any test in this module runs.
set_jax_testing_config()
# Module-level pytest mark: every test in this file carries the `data` mark.
pytestmark = pytest.mark.data


@pytest.mark.quick
@pytest.mark.parametrize(
    'align',
    (Alignment(), Alignment(4, 8, 512)),
    ids=('without padding', 'with padding'),
)
def test_jax_transform(align: Alignment):
    """Smoke-test the JAX transform on the preloaded ethanol example.

    Builds the grid and basis evaluation functions for the 'sto-6g' basis,
    applies the transform returned by `dataloading.get_jax_transform`, and
    checks only the *types* of what comes back: a `System`, and a tuple of
    density matrices whose first entry is a JAX array.

    Args:
        align: Alignment settings; only `align.grid` is consumed here (it is
            forwarded to `get_grid_fn`).
    """
    molecule = examples.get_preloaded('ethanol', 'sto-6g', alignment=1)

    # Element-dependent factories are built from the elements present in ethanol.
    _l_max, make_pvec_basis_fns = get_gto_preloader('sto-6g', set(molecule.atom_z))
    grid_fn = get_grid_fn(1, set(molecule.atom_z), align.grid)
    basis_fn = get_gto_grid_eval_fn(deriv=1, max_angular_momentum=_l_max)

    jax_transform = dataloading.get_jax_transform(grid_fn, basis_fn)
    density_matrices, system = jax_transform(
        molecule, make_pvec_basis_fns(molecule.atom_z)
    )

    assert isinstance(system, System)
    assert isinstance(density_matrices, tuple)
    first_density_matrix = density_matrices[0]
    assert isinstance(first_density_matrix, jnp.ndarray)


# Registry of datasets that ship with a fixed train/val/test split.
# Currently empty, so no presplit dataset is exercised by the parametrized tests.
presplit_datasets: Dict[str, dataloading.PresplitDataset] = {}

# Keyword arguments shared by every dataset constructor in this module:
# both initial-reference-density options are explicitly passed as None.
KWARGS = {
    'initial_ref_density_method_key': None,
    'initial_ref_density_method_kwargs': None,
}

# Datasets split via `random_split(val_fraction, seed=...)` in the tests below.
# NOTE: all datasets in these registries are instantiated at import time, so
# the data under ROOT_DATA_DIR must be available when pytest collects this file.
partially_split_datasets: Dict[str, dataloading.PartiallySplitDataset] = {
    'md17': dataloading.MD17(ROOT_DATA_DIR, 'ethanol', train=True, **KWARGS),  # type: ignore
    # NOTE(review): the third positional argument is presumably `exclude_fluorine`
    # (cf. the keyword used in test_qm9_exclude_fluorine) -- confirm against QM9.
    'qm9': dataloading.QM9(ROOT_DATA_DIR, 7, False, **KWARGS),  # type: ignore
    '3bpa': dataloading.ThreeBPA(ROOT_DATA_DIR, **KWARGS),  # type: ignore
}

# Datasets split via `random_split(train_fraction, val_fraction, seed=...)`
# in the tests below.
unsplit_datasets: Dict[str, dataloading.UnsplitDataset] = {
    'des370k': dataloading.DES370K(ROOT_DATA_DIR, **KWARGS),  # type: ignore
    'qmugs': dataloading.QMugs(ROOT_DATA_DIR, **KWARGS),  # type: ignore
    'spice': dataloading.SPICE(ROOT_DATA_DIR, **KWARGS),  # type: ignore
}

# Union of all registries; used to parametrize the generic dataset tests.
datasets = presplit_datasets | partially_split_datasets | unsplit_datasets


# Union of the three dataset base classes; used both as an annotation and as
# the second argument of `isinstance` in the tests below.
Dataset = PresplitDataset | PartiallySplitDataset | UnsplitDataset


@pytest.mark.quick
@pytest.mark.parametrize('dataset', datasets.values(), ids=datasets.keys())
def test_dataset_init(dataset: Dataset):
    """Check that each registered dataset has the expected type and attributes.

    The dataset must be an instance of one of the three dataset base classes
    and expose its unit attributes, its element set, and both
    initial-reference-density options.
    """
    assert isinstance(dataset, Dataset)

    cls_name = dataset.__class__.__name__
    # The unit attributes get an explicit failure message naming the class.
    for unit_attr in ('energy_unit', 'distance_unit'):
        assert hasattr(dataset, unit_attr), f'{cls_name} does not have {unit_attr}'

    for attr in (
        'unique_elements',
        'initial_ref_density_method_key',
        'initial_ref_density_method_kwargs',
    ):
        assert hasattr(dataset, attr)


@pytest.mark.quick
@pytest.mark.parametrize('dataset', datasets.values(), ids=datasets.keys())
def test__get_item__(dataset: Dataset):
    """Check the structure of a single sample and the item access time.

    A sample unpacks into `(index, raw_inputs, targets)`, where `raw_inputs`
    is a 5-tuple and `targets` is a `Targets` instance. The value returned by
    `dataset.timed_get_item(1)` must stay below 0.4.
    """
    _, raw_inputs, targets = dataset[0]

    elapsed = dataset.timed_get_item(1)
    assert elapsed < 0.4, (
        f'It took {elapsed} seconds to get an item from the dataset.'
    )

    assert isinstance(raw_inputs, tuple)
    assert len(raw_inputs) == 5
    assert isinstance(targets, Targets)


@pytest.mark.quick
@pytest.mark.parametrize(
    'dataset',
    (partially_split_datasets | unsplit_datasets).values(),
    ids=(partially_split_datasets | unsplit_datasets).keys(),
)
def test_dataset_split(dataset: PartiallySplitDataset | UnsplitDataset):
    """Check random splitting directly and through `DatasetEnsemble`.

    For unsplit datasets a 70/10/20 split is requested and the resulting
    split sizes are compared against the requested fractions. Partially
    split datasets are only smoke-tested: `random_split` must run with a
    validation fraction of 0.1.

    In both cases the corresponding `DatasetEnsemble` is built and the first
    sample of each of its train/val/test members is checked for structure.

    Args:
        dataset: A partially split or unsplit dataset from the module-level
            registries.
    """
    if isinstance(dataset, UnsplitDataset):
        train, val, test = dataset.random_split(0.7, 0.1, seed=42)
        n_total = len(dataset)
        expected_fractions = (('train', train, 0.7), ('val', val, 0.1), ('test', test, 0.2))
        for split_name, split, fraction in expected_fractions:
            assert isclose(
                len(split), fraction * n_total, rel_tol=0.001, abs_tol=0.01
            ), f'{split_name} split has {len(split)} of {n_total} samples'
        ensemble = DatasetEnsemble.from_random_split(dataset, 0.7, 0.1, 42)
    else:
        # Smoke check only: the expected sizes depend on the dataset's own
        # predefined test portion, so nothing generic can be asserted here.
        dataset.random_split(0.1, seed=42)
        ensemble = DatasetEnsemble.from_partial_random_split(dataset, 0.1, 42)

    # Every ensemble member must yield (index, 5-tuple of raw inputs, Targets).
    for split_name in ('train', 'val', 'test'):
        _, raw_inputs, targets = next(iter(getattr(ensemble, split_name)))
        assert isinstance(raw_inputs, tuple), f'{split_name}: raw inputs are not a tuple'
        assert len(raw_inputs) == 5, f'{split_name}: expected 5 raw inputs'
        assert isinstance(targets, Targets), f'{split_name}: targets are not Targets'


@pytest.mark.quick
def test_md17():
    """Check the split sizes of the MD17 ethanol dataset.

    With a validation fraction of 0.1 and seed 42 the split must contain
    900 training, 100 validation and 1000 test samples.
    """
    md17 = partially_split_datasets['md17']
    train, val, test = md17.random_split(0.1, seed=42)

    split_sizes = (len(train), len(val), len(test))
    assert split_sizes == (900, 100, 1000)


@pytest.mark.quick
def test_qm9():
    """Check size, first-sample energy and split bookkeeping of QM9.

    The dataset must contain 130831 samples and the energy target of sample 0
    must be close to -40.47893. After a random split with validation fraction
    0.1, all parts are non-empty, they add up to the full dataset, and the
    validation part is `int(0.1 * (total - test))` samples.
    """
    qm9 = partially_split_datasets['qm9']
    n_total = len(qm9)
    assert n_total == 130831

    _, _, targets = qm9[0]
    assert jnp.isclose(targets.energy, -40.47893)  # type: ignore

    train, val, test = qm9.random_split(0.1, seed=42)
    n_train, n_val, n_test = len(train), len(val), len(test)
    assert n_train > 0
    assert n_val > 0
    assert n_test > 0
    assert n_total == n_train + n_val + n_test

    # Validation samples are taken from everything that is not test data.
    n_dev = n_total - n_test
    expected_val = int(0.1 * n_dev)
    assert n_val == expected_val
    assert n_train == n_dev - expected_val


@pytest.mark.quick
def test_qm9_exclude_fluorine():
    """Check size and split bookkeeping of QM9 with `exclude_fluorine=True`.

    The filtered dataset must contain 128908 samples. After a random split
    with validation fraction 0.1, the validation part must be
    `int(0.1 * (total - test))` samples and the training part the remainder
    of the non-test data.
    """
    dataset = dataloading.QM9(ROOT_DATA_DIR, 7, exclude_fluorine=True, **KWARGS)  # type: ignore
    n_total = len(dataset)
    assert n_total == 128908

    train, val, test = dataset.random_split(0.1, seed=42)

    # Validation samples are taken from everything that is not test data.
    len_dev = n_total - len(test)
    len_val = int(0.1 * len_dev)
    assert len(val) == len_val
    assert len(train) == len_dev - len_val


@pytest.mark.quick
def test_3bpa():
    """Check size, one reference energy and a sub-split selection of 3BPA.

    The dataset must contain 13993 samples, and the energy target of sample
    1000 must match a reference value (given in eV and converted with
    `constants.EV_TO_HATREE`). Splitting with an explicit choice of train and
    test sub-splits must give 450 / 50 / 4016 samples.
    """
    three_bpa = partially_split_datasets['3bpa']
    assert len(three_bpa) == 13993

    _, _, targets = three_bpa[1000]
    reference_energy = -17660.06855946407 * constants.EV_TO_HATREE
    assert jnp.isclose(targets.energy, reference_energy)  # type: ignore

    train, val, test = three_bpa.random_split(
        0.1,
        seed=42,
        train_subsplits=['train_300K'],  # type: ignore
        test_subsplits=['test_300K', 'test_dih_beta120'],  # type: ignore
    )
    split_sizes = (len(train), len(val), len(test))
    assert split_sizes == (450, 50, 4016)


@pytest.mark.quick
def test_qm40():
    """
    QM40 should load local SDF structures via RDKit and yield the standard
    RawSample tuple used throughout the codebase.

    Expected layout under ROOT_DATA_DIR:
      qm40/structures/<MOL_ID>/conf_00.sdf (+ conf_01.sdf, conf_02.sdf)

    The dataset kwarg `conformer` controls conformer selection:
      0 | 1 | 2: use one fixed conformer per molecule
      'all': enumerate all three conformers (dataset length x3)

    Only conformer 0 is exercised here. The test checks that the dataset is
    non-empty, that the nuclear positions have one row of three entries per
    atom, and that charge and spin are plain ints.
    """
    qm40 = dataloading.QM40(ROOT_DATA_DIR, conformer=0, **KWARGS)  # type: ignore
    assert len(qm40) > 0

    _, raw_inputs, _ = qm40[0]
    nuc_pos, atom_z, charge, spin, _ = raw_inputs

    n_rows, n_cols = nuc_pos.shape[0], nuc_pos.shape[1]
    assert n_rows == len(atom_z)
    assert n_cols == 3

    for quantum_number in (charge, spin):
        assert isinstance(quantum_number, int)


# Shared 'sto-3g' preloader, built once at import time. The second argument is
# a set of atomic numbers (the same call elsewhere in this file passes
# `set(ethanol.atom_z)`); only `pvec_basis_fn_factory` is used below.
l_max, pvec_basis_fn_factory = get_gto_preloader('sto-3g', {1, 6, 7, 8, 9, 16, 17, 35})

# Parametrization cases for the dataloader test. Each value is a
# (spin_restricted, alignment, use_density_fitting) tuple; the key is the test id.
transformation_cases = {
    'restricted+unaligned+exact': (True, Alignment(), False),
    'unrestricted+aligned+df': (False, Alignment(4, 8, 512), True),
}


@pytest.mark.quick
@pytest.mark.parametrize('dataset', datasets.values(), ids=datasets.keys())
@pytest.mark.parametrize(
    'spin_restricted,alignment,use_density_fitting',
    transformation_cases.values(),
    ids=transformation_cases.keys(),
)
def test_datalpoader(
    dataset: Dataset,
    spin_restricted: bool,
    alignment: Alignment,
    use_density_fitting: bool,
):
    """Build dataloaders for every dataset and check the first batch of each split.

    The dataset is wrapped in a `DatasetEnsemble` using the constructor that
    matches its kind (presplit, partially split, unsplit). Dataloaders are
    then created with the 'sto-3g' preload transform, and the first item of
    the train, val and test loaders must unpack into
    `(PreloadSystem, pvec_basis_fns, Targets)`.

    QMugs is skipped when density fitting is disabled.

    NOTE: the function name contains a typo ("datalpoader"); it is kept so
    that existing test ids and `-k` selections continue to work.

    Args:
        dataset: Dataset instance from the module-level registries.
        spin_restricted: Forwarded to `get_preload_transform`.
        alignment: Forwarded to `get_preload_transform`.
        use_density_fitting: Forwarded to `get_preload_transform`; also
            decides whether QMugs is skipped.
    """
    if isinstance(dataset, dataloading.QMugs) and not use_density_fitting:
        pytest.skip('exact ERI tensors are too large for QMugs')

    transformations = get_preload_transform(
        1,
        'sto-3g',
        spin_restricted,
        alignment,
        use_density_fitting,
        'minao',
        False,
        pvec_basis_fn_factory,
    )

    if isinstance(dataset, PresplitDataset):
        ensemble = DatasetEnsemble.from_presplit_dataset(dataset)
    elif isinstance(dataset, PartiallySplitDataset):
        ensemble = DatasetEnsemble.from_partial_random_split(dataset, 0.1, 42)
    else:
        assert isinstance(dataset, UnsplitDataset)
        ensemble = DatasetEnsemble.from_random_split(dataset, 0.7, 0.1, 42)

    # Read one raw sample first so that a dataset-level failure surfaces
    # before any dataloader is constructed.
    next(iter(ensemble.train))

    _, dataloaders = get_psys_and_dataloaders(ensemble, transformations, True, 0, 1, 42)

    for split_name in ('train', 'val', 'test'):
        psys, _pvec_basis_fns, targets = next(iter(getattr(dataloaders, split_name)))
        assert isinstance(psys, PreloadSystem), (
            f'{split_name}: expected PreloadSystem, got {type(psys).__name__}'
        )
        assert isinstance(targets, Targets), (
            f'{split_name}: expected Targets, got {type(targets).__name__}'
        )


@pytest.mark.slow
@pytest.mark.parametrize('dataset', datasets.values(), ids=datasets.keys())
def test_dataset_split_no_leakage(dataset: Dataset):
    """Check that no sample index appears in more than one split.

    The dataset is split according to its kind. The sample indices (first
    element of every item) of each split are then collected and checked for
    uniqueness within the split and disjointness across splits.

    ThreeBPA is only split, not checked. For MD17 only train vs. val is
    compared.
    """
    if isinstance(dataset, UnsplitDataset):
        splits = dataset.random_split(0.7, 0.1, seed=42)
    elif isinstance(dataset, PartiallySplitDataset):
        splits = dataset.random_split(0.1, seed=42)
    else:  # PresplitDataset
        assert isinstance(dataset, PresplitDataset)
        splits = dataset.split()
    train, val, test = splits

    if isinstance(dataset, dataloading.ThreeBPA):
        # ThreeBPA has a different handling of indices due to its subsplits
        return

    # Walk every split item by item and keep the index each item reports.
    train_ids, val_ids, test_ids = [
        {part[position][0] for position in range(len(part))}
        for part in (train, val, test)
    ]

    # No index may occur twice within a single split.
    assert len(train_ids) == len(train)
    assert len(val_ids) == len(val)
    assert len(test_ids) == len(test)

    assert train_ids.isdisjoint(val_ids)

    if isinstance(dataset, dataloading.MD17):
        # MD17's `train` and `test` sets are disjoint but their enumeration is separate
        return

    assert train_ids.isdisjoint(test_ids)
    assert val_ids.isdisjoint(test_ids)

    n_samples = len(train) + len(val) + len(test)
    assert len(train_ids | val_ids | test_ids) == n_samples


@pytest.mark.quick
def test_random_index_split_no_leakage():
    """Check that `random_index_split` partitions `range(100)` exactly.

    For an 80/10/10 split of 100 indices with seed 42, each part must be free
    of duplicates, the parts must be pairwise disjoint, and together they
    must cover every index from 0 to 99.
    """
    train_idx, val_idx, test_idx = random_index_split(100, (0.8, 0.1, 0.1), 42)

    train_set, val_set, test_set = set(train_idx), set(val_idx), set(test_idx)

    # Converting to a set must not drop anything, i.e. no duplicates per part.
    for unique_ids, raw_ids in (
        (train_set, train_idx),
        (val_set, val_idx),
        (test_set, test_idx),
    ):
        assert len(unique_ids) == len(raw_ids)

    for left, right in (
        (train_set, val_set),
        (train_set, test_set),
        (val_set, test_set),
    ):
        assert left.isdisjoint(right)

    all_indices = train_set | val_set | test_set
    assert all_indices == set(range(100))


@pytest.mark.quick
@pytest.mark.parametrize('dataset', datasets.values(), ids=datasets.keys())
def test_dei_xc_dataset_loading(dataset: BaseDataset):
    """Check that `DEIXCDataset` wraps a dataset and exposes reference targets.

    The wrapper is built with a fixed KS-DFT reference method configuration
    and the first sample is loaded. Its targets must provide total energies,
    a density matrix and XC potential matrices, the latter with a leading
    dimension equal to the requested SCF trajectory length.

    If the reference data is missing (`FileNotFoundError`), the test is
    reported as skipped instead of passing without checking anything.

    Args:
        dataset: Base dataset from the module-level registries.
    """
    scf_trajectory_length = 5

    # Only construction and the first item access are expected to touch the
    # reference files, so only they are guarded against missing data.
    try:
        dei_ds = DEIXCDataset(
            dataset,
            method_key='ks_dft',
            method_kwargs={
                'xc_str': 'wB97M-V',
                'basis': 'def2-SVP',
                'backend': 'pyscf',
                'use_eri_density_fitting': True,
                'use_exchange_density_fitting': True,
                'quadrature_grid_level': 1,
                'spin_restricted': True,
            },
            align_scf_trajectory=scf_trajectory_length,
            shift_dispersion=False,
        )
        _, _, targets = dei_ds[0]
    except FileNotFoundError:
        pytest.skip(
            f'DEIXCDataset could not be initialized for {dataset.__class__.__name__}. '
            'This is expected if the reference values have not been computed yet.'
        )

    assert targets.total_energies is not None
    assert targets.density_matrix is not None
    # NOTE(review): singular `xc_potential_matrix` here vs. plural
    # `xc_potential_matrices` below -- confirm both attributes exist on the targets.
    assert targets.xc_potential_matrix is not None
    # Check SCF trajectory length
    assert targets.xc_potential_matrices.shape[0] == scf_trajectory_length
