from pathlib import Path
import os
import pickle, random
import os.path as osp

import numpy as np
import torch
from torch_geometric.data import Data, Dataset

from dataset.featurization import MoleculeFeaturizer
from utils.rotation import rotation_perturb_mol
# from etflow.commons.io import get_base_data_dir


class EuclideanDataset(Dataset):
    """Dataset yielding one 3D molecular graph per pickled molecule file.

    Every regular file inside ``f"{data_dir}_{split}"`` is treated as one
    sample. A file is unpickled into an object exposing ``smiles``, ``pos``
    (an indexable collection of conformer coordinates) and
    ``atomic_numbers``. On each access one conformer is drawn at random and
    the molecule is featurized from its SMILES string.

    Note that ``partition`` is *not* part of the folder path; it is stored on
    the instance and only used in the "no data files" error message.

    Usage
    -----
    ```python
    # reads the files stored in the folder "data/qm9_train"
    dataset = EuclideanDataset(
        data_dir="data/qm9",
        split="train",
        partition="qm9",
    )
    ```

    Parameters
    ----------
    data_dir:
        Path prefix of the data folder; ``_{split}`` is appended to it.
    split:
        Split suffix used to build the folder name.
    partition:
        Label stored on the instance and reported in error messages.
    cascade:
        Stored as ``self.cascade``; not read anywhere inside this class.
    sigma_scale:
        Value attached unchanged to every returned ``Data`` object.

    Raises
    ------
    ValueError
        If the data folder contains no files.
    """

    def __init__(
        self,
        data_dir: str = "data/qm9",
        split: str = "train",
        partition: str = "drugs",
        cascade: str = 'LocalGenerator',
        sigma_scale: float = 1,
    ):
        super().__init__()
        self.mol_feat = MoleculeFeaturizer()
        # Set split and partition
        self.split = split
        self.partition = partition
        self.data_folder = f'{data_dir}_{split}'
        # Keep regular files only: os.listdir also reports sub-directories,
        # which would make get() fail when it tries to open them.
        self.data_files = sorted(
            name
            for name in os.listdir(self.data_folder)
            if osp.isfile(osp.join(self.data_folder, name))
        )
        self.cascade = cascade
        self.sigma_scale = sigma_scale

        if not self.data_files:
            raise ValueError(
                f"No data files found for partition {partition} and split {split}"
            )

    def len(self):
        """Return the number of sample files found in the data folder."""
        return len(self.data_files)

    def get(self, idx):
        """Load sample ``idx`` and return it as a featurized ``Data`` graph.

        Raises
        ------
        ValueError
            If the loaded sample stores no conformers.
        """
        data_path = osp.join(self.data_folder, self.data_files[idx])
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this dataset at files from a trusted source.
        with open(data_path, 'rb') as f:
            data = pickle.load(f)

        # Get the molecule data
        smiles = data.smiles
        pos_confs = data.pos
        atomic_numbers = data.atomic_numbers

        # Fail with an explicit message instead of NumPy's "low >= high".
        num_confs = len(pos_confs)
        if num_confs == 0:
            raise ValueError(f"No conformers stored in {data_path}")
        # Sample a random conformer (upper bound is exclusive).
        conf_idx = np.random.randint(0, num_confs)
        pos = pos_confs[conf_idx]

        # Featurize molecule
        node_attr = self.mol_feat.get_atom_features(smiles)
        chiral_index, chiral_nbr_index, chiral_tag = self.mol_feat.get_chiral_centers(
            smiles
        )
        edge_index, edge_attr = self.mol_feat.get_edge_index(smiles, False)
        mol = self.mol_feat.get_mol_with_conformer(smiles, pos)

        # Create a new graph with additional features
        return Data(
            sigma_scale=self.sigma_scale,
            pos=pos,
            atomic_numbers=atomic_numbers,
            smiles=smiles,
            edge_index=edge_index,
            chiral_index=chiral_index,
            chiral_nbr_index=chiral_nbr_index,
            chiral_tag=chiral_tag,
            mol=mol,
            node_attr=node_attr,
            edge_attr=edge_attr,
        )
    
