import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from src.datasets.mix_dataset import MixDataset
from src.datasets.pair_dataset import PairDataset
from src.datasets.all_test_dataset import AllTestDataset
from src.datasets.pdb_dataset import PDBDataset
from src.datasets.ptm_dataset import PTMDataset
from src.datasets.feature_prediction_dataset import FeaturePredictionDataset

class BioDataLoader(DataLoader):
    """DataLoader that records which CUDA device belongs to the current rank.

    ``pretrain_device`` starts as ``'cuda:0'`` and is refreshed every time a
    batch is produced: ``'cuda:<rank>'`` when ``torch.distributed`` reports a
    rank, ``'cuda:0'`` otherwise.

    Args:
        dataset: Dataset forwarded to ``torch.utils.data.DataLoader``.
        num_workers: Number of loader worker processes (defaults to 8).
        *args: Extra positional arguments forwarded to ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.
    """

    def __init__(self, dataset, num_workers=8, *args, **kwargs):
        super().__init__(dataset, *args, num_workers=num_workers, **kwargs)
        self.pretrain_device = 'cuda:0'

    def __iter__(self):
        """Yield batches unchanged, updating ``pretrain_device`` before each."""
        for batch in super().__iter__():
            try:
                rank = torch.distributed.get_rank()
            except Exception:
                # Best-effort lookup: any failure to obtain a rank (for
                # example, no process group has been initialised) falls back
                # to device 0.  ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt / SystemExit are not swallowed mid-epoch.
                self.pretrain_device = 'cuda:0'
            else:
                self.pretrain_device = f'cuda:{rank}'
            yield batch


def memory_efficient_collate_fn(batch):
    """Merge per-sample dicts into one dict by concatenation.

    Samples that are ``None`` or whose ``'X'`` entry has length 0 are dropped
    first.  The keys of the first surviving sample decide which keys are
    collated:

    * ``'edge_idx'``: each sample's tensor is offset by the cumulative
      ``'num_nodes'`` of the samples before it, then concatenated along
      dim 1.
    * ``'batch_id'``: each sample's tensor is offset by the sample's position
      in the filtered batch, then concatenated.
    * other tensor values: concatenated along dim 0.
    * string values: gathered into a list.
    * anything else: packed with ``torch.tensor``.

    Args:
        batch: Sequence of sample dicts (or ``None`` placeholders).  Every
            kept sample must provide ``'X'`` and ``'num_nodes'``.

    Returns:
        The collated dict, or ``None`` when no sample survives filtering.
    """
    samples = [one for one in batch if one is not None and len(one['X']) > 0]
    if not samples:
        return None

    # shift[i] is the number of nodes contributed by samples 0..i-1, i.e. the
    # amount by which sample i's node indices must move in the merged batch.
    num_nodes = torch.tensor([one['num_nodes'] for one in samples])
    cumulative = num_nodes.cumsum(dim=0)
    shift = torch.cat([torch.tensor([0], device=cumulative.device), cumulative], dim=0)

    first = samples[0]
    ret = {}
    for key, value in first.items():
        if key == 'edge_idx':
            ret[key] = torch.cat(
                [one[key] + shift[idx] for idx, one in enumerate(samples)], dim=1
            )
        elif key == 'batch_id':
            ret[key] = torch.cat([one[key] + idx for idx, one in enumerate(samples)])
        elif isinstance(value, torch.Tensor):
            # isinstance (not an exact type match) so Tensor subclasses such
            # as nn.Parameter are concatenated as well.
            ret[key] = torch.cat([one[key] for one in samples], dim=0)
        elif isinstance(value, str):
            ret[key] = [one[key] for one in samples]
        else:
            ret[key] = torch.tensor([one[key] for one in samples])
    return ret


