import copy
import os
from typing import Literal, Tuple

import numpy as np
from rdkit import Chem
from tqdm import tqdm

from egxc.dataloading import (
    BaseDataset,
    PartiallySplitDataset,
    RawSample,
    SupportsIndex,
    Targets,
)
from egxc.dataloading.download import download_url, extract_zip
from egxc.dataloading.utils import IndexWrapper


class QM9(PartiallySplitDataset):
    """QM9 molecules with one energy target per sample, stored as ``.npz`` files.

    On first use the raw archive is downloaded into ``<data_dir>/qm9/raw`` and
    converted into ``<processed dir>/samples/<i>.npz`` files (one per molecule)
    plus two index arrays, ``heavy_atom_counts.npy`` and
    ``non_fluorine_idxs.npy``. A ``complete.marker`` file in the raw directory
    records that download and processing finished.

    ``heavy_atoms_thresh`` is only applied in :meth:`random_split`; indexing
    the dataset directly is not restricted by it.
    """

    raw_url1 = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip'  # energy already in hartree
    raw_url2 = 'https://figshare.com/ndownloader/files/3195404'
    # raw_url2 = 'https://figshare.com/files/3195404'
    # Order matters: process() addresses these by position (sdf, csv, skip list).
    raw_file_names = ['gdb9.sdf', 'gdb9.sdf.csv', 'uncharacterized.txt']

    def __init__(
        self,
        data_dir: str,
        heavy_atoms_thresh: int | Literal['debug', 'debug_larger'],
        exclude_fluorine: bool,
        energy_unit: str = 'hartree',
        distance_unit: str = 'ang',
        **kwargs,
    ):
        """Set up paths and download/process the data if that has not completed yet.

        Args:
            data_dir: Root data directory; everything is placed under
                ``<data_dir>/qm9``.
            heavy_atoms_thresh: Heavy-atom count separating train/val
                (``<=``) from test (``>``) in :meth:`random_split`, or
                ``'debug'`` / ``'debug_larger'`` for tiny fixed splits.
            exclude_fluorine: If true, only samples listed in
                ``non_fluorine_idxs.npy`` are exposed.
            energy_unit: Must be ``'hartree'``.
            distance_unit: Must be ``'ang'``.
            **kwargs: Forwarded unchanged to ``init_params``.
        """
        assert energy_unit == 'hartree', (
            'present implementation only supports hartree as energy unit'
        )
        assert distance_unit == 'ang', (
            'present implementation only supports angstrom as distance unit'
        )

        self.heavy_atoms_thresh = heavy_atoms_thresh
        self.exclude_fluorine = exclude_fluorine
        data_dir = os.path.join(data_dir, 'qm9')
        self.raw_dir = os.path.join(data_dir, 'raw')
        # NOTE(review): `self.directory`, used below, is presumably set by
        # init_params from its first argument -- confirm in the base class.
        self.init_params(
            os.path.join(data_dir, 'processed'),
            energy_unit=energy_unit,
            distance_unit=distance_unit,
            **kwargs,
        )

        # TODO factor this logic out into BaseDataset?
        # The marker is written only after both steps succeed, so an
        # interrupted run is repeated on the next construction.
        complete_file = os.path.join(self.raw_dir, 'complete.marker')
        if not os.path.exists(self.directory) or not os.path.exists(complete_file):
            self.download()
            self.process()
            with open(complete_file, 'w') as f:
                f.write('complete.marker')

        if exclude_fluorine:
            self.non_fluorine_idxs = np.load(
                os.path.join(self.directory, 'non_fluorine_idxs.npy')
            )
            self.unique_elements = {1, 6, 7, 8}  # H, C, N, O
        else:
            self.unique_elements = {1, 6, 7, 8, 9}  # H, C, N, O, F

        # TODO: heavy_atoms_thresh only takes effect if we use random_split()
        # the Dataset class itself will return all heavy atoms

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """Load one sample from disk.

        When ``exclude_fluorine`` is set, ``idx`` is first translated through
        ``non_fluorine_idxs`` to the on-disk sample index; the returned index
        is that translated value.

        Returns:
            ``(idx, (nuc_pos, atom_z, 0, 0, aux_data), targets)`` where
            ``targets`` carries only the stored energy.
        """
        if self.exclude_fluorine:
            idx = self.non_fluorine_idxs[idx]
        path = os.path.join(self.directory, 'samples', f'{idx}.npz')  # type: ignore
        # The context manager closes the archive's file handle right away;
        # without it every access keeps a descriptor open until garbage
        # collection. Members are read inside the block because they are
        # loaded lazily on key access.
        # NOTE(review): process() stores plain numeric arrays, so
        # allow_pickle=True is probably unnecessary -- confirm before tightening.
        with np.load(path, allow_pickle=True) as data:
            nuc_pos = data['nuc_pos']
            atom_z = data['atom_z']
            energy = data['energy']
        nuc_pos, atom_z = self.sort_atoms(nuc_pos, atom_z)
        targets = Targets(energy, None, None)
        aux_data = self.load_aux(idx)
        # NOTE(review): the two zeros are presumably fixed per-molecule
        # properties expected by RawSample -- confirm their meaning there.
        return idx, (nuc_pos, atom_z, 0, 0, aux_data), targets

    def __len__(self) -> int:
        """Return the number of samples exposed under the current fluorine setting."""
        if self.exclude_fluorine:
            return self._len_without_fluorine()
        return self._len_with_fluorine()

    def _len_with_fluorine(self) -> int:
        """Return the number of entries in the ``samples`` directory."""
        return len(os.listdir(os.path.join(self.directory, 'samples')))

    def _len_without_fluorine(self) -> int:
        """Return the number of fluorine-free samples."""
        return len(self.non_fluorine_idxs)

    def download(self) -> None:
        """Fetch the raw files into ``raw_dir``.

        The zip from ``raw_url1`` is extracted and then deleted. The file from
        ``raw_url2`` is expected to land as ``3195404`` and is renamed to
        ``uncharacterized.txt``.
        """
        file_path = download_url(self.raw_url1, self.raw_dir)
        extract_zip(file_path, self.raw_dir)
        os.unlink(file_path)

        download_url(self.raw_url2, self.raw_dir)
        # os.replace overwrites an existing target on every platform, so a
        # repeated download after an interrupted run does not fail here.
        os.replace(
            os.path.join(self.raw_dir, '3195404'),
            os.path.join(self.raw_dir, 'uncharacterized.txt'),
        )

    def process(self) -> None:
        """Convert the raw files into per-sample ``.npz`` files and index arrays.

        Molecules listed in ``uncharacterized.txt`` are dropped. The remaining
        ones are numbered consecutively from 0, and that numbering is used for
        the sample file names and for both saved index arrays.
        """
        with open(os.path.join(self.raw_dir, self.raw_file_names[2])) as f:
            # A set, because membership is tested once per molecule below.
            skip = {int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]}

        with open(os.path.join(self.raw_dir, self.raw_file_names[1])) as f:
            target = [
                [float(x) for x in line.split(',')[1:20]]
                for line in f.read().split('\n')[1:-1]
            ]
        y = np.asarray(target, dtype=np.float32)
        # Move the first three target columns to the end. np.concatenate is
        # used because the np.concat alias only exists in NumPy >= 2.0.
        y = np.concatenate([y[:, 3:], y[:, :3]], axis=-1)

        atom_z = []
        nuc_pos = []
        energies = []

        suppl = Chem.SDMolSupplier(
            os.path.join(self.raw_dir, self.raw_file_names[0]),
            removeHs=False,
            sanitize=False,
        )

        for i, mol in enumerate(tqdm(suppl)):
            if i in skip:
                continue
            atom_z.append([atom.GetAtomicNum() for atom in mol.GetAtoms()])
            nuc_pos.append(
                np.asarray(mol.GetConformer().GetPositions(), dtype=np.float32)
            )
            energies.append(y[i, 7])  # u0 column

        out_dir = os.path.join(self.directory, 'samples')
        # exist_ok=True: a run interrupted before the completion marker was
        # written leaves this directory behind, and the retry must not crash.
        # Files are rewritten under the same names.
        os.makedirs(out_dir, exist_ok=True)

        heavy_atom_counts = [sum(z > 1 for z in zs) for zs in atom_z]
        non_fluorine_idxs = [
            i for i, zs in enumerate(atom_z) if not any(z == 9 for z in zs)
        ]

        np.save(
            os.path.join(self.directory, 'heavy_atom_counts.npy'),
            np.asarray(heavy_atom_counts),
        )  # type: ignore
        np.save(
            os.path.join(self.directory, 'non_fluorine_idxs.npy'),
            np.asarray(non_fluorine_idxs),
        )  # type: ignore

        for i, (pos, zs, energy) in enumerate(zip(nuc_pos, atom_z, energies)):
            np.savez_compressed(
                os.path.join(out_dir, f'{i}.npz'),
                nuc_pos=pos,
                atom_z=zs,
                energy=energy,
            )

    def random_split(
        self, val_fraction: float, seed: int
    ) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """Split into train, validation and test sets.

        With an integer ``heavy_atoms_thresh``, molecules with at most that
        many heavy atoms are shuffled with ``seed`` and divided into train and
        validation, and molecules with more heavy atoms form the test set.
        The fluorine filter is applied to both groups. In the two debug modes
        fixed two-sample sets are returned and ``val_fraction`` and ``seed``
        are ignored.

        Args:
            val_fraction: Share of the train/val pool assigned to validation
                (rounded down).
            seed: Seed for the permutation of the train/val pool.

        Returns:
            ``(train_set, val_set, test_set)``.
        """
        all_ids = np.arange(self._len_with_fluorine())
        if self.exclude_fluorine:
            mask = np.zeros_like(all_ids, dtype=bool)
            mask[self.non_fluorine_idxs] = True
        else:
            mask = np.ones_like(all_ids, dtype=bool)

        print(
            f'{self.__class__.__name__} splitting dataset: {len(all_ids)} total samples, {len(all_ids[mask])} after fluorine filter (exclude_fluorine={self.exclude_fluorine})'
        )

        if self.heavy_atoms_thresh == 'debug':
            # minimal pipeline for debugging 2 molecules each
            train_set = IndexWrapper(self, [0, 1])
            val_set = IndexWrapper(self, [2, 3])
            test_set = IndexWrapper(self, [4, 5])
        elif self.heavy_atoms_thresh == 'debug_larger':
            # minimal pipeline for debugging 2 large molecules each
            train_set = IndexWrapper(self, [100, 101])
            val_set = IndexWrapper(self, [102, 103])
            test_set = IndexWrapper(self, [104, 105])
        else:
            heavy_atom_counts = np.load(
                os.path.join(self.directory, 'heavy_atom_counts.npy')
            )
            small_enough = heavy_atom_counts <= self.heavy_atoms_thresh
            dev_mask = np.logical_and(mask, small_enough)
            test_mask = np.logical_and(mask, heavy_atom_counts > self.heavy_atoms_thresh)
            print(
                f'  heavy_atoms_thresh={self.heavy_atoms_thresh}: {dev_mask.sum()} train/val (<=), {test_mask.sum()} test (>)'
            )

            # The ids selected below are on-disk sample indices that already
            # passed the fluorine mask, so the wrapped copy must not translate
            # them through non_fluorine_idxs a second time.
            full_set = copy.deepcopy(self)
            full_set.exclude_fluorine = False
            dev_set = IndexWrapper(full_set, all_ids[dev_mask].tolist())
            test_set = IndexWrapper(full_set, all_ids[test_mask].tolist())

            n = len(dev_set)
            n_val = int(n * val_fraction)
            n_train = n - n_val
            shuffled = np.random.RandomState(seed).permutation(np.arange(n))
            train_set = IndexWrapper(dev_set, shuffled[:n_train].tolist())
            val_set = IndexWrapper(dev_set, shuffled[n_train:].tolist())

        return train_set, val_set, test_set
