'''MISATO, a database for protein-ligand interactions
    Copyright (C) 2023  
                        Till Siebenmorgen  (till.siebenmorgen@helmholtz-munich.de)
                        Sabrina Benassou   (s.benassou@fz-juelich.de)
                        Filipe Menezes     (filipe.menezes@helmholtz-munich.de)
                        Erinç Merdivan     (erinc.merdivan@helmholtz-munich.de)

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software 
    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA'''

import os

from pytorch_lightning import LightningDataModule
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from torch.utils.data.distributed import DistributedSampler

import torch
import numpy as np
from torch_geometric.data import Data
from dataset.mol_protein_dataset import MolProtDataset, MDTransform
from dataset.AdK_dataset import MDAnalysisDataset, collate_mda
from dataset.md17_dataset import MD17Dataset

class MDDataModule(LightningDataModule):
    """Lightning data module that serves three ``MolProtDataset`` splits.

    ``setup`` builds the train/val/test datasets from one shared HDF5 file
    and one split file per partition, all given relative to ``files_root``.
    ``train_dataloader``, ``val_dataloader`` and ``test_dataloader`` wrap
    those datasets in ``torch_geometric`` data loaders; only the training
    loader shuffles.
    """

    def __init__(
        self,
        files_root: str,
        h5file = "h5_files/MD_dataset_soft_hard_noH.hdf5",
        train = "splits/train_soft_hard.txt",
        val = "splits/val_soft_hard.txt",
        test = "splits/test_soft_hard.txt",
        batch_size = 16,
        num_workers = 48,
        transform = T.RandomTranslate(0.05)
    ):
        """Store the configuration; no data is read until ``setup`` runs.

        Args:
            files_root: Directory that every other path is joined onto.
            h5file: HDF5 file path, relative to ``files_root``.
            train: Training split file, relative to ``files_root``.
            val: Validation split file, relative to ``files_root``.
            test: Test split file, relative to ``files_root``.
            batch_size: Batch size used by all three loaders.
            num_workers: Worker count used by all three loaders.
            transform: Kept on the instance as ``self.transform``; nothing
                in this class applies it.
        """
        super().__init__()

        # Exposes the init arguments through `self.hparams` and has them
        # stored in checkpoints; they are not forwarded to the logger.
        self.save_hyperparameters(logger=False)

        # Loader settings.
        self.batch_size, self.num_workers = batch_size, num_workers
        self.transform = transform

        # File locations (joined with `files_root` in `setup`).
        self.files_root, self.h5file = files_root, h5file
        self.train, self.val, self.test = train, val, test

    def setup(self, stage=None):
        """Create ``self.data_train``, ``self.data_val`` and ``self.data_test``.

        All three datasets are built on every call; ``stage`` is accepted
        but not consulted.
        """

        def _under_root(relative_path):
            # Every configured path is relative to `files_root`.
            return os.path.join(self.files_root, relative_path)

        self.data_train = MolProtDataset(_under_root(self.h5file), _under_root(self.train))
        self.data_val = MolProtDataset(_under_root(self.h5file), _under_root(self.val))
        self.data_test = MolProtDataset(_under_root(self.h5file), _under_root(self.test))

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a loader using the module's batch/worker settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffling loader over ``self.data_train``."""
        return self._make_loader(self.data_train, True)

    def val_dataloader(self):
        """Return the non-shuffling loader over ``self.data_val``."""
        return self._make_loader(self.data_val, False)

    def test_dataloader(self):
        """Return the non-shuffling loader over ``self.data_test``."""
        return self._make_loader(self.data_test, False)

