import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from src import utils
from src.datamodules import register_datamodule
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from src.datamodules.datasets.gym import GymDataset
from src.datamodules.datasets.data_utils import Alphabet, MaxTokensBatchSampler
from src.utils.config import compose_config as Cfg, merge_config

from omegaconf import DictConfig
log = utils.get_logger(__name__)

@register_datamodule('gym')
class GymDataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = './data/SKEMPI_v2_cache',
        max_length: int = 1024,
        atoms: List[str] = ('N', 'CA', 'C', 'O'),
        alphabet=None,
        batch_size: int = 1024,
        max_tokens: int = 6000,
        sort: bool = False,
        num_workers: int = 0,
        pin_memory: bool = False,
        train_split: str = 'train',
        valid_split: str = 'valid',
        test_split: str = 'test',
        num_cvfolds: int = 3,
        cvfold_index: int = 1,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)
        self.alphabet = None
        # print(config, max_length)
        self.train_data: Optional[Dataset] = None
        self.valid_data: Optional[Dataset] = None
        self.test_data: Optional[Dataset] = None
    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
        so be careful not to execute the random split twice! The `stage` can be used to
        differentiate whether it's called before trainer.fit()` or `trainer.test()`.
        """
        # load datasets only if they're not loaded already
        if stage == 'fit':
            train = GymDataset(
                self.hparams['data_dir'],
                split='train',
                num_cvfolds=self.hparams['num_cvfolds'],
                cvfold_index=self.hparams['cvfold_index']
            )
            valid = GymDataset(
                self.hparams['data_dir'],
                split='val',
                num_cvfolds=self.hparams['num_cvfolds'],
                cvfold_index=self.hparams['cvfold_index']
            )
            self.train_dataset = train
            self.valid_dataset = valid
        elif stage == 'test' or stage == 'predict':
            test = GymDataset(
                self.hparams['data_dir'],
                split='val',
                num_cvfolds=self.hparams['num_cvfolds'],
                cvfold_index=self.hparams['cvfold_index']
            )
            self.test_dataset = test
        else:
            raise ValueError(f"Invalid stage: {stage}.")

        self.alphabet = Alphabet(**self.hparams['alphabet'])

        self.collate_batch = self.alphabet.featurizer

    def _build_batch_sampler(self, dataset, max_tokens, shuffle=False, distributed=True):
        is_distributed = distributed and torch.distributed.is_initialized()
        batch_sampler = MaxTokensBatchSampler(
            dataset=dataset,
            shuffle=shuffle,
            distributed=is_distributed,
            batch_size=self.hparams['batch_size'],
            max_tokens=max_tokens,
            sort=self.hparams['sort'],
            drop_last=False,
            sort_key=lambda i: len(dataset[i]['aa'])
        )
        return batch_sampler

    def train_dataloader(self):
        if not hasattr(self, 'train_batch_sampler'):
            self.train_batch_sampler = self._build_batch_sampler(
                self.train_dataset,
                max_tokens=self.hparams['max_tokens'],
                shuffle=True
            )
        return DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_batch_sampler,
            num_workers=self.hparams['num_workers'],
            pin_memory=self.hparams['pin_memory'],
            collate_fn=self.collate_batch
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.valid_dataset,
            batch_sampler=self._build_batch_sampler(
                self.valid_dataset, max_tokens=self.hparams['max_tokens'], distributed=False),
            num_workers=self.hparams['num_workers'],
            pin_memory=self.hparams['pin_memory'],
            collate_fn=self.collate_batch
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_dataset,
            batch_sampler=self._build_batch_sampler(
                self.test_dataset, max_tokens=self.hparams['max_tokens'], distributed=False),
            num_workers=self.hparams['num_workers'],
            pin_memory=self.hparams['pin_memory'],
            collate_fn=self.collate_batch
        )