
from sklearn.model_selection import train_test_split

"""
The input is a list of feature sets
X_sequence: X_1, X_2, ..., X_T and a list of label sets
Y_sequence: Y_1, Y_2, ..., Y_T

They form together a list of datasets:
D1, ... DT

We split them at some time t (offline_t) to separate the training timesteps from the testing timesteps

D1, .. Dt | Dt+1,  .... DT

and each individual dataset D is split into train, val and test

(Dtest, Dval Dtrain)1 , ... (Dtest, Dval Dtrain)t | (Dtest, Dval Dtrain)t+1 , ... (Dtest, Dval Dtrain)T

"""


class SeqDataset:
    """Sequence of per-timestep datasets, each split into train/val/test.

    The timesteps ``0 .. offline_t - 1`` are stored in ``self.offline`` and
    the timesteps ``offline_t .. T - 1`` in ``self.online``. Both are dicts
    keyed by the absolute timestep index ``t``; every value is a dict with
    the keys ``X_train``, ``y_train``, ``X_val``, ``y_val``, ``X_test`` and
    ``y_test``.

    Args:
        X_sequence: One feature collection per timestep.
        Y_sequence: One label collection per timestep, aligned with
            ``X_sequence``.
        offline_t: Index of the first online timestep.
        test_split: ``test_size`` passed to ``train_test_split`` when
            carving the test set out of a timestep's data.
        val_split: ``test_size`` passed to ``train_test_split`` when carving
            the validation set out of what remains after the test split
            (so it is relative to the non-test data, not the full dataset).
        seed: ``random_state`` used for every split.
    """

    def __init__(self, X_sequence: list, Y_sequence: list, offline_t: int,
                 test_split: float, val_split: float, seed: int):
        self.X_sequence = X_sequence
        self.Y_sequence = Y_sequence
        self.offline_t = offline_t
        self.test_split = test_split
        self.val_split = val_split
        self.seed = seed
        self.T = len(X_sequence)

        self.create_offline_online_datasets()

    def _split_timestep(self, X, y) -> dict:
        """Split one timestep's (X, y) into train/val/test parts."""
        X_rest, X_test, y_rest, y_test = train_test_split(
            X, y, test_size=self.test_split, random_state=self.seed
        )
        # The validation set is taken from the data left after removing the
        # test set, so val_split is a fraction of that remainder.
        X_train, X_val, y_train, y_val = train_test_split(
            X_rest, y_rest, test_size=self.val_split, random_state=self.seed
        )
        return {
            "X_train": X_train,
            "y_train": y_train,
            "X_val": X_val,
            "y_val": y_val,
            "X_test": X_test,
            "y_test": y_test,
        }

    def create_offline_online_datasets(self):
        """Populate ``self.offline`` and ``self.online`` with split datasets.

        Every timestep is split with the same ``self.seed`` so that results
        are reproducible and controlled entirely by the constructor's
        ``seed`` argument (the online test split previously ignored it and
        used a fixed ``random_state=0``).
        """
        X_offline = self.X_sequence[:self.offline_t]
        y_offline = self.Y_sequence[:self.offline_t]
        X_online = self.X_sequence[self.offline_t:]
        y_online = self.Y_sequence[self.offline_t:]

        self.offline = {
            t: self._split_timestep(X_t, y_t)
            for t, (X_t, y_t) in enumerate(zip(X_offline, y_offline))
        }
        # Online entries keep their absolute timestep index as key.
        self.online = {
            t + self.offline_t: self._split_timestep(X_t, y_t)
            for t, (X_t, y_t) in enumerate(zip(X_online, y_online))
        }

    def get_interesting_datasets(self):
        """Return train and test (query) data at three reference timesteps.

        The timesteps are ``0`` (first), ``offline_t`` (first online
        timestep) and ``T - 1`` (last). For each of them the ``D_query_*``
        entry is the test split and the ``D_*`` entry is the train split.

        Returns:
            Dict with keys ``D_query_0``, ``D_0``, ``D_query_t``, ``D_t``,
            ``D_query_T`` and ``D_T``; each value is an ``(X, y)`` tuple.

        Raises:
            KeyError: If one of the reference timesteps has no stored
                dataset (e.g. ``offline_t`` is not smaller than ``T``).
        """
        last = self.T - 1
        return {
            'D_query_0': self.get_X_Y(0, split='test'),
            'D_0': self.get_X_Y(0, split='train'),
            'D_query_t': self.get_X_Y(self.offline_t, split='test'),
            'D_t': self.get_X_Y(self.offline_t, split='train'),
            'D_query_T': self.get_X_Y(last, split='test'),
            'D_T': self.get_X_Y(last, split='train'),
        }

    def get_X_Y(self, t: int, split: str):
        """Return the ``(X, y)`` pair of the given split at timestep ``t``.

        Args:
            t: Absolute timestep index; values below ``offline_t`` are
                looked up in ``self.offline``, all others in ``self.online``.
            split: Suffix of the stored keys, i.e. ``'train'``, ``'val'`` or
                ``'test'``.

        Raises:
            KeyError: If ``t`` or ``split`` does not exist in the store.
        """
        store = self.offline if t < self.offline_t else self.online
        dataset = store[t]
        return dataset['X_' + split], dataset['y_' + split]