import torch
import numpy as np
from torch import Tensor
from torch.utils.data import DataLoader



class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable container of inputs and optional targets.

    Items are converted with ``Tensor(...)`` when they are accessed. Without
    targets an item is the input tensor alone; with targets it is an
    ``(x, y)`` pair of tensors.
    """

    def __init__(self, X, y=None):
        super().__init__()
        self.X = X
        self.y = y
        # Number of tensors produced per item: inputs only, or inputs + targets.
        self.n_inp = 2 if y is not None else 1

    def __len__(self):
        """Return the number of samples, i.e. ``len(X)``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``X[idx]`` as a tensor, paired with ``y[idx]`` when targets exist."""
        inputs = Tensor(self.X[idx])
        if self.n_inp <= 1:
            return inputs
        return inputs, Tensor(self.y[idx])

    def new(self, X, y=None):
        """Build another dataset of this same class from ``X`` and optional ``y``.

        ``X`` is expected to be a numpy array (per the original note on this
        method).
        """
        dataset_cls = type(self)
        return dataset_cls(X, y)


class DataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` subclass that can create sibling loaders."""

    def new(self, dataset, batch_size, **kwargs):
        """Create a loader of this same class over ``dataset``.

        ``dataset`` and ``batch_size`` are passed positionally to the class
        constructor and ``kwargs`` are forwarded unchanged. No settings of
        the current loader are carried over.
        """
        loader_cls = type(self)
        return loader_cls(dataset, batch_size, **kwargs)



# def collate(idxs, ds): 
#     xb, yb = zip(*[ds[i] for i in idxs])
#     return torch.stack(xb),torch.stack(yb)


class DataLoaders:
    """Container grouping a training loader with optional additional loaders.

    The first loader is exposed as ``train`` and the second, if given, as
    ``valid`` (otherwise ``None``). All loaders are reachable by position
    through indexing, and ``len()`` reports how many were given.

    Raises:
        IndexError: if constructed without any loader.
    """

    def __init__(self, *dls):
        self.loaders = list(dls)
        self.train = self.loaders[0]
        self.valid = self.loaders[1] if len(self.loaders) > 1 else None
        # Default batch size for loaders created later through ``add_dl``.
        self.batch_size = self.train.batch_size

    def __getitem__(self, idx):
        return self.loaders[idx]

    def __len__(self):
        return len(self.loaders)

    def add_dl(self, test_data, batch_size=None, **kwargs):
        """Return a loader for ``test_data``, modelled on the training loader.

        Despite the name, the resulting loader is only returned; it is not
        appended to ``self.loaders``.

        Args:
            test_data: an existing data loader (returned unchanged), a
                dataset, or raw data that the training dataset's ``new``
                method can wrap.
            batch_size: batch size of the new loader; defaults to the
                training loader's batch size.
            **kwargs: forwarded to the training loader's ``new`` method.

        Returns:
            ``test_data`` itself when it already is a loader, otherwise the
            loader produced by ``self.train.new(...)``.
        """
        # Check against the torch base classes rather than this module's own
        # subclasses, so loaders and datasets built directly with torch are
        # recognised instead of being wrapped a second time.
        loader_types = (torch.utils.data.DataLoader,)
        try:
            # Ray Train's wrapped loader is a private symbol of an optional
            # package; its absence must not break add_dl for everyone else.
            from ray.train.torch import _WrappedDataLoader
        except ImportError:
            pass
        else:
            loader_types += (_WrappedDataLoader,)

        # Already a loader: nothing to build.
        if isinstance(test_data, loader_types):
            return test_data

        if batch_size is None:
            batch_size = self.batch_size

        # Raw data is wrapped with the same dataset class as the training set.
        if not isinstance(test_data, torch.utils.data.Dataset):
            test_data = self.train.dataset.new(test_data)

        # Build the loader with the same loader class as the training loader.
        return self.train.new(test_data, batch_size, **kwargs)

    


def get_ts_dls(X_train, y_train, X_valid, y_valid, batch_size, num_workers):
    """Build supervised training and validation loaders for time-series arrays.

    Args:
        X_train: training inputs, array of size n_samples x n_feas x seq_len.
        y_train: training targets, array of size n_samples x out_feas.
        X_valid: validation inputs; presumably laid out like ``X_train``.
        y_valid: validation targets; presumably laid out like ``y_train``.
        batch_size: batch size used for both loaders.
        num_workers: ``num_workers`` setting passed to both loaders.

    Returns:
        DataLoaders: container with the training loader (shuffled) and the
        validation loader (in order), carrying the extra attributes
        ``vars`` (``X_train.shape[1]``), ``len`` (``X_train.shape[2]``) and
        ``c`` (``y_train.shape[1]``).
    """
    train_ds = Dataset(X_train, y_train)
    valid_ds = Dataset(X_valid, y_valid)

    train_dl = DataLoader(train_ds, batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=True)
    # Validation data is iterated in its stored order so that per-sample
    # outputs line up with X_valid / y_valid and runs are reproducible.
    valid_dl = DataLoader(valid_ds, batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=False)
    dls = DataLoaders(train_dl, valid_dl)

    # Shape metadata taken from the training arrays.
    dls.vars, dls.len = X_train.shape[1], X_train.shape[2]
    dls.c = y_train.shape[1]

    return dls


def get_ts_udls(X_train, X_valid, batch_size, num_workers):
    """Build training and validation loaders for target-less time-series arrays.

    Args:
        X_train: training inputs, array of size n_samples x n_feas x seq_len.
        X_valid: validation inputs; presumably laid out like ``X_train``.
        batch_size: batch size used for both loaders.
        num_workers: ``num_workers`` setting passed to both loaders.

    Returns:
        DataLoaders: container with the training loader (shuffled) and the
        validation loader (in order), carrying the extra attributes
        ``vars`` (``X_train.shape[1]``), ``len`` (``X_train.shape[2]``) and
        ``c``, which is set equal to ``vars``.
    """
    train_ds = Dataset(X_train, None)
    valid_ds = Dataset(X_valid, None)

    train_dl = DataLoader(train_ds, batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=True)
    # Validation data is iterated in its stored order so that per-sample
    # outputs line up with X_valid and runs are reproducible.
    valid_dl = DataLoader(valid_ds, batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=False)
    udls = DataLoaders(train_dl, valid_dl)

    # Shape metadata taken from the training array; with no targets, the
    # output size mirrors the number of input variables.
    udls.vars, udls.len = X_train.shape[1], X_train.shape[2]
    udls.c = udls.vars

    return udls