import os
import json
import numpy as np
import tqdm
import pickle
import torch
import rdkit.Chem as Chem

from typing import List, Tuple, Dict, Any

from torch.utils.data import random_split, Dataset, Subset
from rdkit.Chem import AllChem
# from rdkit.Chem.rdchem.Mol import Mol

from data.config import *
from data.structures import PackedMolGraph
from data.utils import split_by_interval, get_mean_std
from data.phi_psi import get_phi, get_psi
from data.qm7.load_qm7 import load_qm7
from data.qm8.load_qm8 import load_qm8
from data.qm9.load_qm9 import load_qm9

# Standard deviations of the zero-mean Gaussian noise that GeoMolDataset adds to a pack's
# geometry when ``disturb`` is requested; one disturbed copy is produced per value.
LIST_SIGMA = [0.1, 0.2, 0.4, 0.8, 1.6]


class GeoMolDataset(Dataset):
    """Dataset of packed molecular graphs with optional targets, geometries and precomputed extras.

    Item ``index`` is the tuple ``(packed_mol_graph, smiles_set, properties, dft_geometry,
    rdkit_geometry, extra_dict)``. The optional entries are ``None`` when the corresponding list
    was not supplied; ``extra_dict`` is empty unless features were requested via ``**kwargs``.
    """

    def __init__(self,
                 list_packed_mol_graph: List[PackedMolGraph],
                 list_smiles_set: List[List[str]],
                 list_properties: List[torch.Tensor] = None,
                 list_dft_geometry: List[torch.FloatTensor] = None,
                 list_rdkit_geometry: List[torch.FloatTensor] = None,
                 use_cuda=False,
                 use_tqdm=False,
                 **kwargs):
        """Store the per-pack lists and precompute the requested extra features.

        :param list_packed_mol_graph: one ``PackedMolGraph`` per pack; its length is the dataset size.
        :param list_smiles_set: SMILES strings of the molecules in each pack.
        :param list_properties: optional per-pack property tensors.
        :param list_dft_geometry: optional per-pack DFT coordinates.
        :param list_rdkit_geometry: optional per-pack RDKit-embedded coordinates.
        :param use_cuda: accepted for interface compatibility but not used here
            (``get_phi``/``get_psi`` are always called with ``use_cuda=False``).
        :param use_tqdm: show a progress bar while computing phi/psi features.
        :param kwargs: a truthy ``disturb`` stores ``'disturb_geometries'`` (one noisy copy of the
            geometry per sigma in ``LIST_SIGMA``); a truthy ``phi``/``psi`` stores the
            ``'<name>_w1'``, ``'<name>_w2'``, ``'<name>_flat'`` and ``'<name>_g'`` entries.
            These features use the DFT geometry when given, otherwise the RDKit geometry,
            and require at least one of them.
        """
        super().__init__()
        self.n_pack = len(list_packed_mol_graph)
        # Every supplied per-pack list must line up one-to-one with the packed graphs.
        assert len(list_smiles_set) == self.n_pack
        for optional_list in (list_properties, list_dft_geometry, list_rdkit_geometry):
            if optional_list is not None:
                assert len(optional_list) == self.n_pack

        self.list_packed_mol_graph = list_packed_mol_graph
        self.list_smiles_set = list_smiles_set
        self.list_properties = list_properties
        self.list_dft_geometry = list_dft_geometry
        self.list_rdkit_geometry = list_rdkit_geometry
        self.list_extra_dict: List[Dict[str, Any]] = [{} for _ in range(self.n_pack)]

        # Derived features prefer the DFT geometry and fall back to the RDKit embedding.
        list_geometry = self.list_dft_geometry if self.list_dft_geometry is not None else self.list_rdkit_geometry

        if kwargs.get('disturb'):
            assert list_geometry is not None
            for i in range(self.n_pack):
                geometry = list_geometry[i]
                # One noisy copy per noise scale; the noise is drawn from the global torch RNG.
                self.list_extra_dict[i]['disturb_geometries'] = [
                    geometry + torch.normal(0, sigma, size=geometry.shape) for sigma in LIST_SIGMA]

        if kwargs.get('phi'):
            self._attach_phi_psi_features('phi', get_phi, list_geometry, use_tqdm)
        if kwargs.get('psi'):
            self._attach_phi_psi_features('psi', get_psi, list_geometry, use_tqdm)

    def _attach_phi_psi_features(self, prefix: str, compute, list_geometry, use_tqdm: bool) -> None:
        """Run ``compute`` (``get_phi`` or ``get_psi``) on every pack and store its four outputs.

        The outputs go into each pack's extra dict under ``f'{prefix}_w1'``, ``f'{prefix}_w2'``,
        ``f'{prefix}_flat'`` and ``f'{prefix}_g'``.
        """
        assert list_geometry is not None
        pack_indices = range(self.n_pack)
        if use_tqdm:
            pack_indices = tqdm.tqdm(pack_indices, total=self.n_pack)
        for i in pack_indices:
            w1, w2, flat_indices, g = compute(self.list_packed_mol_graph[i].mask_matrices,
                                              list_geometry[i], use_cuda=False)
            extra = self.list_extra_dict[i]
            extra[f'{prefix}_w1'] = w1
            extra[f'{prefix}_w2'] = w2
            extra[f'{prefix}_flat'] = flat_indices
            extra[f'{prefix}_g'] = g

    def __getitem__(self, index) -> Tuple[PackedMolGraph, List[str], torch.Tensor,
                                          torch.FloatTensor, torch.FloatTensor, Dict[str, Any]]:
        """Return ``(graph, smiles, properties, dft_geometry, rdkit_geometry, extra_dict)`` of pack ``index``."""
        return self.list_packed_mol_graph[index], self.list_smiles_set[index], \
               self.list_properties[index] if self.list_properties is not None else None, \
               self.list_dft_geometry[index] if self.list_dft_geometry is not None else None, \
               self.list_rdkit_geometry[index] if self.list_rdkit_geometry is not None else None, \
               self.list_extra_dict[index]

    def __len__(self):
        """Return the number of packs."""
        return self.n_pack


