"""
Data module for multitask regression data.

Author:
Date: October 29, 2023
"""
import os
from typing import Optional, Tuple, Dict

import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset

from krt import KRT_PATH


class ShuffleDataLoaderWrapper:
    """Iterator over a data loader's batches with dimension 1 shuffled.

    Each ``(x, y)`` batch pulled from the wrapped loader is re-ordered along
    dimension 1 with a single freshly drawn random permutation that is applied
    to both tensors, so x/y entries stay aligned with each other.
    """

    def __init__(self, inner_dl: DataLoader):
        """Constructor.

        Args:
            inner_dl: Inner data loader to use. An iterator over it is opened
                right away and kept on the instance.
        """
        self.inner_dl = inner_dl
        self.iter_dl = iter(inner_dl)

    def __len__(self):
        """Return the number of batches reported by the inner loader."""
        return len(self.inner_dl)

    def __iter__(self):
        """Return the wrapper itself; the current pass is not restarted."""
        return self

    def __next__(self):
        """Return the next batch with both tensors permuted along dim 1.

        Raises:
            StopIteration: When the inner iterator is exhausted. A new inner
                iterator is opened first, so the wrapper can be looped again.
        """
        try:
            batch = next(self.iter_dl)
        except StopIteration:
            # Re-arm for the following pass before signalling the end of
            # this one.
            self.iter_dl = iter(self.inner_dl)
            raise StopIteration
        inputs, targets = batch
        # One permutation shared by inputs and targets keeps pairs aligned.
        order = torch.randperm(inputs.shape[1])
        shuffled_inputs = inputs[:, order]
        shuffled_targets = targets[:, order]
        return shuffled_inputs, shuffled_targets

class FeaturizeDataLoaderWrapper:
    """Iterator that swaps index tensors in a batch for rows of a feature table.

    For every ``(x, y)`` batch of the wrapped loader, ``x`` has its last axis
    squeezed, is cast to ``long`` and is then used to index ``X``; ``y`` is
    passed through untouched.
    """

    def __init__(self, inner_dl: DataLoader, X):
        """Constructor.

        Args:
            inner_dl: Inner data loader (or iterable with ``__len__``) to use.
                An iterator over it is opened right away.
            X: Indexable feature table that the batch indices look up into.
        """
        self.X = X
        self.inner_dl = inner_dl
        self.iter_dl = iter(inner_dl)

    def __len__(self):
        """Return the number of batches reported by the inner loader."""
        return len(self.inner_dl)

    def __iter__(self):
        """Return the wrapper itself; the current pass is not restarted."""
        return self

    def __next__(self):
        """Return ``(X[indices], y)`` for the next batch of the inner loader.

        Raises:
            StopIteration: When the inner iterator is exhausted. A new inner
                iterator is opened first, so the wrapper can be looped again.
        """
        try:
            batch = next(self.iter_dl)
        except StopIteration:
            # Re-arm for the following pass before signalling the end of
            # this one.
            self.iter_dl = iter(self.inner_dl)
            raise StopIteration
        index_tensor, targets = batch
        row_ids = index_tensor.squeeze(-1).long()
        features = self.X[row_ids]
        return features, targets


