import os
import sys
import numpy as np
import tqdm
import pickle
import torch
import rdkit.Chem as Chem

from typing import List, Tuple, Dict, Any
from torch.utils.data import random_split, Dataset
from rdkit.Chem import AllChem
# from rdkit.Chem.rdchem.Mol import Mol as Molecule

from data.config import *
from data.structures import PackedMolGraph
from data.utils import split_by_interval, get_mean_std
from data.phi_psi import get_phi, get_psi
from data.geom_qm9.load_qm9 import GENERATE_RATE, load_qm9
from data.geom_drugs.load_drugs import GENERATE_RATE, load_drugs
# Both loaders export a GENERATE_RATE and the import above shadows the QM9 one with the
# Drugs one; keep per-dataset aliases so each dataset can use its own rate.
from data.geom_qm9.load_qm9 import GENERATE_RATE as QM9_GENERATE_RATE
from data.geom_drugs.load_drugs import GENERATE_RATE as DRUGS_GENERATE_RATE
from data.load_data import dft_mol_positions as get_positions
from data.load_data import rdkit_mol_positions as rdkit_positions
from data.load_data import GeoMolDataset

# Progress printer used for the status messages in this module. Despite the name it is
# plain `print` (stdout); rebind it here to redirect or silence those messages.
perr = print


class MultiGeoMolDataset(Dataset):
    """Dataset of molecular-graph packs, each paired with a list of conformers.

    Item ``i`` is ``(packed_mol_graph, smiles, target_confs, rdkit_confs, extra)``
    where ``target_confs`` / ``rdkit_confs`` are the per-pack conformer lists given
    to the constructor and ``extra`` is a dict that optionally holds precomputed
    phi / psi features (enabled with the ``phi`` / ``psi`` keyword flags).
    """

    def __init__(self,
                 list_packed_mol_graph: List[PackedMolGraph],
                 list_smiles: List[str],
                 list_list_target_conf: List[List[torch.FloatTensor]],
                 list_list_rdkit_conf: List[List[torch.FloatTensor]],
                 use_cuda=False,
                 **kwargs):
        """Store the parallel per-pack lists and optionally precompute phi/psi features.

        Args:
            list_packed_mol_graph: one ``PackedMolGraph`` per pack.
            list_smiles: one SMILES string per pack.
            list_list_target_conf: per pack, the list of target conformer positions.
            list_list_rdkit_conf: per pack, the list of RDKit-embedded positions.
            use_cuda: accepted for interface compatibility; currently unused
                (features are always computed with ``use_cuda=False``).
            **kwargs: truthy ``phi`` / ``psi`` enable feature precomputation;
                any other keys (e.g. ``use_tqdm``) are accepted and ignored.
        """
        super(MultiGeoMolDataset, self).__init__()
        self.n_pack = len(list_packed_mol_graph)
        assert len(list_smiles) == self.n_pack
        assert len(list_list_target_conf) == self.n_pack
        assert len(list_list_rdkit_conf) == self.n_pack

        self.list_packed_mol_graph = list_packed_mol_graph
        self.list_smiles = list_smiles
        self.list_extra_dict: List[Dict[str, Any]] = [{} for _ in range(self.n_pack)]
        self.list_list_target_conf = list_list_target_conf
        self.list_list_rdkit_conf = list_list_rdkit_conf

        perr('\t\tPacking dataset...')
        if kwargs.get('phi'):
            self._attach_features(get_phi, 'phis')
        if kwargs.get('psi'):
            self._attach_features(get_psi, 'psis')

    def _attach_features(self, feature_fn, prefix: str):
        """Run ``feature_fn`` on every target conformer and store its 4 outputs per pack.

        ``feature_fn(mask_matrices, pos, use_cuda=False)`` must return 4 values per
        conformer. They are transposed into 4 tuples (one entry per conformer) stored
        under ``f'{prefix}_w1'``, ``'_w2'``, ``'_flat'`` and ``'_g'`` in the pack's
        extra dict.
        """
        for packed_mol_graph, list_pos, extra in zip(self.list_packed_mol_graph,
                                                     self.list_list_target_conf,
                                                     self.list_extra_dict):
            outputs = [feature_fn(packed_mol_graph.mask_matrices, pos, use_cuda=False)
                       for pos in list_pos]
            if outputs:
                w1, w2, flat, g = zip(*outputs)
            else:
                # No conformers: store empty tuples instead of failing to unpack.
                w1 = w2 = flat = g = ()
            extra[f'{prefix}_w1'] = w1
            extra[f'{prefix}_w2'] = w2
            extra[f'{prefix}_flat'] = flat
            extra[f'{prefix}_g'] = g

    def __getitem__(self, index) -> Tuple[PackedMolGraph, str, List[torch.Tensor],
                                          List[torch.Tensor], Dict[str, Any]]:
        """Return ``(graph, smiles, target_confs, rdkit_confs, extra_dict)`` for pack ``index``."""
        return self.list_packed_mol_graph[index], self.list_smiles[index], self.list_list_target_conf[index], \
               self.list_list_rdkit_conf[index], self.list_extra_dict[index]

    def __len__(self):
        """Number of packs."""
        return self.n_pack


