import pandas as pd
import torch
import random
from omegaconf import DictConfig
import hydra
from typing import TypeVar, Optional, Union, Iterable, Sequence, Callable, Any
from pytorch_lightning import LightningDataModule
from torch import Generator
from torch.utils.data import DataLoader, Sampler
from affinityenhancer.data.collate.gearnet_collate \
    import gearnet_collate, gearnet_positions_collate, gearnet_paired_collate, gearnet_annotate_collate,\
    paired_collate, onehot_collate
from affinityenhancer.data.datasets.gearnet_dataset \
    import GearNetPairedAbStructureSequenceDataset,\
    GearNetHeavyLightStructureSequenceEdgesDataset, \
    GearNetHeavyLightStructureSurfSequenceDataset, GearNetHeavyLightAnnotatedDataset,\
    GearNetSequenceEdgesDataset
from affinityenhancer.data.datasets.paired_input_gearnet_dataset \
    import GearNetPairedDataset, PairedDataset

# Generic element type referenced by the ``collate_fn`` type annotations.
T = TypeVar("T")

class GearnetSequenceDecoderDataModule(LightningDataModule):
    """LightningDataModule over a sequence table split into train/val/test.

    The table is read from ``datapath``. Each split is selected by the value
    ('train', 'val', 'test') of a partition column and wrapped in a
    ``GearNetPairedAbStructureSequenceDataset``. The dataloader methods defined
    here are inherited by the other data modules in this file.
    """

    def __init__(
        self,
        datapath,
        s3_pdb_path,
        data_dir,
        *,
        seed: int = 0xDEADBEEF,
        batch_size: int = 32,
        shuffle: bool = True,
        sampler: Optional[Union[Iterable, Sampler]] = None,
        batch_sampler: Optional[Union[Iterable[Sequence], Sampler[Sequence]]] = None,
        num_workers: int = 12,
        collate_fn: Optional[
            Callable[['list[T]'], Any]
        ] = None,
        pin_memory: bool = True,
        drop_last: bool = False,
        columns: list[str] = ['seqid'],
        weight_column: Optional[str] = "weight",
        partition_column: Optional[str] = "partition",
        split: Optional[str] = "iid",
        additional_columns: list[str] = []
    ) -> None:
        """Store configuration; no data is read here.

        Args:
            datapath: Table to load. Paths ending in ``.parquet`` are read with
                ``pd.read_parquet``, anything else with ``pd.read_csv``.
            s3_pdb_path: Passed through to the dataset constructor.
            data_dir: Passed through to the dataset constructor.
            seed: Seeds ``random``/``torch`` in ``setup`` and the generator the
                training loader shuffles with.
            batch_size: Batch size of every dataloader.
            shuffle: Shuffle the training loader (only when no train sampler
                is set, since DataLoader forbids combining the two).
            sampler: Sampler used by ``predict_dataloader``.
            batch_sampler: Stored; not used by the dataloaders defined here.
            num_workers: Number of DataLoader worker processes.
            collate_fn: Batch collate function; ``gearnet_collate`` when None.
            pin_memory: Forwarded to every DataLoader.
            drop_last: Drop the final incomplete batch of the training loader.
            columns: Copied to ``self.columns``.
            weight_column: Stored; not read by the methods defined here.
            partition_column: Column whose 'train'/'val'/'test' values select
                the splits.
            split: ``'iid'`` replaces the partition with a generated
                ``iid_partition`` column whose first 5% of rows are 'val'.
            additional_columns: Extra column names. The list is copied, so the
                mutable default is never shared between instances.
        """
        super().__init__()

        self.datapath = datapath
        self.s3_pdb_path = s3_pdb_path
        self.data_dir = data_dir

        self._seed = seed
        self._generator = Generator().manual_seed(seed)
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._sampler = sampler
        self._batch_sampler = batch_sampler
        self.columns = list(columns)
        self._weight_column = weight_column
        self._partition_column = partition_column
        self._split = split
        # Copy: prepare_data extends this list and the default value is shared.
        self._additional_columns = list(additional_columns)

        self._num_workers = num_workers
        # Honour a caller-supplied collate function; fall back to gearnet_collate.
        self._collate_fn = gearnet_collate if collate_fn is None else collate_fn
        self._pin_memory = pin_memory
        self._drop_last = drop_last

        self._df = None

    def prepare_data(self) -> None:
        """Load ``datapath`` into ``self._df`` and record the partition column.

        Safe to call more than once: the partition column is added to
        ``_additional_columns`` at most once.
        """
        # NOTE(review): only the partition column is tracked, even when a weight
        # column is configured; confirm whether the weight column belongs here too.
        if self._partition_column not in self._additional_columns:
            self._additional_columns = [*self._additional_columns, self._partition_column]

        if str(self.datapath).endswith('.parquet'):
            self._df = pd.read_parquet(self.datapath)
        else:
            self._df = pd.read_csv(self.datapath)
        print(self._df.columns)
        print(self._df.shape[0])

    def _make_dataset(self, df: pd.DataFrame):
        """Wrap one subset of rows of ``self._df`` in this module's dataset class."""
        return GearNetPairedAbStructureSequenceDataset(
            df,
            s3_pdb_path=self.s3_pdb_path,
            data_dir=self.data_dir,
        )

    def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
        """Build the datasets required for the Lightning ``stage``.

        'fit' builds train and val, 'validate' builds val, and 'predict' builds
        a dataset over the whole table. For every stage, a test dataset is
        built when the partition column exists and has 'test' rows; otherwise
        ``_test_dataset`` is None.
        """
        random.seed(self._seed)
        torch.manual_seed(self._seed)

        if self._df is None:
            self.prepare_data()

        if stage in ("fit", "validate"):
            if self._split == 'iid':
                # Deterministic split: the first 5% of rows (file order) are 'val'.
                # The whole column is assigned at once because chained assignment
                # (df[col][:n] = ...) is a silent no-op under pandas Copy-on-Write.
                n_rows = self._df.shape[0]
                n_val = int(n_rows * 0.05)
                self._partition_column = 'iid_partition'
                self._df['iid_partition'] = ['val'] * n_val + ['train'] * (n_rows - n_val)

            partition = self._df[self._partition_column]
            if stage == "fit":
                self._train_dataset = self._make_dataset(self._df[partition == 'train'])
            self._train_sampler = None
            self._val_sampler = None
            self._val_dataset = self._make_dataset(self._df[partition == 'val'])

        self._test_dataset = None
        # Prediction data may carry no partition column at all.
        if self._partition_column in self._df.columns:
            test_df = self._df[self._df[self._partition_column] == 'test']
            if not test_df.empty:
                self._test_dataset = self._make_dataset(test_df)

        if stage == "predict":
            self._predict_dataset = self._make_dataset(self._df)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader, honouring ``shuffle`` and ``drop_last``.

        Shuffling draws from the generator seeded in ``__init__``.
        """
        sampler = self._train_sampler
        return DataLoader(
            self._train_dataset,
            batch_size=self._batch_size,
            # DataLoader rejects shuffle=True combined with an explicit sampler.
            shuffle=bool(self._shuffle) and sampler is None,
            sampler=sampler,
            num_workers=self._num_workers,
            collate_fn=self._collate_fn,
            pin_memory=self._pin_memory,
            drop_last=self._drop_last,
            generator=self._generator,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (never shuffled)."""
        return DataLoader(
            self._val_dataset,
            batch_size=self._batch_size,
            shuffle=False,
            sampler=self._val_sampler,
            num_workers=self._num_workers,
            collate_fn=self._collate_fn,
            pin_memory=self._pin_memory,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test loader; ``_test_dataset`` is None when there were no 'test' rows."""
        return DataLoader(
            self._test_dataset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self._num_workers,
            collate_fn=self._collate_fn,
            pin_memory=self._pin_memory,
        )

    def predict_dataloader(self) -> DataLoader:
        """Return the prediction loader, using the ``sampler`` given at construction."""
        return DataLoader(
            self._predict_dataset,
            batch_size=self._batch_size,
            shuffle=False,
            sampler=self._sampler,
            num_workers=self._num_workers,
            collate_fn=self._collate_fn,
            pin_memory=self._pin_memory,
        )


class GearnetHeavyLightSequenceDecoderDataModule(GearnetSequenceDecoderDataModule):
    """Sequence data module whose splits are ``GearNetHeavyLightStructureSequenceEdgesDataset``s.

    Reuses the parent's configuration, data loading and dataloaders; only
    ``setup`` and the dataset class differ.
    """

    def __init__(
        self,
        datapath,
        s3_pdb_path,
        data_dir,
        *,
        seed: int = 0xDEADBEEF,
        batch_size: int = 32,
        shuffle: bool = True,
        sampler: Optional[Union[Iterable, Sampler]] = None,
        batch_sampler: Optional[Union[Iterable[Sequence], Sampler[Sequence]]] = None,
        num_workers: int = 12,
        collate_fn: Optional[
            Callable[['list[T]'], Any]
        ] = None,
        pin_memory: bool = True,
        drop_last: bool = False,
        columns: list[str] = ['seqid'],
        weight_column: Optional[str] = "weight",
        partition_column: Optional[str] = "partition",
        split: Optional[str] = "iid",
        additional_columns: list[str] = [],
        edges: bool = False
    ) -> None:
        """Accept the parent's arguments plus ``edges``.

        Args:
            edges: Forwarded as ``edges=`` to every dataset this module builds.

        All other arguments are passed unchanged to the parent constructor.
        The list arguments are copied first, so their mutable defaults are
        never shared or mutated.
        """
        super().__init__(datapath,
                         s3_pdb_path,
                         data_dir,
                         seed=seed,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         columns=list(columns),
                         weight_column=weight_column,
                         partition_column=partition_column,
                         split=split,
                         additional_columns=list(additional_columns))
        self.edges = edges

    def _make_dataset(self, df: pd.DataFrame):
        """Wrap one subset of rows of ``self._df`` in the heavy/light dataset class."""
        return GearNetHeavyLightStructureSequenceEdgesDataset(
            df,
            s3_pdb_path=self.s3_pdb_path,
            data_dir=self.data_dir,
            edges=self.edges,
        )

    def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
        """Build the datasets required for the Lightning ``stage``.

        'fit' builds train and val, 'validate' builds val, and 'predict' builds
        a dataset over the whole table. With ``split == 'iid'`` the first 5% of
        rows are labelled 'val' in a generated ``iid_partition`` column. For
        every stage, a test dataset is built when the partition column exists
        and has 'test' rows; otherwise ``_test_dataset`` is None.
        """
        random.seed(self._seed)
        torch.manual_seed(self._seed)

        if self._df is None:
            self.prepare_data()

        if stage in ("fit", "validate"):
            if self._split == 'iid':
                # Assign the whole column: chained assignment (df[col][:n] = ...)
                # is a silent no-op under pandas Copy-on-Write.
                n_rows = self._df.shape[0]
                n_val = int(n_rows * 0.05)
                self._partition_column = 'iid_partition'
                self._df['iid_partition'] = ['val'] * n_val + ['train'] * (n_rows - n_val)

            partition = self._df[self._partition_column]
            if stage == "fit":
                self._train_dataset = self._make_dataset(self._df[partition == 'train'])
            self._train_sampler = None
            self._val_sampler = None
            self._val_dataset = self._make_dataset(self._df[partition == 'val'])

        self._test_dataset = None
        # Prediction data may carry no partition column at all.
        if self._partition_column in self._df.columns:
            test_df = self._df[self._df[self._partition_column] == 'test']
            if not test_df.empty:
                self._test_dataset = self._make_dataset(test_df)

        if stage == "predict":
            self._predict_dataset = self._make_dataset(self._df)




class GearnetOneHotSequenceDecoderDataModule(GearnetSequenceDecoderDataModule):
    """Sequence data module whose splits are ``GearNetSequenceEdgesDataset``s.

    Batches are collated with ``onehot_collate`` unless a ``collate_fn`` is
    supplied. Data loading and the dataloaders are inherited from the parent.
    """

    def __init__(
        self,
        datapath,
        s3_pdb_path,
        data_dir,
        *,
        seed: int = 0xDEADBEEF,
        batch_size: int = 32,
        shuffle: bool = True,
        sampler: Optional[Union[Iterable, Sampler]] = None,
        batch_sampler: Optional[Union[Iterable[Sequence], Sampler[Sequence]]] = None,
        num_workers: int = 12,
        collate_fn: Optional[
            Callable[['list[T]'], Any]
        ] = None,
        pin_memory: bool = True,
        drop_last: bool = False,
        columns: list[str] = ['seqid'],
        weight_column: Optional[str] = "weight",
        partition_column: Optional[str] = "partition",
        split: Optional[str] = "iid",
        additional_columns: list[str] = [],
        edges: bool = False
    ) -> None:
        """Accept the parent's arguments plus ``edges``.

        Args:
            collate_fn: Batch collate function; ``onehot_collate`` when None.
            edges: Forwarded as ``edges=`` to every dataset this module builds.

        All other arguments are passed unchanged to the parent constructor.
        The list arguments are copied first, so their mutable defaults are
        never shared or mutated.
        """
        super().__init__(datapath,
                         s3_pdb_path,
                         data_dir,
                         seed=seed,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         columns=list(columns),
                         weight_column=weight_column,
                         partition_column=partition_column,
                         split=split,
                         additional_columns=list(additional_columns))
        self.edges = edges
        # Set explicitly in both cases so a caller-supplied collate_fn is kept.
        self._collate_fn = onehot_collate if collate_fn is None else collate_fn

    def _make_dataset(self, df: pd.DataFrame):
        """Wrap one subset of rows of ``self._df`` in ``GearNetSequenceEdgesDataset``."""
        return GearNetSequenceEdgesDataset(
            df,
            s3_pdb_path=self.s3_pdb_path,
            data_dir=self.data_dir,
            edges=self.edges,
        )

    def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
        """Build the datasets required for the Lightning ``stage``.

        'fit' builds train and val, 'validate' builds val, and 'predict' builds
        a dataset over the whole table. With ``split == 'iid'`` the first 5% of
        rows are labelled 'val' in a generated ``iid_partition`` column. For
        every stage, a test dataset is built when the partition column exists
        and has 'test' rows; otherwise ``_test_dataset`` is None.
        """
        random.seed(self._seed)
        torch.manual_seed(self._seed)

        if self._df is None:
            self.prepare_data()

        if stage in ("fit", "validate"):
            if self._split == 'iid':
                # Assign the whole column: chained assignment (df[col][:n] = ...)
                # is a silent no-op under pandas Copy-on-Write.
                n_rows = self._df.shape[0]
                n_val = int(n_rows * 0.05)
                self._partition_column = 'iid_partition'
                self._df['iid_partition'] = ['val'] * n_val + ['train'] * (n_rows - n_val)

            partition = self._df[self._partition_column]
            if stage == "fit":
                self._train_dataset = self._make_dataset(self._df[partition == 'train'])
            self._train_sampler = None
            self._val_sampler = None
            self._val_dataset = self._make_dataset(self._df[partition == 'val'])

        self._test_dataset = None
        # Prediction data may carry no partition column at all.
        if self._partition_column in self._df.columns:
            test_df = self._df[self._df[self._partition_column] == 'test']
            if not test_df.empty:
                self._test_dataset = self._make_dataset(test_df)

        if stage == "predict":
            self._predict_dataset = self._make_dataset(self._df)


class GearnetPairedDataModule(GearnetSequenceDecoderDataModule):
    """Data module over sequence *pairs*, wrapped in ``GearNetPairedDataset``.

    Pair rows (with ``first_seqid``/``second_seqid`` columns) are read from
    parquet into ``self._df_paired``. The splits are selected from the pair
    table, while ``self._df`` holds the de-duplicated ``seqid`` values.
    """

    def __init__(
        self,
        datapath_paired: Union[str , list],
        s3_pdb_path: str,
        data_dir,
        *,
        seed: int = 0xDEADBEEF,
        batch_size: int = 32,
        shuffle: bool = True,
        sampler: Optional[Union[Iterable, Sampler]] = None,
        batch_sampler: Optional[Union[Iterable[Sequence], Sampler[Sequence]]] = None,
        num_workers: int = 12,
        collate_fn: Optional[
            Callable[['list[T]'], Any]
        ] = None,
        pin_memory: bool = True,
        drop_last: bool = False,
        columns: list[str] = ['seqid'],
        weight_column: Optional[str] = "weight",
        partition_column: Optional[str] = "partition",
        split: Optional[str] = "iid",
        additional_columns: list[str] = [],
        edges: bool = False,
    ) -> None:
        """Store configuration; no data is read here.

        Args:
            datapath_paired: A parquet path, or an iterable of parquet paths
                whose tables are concatenated and de-duplicated.
            collate_fn: Batch collate function; ``gearnet_paired_collate`` when None.
            edges: Forwarded as ``edges=`` to every dataset this module builds.

        The remaining arguments are passed to the parent constructor. The
        parent's ``datapath`` is unused here, so an empty string is passed.
        List arguments are copied so their mutable defaults are never shared.
        """
        datapath = ""
        super().__init__(datapath,
                         s3_pdb_path,
                         data_dir,
                         seed=seed,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         columns=list(columns),
                         weight_column=weight_column,
                         partition_column=partition_column,
                         split=split,
                         additional_columns=list(additional_columns))

        self.datapath_paired = datapath_paired
        # Honour a caller-supplied collate function instead of always overriding it.
        self._collate_fn = gearnet_paired_collate if collate_fn is None else collate_fn
        self.edges = edges

    def prepare_data(self) -> None:
        """Load the pair table(s) into ``self._df_paired`` and build ``self._df``.

        ``self._df`` is a one-column (``seqid``) table holding every distinct
        value of ``first_seqid`` and ``second_seqid``.
        """
        # Rebind rather than extend in place, and add the column at most once.
        if self._partition_column not in self._additional_columns:
            self._additional_columns = [*self._additional_columns, self._partition_column]

        if isinstance(self.datapath_paired, str):
            self._df_paired = pd.read_parquet(self.datapath_paired)
        else:
            frames = [pd.read_parquet(path) for path in self.datapath_paired]
            # ignore_index: every file has its own index, and repeated labels
            # would make later label-based (.loc) relabelling hit extra rows.
            self._df_paired = pd.concat(frames, ignore_index=True).drop_duplicates()
        seqids = (self._df_paired['first_seqid'].values.tolist()
                  + self._df_paired['second_seqid'].values.tolist())
        self._df = pd.DataFrame({'seqid': seqids}).drop_duplicates()
        print(self._df_paired.columns)
        print(self._df_paired.shape[0])

    def _make_dataset(self, pairs: pd.DataFrame):
        """Wrap a subset of pair rows in ``GearNetPairedDataset`` alongside ``self._df``."""
        return GearNetPairedDataset(
            self._df,
            pairs,
            s3_pdb_path=self.s3_pdb_path,
            data_dir=self.data_dir,
            edges=self.edges,
        )

    def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
        """Build the pair datasets required for the Lightning ``stage``.

        'fit' builds train and val, 'validate' builds val, and 'predict' builds
        a dataset over every loaded pair. With ``split == 'iid'`` the first 5%
        of pairs are labelled 'val'. If the partition has no 'val' rows, the
        first 5% of 'train' pairs are relabelled 'val'. For every stage, a test
        dataset is built when the partition column exists and has 'test' rows.

        Raises:
            ValueError: If no validation pairs remain (fewer than 20 pairs
                gives a 5% split of zero rows).
        """
        random.seed(self._seed)
        torch.manual_seed(self._seed)

        if self._df is None:
            self.prepare_data()

        pairs = self._df_paired
        if stage in ("fit", "validate"):
            n_rows = pairs.shape[0]
            n_val = int(n_rows * 0.05)
            if self._split == 'iid':
                # Whole-column assignment; chained assignment is a no-op under
                # pandas Copy-on-Write.
                self._partition_column = 'iid_partition'
                pairs[self._partition_column] = ['val'] * n_val + ['train'] * (n_rows - n_val)

            column = self._partition_column
            if not (pairs[column] == 'val').any():
                # Positional mask of the first n_val 'train' rows; independent
                # of whether index labels are unique.
                is_train = pairs[column] == 'train'
                first_train = is_train & (is_train.cumsum() <= n_val)
                pairs.loc[first_train.to_numpy(), column] = 'val'

            if stage == "fit":
                self._train_dataset = self._make_dataset(pairs[pairs[column] == 'train'])
            self._train_sampler = None
            self._val_sampler = None

            val_pairs = pairs[pairs[column] == 'val']
            if val_pairs.empty:
                raise ValueError(
                    f"No validation pairs in partition column {column!r}; "
                    "the pair table may be too small for a 5% split."
                )
            self._val_dataset = self._make_dataset(val_pairs)

        self._test_dataset = None
        # Prediction data may carry no partition column at all.
        if self._partition_column in pairs.columns:
            test_pairs = pairs[pairs[self._partition_column] == 'test']
            if not test_pairs.empty:
                self._test_dataset = self._make_dataset(test_pairs)

        if stage == "predict":
            # Predict over every loaded pair, as the single-sequence modules
            # predict over their whole table.
            self._predict_dataset = self._make_dataset(pairs)


class PairedDataModule(GearnetSequenceDecoderDataModule):
    """Data module over sequence *pairs*, wrapped in ``PairedDataset``.

    Pair rows (with ``first_seqid``/``second_seqid`` columns) are read from
    parquet into ``self._df_paired``. The splits are selected from the pair
    table, while ``self._df`` holds the de-duplicated ``seqid`` values.
    """

    def __init__(
        self,
        datapath_paired: Union[str , list],
        s3_pdb_path: str,
        data_dir,
        *,
        seed: int = 0xDEADBEEF,
        batch_size: int = 32,
        shuffle: bool = True,
        sampler: Optional[Union[Iterable, Sampler]] = None,
        batch_sampler: Optional[Union[Iterable[Sequence], Sampler[Sequence]]] = None,
        num_workers: int = 12,
        collate_fn: Optional[
            Callable[['list[T]'], Any]
        ] = None,
        pin_memory: bool = True,
        drop_last: bool = False,
        columns: list[str] = ['seqid'],
        weight_column: Optional[str] = "weight",
        partition_column: Optional[str] = "partition",
        split: Optional[str] = "iid",
        additional_columns: list[str] = [],
        edges: bool = False
    ) -> None:
        """Store configuration; no data is read here.

        Args:
            datapath_paired: A parquet path, or an iterable of parquet paths
                whose tables are concatenated and de-duplicated.
            collate_fn: Batch collate function; ``paired_collate`` when None.
            edges: Forwarded as ``edges=`` to every dataset this module builds.

        The remaining arguments are passed to the parent constructor. The
        parent's ``datapath`` is unused here, so an empty string is passed.
        List arguments are copied so their mutable defaults are never shared.
        """
        datapath = ""
        super().__init__(datapath,
                         s3_pdb_path,
                         data_dir,
                         seed=seed,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         columns=list(columns),
                         weight_column=weight_column,
                         partition_column=partition_column,
                         split=split,
                         additional_columns=list(additional_columns))

        self.datapath_paired = datapath_paired
        # Honour a caller-supplied collate function instead of always overriding it.
        self._collate_fn = paired_collate if collate_fn is None else collate_fn
        self.edges = edges

    def prepare_data(self) -> None:
        """Load the pair table(s) into ``self._df_paired`` and build ``self._df``.

        ``self._df`` is a one-column (``seqid``) table holding every distinct
        value of ``first_seqid`` and ``second_seqid``.
        """
        # Rebind rather than extend in place, and add the column at most once.
        if self._partition_column not in self._additional_columns:
            self._additional_columns = [*self._additional_columns, self._partition_column]

        if isinstance(self.datapath_paired, str):
            self._df_paired = pd.read_parquet(self.datapath_paired)
        else:
            frames = [pd.read_parquet(path) for path in self.datapath_paired]
            # ignore_index: every file has its own index, and repeated labels
            # would make later label-based (.loc) relabelling hit extra rows.
            self._df_paired = pd.concat(frames, ignore_index=True).drop_duplicates()
        seqids = (self._df_paired['first_seqid'].values.tolist()
                  + self._df_paired['second_seqid'].values.tolist())
        self._df = pd.DataFrame({'seqid': seqids}).drop_duplicates()
        print(self._df_paired.columns)
        print(self._df_paired.shape[0])

    def _make_dataset(self, pairs: pd.DataFrame):
        """Wrap a subset of pair rows in ``PairedDataset`` alongside ``self._df``."""
        return PairedDataset(
            self._df,
            pairs,
            s3_pdb_path=self.s3_pdb_path,
            data_dir=self.data_dir,
            edges=self.edges,
        )

    def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
        """Build the pair datasets required for the Lightning ``stage``.

        'fit' builds train and val, 'validate' builds val, and 'predict' builds
        a dataset over every loaded pair. With ``split == 'iid'`` the first 5%
        of pairs are labelled 'val'. If the partition has no 'val' rows, the
        first 5% of 'train' pairs are relabelled 'val'. For every stage, a test
        dataset is built when the partition column exists and has 'test' rows.

        Raises:
            ValueError: If no validation pairs remain (fewer than 20 pairs
                gives a 5% split of zero rows).
        """
        random.seed(self._seed)
        torch.manual_seed(self._seed)

        if self._df is None:
            self.prepare_data()

        pairs = self._df_paired
        if stage in ("fit", "validate"):
            n_rows = pairs.shape[0]
            n_val = int(n_rows * 0.05)
            if self._split == 'iid':
                # Whole-column assignment; chained assignment is a no-op under
                # pandas Copy-on-Write.
                self._partition_column = 'iid_partition'
                pairs[self._partition_column] = ['val'] * n_val + ['train'] * (n_rows - n_val)

            column = self._partition_column
            if not (pairs[column] == 'val').any():
                # Positional mask of the first n_val 'train' rows; independent
                # of whether index labels are unique.
                is_train = pairs[column] == 'train'
                first_train = is_train & (is_train.cumsum() <= n_val)
                pairs.loc[first_train.to_numpy(), column] = 'val'

            if stage == "fit":
                self._train_dataset = self._make_dataset(pairs[pairs[column] == 'train'])
            self._train_sampler = None
            self._val_sampler = None

            val_pairs = pairs[pairs[column] == 'val']
            if val_pairs.empty:
                raise ValueError(
                    f"No validation pairs in partition column {column!r}; "
                    "the pair table may be too small for a 5% split."
                )
            self._val_dataset = self._make_dataset(val_pairs)

        self._test_dataset = None
        # Prediction data may carry no partition column at all.
        if self._partition_column in pairs.columns:
            test_pairs = pairs[pairs[self._partition_column] == 'test']
            if not test_pairs.empty:
                self._test_dataset = self._make_dataset(test_pairs)

        if stage == "predict":
            # Predict over every loaded pair, as the single-sequence modules
            # predict over their whole table.
            self._predict_dataset = self._make_dataset(pairs)