class MultitaskRegressionData:
    """Train / validation / test data loaders for multitask regression.

    The constructor randomly splits ``(x_data, y_data)`` into training and
    validation sets, optionally standardizes them with training statistics,
    and exposes the resulting loaders as ``train_data``, ``val_data`` and
    ``test_data`` (``None`` when no test loader was supplied).
    """

    def __init__(
        self,
        x_data: Tensor,
        y_data: Tensor,
        batch_size: int,
        val_proportion: float = 0.05,
        te_data_loader: Optional[DataLoader] = None,
        pin_memory: bool = True,
        num_workers: int = 4,
        seed: int = 0,
        shuffle_idxs: bool = True,
        standardize_features: int = 0,  # how many features to standardize
        **kwargs
    ):
        """Constructor.

        Args:
            x_data: X data for train and validation sets. Has shape
                (num function draws, data per function, x dimension)
            y_data: Y data for train and validation sets. Has shape
                (num function draws, data per function, y dimension).
                All dimensions but the last should match x_data and indices
                should correspond with each other.
            batch_size: Batch size.
            val_proportion: Proportion of data to use for validation. The
                validation count is ``int(len(x_data) * val_proportion)`` and
                may therefore be zero for small data sets.
            te_data_loader: Data loader to use for testing, or None. Its
                dataset must expose ``.tensors`` (e.g. a ``TensorDataset``)
                because the number of test functions is read from
                ``dataset.tensors[0]``.
            pin_memory: Whether to pin memory in the data loader.
            num_workers: Number of workers to have in the data loader.
            seed: Seed for the random splitting.
            shuffle_idxs: Whether to shuffle the indices of the batch. Because we
                partition the dataset into context and target randomly, this will
                ensure that we use different parts of the data for condition and
                target prediction. Only the training loader is wrapped.
            standardize_features: If positive, the last ``standardize_features``
                x-features and the y targets are standardized with statistics
                of the training split (see ``standardize_datasets``).
            **kwargs: Accepted and ignored.
        """
        self.dim_x = x_data.shape[-1]
        self.dim_y = y_data.shape[-1]
        # Split the data and make dataloaders.
        self.num_val = int(len(x_data) * val_proportion)
        self.num_tr = len(x_data) - self.num_val
        tr_dataset, val_dataset = random_split(
            TensorDataset(x_data, y_data),
            [self.num_tr, self.num_val],
            generator=torch.Generator().manual_seed(seed),
        )
        # Always define test_data so that te_num_batches works (and the test
        # loader is kept) even when no standardization is requested.
        self.test_data = te_data_loader
        if standardize_features > 0:
            tr_dataset, val_dataset, self.test_data = self.standardize_datasets(
                tr_dataset,
                val_dataset,
                te_data_loader,
                feat_slice=slice(-standardize_features, None),
            )
        self.train_data = DataLoader(
            tr_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=pin_memory,
        )
        if shuffle_idxs:
            self.train_data = ShuffleDataLoaderWrapper(self.train_data)
        self.val_data = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=pin_memory,
        )
        self.batches_per_epoch = len(self.train_data)
        # Possibly account for test data.
        if te_data_loader is None:
            self.num_te = 0
        else:
            self.num_te = te_data_loader.dataset.tensors[0].shape[0]
        self.L = x_data.shape[1]

    @staticmethod
    def standardize_datasets(
        tr_subset: Subset,
        val_subset: Subset,
        test_loader: Optional[DataLoader] = None,
        *,                         # keyword-only below
        feat_slice: slice = slice(2, 4),
        eps: float = 1e-8
    ) -> Tuple[TensorDataset, TensorDataset, Optional[DataLoader]]:
        """Standardize selected x-features and the y targets.

        Statistics are computed **only** from the training subset: a
        per-feature mean/std over the columns selected by ``feat_slice`` and a
        single scalar mean/std over all y values. ``eps`` is added to each std
        before dividing.

        Args:
            tr_subset: Training samples; iterating must yield ``(x, y)`` pairs
                where ``x`` is 2-D with features on the last axis.
            val_subset: Validation samples in the same format. May be empty.
            test_loader: Optional test loader. Its dataset is iterated as
                ``(x, y)`` pairs, so it must yield exactly two tensors per
                sample.
            feat_slice: Slice of x-feature columns to standardize. Defaults to
                columns 2 and 3.
            eps: Constant added to the standard deviations.

        Returns:
            ``(tr_std, val_std, test_loader_std)``. The first two are
            ``TensorDataset`` objects; the third is a new, non-shuffling
            ``DataLoader`` over the standardized test data that copies
            batch_size / num_workers / pin_memory / drop_last /
            persistent_workers from ``test_loader``, or None if
            ``test_loader`` was None.
        """
        # ---------- compute (mean, std) from training ---------------------------
        # Materialize the subset once instead of indexing it twice.
        tr_samples = list(tr_subset)
        x_train = torch.cat([x for x, _ in tr_samples], dim=0)
        y_train = torch.cat([y for _, y in tr_samples], dim=0)

        mean_x = x_train[:, feat_slice].mean(0)
        std_x = x_train[:, feat_slice].std(0) + eps
        mean_y = y_train.mean()
        std_y = y_train.std() + eps

        # Per-sample shapes, used to build correctly shaped empty datasets.
        x_sample_shape = tuple(tr_samples[0][0].shape)
        y_sample_shape = tuple(tr_samples[0][1].shape)

        # ---------- helper for any Dataset-like object --------------------------
        def _transform_dataset(ds) -> TensorDataset:
            xs, ys = [], []
            for x, y in ds:
                x = x.clone()  # never modify the caller's tensors in place
                x[:, feat_slice] = (x[:, feat_slice] - mean_x) / std_x
                xs.append(x)
                ys.append((y - mean_y) / std_y)
            if not xs:
                # torch.stack rejects an empty list; an empty split (e.g. a
                # zero-length validation set) becomes an empty dataset.
                return TensorDataset(
                    x_train.new_empty((0, *x_sample_shape)),
                    y_train.new_empty((0, *y_sample_shape)),
                )
            return TensorDataset(torch.stack(xs, dim=0), torch.stack(ys, dim=0))

        tr_std = _transform_dataset(tr_samples)
        val_std = _transform_dataset(val_subset)

        # ---------- test loader (optional) --------------------------------------
        test_loader_std = None
        if test_loader is not None:
            # Standardize the *dataset* behind the DataLoader, then re-create
            # a DataLoader that mirrors the original settings.
            test_loader_std = DataLoader(
                _transform_dataset(test_loader.dataset),
                batch_size=test_loader.batch_size,
                shuffle=False,  # keep deterministic
                num_workers=test_loader.num_workers,
                pin_memory=test_loader.pin_memory,
                drop_last=test_loader.drop_last,
                persistent_workers=getattr(
                    test_loader, "persistent_workers", False
                ),
            )

        return tr_std, val_std, test_loader_std

    @property
    def train_num_batches(self):
        """Number of batches in the training loader."""
        return len(self.train_data)

    @property
    def val_num_batches(self):
        """Number of batches in the validation loader."""
        return len(self.val_data)

    @property
    def te_num_batches(self):
        """Number of batches in the test loader, or 0 if there is none."""
        return 0 if self.test_data is None else len(self.test_data)

    @classmethod
    def construct_data_from_dir(
        cls,
        path: str,
        test_normalize: bool,
        te_path: Optional[str] = None,
        standardize_features: int = 0,
        **kwargs
    ):
        """Instantiate the class from a directory.

        All paths are resolved relative to ``KRT_PATH``.

        Args:
            path: Path to directory containing
                * x_data.pt
                * y_data.pt
            test_normalize: If true (and ``te_path`` is given), the two
                log-probability files are also loaded and appended to the test
                dataset after x and y.
            te_path: Optional path containing
                * te_x_data.pt
                * te_y_data.pt
                * cum_joint_logprob.pt (only read if ``test_normalize``)
                * marginal_logprob.pt (only read if ``test_normalize``)
            standardize_features: Forwarded to the constructor.
            **kwargs: Forwarded to the constructor. ``batch_size`` is also
                used for the test loader, as is ``pin_memory`` (default True).
        """
        x_data = torch.load(os.path.join(KRT_PATH, path, 'x_data.pt'))
        y_data = torch.load(os.path.join(KRT_PATH, path, 'y_data.pt'))
        te_data = None
        if te_path is not None:
            te_tensors = [
                torch.load(os.path.join(KRT_PATH, te_path, 'te_x_data.pt')),
                torch.load(os.path.join(KRT_PATH, te_path, 'te_y_data.pt')),
            ]
            if test_normalize:
                te_tensors.append(torch.load(
                    os.path.join(KRT_PATH, te_path, 'cum_joint_logprob.pt')))
                te_tensors.append(torch.load(
                    os.path.join(KRT_PATH, te_path, 'marginal_logprob.pt')))
            te_data = DataLoader(
                TensorDataset(*te_tensors),
                batch_size=kwargs['batch_size'],
                shuffle=False,
                drop_last=False,
                pin_memory=kwargs.get('pin_memory', True),
            )
        return cls(
            x_data=x_data,
            y_data=y_data,
            te_data_loader=te_data,
            standardize_features=standardize_features,
            **kwargs
        )