class SupportedMultiDatasets:
    """String identifiers of the datasets accepted by ``load_multi_data``."""
    GEOM_DRUGS = 'geom_drugs'
    GEOM_DRUGS_SMALL = 'geom_drugs-small'
    GEOM_QM9 = 'geom_qm9'
    GEOM_QM9_SMALL = 'geom_qm9-small'

    @staticmethod
    def tolist():
        """Return every supported dataset name as a new list.

        Built from the class constants so the list cannot drift out of sync
        with them.
        """
        return [SupportedMultiDatasets.GEOM_DRUGS, SupportedMultiDatasets.GEOM_DRUGS_SMALL,
                SupportedMultiDatasets.GEOM_QM9, SupportedMultiDatasets.GEOM_QM9_SMALL]


def load_multi_data(dataset_name: str, n_mol_per_pack: int = 1, n_pack_per_batch: int = 128,
                    dataset_token: str = None, seed=0, force_save=False, use_cuda=False, use_tqdm=False,
                    use_phi_psi=False, train_only=False
                    ) -> Tuple[GeoMolDataset, MultiGeoMolDataset, MultiGeoMolDataset]:
    """Build, or load from a pickle cache, the train / validate / test datasets.

    The training split becomes a ``GeoMolDataset`` made of packs of up to
    ``n_mol_per_pack`` molecules. The validate and test splits become
    ``MultiGeoMolDataset``s in which every molecule carries all of its target
    conformers, plus ``GENERATE_RATE`` (of the matching dataset) times as many
    RDKit-embedded conformers.

    Args:
        dataset_name: one of ``SupportedMultiDatasets.tolist()``.
        n_mol_per_pack: molecules per training pack.
        n_pack_per_batch: unused here; kept for signature compatibility.
        dataset_token: optional suffix for the cache file name. The cache key is
            only ``dataset_name`` + ``dataset_token``, so ``n_mol_per_pack`` and
            ``use_phi_psi`` are NOT part of it. Use distinct tokens for
            differing settings.
        seed: asserted to equal ``torch.initial_seed()``.
        force_save: rebuild even if a cache exists (also forwarded to the raw loaders).
        use_cuda: forwarded to the dataset constructors.
        use_tqdm: forwarded to ``MultiGeoMolDataset``.
        use_phi_psi: precompute phi/psi features for the training dataset.
        train_only: skip the validate/test splits (returned as ``None``).

    Returns:
        ``(train_dataset, validate_dataset, test_dataset)``. The last two are
        ``None`` when ``train_only`` builds a fresh dataset.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    assert torch.initial_seed() == seed
    pickles_dir = f'data/{DATASET_PICKLES_DIR}'
    if dataset_token is None:
        pickle_path = f'{pickles_dir}/{dataset_name}.pickle'
    else:
        pickle_path = f'{pickles_dir}/{dataset_name}-{dataset_token}.pickle'

    if not force_save and os.path.exists(pickle_path):
        try:
            # NOTE(security): unpickling can run arbitrary code; only load caches this function wrote.
            with open(pickle_path, 'rb') as fp:
                train_dataset, validate_dataset, test_dataset = pickle.load(fp)
        except (EOFError, pickle.UnpicklingError):
            # Empty or truncated cache (e.g. from an interrupted dump): rebuild it below.
            pass
        else:
            # A cache built with train_only=True holds None for validate/test; it can only
            # satisfy train-only requests, otherwise the full set is rebuilt.
            if train_only or (validate_dataset is not None and test_dataset is not None):
                return train_dataset, validate_dataset, test_dataset

    # Each raw loader comes with its own conformer generation rate (the module-level
    # GENERATE_RATE name is shadowed by the Drugs import).
    if dataset_name == SupportedMultiDatasets.GEOM_DRUGS_SMALL:
        raw_splits = load_drugs(small=True, force_save=force_save)
        generate_rate = DRUGS_GENERATE_RATE
    elif dataset_name == SupportedMultiDatasets.GEOM_DRUGS:
        raw_splits = load_drugs(force_save=force_save)
        generate_rate = DRUGS_GENERATE_RATE
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9_SMALL:
        raw_splits = load_qm9(small=True, force_save=force_save)
        generate_rate = QM9_GENERATE_RATE
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9:
        raw_splits = load_qm9(force_save=force_save)
        generate_rate = QM9_GENERATE_RATE
    else:
        raise ValueError(f'Unsupported dataset {dataset_name!r}; '
                         f'expected one of {SupportedMultiDatasets.tolist()}')
    train_list_mol, validate_list_list_mol, test_list_list_mol = raw_splits

    def generate_dataset(list_mol: List[Any], list_smiles: List[str], name='train'):
        """Pack ``list_mol`` into a ``GeoMolDataset``, dropping molecules PackedMolGraph rejects."""
        perr(f'\tGenerating {name} dataset...')
        indices_each_pack = split_by_interval(len(list_mol), n_mol_per_pack)
        perr('\t\tPacking molecular graphs...')
        list_packed_mol_graph = [PackedMolGraph([list_mol[idx] for idx in indices]) for indices in indices_each_pack]
        perr('\t\tDeleting bad molecular graphs...')
        # `mask` is used as the in-pack offsets of the molecules each graph kept; map
        # them to global indices and drop packs left empty, which would otherwise make
        # np.vstack below fail on an empty list.
        kept = [(graph, [indices[0] + j for j in graph.mask])
                for graph, indices in zip(list_packed_mol_graph, indices_each_pack)
                if len(graph.mask)]
        list_packed_mol_graph = [graph for graph, _ in kept]
        indices_each_pack = [indices for _, indices in kept]
        perr('\t\tLoading SMILES...')
        list_smiles_set = [[list_smiles[idx] for idx in indices] for indices in indices_each_pack]

        perr('\t\tLoading CREST geometries...')
        list_dft_geometry = [torch.FloatTensor(np.vstack([get_positions(list_mol[idx]) for idx in indices]))
                             for indices in indices_each_pack]
        perr('\t\tEmbedding RDKit geometries...')
        list_rdkit_geometry = [torch.FloatTensor(np.vstack([rdkit_positions(list_mol[idx]) for idx in indices]))
                               for indices in indices_each_pack]

        perr('\t\tPacking dataset...')
        kwargs = {}
        if use_phi_psi:
            kwargs['phi'] = 1
            kwargs['psi'] = 1
        dataset = GeoMolDataset(list_packed_mol_graph=list_packed_mol_graph,
                                list_smiles_set=list_smiles_set,
                                list_properties=None,
                                list_dft_geometry=list_dft_geometry,
                                list_rdkit_geometry=list_rdkit_geometry,
                                use_cuda=use_cuda, **kwargs)
        perr('\t\tFinished')
        return dataset

    def generate_multi_dataset(list_list_mol: List[List[Any]], list_smiles: List[str], name='validate'):
        """Build a ``MultiGeoMolDataset`` with one single-molecule pack per conformer list."""
        perr(f'\tGenerating {name} multi dataset...')
        n_mol = len(list_list_mol)
        perr('\t\tPacking molecular graphs...')
        list_packed_mol_graph = [PackedMolGraph([list_list_mol[i][0]]) for i in range(n_mol)]
        perr('\t\tDeleting bad molecular graphs...')
        good_indices = [i for i in range(n_mol) if len(list_packed_mol_graph[i].mask)]
        list_packed_mol_graph = [list_packed_mol_graph[i] for i in good_indices]
        list_list_mol = [list_list_mol[i] for i in good_indices]
        perr('\t\tLoading SMILES...')
        list_smiles = [list_smiles[i] for i in good_indices]
        perr('\t\tLoading CREST geometries...')
        list_list_target_conf = [[torch.FloatTensor(get_positions(mol)) for mol in list_mol]
                                 for list_mol in list_list_mol]
        perr('\t\tEmbedding RDKit geometries...')
        # generate_rate RDKit embeddings per reference conformer, each with a distinct seed.
        list_list_rdkit_conf = [
            [torch.FloatTensor(rdkit_positions(list_mol[0], seed=i)) for i in range(len(list_mol) * generate_rate)]
            for list_mol in list_list_mol]

        # phi/psi precomputation is deliberately disabled for the multi datasets.
        kwargs = {}
        dataset = MultiGeoMolDataset(list_packed_mol_graph=list_packed_mol_graph,
                                     list_smiles=list_smiles,
                                     list_list_target_conf=list_list_target_conf,
                                     list_list_rdkit_conf=list_list_rdkit_conf,
                                     use_cuda=use_cuda, use_tqdm=use_tqdm, **kwargs)
        perr('\t\tFinished')
        return dataset

    train_list_smiles: List[str] = [Chem.MolToSmiles(mol) for mol in train_list_mol]
    train_dataset = generate_dataset(train_list_mol, train_list_smiles, name='train')
    if train_only:
        validate_dataset = test_dataset = None
    else:
        validate_list_smiles: List[str] = [Chem.MolToSmiles(list_mol[0]) for list_mol in validate_list_list_mol]
        test_list_smiles: List[str] = [Chem.MolToSmiles(list_mol[0]) for list_mol in test_list_list_mol]
        validate_dataset = generate_multi_dataset(validate_list_list_mol, validate_list_smiles, name='validate')
        test_dataset = generate_multi_dataset(test_list_list_mol, test_list_smiles, name='test')

    os.makedirs(pickles_dir, exist_ok=True)
    # Dump to a temporary file and atomically swap it in, so an interrupted dump never
    # leaves a truncated cache at pickle_path.
    tmp_path = f'{pickle_path}.tmp'
    with open(tmp_path, 'wb') as fp:
        pickle.dump((train_dataset, validate_dataset, test_dataset), fp)
    os.replace(tmp_path, pickle_path)

    return train_dataset, validate_dataset, test_dataset