class SupportedDatasets:
    """String identifiers of the datasets that ``load_data`` can build."""

    QM7: str = 'qm7'
    QM8: str = 'qm8'
    QM9: str = 'qm9'


def dft_mol_positions(mol) -> np.ndarray:
    """Return the coordinates stored in the default conformer of ``mol``.

    For the QM datasets this is the geometry that came with the molecule, so it must be read
    before ``rdkit_mol_positions`` re-embeds ``mol``.
    """
    conformer = mol.GetConformer()
    return conformer.GetPositions()


def rdkit_mol_positions(mol, seed=0) -> np.ndarray:
    """Embed ``mol`` with RDKit and return the resulting atom coordinates.

    ATTENTION: This operation will overwrite DFT geometry! ``mol`` is modified in place, so read
    any stored conformer before calling this.

    :param mol: RDKit molecule to embed.
    :param seed: random seed forwarded to ``AllChem.EmbedMolecule``.
    :return: an ``(n_atoms, 3)`` float64 coordinate array, or all zeros if embedding fails.
    """
    # ``np.float`` (removed in NumPy 1.24) was an alias of the builtin float, i.e. float64.
    fallback = np.zeros((mol.GetNumAtoms(), 3), dtype=np.float64)
    try:
        conf_id = AllChem.EmbedMolecule(mol, randomSeed=seed)
        if conf_id < 0:
            # EmbedMolecule signals failure by returning -1 rather than raising.
            return fallback
        return mol.GetConformer(conf_id).GetPositions()
    except ValueError:
        # e.g. no usable conformer on the molecule after embedding.
        return fallback