class MultitaskRegressionBioData:
    """Train / validation / test loaders for index-encoded regression data.

    Unlike ``MultitaskRegressionData``, the three splits are supplied
    separately. The ``x`` tensors hold row indices into a per-split feature
    table (``X_data``, ``val_X_data``, ``te_X_data``); each loader is wrapped
    in ``FeaturizeDataLoaderWrapper`` so batches carry the looked-up rows.
    """

    def __init__(
        self,
        X_data: Tensor,
        x_data: Tensor,
        y_data: Tensor,
        val_X_data: Tensor,
        val_x_data: Tensor,
        val_y_data: Tensor,
        te_X_data: Tensor,
        te_x_data: Tensor,
        te_y_data: Tensor,
        batch_size: int,
        val_proportion: float = 0.05,
        pin_memory: bool = True,
        num_workers: int = 4,
        seed: int = 0,
        shuffle_idxs: bool = True,
        **kwargs
    ):
        """Constructor.

        Args:
            X_data: Feature table for the training split; its last dimension
                is reported as ``dim_x``.
            x_data: Index data for the training split. Its last axis is
                squeezed and the values are cast to long to index ``X_data``.
                ``x_data.shape[1]`` is stored as ``L``.
            y_data: Y data for the training split. Leading dimensions should
                match x_data and indices should correspond with each other.
            val_X_data: Feature table for the validation split.
            val_x_data: Index data for the validation split.
            val_y_data: Y data for the validation split.
            te_X_data: Feature table for the test split.
            te_x_data: Index data for the test split.
            te_y_data: Y data for the test split.
            batch_size: Batch size for the train and validation loaders. The
                test loader always uses a batch size of 1.
            val_proportion: Not used by this class (splits are given
                explicitly).
            pin_memory: Whether to pin memory in the data loader.
            num_workers: Number of workers to have in the data loader.
            seed: Not used by this class.
            shuffle_idxs: Whether to shuffle the indices of the batch. Because we
                partition the dataset into context and target randomly, this will
                ensure that we use different parts of the data for condition and
                target prediction. Applied to all three loaders.
            **kwargs: Accepted and ignored.
        """
        self.dim_x = X_data.shape[-1]
        self.dim_y = y_data.shape[-1]

        tr_dataset = TensorDataset(x_data, y_data)
        val_dataset = TensorDataset(val_x_data, val_y_data)
        te_dataset = TensorDataset(te_x_data, te_y_data)

        self.num_val = len(val_dataset)
        self.num_tr = len(tr_dataset)

        self.train_data = DataLoader(
            tr_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=pin_memory,
        )
        if shuffle_idxs:
            self.train_data = ShuffleDataLoaderWrapper(self.train_data)
        self.train_data = FeaturizeDataLoaderWrapper(self.train_data, X_data)

        self.val_data = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=pin_memory,
        )
        if shuffle_idxs:
            self.val_data = ShuffleDataLoaderWrapper(self.val_data)
        self.val_data = FeaturizeDataLoaderWrapper(self.val_data, val_X_data)

        self.test_data = DataLoader(
            te_dataset,
            batch_size=1,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=pin_memory,
        )
        if shuffle_idxs:
            self.test_data = ShuffleDataLoaderWrapper(self.test_data)
        self.test_data = FeaturizeDataLoaderWrapper(self.test_data, te_X_data)

        self.batches_per_epoch = len(self.train_data)
        # NOTE(review): num_te is reported as 0 even though a test loader
        # exists. The original code first stored len(te_dataset) and then
        # overwrote it with 0; the effective value is kept here. Confirm
        # whether consumers rely on 0 before changing it.
        self.num_te = 0
        self.L = x_data.shape[1]

    @property
    def train_num_batches(self):
        """Number of batches in the training loader."""
        return len(self.train_data)

    @property
    def val_num_batches(self):
        """Number of batches in the validation loader."""
        return len(self.val_data)

    @property
    def te_num_batches(self):
        """Number of batches in the test loader, or 0 if there is none."""
        return 0 if self.test_data is None else len(self.test_data)

    @staticmethod
    def _load_split(path, prefix, split, task_index, prune_until):
        """Load one split's feature table, index tensor and one-task targets.

        Reads ``<prefix><split>_X.pt``, ``<prefix><split>_x_data.pt`` and
        ``<prefix><split>_y_data.pt`` from ``KRT_PATH/path``. For the x and y
        tensors the last entry along dim 0 is dropped, dim 1 is reversed, the
        first ``prune_until`` positions along dim 1 are removed, and for y
        only column ``task_index`` of the last axis is kept (as a trailing
        axis of size 1).
        """
        def _load(suffix):
            return torch.load(
                os.path.join(KRT_PATH, path, prefix + split + suffix))

        X = _load('_X.pt')
        x = torch.flip(_load('_x_data.pt')[:-1, :, :], [1])
        y = torch.flip(_load('_y_data.pt')[:-1, :, :], [1])
        x = x[:, prune_until:, :]
        y = y[:, prune_until:, task_index].unsqueeze(-1)
        return X, x, y

    @classmethod
    def construct_data_from_dir(
        cls,
        path: str,
        prefix: str,
        task_index: int,
        prune_until: int,
        te_path: Optional[str] = None,
        **kwargs
    ):
        """Instantiate the class from a directory. Flips the sequence order
        so as to put the sequence in ascending order of similarity to a particular
        ligand, and prunes prune_until ligands that have lowest similarity

        Args:
            path: Path (relative to ``KRT_PATH``) to a directory containing,
                for each split ``s`` in ``tr``, ``val``, ``te``:
                * <prefix><s>_X.pt
                * <prefix><s>_x_data.pt
                * <prefix><s>_y_data.pt
            prefix: String prepended to every file name.
            task_index: Index into the last axis of the y tensors selecting
                the single task to keep.
            prune_until: Number of leading positions along dim 1 (after the
                flip) to discard.
            te_path: Currently unused; the test files are read from ``path``.
            **kwargs: Forwarded to the constructor (e.g. ``batch_size``).
        """
        X_data, x_data, y_data = cls._load_split(
            path, prefix, 'tr', task_index, prune_until)
        val_X_data, val_x_data, val_y_data = cls._load_split(
            path, prefix, 'val', task_index, prune_until)
        te_X_data, te_x_data, te_y_data = cls._load_split(
            path, prefix, 'te', task_index, prune_until)

        # The constructor requires the test tensors; previously they were
        # loaded but never passed on. setdefault keeps precedence for callers
        # that already supply them through kwargs.
        kwargs.setdefault('te_X_data', te_X_data)
        kwargs.setdefault('te_x_data', te_x_data)
        kwargs.setdefault('te_y_data', te_y_data)

        return cls(
            X_data=X_data,
            x_data=x_data,
            y_data=y_data,
            val_X_data=val_X_data,
            val_x_data=val_x_data,
            val_y_data=val_y_data,
            **kwargs
        )
