import os
from typing import Iterable, Tuple

import numpy as onp
import pandas as pd
from openqdc import datasets as oqdc_datasets

from egxc.dataloading.datasets.base import (
    BaseDataset,
    RawSample,
    SupportsIndex,
    Targets,
    UnsplitDataset,
)
from egxc.dataloading.utils import IndexWrapper, random_index_split


def _max_period(atomic_numbers: Iterable) -> int:
    """
    Calculate the maximum period of the elements in the dataset.
    """
    rows = onp.array((2, 10, 18, 36, 54, 86, 118))
    max_z = max(atomic_numbers)
    return int((max_z > rows).sum() + 1)


class OQDCdataset(UnsplitDataset):
    """
    Common base for datasets that are loaded through `openqdc`.

    Subclasses override `init_dataframe` to fill `self.data`. The constructor
    then derives the distinct atomic numbers and the highest period among
    them from the 'atomic_numbers' column.
    """

    data: pd.DataFrame
    directory: str

    def __init__(
        self,
        data_dir: str,
        energy_unit: str = 'hartree',
        distance_unit: str = 'ang',
        **kwargs,
    ):
        """
        Set up the dataset and collect element statistics.

        All arguments are forwarded unchanged to `init_params` (defined
        outside this class; presumably it stores the directory and units
        that `init_dataframe` implementations read -- confirm there).
        """
        self.init_params(data_dir, energy_unit, distance_unit, **kwargs)
        self.init_dataframe()
        # Flatten the per-sample atomic-number sequences, then deduplicate.
        flattened = self.data['atomic_numbers'].explode()
        self.unique_elements = flattened.unique()  # type: ignore
        self.max_period = _max_period(self.unique_elements)

    def init_dataframe(self):
        """Populate `self.data`; must be provided by each subclass."""
        raise NotImplementedError

    def random_split(
        self, train_fraction: float, val_fraction: float, seed: int
    ) -> Tuple[BaseDataset, BaseDataset, BaseDataset]:
        """
        Split the dataset into train, validation and test views.

        The test fraction is the remainder `1 - train_fraction - val_fraction`.
        The index sets come from `random_index_split` called with `seed`, and
        each one is wrapped in an `IndexWrapper` around this dataset.
        """
        fractions = (
            train_fraction,
            val_fraction,
            1 - train_fraction - val_fraction,
        )
        train_idx, val_idx, test_idx = random_index_split(len(self), fractions, seed)
        return (
            IndexWrapper(self, train_idx),
            IndexWrapper(self, val_idx),
            IndexWrapper(self, test_idx),
        )


class DES370K(OQDCdataset):
    """DES370K dimer dataset loaded through `openqdc`."""

    def init_dataframe(self):
        """Load DES370K from the 'des370k' subdirectory into `self.data`."""
        self.directory = os.path.join(self.directory, 'des370k')
        source = oqdc_datasets.DES370K(
            cache_dir=self.directory,
            energy_unit=self.energy_unit,
            distance_unit=self.distance_unit,
        )
        self.data = pd.DataFrame(source)

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """
        Return sample `idx` as
        `(idx, (positions, atomic numbers, charge, spin, aux data), targets)`.

        All target fields are None for this dataset.
        """
        entry = self.data.iloc[idx]  # type: ignore
        # TODO: add dimer splitting
        nuc_pos, atom_z = self.sort_atoms(entry['positions'], entry['atomic_numbers'])
        # The charge is taken as the sum over the *distinct* values in 'charges'.
        # TODO: check if this is correct (weird charge format)
        charge = sum(set(entry['charges']))
        n_electrons = sum(entry['atomic_numbers']) - charge
        # Parity of the electron count.
        spin = n_electrons % 2
        # Note that des370k does not have 'total' energies; the energies
        # contained are interaction energies between dimers.
        # TODO: reconcile with the s2gnn branch (note kept to force a merge conflict)
        targets = Targets(None, None, None)
        aux_data = self.load_aux(idx)
        return idx, (nuc_pos, atom_z, charge, spin, aux_data), targets


class QMugs(OQDCdataset):
    """QMugs dataset loaded through `openqdc`."""

    def init_dataframe(self):
        """Load QMugs from the 'qmugs' subdirectory into `self.data`."""
        self.directory = os.path.join(self.directory, 'qmugs')
        data = oqdc_datasets.QMugs(
            energy_unit=self.energy_unit,
            distance_unit=self.distance_unit,
            cache_dir=self.directory,
        )
        self.data = pd.DataFrame(data)

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """
        Return sample `idx` as
        `(idx, (positions, atomic numbers, charge, spin, aux data), targets)`.

        The targets carry the energy at index 1 of the 'energies' entry and
        the 'forces' entry. When a force array is present it is passed
        through `sort_atoms` together with the positions and atomic numbers
        so that all per-atom quantities share the same atom order.
        """
        row = self.data.iloc[idx]  # type: ignore
        nuc_pos = row['positions']
        atom_z = row['atomic_numbers']
        nuc_forces = row['forces']
        # Forces must follow the same atom permutation as the positions;
        # otherwise they would refer to the pre-sort atom order. Missing
        # forces (None or a scalar placeholder) are passed through untouched.
        if nuc_forces is not None and onp.ndim(nuc_forces) > 0:
            nuc_pos, atom_z, nuc_forces = self.sort_atoms(nuc_pos, atom_z, nuc_forces)
        else:
            nuc_pos, atom_z = self.sort_atoms(nuc_pos, atom_z)
        # NOTE(review): the original carried a "TODO: add dimer splitting"
        # here, which looks copied from the dimer dataset -- confirm whether
        # it applies to QMugs.
        # The charge is taken as the sum over the *distinct* values in 'charges'.
        # TODO: check if this is correct (weird charge format)
        charge = sum(set(row['charges']))
        number_of_electrons = sum(row['atomic_numbers']) - charge
        # Parity of the electron count.
        spin = number_of_electrons % 2
        total_energy = row['energies'][1]  # DFT target
        targets = Targets(total_energy, nuc_forces, None)
        aux_data = self.load_aux(idx)
        return idx, (nuc_pos, atom_z, charge, spin, aux_data), targets


class SPICE(OQDCdataset):
    """SPICE dataset loaded through `openqdc`."""

    def init_dataframe(self):
        """Load SPICE from the 'spice' subdirectory into `self.data`."""
        self.directory = os.path.join(self.directory, 'spice')
        source = oqdc_datasets.Spice(
            cache_dir=self.directory,
            energy_unit=self.energy_unit,
            distance_unit=self.distance_unit,
        )
        self.data = pd.DataFrame(source)

    def __getitem__(self, idx: SupportsIndex) -> RawSample:
        """
        Return sample `idx` as
        `(idx, (positions, atomic numbers, charge, spin, aux data), targets)`.

        Positions, atomic numbers and forces go through `sort_atoms` together;
        the targets carry the squeezed 'energies' entry and the sorted forces.
        """
        entry = self.data.iloc[idx]  # type: ignore
        total_energy = entry['energies'].squeeze()
        nuc_pos, atom_z, nuc_forces = self.sort_atoms(
            entry['positions'], entry['atomic_numbers'], entry['forces']
        )
        # The charge is taken as the sum over the *distinct* values in 'charges'.
        # TODO: verify charge format
        charge = sum(set(entry['charges']))
        # Parity of the electron count.
        spin = (sum(atom_z) - charge) % 2

        targets = Targets(total_energy, nuc_forces, None)
        aux_data = self.load_aux(idx)
        return idx, (nuc_pos, atom_z, charge, spin, aux_data), targets