def dataset_is_geometrical(dataset_name: str) -> bool:
    """Return True if DFT geometries are read for ``dataset_name`` (QM7, QM8 or QM9)."""
    geometrical_datasets = (SupportedDatasets.QM7, SupportedDatasets.QM8, SupportedDatasets.QM9)
    return dataset_name in geometrical_datasets


def dataset_is_regressive(dataset_name: str) -> bool:
    """Return True if ``dataset_name`` is a regression dataset (QM7, QM8 or QM9).

    The targets of such datasets are mean/std-normalized by ``load_data``.
    """
    regression_names = (SupportedDatasets.QM7, SupportedDatasets.QM8, SupportedDatasets.QM9)
    return any(dataset_name == name for name in regression_names)


def _maybe_tqdm(iterable, use_tqdm: bool):
    """Wrap a sized ``iterable`` in a tqdm progress bar when ``use_tqdm`` is set."""
    return tqdm.tqdm(iterable, total=len(iterable)) if use_tqdm else iterable


def _dump_pickle_atomically(obj, path: str) -> None:
    """Pickle ``obj`` to ``path`` via a temp file so an interrupted run never leaves a truncated cache."""
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'wb') as fp:
        pickle.dump(obj, fp)
    os.replace(tmp_path, path)


def load_data(dataset_name: str, n_mol_per_pack: int = 1, n_pack_per_batch: int = 128,
              dataset_token: str = None, seed=0, force_save=False, use_cuda=False, use_phi_psi=False, use_disturb=False,
              use_tqdm=False
              ) -> Tuple[GeoMolDataset, GeoMolDataset, GeoMolDataset, torch.Tensor]:
    """Load (or build and cache) the train/validate/test splits of a supported dataset.

    The result is cached as a pickle under ``data/{DATASET_PICKLES_DIR}``, keyed only by
    ``dataset_name`` and ``dataset_token``. When other arguments change, pass a different
    ``dataset_token`` or ``force_save=True``; otherwise the stale cache is returned.

    :param dataset_name: one of the ``SupportedDatasets`` names.
    :param n_mol_per_pack: pack size passed to ``split_by_interval``.
    :param n_pack_per_batch: unused here.
    :param dataset_token: optional suffix distinguishing cache files.
    :param seed: must equal ``torch.initial_seed()``; also selects ``data/{dataset_name}/split-{seed}.json``.
    :param force_save: ignore an existing cache (also forwarded to the raw dataset loaders).
    :param use_cuda: forwarded to ``GeoMolDataset``.
    :param use_phi_psi: precompute phi and psi features.
    :param use_disturb: precompute noisy copies of the geometries.
    :param use_tqdm: show progress bars.
    :return: ``(train_dataset, validate_dataset, test_dataset, properties)``. The datasets are
        ``random_split`` subsets (80/10/10 of the packs). ``properties`` is the stacked,
        un-normalized property tensor in pack order.
    :raises ValueError: if ``dataset_name`` is not supported.
    """
    assert torch.initial_seed() == seed
    if dataset_token is None:
        pickle_path = f'data/{DATASET_PICKLES_DIR}/{dataset_name}.pickle'
    else:
        pickle_path = f'data/{DATASET_PICKLES_DIR}/{dataset_name}-{dataset_token}.pickle'
    if not force_save and os.path.exists(pickle_path):
        try:
            # Only caches written by this function are loaded; never point this at untrusted pickles.
            with open(pickle_path, 'rb') as fp:
                train_dataset, validate_dataset, test_dataset, properties = pickle.load(fp)
            return train_dataset, validate_dataset, test_dataset, properties
        except (EOFError, pickle.UnpicklingError):
            # Truncated/corrupt cache: fall through and rebuild it.
            pass

    raw_loaders = {
        SupportedDatasets.QM7: load_qm7,
        SupportedDatasets.QM8: load_qm8,
        SupportedDatasets.QM9: load_qm9,
    }
    if dataset_name not in raw_loaders:
        raise ValueError(f'Unsupported dataset: {dataset_name!r}')
    molecules, properties = raw_loaders[dataset_name](force_save=force_save)
    list_smiles: List[str] = [Chem.MolToSmiles(mol) for mol in molecules]
    assert len(molecules) == properties.shape[0]
    properties = torch.FloatTensor(properties)

    split_file = f'data/{dataset_name}/split-{seed}.json'
    has_split_file = os.path.exists(split_file)
    if has_split_file:
        with open(split_file) as fp:
            split_dict = json.load(fp)
        # NOTE(review): the file's train/validate/test partition only fixes the pack order here;
        # random_split below re-partitions the packs 80/10/10 at random, so the file's split is
        # not preserved. Confirm this is intended.
        given_list = split_dict['train'] + split_dict['validate'] + split_dict['test']
        indices_each_pack = split_by_interval(len(given_list), n_mol_per_pack, given_list=given_list)
    else:
        indices_each_pack = split_by_interval(len(molecules), n_mol_per_pack)

    print(f'\t\tPacking molecular graphs...')
    list_packed_mol_graph = [PackedMolGraph([molecules[idx] for idx in indices])
                             for indices in _maybe_tqdm(indices_each_pack, use_tqdm)]

    # Map each graph's ``mask`` (positions within its pack, presumably the molecules PackedMolGraph
    # kept) back to dataset indices so SMILES/properties/geometries stay aligned with the graphs.
    if has_split_file:
        indices_each_pack = [[indices[j] for j in graph.mask]
                             for indices, graph in zip(indices_each_pack, list_packed_mol_graph)]
    else:
        indices_each_pack = [[indices[0] + j for j in graph.mask]
                             for indices, graph in zip(indices_each_pack, list_packed_mol_graph)]

    list_smiles_set = [[list_smiles[idx] for idx in indices] for indices in indices_each_pack]
    list_properties = [properties[indices, :] for indices in indices_each_pack]
    properties = torch.vstack(list_properties)

    if dataset_is_regressive(dataset_name):
        mean, std = get_mean_std(properties)
        list_properties = [(p - mean) / std for p in list_properties]

    # DFT positions must be read before the RDKit embedding below overwrites the conformers.
    if dataset_is_geometrical(dataset_name):
        print(f'\t\tLoading DFT geometries...')
        list_dft_geometry = [
            torch.FloatTensor(np.vstack([dft_mol_positions(molecules[idx]) for idx in indices]))
            for indices in _maybe_tqdm(indices_each_pack, use_tqdm)]
    else:
        list_dft_geometry = None

    print(f'\t\tEmbedding RDKit geometries...')
    list_rdkit_geometry = [
        torch.FloatTensor(np.vstack([rdkit_mol_positions(molecules[idx]) for idx in indices]))
        for indices in _maybe_tqdm(indices_each_pack, use_tqdm)]

    kwargs = {}
    if use_phi_psi:
        kwargs['phi'] = 1
        kwargs['psi'] = 1
    if use_disturb:
        kwargs['disturb'] = 1
    dataset = GeoMolDataset(list_packed_mol_graph=list_packed_mol_graph,
                            list_smiles_set=list_smiles_set,
                            list_properties=list_properties,
                            list_dft_geometry=list_dft_geometry,
                            list_rdkit_geometry=list_rdkit_geometry,
                            use_cuda=use_cuda, use_tqdm=use_tqdm, **kwargs)
    n_pack = len(dataset)
    n_validate_pack = int(n_pack * 0.1)
    n_test_pack = int(n_pack * 0.1)
    n_train_pack = n_pack - n_validate_pack - n_test_pack
    train_dataset, validate_dataset, test_dataset = random_split(
        dataset=dataset, lengths=[n_train_pack, n_validate_pack, n_test_pack])

    os.makedirs(f'data/{DATASET_PICKLES_DIR}', exist_ok=True)
    _dump_pickle_atomically((train_dataset, validate_dataset, test_dataset, properties), pickle_path)

    return train_dataset, validate_dataset, test_dataset, properties