class DInterface(pl.LightningDataModule):
    """Lightning data module that serves one of the registered dataset types.

    The constructor eagerly builds a base dataset (``self.dataset``);
    ``setup`` then builds ``trainset`` / ``valset`` / ``testset`` for the
    requested stage, and the ``*_dataloader`` hooks wrap them in
    ``BioDataLoader`` with ``memory_efficient_collate_fn``.

    Args:
        data_path: Path forwarded to every dataset constructor.
        batch_size: Batch size used for every dataloader.
        num_workers: Worker count used for every dataloader.
        dataset_type: One of ``'MixDataset'``, ``'AllTestDataset'``,
            ``'PairDataset'``, ``'PDBDataset'``,
            ``'FeaturePredictionDataset'`` or ``'PTMDataset'``.
        remove_pdb: Forwarded to the base ``MixDataset`` only.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``dataset_type`` is not one of the registered names.
    """

    def __init__(self, data_path, batch_size, num_workers, dataset_type='MixDataset', remove_pdb=False, **kwargs):
        super().__init__()
        registry = {
            'MixDataset': MixDataset,
            'AllTestDataset': AllTestDataset,
            'PairDataset': PairDataset,
            'PDBDataset': PDBDataset,
            'FeaturePredictionDataset': FeaturePredictionDataset,
            'PTMDataset': PTMDataset,
        }
        if dataset_type not in registry:
            # Fail here with a readable message instead of later with
            # "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown dataset_type {dataset_type!r}; "
                f"expected one of {sorted(registry)}"
            )

        self.dataset_type = dataset_type
        self.data_module = registry[dataset_type]
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers

        if dataset_type == 'MixDataset':
            self.dataset = self.data_module(data_path=data_path, split='train', remove_pdb=remove_pdb)
        elif dataset_type == 'AllTestDataset':
            self.dataset = self.data_module(data_path=data_path, split='test')
        else:
            self.dataset = self.data_module(data_path=data_path, split='train')

    def _build_split(self, split):
        """Build a dataset for ``split`` from ``data_path`` alone."""
        return self.data_module(data_path=self.data_path, split=split)

    def _build_mix_split(self, split):
        """Build a MixDataset split from the base dataset's pdb/afdb data."""
        return self.data_module(
            data_path=self.data_path, split=split,
            pdb_data=self.dataset.pdb_data[split],
            afdb_data=self.dataset.afdb_data[split],
        )

    def _build_data_split(self, split):
        """Build a split from the base dataset's ``all_data[split]``."""
        return self.data_module(
            data_path=self.data_path, split=split,
            data=self.dataset.all_data[split],
        )

    def _setup_splits(self, stage, build, test_when_unset):
        """Create the fit / validate / test datasets required by ``stage``.

        Args:
            stage: Lightning stage name, or ``None``.
            build: Callable mapping a split name to a dataset.
            test_when_unset: Whether ``stage=None`` also builds ``testset``.
        """
        if stage == "fit" or stage is None:
            self.trainset = build("train")
            self.valset = build("val")
        elif stage == "validate":
            # Without this, val_dataloader() would find no ``valset`` when
            # validation is run on its own.
            self.valset = build("val")
        if stage == "test" or (stage is None and test_when_unset):
            self.testset = build("test")

    def _setup_predict_splits(self):
        """Create all three datasets from the base dataset's ``all_data``."""
        self.trainset = self._build_data_split("train")
        self.valset = self._build_data_split("val")
        self.testset = self._build_data_split("test")

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        * ``AllTestDataset``: train/val/test are all built from the ``"test"``
          split, whatever the stage.
        * ``MixDataset``: fit builds train+val, validate builds val, test
          builds test; ``None`` builds all three.  Nothing is built for
          predict.
        * ``PairDataset``: as ``MixDataset``, plus predict builds all three.
        * ``PDBDataset`` / ``FeaturePredictionDataset`` / ``PTMDataset``: fit
          (or ``None``) builds train+val, validate builds val, test builds
          test, predict builds all three from ``self.dataset.all_data``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"``, ``"predict"`` or
                ``None``.
        """
        kind = self.dataset_type
        if kind == 'AllTestDataset':
            # This type only reads the test split, so every loader gets its
            # own dataset instance built from it.
            self.trainset = self._build_split("test")
            self.valset = self._build_split("test")
            self.testset = self._build_split("test")
        elif kind == 'MixDataset':
            self._setup_splits(stage, self._build_mix_split, test_when_unset=True)
        elif kind == 'PairDataset':
            # stage=None already yields train/val/test through _setup_splits,
            # so the predict rebuild is only needed for an explicit "predict".
            self._setup_splits(stage, self._build_data_split, test_when_unset=True)
            if stage == "predict":
                self._setup_predict_splits()
        elif kind in ('PDBDataset', 'FeaturePredictionDataset', 'PTMDataset'):
            # NOTE(review): for these types stage=None does not build the
            # test set, and the predict stage reads self.dataset.all_data and
            # passes data=... although fit/test build without it -- confirm
            # both are intended.
            self._setup_splits(stage, self._build_split, test_when_unset=False)
            if stage == "predict":
                self._setup_predict_splits()

    def create_dataloader(self, dataset, is_train=False):
        """Wrap ``dataset`` in a ``BioDataLoader``.

        Args:
            dataset: Dataset to load from.
            is_train: When true the loader shuffles.

        Returns:
            A ``BioDataLoader`` using ``memory_efficient_collate_fn``.
        """
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=is_train,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
            collate_fn=memory_efficient_collate_fn,
        )
        if self.num_workers > 0:
            # DataLoader rejects an explicit prefetch_factor when loading in
            # the main process (num_workers == 0).
            loader_kwargs['prefetch_factor'] = 4
        return BioDataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled training loader.

        For ``MixDataset`` the training set's ``_update_afdb()`` is called
        first, each time this hook runs.
        """
        if self.dataset_type == "MixDataset":
            self.trainset._update_afdb()
        # NOTE(review): a previously disabled branch returned a non-shuffled
        # loader for PDBDataset on the grounds that an iterable dataset does
        # not use shuffle -- confirm PDBDataset is map-style.
        return self.create_dataloader(self.trainset, is_train=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return self.create_dataloader(self.valset)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return self.create_dataloader(self.testset)

    def predict_dataloader(self):
        """Return the (unshuffled) loader used for prediction.

        Only the training split is predicted on; loaders for the val/test
        splits are not returned.
        """
        return self.create_dataloader(self.trainset, is_train=False)