class InferenceDataset(Dataset):
    """Dataset yielding one graph per molecule file with all its conformers.

    Every regular file inside ``f"{data_dir}/{partition}_{split}"`` is treated
    as one sample. A file is unpickled into an object exposing ``smiles``,
    ``pos`` (a sequence of conformer coordinate tensors),
    ``atomic_numbers``, ``mask_edges``, ``mask_rotate`` and
    ``subgraph_batch``. The returned graph stores every conformer, stacked
    with ``torch.stack``, under ``pos``.

    Usage
    -----
    ```python
    # reads the files stored in the folder "data/qm9_test"
    dataset = InferenceDataset(
        data_dir="data",
        split="test",
        partition="qm9",
    )
    ```

    Parameters
    ----------
    data_dir:
        Directory that contains the ``{partition}_{split}`` folder.
    split:
        Split suffix used to build the folder name.
    partition:
        Partition prefix used to build the folder name.

    Raises
    ------
    ValueError
        If the data folder contains no files.
    """

    def __init__(
        self,
        data_dir: str = "data",
        split: str = "train",
        partition: str = "qm9",
    ):
        super().__init__()
        self.mol_feat = MoleculeFeaturizer()
        # Set split and partition
        self.split = split
        self.partition = partition
        self.data_folder = f'{data_dir}/{partition}_{split}'
        # Keep regular files only: os.listdir also reports sub-directories,
        # which would make get() fail when it tries to open them.
        self.data_files = sorted(
            name
            for name in os.listdir(self.data_folder)
            if osp.isfile(osp.join(self.data_folder, name))
        )
        if not self.data_files:
            raise ValueError(
                f"No data files found for partition {partition} and split {split}"
            )

    def len(self):
        """Return the number of sample files found in the data folder."""
        return len(self.data_files)

    def get(self, idx):
        """Load sample ``idx`` and return it as a featurized ``Data`` graph.

        Raises
        ------
        ValueError
            If the loaded sample stores no conformers.
        """
        data_path = osp.join(self.data_folder, self.data_files[idx])
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this dataset at files from a trusted source.
        with open(data_path, 'rb') as f:
            data = pickle.load(f)

        # Get the molecule data
        smiles = data.smiles
        pos_confs = data.pos
        atomic_numbers = data.atomic_numbers

        # Fail with an explicit message instead of NumPy's "low >= high".
        num_confs = len(pos_confs)
        if num_confs == 0:
            raise ValueError(f"No conformers stored in {data_path}")
        # The sampled conformer is only used to build `mol`; `pos` below
        # carries all conformers.
        # NOTE(review): this makes `mol` vary between calls - confirm that a
        # random (rather than fixed) conformer is intended at inference time.
        conf_idx = np.random.randint(0, num_confs)
        pos = pos_confs[conf_idx]

        # Featurize molecule
        node_attr = self.mol_feat.get_atom_features(smiles)
        chiral_index, chiral_nbr_index, chiral_tag = self.mol_feat.get_chiral_centers(
            smiles
        )
        edge_index, edge_attr = self.mol_feat.get_edge_index(smiles, False)
        mol = self.mol_feat.get_mol_with_conformer(smiles, pos)

        # Create a new graph with additional features
        return Data(
            pos=torch.stack(pos_confs),
            atomic_numbers=atomic_numbers,
            smiles=smiles,
            edge_index=edge_index,
            chiral_index=chiral_index,
            chiral_nbr_index=chiral_nbr_index,
            chiral_tag=chiral_tag,
            mol=mol,
            node_attr=node_attr,
            edge_attr=edge_attr,
            mask_edges=data.mask_edges,
            mask_rotate=data.mask_rotate,
            subgraph_batch=data.subgraph_batch,
        )