def _build_loaders(datasets, args, dist, rank, loader_cls):
    """Wrap every split in ``datasets`` with a loader, optionally distributed.

    Args:
        datasets: Mapping from split name to dataset. The ``'train'`` split
            is shuffled and drops its last incomplete batch; every other
            split does neither.
        args: Namespace providing ``num_workers`` plus either ``batch_size``
            (non-distributed) or ``global_batch_size`` and ``global_seed``
            (distributed).
        dist: Object exposing ``get_world_size()``, or ``None`` for the
            non-distributed path.
        rank: Rank passed to each ``DistributedSampler``.
        loader_cls: Loader class to instantiate for each split.

    Returns:
        ``(dataloaders, samplers)``: two dicts keyed by split name.
        ``samplers`` is ``None`` when ``dist`` is ``None``.
    """
    if dist is None:
        dataloaders = {
            split: loader_cls(dataset,
                              batch_size=args.batch_size,
                              shuffle=(split == 'train'),
                              drop_last=(split == 'train'),
                              num_workers=args.num_workers)
            for split, dataset in datasets.items()
        }
        return dataloaders, None

    world_size = dist.get_world_size()
    samplers = {
        split: DistributedSampler(dataset, num_replicas=world_size,
                                  rank=rank, shuffle=(split == 'train'),
                                  seed=args.global_seed)
        for split, dataset in datasets.items()
    }
    dataloaders = {
        split: loader_cls(dataset,
                          # The global batch is divided evenly across ranks.
                          batch_size=int(args.global_batch_size // world_size),
                          num_workers=args.num_workers,
                          # Shuffling is delegated to the sampler here.
                          sampler=samplers[split],
                          drop_last=(split == 'train'),
                          pin_memory=False)
        for split, dataset in datasets.items()
    }
    return dataloaders, samplers


def get_dataloader(args, dist=None, rank=None):
    """Build train/valid/test data loaders for the dataset named by ``args``.

    The dataset family is selected from ``args.data_set``:

    * ``'MISATO'`` (exact match): ``MolProtDataset`` with a PyG ``DataLoader``.
    * ``'AdK'`` (exact match): ``MDAnalysisDataset`` with a torch ``DataLoader``.
    * contains ``'MD17'``: ``MD17Dataset`` with a torch ``DataLoader``.
    * contains ``'OC'``: ``LmdbDataset`` with a torch ``DataLoader``.
    * contains ``'protein-unit'``: ``.npz`` files from ``args.data_dir``
      wrapped as PyG ``Data`` objects, with a PyG ``DataLoader``.

    Args:
        args: Namespace with ``data_set`` and ``num_workers``, and,
            depending on the branch, ``debug``, ``data_dir``,
            ``delta_frame``, ``load_cached``, ``backbone``, ``batch_size``,
            ``global_batch_size`` and ``global_seed``. The MD17 and OC
            branches also write ``args.mol``.
        dist: Object exposing ``get_world_size()`` to enable distributed
            sampling, or ``None`` for plain loaders.
        rank: Rank of the current process; only used when ``dist`` is given.

    Returns:
        ``(dataloaders, samplers)``: dicts keyed by ``'train'``, ``'valid'``
        and ``'test'``. ``samplers`` is ``None`` when ``dist`` is ``None``.
        Both are ``None`` when ``args.data_set`` matches no known dataset.
    """
    splits = ['train', 'valid', 'test']

    if args.data_set == 'MISATO':
        mdh5_file = 'data/MD/h5_files/MD.hdf5'
        if args.debug:
            idx_lists = ["data/MD/splits/train_tinyMD.txt",
                         "data/MD/splits/val_tinyMD.txt",
                         "data/MD/splits/test_tinyMD.txt"]
        else:
            idx_lists = ["data/MD/splits/train_MD.txt",
                         "data/MD/splits/val_MD.txt",
                         "data/MD/splits/test_MD.txt"]
        datasets = {split: MolProtDataset(mdh5_file, idx_file)
                    for split, idx_file in zip(splits, idx_lists)}
        loader_cls = DataLoader

    elif args.data_set == 'AdK':
        shared = dict(tmp_dir=args.data_dir, delta_frame=args.delta_frame,
                      load_cached=args.load_cached, backbone=args.backbone)
        datasets = {
            'train': MDAnalysisDataset('adk', partition='train', **shared),
            'valid': MDAnalysisDataset('adk', partition='valid', **shared),
            'test': MDAnalysisDataset('adk', partition='test',
                                      test_rot=False, test_trans=False, **shared),
        }
        loader_cls = torch.utils.data.DataLoader

    elif 'MD17' in args.data_set:
        # Everything after the first five characters is taken as the
        # molecule name (presumably the name looks like "MD17_<mol>").
        args.mol = args.data_set[5:]
        max_training_samples = 10000
        shared = dict(data_dir=args.data_dir, molecule_type=args.mol,
                      delta_frame=args.delta_frame)
        datasets = {
            'train': MD17Dataset(partition='train', max_samples=max_training_samples, **shared),
            'valid': MD17Dataset(partition='val', max_samples=2000, **shared),
            'test': MD17Dataset(partition='test', max_samples=2000, **shared),
        }
        loader_cls = torch.utils.data.DataLoader

    elif 'OC' in args.data_set:
        # NOTE(review): `LmdbDataset` is not imported in the visible part of
        # this module; unless it is provided elsewhere this branch raises
        # NameError. Add the proper import once its origin is confirmed.
        dataset_train = LmdbDataset({"src": args.data_dir})
        dataset_test = LmdbDataset({"src": "data/s2ef/all/test_id/"})
        args.mol = args.data_set[5:]
        # The same dataset object backs both validation and test.
        datasets = {'train': dataset_train, 'valid': dataset_test, 'test': dataset_test}
        loader_cls = torch.utils.data.DataLoader

    elif 'protein-unit' in args.data_set:
        if args.debug:
            file_list = ['AA_no_smd.npz']
        else:
            file_list = ['AA_no_smd.npz', 'DD_no_smd.npz', 'GG_no_smd.npz', 'KK_no_smd.npz',
                         'NN_no_smd.npz', 'RR_no_smd.npz', 'VV_no_smd.npz', 'AN_no_smd.npz',
                         'EE_no_smd.npz', 'HH_no_smd.npz', 'LL_no_smd.npz', 'PP_no_smd.npz',
                         'SS_no_smd.npz', 'WW_no_smd.npz', 'CC_no_smd.npz', 'FF_no_smd.npz',
                         'II_no_smd.npz', 'MM_no_smd.npz', 'QQ_no_smd.npz', 'TT_no_smd.npz',
                         'YY_no_smd.npz']
        # Keep one Data object per file. Rebinding a single variable inside
        # the loop would silently discard every file except the last one.
        dataset = []
        for file in file_list:
            # The context manager closes the archive once its arrays are read.
            with np.load(os.path.join(args.data_dir, file)) as data:
                dataset.append(Data(data['Z'], pos=data['R'], ids=data['id'],
                                    y=data['E'], f=data['F'], n=data['N']))
        # NOTE(review): the arrays are handed to Data as NumPy arrays and all
        # three splits share the same list -- confirm both are intended.
        datasets = {'train': dataset, 'valid': dataset, 'test': dataset}
        loader_cls = DataLoader

    else:
        # Unknown dataset name: callers receive a pair of Nones.
        return None, None

    return _build_loaders(datasets, args, dist, rank, loader_cls)

def get_datasets(args):
    """Build the train/valid/test ``MolProtDataset`` objects.

    Args:
        args: Namespace whose ``debug`` flag selects the ``tinyMD`` split
            files instead of the full ``MD`` ones.

    Returns:
        Dict keyed by ``'train'``, ``'valid'`` and ``'test'``; every dataset
        reads from the same HDF5 file.
    """
    # The HDF5 file is the same in both modes; only the split files differ.
    mdh5_file = 'data/MD/h5_files/MD.hdf5'

    if args.debug:
        index_files = {
            'train': "data/MD/splits/train_tinyMD.txt",
            'valid': "data/MD/splits/val_tinyMD.txt",
            'test': "data/MD/splits/test_tinyMD.txt",
        }
    else:
        index_files = {
            'train': "data/MD/splits/train_MD.txt",
            'valid': "data/MD/splits/val_MD.txt",
            'test': "data/MD/splits/test_MD.txt",
        }

    datasets = {}
    for split, idx_file in index_files.items():
        datasets[split] = MolProtDataset(mdh5_file, idx_file)
    return datasets


if __name__ == "__main__":
    # Smoke check: the data module can be constructed. `files_root` has no
    # default, so it must be supplied or the call raises TypeError.
    # "data/MD" is the directory the loader helpers in this module read
    # from; adjust it if the data lives elsewhere.
    _ = MDDataModule(files_root="data/MD")
