import csv
import os
from typing import Iterable, Tuple

import numpy as np
from tqdm.auto import tqdm

from egxc.dataloading import RawSample, SupportsIndex, Targets, UnsplitDataset
from egxc.dataloading.datasets.base import BaseDataset
from egxc.dataloading.utils import IndexWrapper

# Atomic number for every element symbol that occurs in QM40
# (H, C, N, O, F, S, Cl).
ELEMENT_TO_Z = dict(H=1, C=6, N=7, O=8, F=9, S=16, Cl=17)

# Atomic numbers present in QM40, derived from the symbol table so the two
# constants cannot drift apart.
QM40_ELEMENTS = set(ELEMENT_TO_Z.values())


class QM40(UnsplitDataset):
    """
    QM40 dataset backed by CSV files from the official release.
    https://doi.org/10.6084/m9.figshare.25993060

    Expects the following files in `<data_dir>/qm40/raw/`:
    - QM40_main.csv: molecule-level properties (Zinc_id, smile, energies, etc.)
    - QM40_xyz.csv: atomic coordinates (Zinc_id, atom, final_x/y/z, charge)

    Processing writes per-molecule samples under:
      <data_dir>/qm40/processed/samples/<idx>.npz

    QM40 contains optimized geometries only (no conformers).
    Maximum heavy atoms per molecule: 40.

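    Filtering and subsampling (applied in this order):
    - exclude_elements: if set, drops molecules containing any of these atomic numbers
    - max_heavy_atoms: if set, excludes molecules with more than this many heavy atoms
    - samples_per_heavy_atom_bin: if set, limits to N molecules per heavy atom count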

    Examples
    --------
    ```python
    # Basic usage
    dataset = QM40(data_dir)

    # With subsampling: 100 molecules per heavy atom count bin
    dataset = QM40(data_dir, samples_per_heavy_atom_bin=100)

    # Limit to molecules with at most 30 heavy atoms
    dataset = QM40(data_dir, max_heavy_atoms=30)

    # Split the dataset
    train, val, test = dataset.random_split(train_fraction=0.8, val_fraction=0.1, seed=42)
    ```

    """

    def __init__(
        self,
        data_dir: str,
        samples_per_heavy_atom_bin: int | None = None,
        exclude_elements: Iterable[int] | None = None,
        max_heavy_atoms: int | None = None,
        show_progress: bool = True,
        energy_unit: str = 'hartree',
        distance_unit: str = 'ang',
        **kwargs,
    ):
        """
        Load the processed QM40 data (building it first if necessary) and
        select the samples exposed by this dataset instance.

        Parameters
        ----------
        data_dir
            Root data directory; QM40 files live under `<data_dir>/qm40/`.
        samples_per_heavy_atom_bin
            If set, keep at most this many molecules per heavy atom count.
        exclude_elements
            Atomic numbers; molecules containing any of them are dropped.
            Must be a subset of `QM40_ELEMENTS`.
        max_heavy_atoms
            If set, drop molecules with more heavy atoms than this.
        show_progress
            Enables progress bars during processing and the status line that
            reports which element filter is used.
        energy_unit, distance_unit
            Only 'hartree' and 'ang' are accepted.
        **kwargs
            Forwarded unchanged to `init_params`.

        Raises
        ------
        AssertionError
            If an unsupported unit is requested.
        ValueError
            If `exclude_elements` contains an element that is not in QM40.
        KeyError
            If a required per-element filter is missing from the processed data.
        RuntimeError
            If no element filtering method is available.
        """
        assert energy_unit == 'hartree', (
            'present implementation only supports hartree as energy unit'
        )
        assert distance_unit == 'ang', (
            'present implementation only supports angstrom as distance unit'
        )

        self.samples_per_heavy_atom_bin = samples_per_heavy_atom_bin
        self.exclude_elements = (
            set(int(z) for z in exclude_elements)
            if exclude_elements is not None
            else set()
        )
        self.max_heavy_atoms = max_heavy_atoms
        self.show_progress = show_progress

        # Validate exclude_elements against known QM40 elements *before* any
        # I/O, so a bad argument fails immediately instead of after a full
        # (and potentially very long) preprocessing run.
        invalid = self.exclude_elements.difference(QM40_ELEMENTS)
        if invalid:
            raise ValueError(
                f'QM40 exclude_elements contains elements not in QM40: {sorted(invalid)}. '
                f'QM40 only contains: {sorted(QM40_ELEMENTS)}'
            )

        data_dir = os.path.join(data_dir, 'qm40')
        self.raw_dir = os.path.join(data_dir, 'raw')

        self.init_params(
            os.path.join(data_dir, 'processed'),
            energy_unit=energy_unit,
            distance_unit=distance_unit,
            **kwargs,
        )

        # Locations of the required raw CSV files
        self.xyz_csv_path = os.path.join(self.raw_dir, 'QM40_xyz.csv')
        self.main_csv_path = os.path.join(self.raw_dir, 'QM40_main.csv')

        # Preprocessing is considered complete when the required processed artifacts exist.
        complete_file = os.path.join(self.raw_dir, 'complete.marker')
        samples_dir = os.path.join(self.directory, 'samples')
        metadata_path = os.path.join(self.directory, 'metadata.npz')
        heavy_atom_counts_path = os.path.join(self.directory, 'heavy_atom_counts.npy')

        processed_complete = (
            os.path.isdir(samples_dir)
            and os.path.isfile(metadata_path)
            and os.path.isfile(heavy_atom_counts_path)
        )
        if not processed_complete:
            self.process()
        if not processed_complete or not os.path.exists(complete_file):
            # The raw directory may have been removed after processing (e.g. to
            # save disk space); make sure the marker can still be written.
            os.makedirs(self.raw_dir, exist_ok=True)
            with open(complete_file, 'w') as f:
                f.write('complete.marker')

        # NpzFile keeps its file handle open until closed, so read everything
        # that is needed inside a context manager.
        with np.load(metadata_path, allow_pickle=False) as meta:
            self.sample_mol_ids = meta['sample_mol_ids'].astype(str)
            self.sample_z_mask: np.ndarray | None = None
            if 'sample_z_mask' in meta.files:
                self.sample_z_mask = meta['sample_z_mask'].astype(np.uint64)

        self.unique_elements = QM40_ELEMENTS.difference(self.exclude_elements)

        self.heavy_atom_counts = np.load(heavy_atom_counts_path)

        # Start with all samples
        self.sample_indices = np.arange(len(self.sample_mol_ids), dtype=np.int64)

        # Filter by elements BEFORE subsampling.
        if self.exclude_elements:
            self.sample_indices = self._filter_excluded_elements()

        # Filter by maximum heavy atoms if requested
        if self.max_heavy_atoms is not None:
            counts = self.heavy_atom_counts[self.sample_indices]
            keep = counts <= self.max_heavy_atoms
            self.sample_indices = self.sample_indices[keep]

        # Apply subsampling if requested
        if self.samples_per_heavy_atom_bin is not None:
            self.sample_indices = self._subsample_by_heavy_atoms(
                self.sample_indices, self.samples_per_heavy_atom_bin
            )

    def _filter_excluded_elements(self) -> np.ndarray:
        """
        Return the indices of samples containing none of `self.exclude_elements`.

        Tries, in order: the precomputed common filter (exactly F, S, Cl
        excluded), the precomputed per-element filters, and finally the
        per-sample element bitmask stored in the metadata.

        Raises
        ------
        KeyError
            If the per-element filter file lacks an entry for an excluded element.
        ValueError
            If the bitmask fallback is asked for an element outside 0 <= Z < 64.
        RuntimeError
            If none of the three filtering methods is available.
        """
        common_exclude = {9, 16, 17}  # F, S, Cl (non-CHNO elements in QM40)
        filters_path = os.path.join(self.directory, 'filters.npz')
        element_filters_path = os.path.join(self.directory, 'element_filters.npz')

        if self.exclude_elements == common_exclude and os.path.isfile(filters_path):
            if self.show_progress:
                print('  Using precomputed common filter (fastest)', flush=True)
            with np.load(filters_path, allow_pickle=False) as filters:
                return filters['no_F_S_Cl'].astype(np.int64)

        if os.path.isfile(element_filters_path):
            if self.show_progress:
                print('  Using per-element filters (fast)', flush=True)
            with np.load(element_filters_path, allow_pickle=False) as ef:
                for z in self.exclude_elements:
                    key = f'no_z{int(z)}'
                    if key not in ef.files:
                        raise KeyError(
                            f'Missing precomputed filter "{key}" in {element_filters_path}. '
                            f'Re-run QM40.process() to regenerate filters.'
                        )

                # Intersect the "keep" index lists of all excluded elements.
                n_samples = len(self.sample_mol_ids)
                keep_mask = np.ones(n_samples, dtype=bool)
                for z in self.exclude_elements:
                    keep_idx = ef[f'no_z{int(z)}']
                    temp_mask = np.zeros(n_samples, dtype=bool)
                    temp_mask[keep_idx] = True
                    keep_mask &= temp_mask
            return np.where(keep_mask)[0].astype(np.int64)

        if self.sample_z_mask is not None:
            if self.show_progress:
                print('  Using bitmask filter (fast)', flush=True)
            # Bit Z of a sample's mask is set when the sample contains element Z.
            exclude_mask = np.uint64(0)
            for z in self.exclude_elements:
                if z < 0 or z >= 64:
                    raise ValueError(
                        f'QM40 exclude_elements only supports 0 <= Z < 64, got Z={z}'
                    )
                exclude_mask |= np.uint64(1) << np.uint64(z)
            masks = self.sample_z_mask[self.sample_indices]
            keep = (masks & exclude_mask) == 0
            return self.sample_indices[keep]

        raise RuntimeError(
            'No element filtering method available. Re-run process().'
        )

    def _subsample_by_heavy_atoms(
        self, indices: np.ndarray, n_per_bin: int, seed: int = 0
    ) -> np.ndarray:
        """
        Subsample indices to have at most n_per_bin molecules per heavy atom count.
        """
        rng = np.random.RandomState(seed)
        counts = self.heavy_atom_counts[indices]

        unique_counts = np.unique(counts)
        keep_mask = np.zeros(len(indices), dtype=bool)

        for hac in unique_counts:
            bin_mask = counts == hac
            bin_indices = np.where(bin_mask)[0]

            if len(bin_indices) <= n_per_bin:
                keep_mask[bin_mask] = True
            else:
                selected = rng.choice(bin_indices, size=n_per_bin, replace=False)
                keep_mask[selected] = True

        return indices[keep_mask]

    def __len__(self) -> int:
        return int(len(self.sample_indices))

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """
        Load one processed sample from disk.

        `idx` is a position in the filtered/subsampled view (`sample_indices`);
        it is translated to the on-disk sample id, which names the
        `samples/<sample_id>.npz` file.

        Returns
        -------
        A tuple `(sample_id, (nuc_pos, atom_z, charge, spin, aux_data), targets)`.
        No target values are stored for QM40, so `targets` holds only `None`
        entries. `nuc_pos` and `atom_z` are passed through `self.sort_atoms`
        and `aux_data` comes from `self.load_aux`; both are defined outside
        this class.
        """
        sample_id = int(self.sample_indices[int(idx)])

        path = os.path.join(self.directory, 'samples', f'{sample_id}.npz')
        # An NpzFile keeps the underlying file open until it is closed. Read
        # all fields inside the context manager so the descriptor is released
        # deterministically instead of whenever the object gets collected.
        with np.load(path, allow_pickle=False) as data:
            nuc_pos = data['nuc_pos']
            atom_z = data['atom_z']
            charge = int(data['charge'])
            spin = int(data['spin'])

        nuc_pos, atom_z = self.sort_atoms(nuc_pos, atom_z)
        targets = Targets(None, None, None)
        aux_data = self.load_aux(sample_id)
        return sample_id, (nuc_pos, atom_z, charge, spin, aux_data), targets

    def random_split(
        self, train_fraction: float, val_fraction: float, seed: int
    ) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """
        Split the dataset into train, val, and test sets.
        """
        test_fraction = 1.0 - train_fraction - val_fraction
        if test_fraction < 0:
            raise ValueError(
                f'train_fraction + val_fraction must be <= 1.0, '
                f'got {train_fraction} + {val_fraction} = {train_fraction + val_fraction}'
            )

        n = len(self.sample_indices)
        n_train = int(n * train_fraction)
        n_val = int(n * val_fraction)

        rng = np.random.RandomState(seed)
        perm = rng.permutation(n)

        train_idx = perm[:n_train].tolist()
        val_idx = perm[n_train : n_train + n_val].tolist()
        test_idx = perm[n_train + n_val :].tolist()

        return (
            IndexWrapper(self, train_idx),
            IndexWrapper(self, val_idx),
            IndexWrapper(self, test_idx),
        )

    def process(self) -> None:
        """
        Convert the raw QM40 CSV files into the processed on-disk layout.

        Reads the molecule IDs (first column) from `QM40_main.csv`, which
        defines the sample ordering, and the atoms of every molecule from
        `QM40_xyz.csv` (columns: Zinc_id, smile, atom, init_x/y/z,
        final_x/y/z, charge), using the final coordinates. Writes, below
        `self.directory`:

        - `samples/<idx>.npz`: nuc_pos, atom_z, charge and spin per molecule
        - `metadata.npz`: molecule IDs, unique elements, per-sample element bitmask
        - `heavy_atom_counts.npy`: number of atoms with Z > 1 per molecule
        - `element_filters.npz`: for each element, indices of samples without it
        - `filters.npz`: indices of samples containing none of F, S, Cl

        Raises
        ------
        FileNotFoundError
            If one of the raw CSV files is missing.
        ValueError
            If the main CSV lists no molecules, an xyz row has too few
            columns, or a listed molecule has no atoms.
        KeyError
            If an atom symbol is not in `ELEMENT_TO_Z`.
        """
        print('Processing QM40 dataset from CSV files', flush=True)

        if not os.path.isfile(self.xyz_csv_path):
            raise FileNotFoundError(f'QM40_xyz.csv not found: {self.xyz_csv_path}')
        if not os.path.isfile(self.main_csv_path):
            raise FileNotFoundError(f'QM40_main.csv not found: {self.main_csv_path}')

        metadata_path = os.path.join(self.directory, 'metadata.npz')
        heavy_atom_counts_path = os.path.join(self.directory, 'heavy_atom_counts.npy')
        element_filters_path = os.path.join(self.directory, 'element_filters.npz')
        filters_path = os.path.join(self.directory, 'filters.npz')

        os.makedirs(self.directory, exist_ok=True)
        out_dir = os.path.join(self.directory, 'samples')
        os.makedirs(out_dir, exist_ok=True)

        # Read molecule IDs from main CSV to get the ordering
        print('Reading QM40_main.csv for molecule IDs...', flush=True)
        mol_ids: list[str] = []
        with open(self.main_csv_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header
            for row in reader:
                if not row:  # blank line (e.g. trailing newline at end of file)
                    continue
                mol_ids.append(row[0])
        print(f'  Found {len(mol_ids)} molecules', flush=True)

        if not mol_ids:
            # Fail before writing anything: an empty set of artifacts would
            # otherwise be mistaken for a completed preprocessing run.
            raise ValueError(f'No molecules found in {self.main_csv_path}')

        # Parse xyz CSV and group atoms by molecule
        print('Parsing QM40_xyz.csv...', flush=True)
        # Structure: Zinc_id,smile,atom,init_x,init_y,init_z,final_x,final_y,final_z,charge

        # Pre-allocate storage for each molecule
        mol_atoms: dict[str, list[tuple[str, float, float, float]]] = {
            mol_id: [] for mol_id in mol_ids
        }

        with open(self.xyz_csv_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header
            line_count = 0
            for row in tqdm(
                reader,
                desc='reading xyz',
                disable=not self.show_progress,
                unit='atoms',
                mininterval=1.0,
            ):
                if not row:  # blank line
                    continue
                if len(row) < 9:
                    raise ValueError(
                        f'Malformed row at line {reader.line_num} of '
                        f'{self.xyz_csv_path}: expected at least 9 columns, '
                        f'got {len(row)}'
                    )

                # Rows of molecules not listed in the main CSV are ignored.
                atoms_of_mol = mol_atoms.get(row[0])
                if atoms_of_mol is not None:
                    # Use final (optimized) coordinates
                    atoms_of_mol.append(
                        (row[2], float(row[6]), float(row[7]), float(row[8]))
                    )
                line_count += 1

        print(f'  Parsed {line_count} atom rows', flush=True)

        # Process each molecule and write sample files
        print('Writing sample files...', flush=True)
        unique_elements: set[int] = set()
        heavy_atom_counts: list[int] = []
        sample_z_masks: list[np.uint64] = []

        for sample_id, mol_id in enumerate(
            tqdm(mol_ids, desc='processing', disable=not self.show_progress)
        ):
            atoms = mol_atoms[mol_id]
            if len(atoms) == 0:
                raise ValueError(f'No atoms found for molecule {mol_id}')

            # Convert to arrays
            try:
                atom_z = np.array(
                    [ELEMENT_TO_Z[sym] for sym, _, _, _ in atoms], dtype=np.uint8
                )
            except KeyError as err:
                raise KeyError(
                    f'Unsupported element symbol {err.args[0]!r} in molecule {mol_id}'
                ) from err
            nuc_pos = np.array([[x, y, z] for _, x, y, z in atoms], dtype=np.float32)

            # Compute heavy atom count (Z > 1)
            heavy_count = int(np.sum(atom_z > 1))
            heavy_atom_counts.append(heavy_count)

            # Update unique elements
            present = [int(z) for z in np.unique(atom_z).tolist()]
            unique_elements.update(present)

            # Build element bitmask for fast filtering: bit Z is set when the
            # molecule contains element Z (only Z < 64 fits into 64 bits).
            z_mask = np.uint64(0)
            for zi in present:
                if 0 <= zi < 64:
                    z_mask |= np.uint64(1) << np.uint64(zi)
            sample_z_masks.append(z_mask)

            # QM40 molecules are neutral with closed-shell
            charge = 0
            n_electrons = int(atom_z.sum()) - charge
            spin = n_electrons % 2

            # Write sample file
            sample_path = os.path.join(out_dir, f'{sample_id}.npz')
            np.savez_compressed(
                sample_path,
                nuc_pos=nuc_pos,
                atom_z=atom_z,
                charge=np.asarray(charge, dtype=np.int16),
                spin=np.asarray(spin, dtype=np.int8),
            )

        # Save metadata
        print('Saving metadata...', flush=True)
        szm = np.asarray(sample_z_masks, dtype=np.uint64)
        np.savez(
            metadata_path,
            sample_mol_ids=np.asarray(mol_ids, dtype=str),
            unique_elements=np.asarray(sorted(unique_elements), dtype=np.uint8),
            sample_z_mask=szm,
        )
        np.save(
            heavy_atom_counts_path,
            np.asarray(heavy_atom_counts, dtype=np.int32),
        )

        # Build element filters: for each element, the samples NOT containing it
        print('Building element filters...', flush=True)
        element_filters: dict[str, np.ndarray] = {}
        for z in sorted(unique_elements):
            bit = np.uint64(1) << np.uint64(z)
            keep_idx = np.where((szm & bit) == 0)[0].astype(np.int64)
            element_filters[f'no_z{z}'] = keep_idx
        np.savez(element_filters_path, **element_filters)  # type: ignore

        # Build common filter (F, S, Cl - non-CHNO elements in QM40)
        print('Building common element filter (excluding F, S, Cl)...', flush=True)
        common_exclude_mask = (
            (np.uint64(1) << np.uint64(9))  # F
            | (np.uint64(1) << np.uint64(16))  # S
            | (np.uint64(1) << np.uint64(17))  # Cl
        )
        keep_common = np.where((szm & common_exclude_mask) == 0)[0].astype(np.int64)
        np.savez(filters_path, no_F_S_Cl=keep_common)

        print('QM40 processing done', flush=True)
        print(f'  Total molecules: {len(mol_ids)}', flush=True)
        print(f'  Unique elements: {sorted(unique_elements)}', flush=True)
        print(
            f'  Heavy atom range: {min(heavy_atom_counts)} - {max(heavy_atom_counts)}',
            flush=True,
        )
