import pandas as pd
import numpy as np
import math
import os

class SplitTrainTest:
    """Build lists of time series file names for the train/val/test subsets.

    The features table found at ``path_features`` is loaded on construction
    and kept in ``df_features``. The raw data directory ``path_raw`` is
    expected to contain one sub-directory per dataset, each holding the
    files of that dataset.
    """

    def __init__(self,
                 path_features: str,
                 path_raw: str = "data/TSB_128/",
                 datasets_names: "list | None" = None,
                 train_ratio: float = 0.7,
                 seed: "int | None" = 42):
        """Store the configuration and load the features table.

        :param path_features: path of the CSV file read into ``df_features``
        :param path_raw: directory holding one sub-directory per dataset
        :param datasets_names: dataset names, defaults to ``['KDD21']``
            (stored only; not used by ``create_splits_raw``)
        :param train_ratio: fraction of each dataset's files put in the
            train subset (the count is rounded up)
        :param seed: seed given to ``np.random.seed`` before splitting;
            ``None`` leaves NumPy's global random state untouched
        """
        self.path_raw = path_raw
        self.df_features = pd.read_csv(path_features)
        # A fresh list per instance: a list literal used as the default
        # value would be shared (and mutable) across all instances.
        self.datasets_names = ['KDD21'] if datasets_names is None else datasets_names
        self.train_ratio = train_ratio
        self.seed = seed

    @staticmethod
    def _read_subset(df, name):
        """Return the entries of row ``name`` of a splits table, NaN padding removed."""
        return [
            x for x in df.loc[name].tolist()
            if not (isinstance(x, float) and math.isnan(x))
        ]

    def create_splits_raw(self, read_from_file=None):
        """Create the train/val/test lists of time series file names.

        Three modes, depending on ``read_from_file``:

        * ``None``: every sub-directory of ``self.path_raw`` is a dataset.
          For each one, ``ceil(n_files * self.train_ratio)`` files are drawn
          at random for the train set and the remaining files go to the
          test set. The val set stays empty.
        * a file with ``train_set`` and ``val_set`` rows: both rows are
          returned as they are and the test set stays empty.
        * a file with ``train_set`` and ``test_set`` rows: the ``test_set``
          row is returned as the *val* set. The ``train_set`` row is only
          used to collect dataset names (the part before the first ``/``);
          those datasets are then split randomly from disk as in the first
          mode.

        NOTE(review): in the third mode the file's ``test_set`` row ends up
        in the ``val_set`` slot and the random hold-out in ``test_set``.
        This is kept exactly as it was; confirm with the callers that this
        ordering is the intended one.

        Entries of the randomly created subsets have the form
        ``<dataset>/<file name>``.

        :param read_from_file: CSV file to read fixed splits from (its
            first column is used as index), or None for a random split

        :return train_set: list of strings of time series file names
        :return val_set: list of strings of time series file names
        :return test_set: list of strings of time series file names

        :raises ValueError: if the file holds neither ``train_set`` +
            ``val_set`` rows nor ``train_set`` + ``test_set`` rows

        Warning : this function has been modified from the original MSAD pipeline
        """
        train_set = []
        val_set = []
        test_set = []

        # Seed NumPy's global generator; `is not None` so that 0 is a valid seed
        if self.seed is not None:
            np.random.seed(self.seed)

        if read_from_file is not None:
            df = pd.read_csv(read_from_file, index_col=0)
            subsets = set(df.index)

            if 'train_set' in subsets and 'val_set' in subsets:
                # Fixed train/val split: nothing to draw at random
                train_set = self._read_subset(df, 'train_set')
                val_set = self._read_subset(df, 'val_set')
                return train_set, val_set, test_set

            if 'train_set' in subsets and 'test_set' in subsets:
                listed_files = self._read_subset(df, 'train_set')
                val_set = self._read_subset(df, 'test_set')
                # Only the dataset names are kept from the listed files
                datasets = sorted({x.split('/')[0] for x in listed_files})
            else:
                raise ValueError('Did not expect this type of file.')
        else:
            datasets = sorted(
                x for x in os.listdir(self.path_raw)
                if os.path.isdir(os.path.join(self.path_raw, x))
            )

        # If the path is not a directory, then it is a file: go one level up.
        # Done on a local copy so repeated calls never alter self.path_raw.
        root = self.path_raw
        if not os.path.isdir(root):
            root = os.path.dirname(root)

        # Random split of every dataset
        for dataset in datasets:
            # os.listdir returns names in arbitrary order; sorting makes the
            # index -> file mapping (and so the seeded split) reproducible.
            fnames = sorted(os.listdir(os.path.join(root, dataset)))

            # Decide on the size of the train subset
            n_timeseries = len(fnames)
            train_split = math.ceil(n_timeseries * self.train_ratio)

            # Select random files for the train subset
            train_idx = np.random.choice(
                np.arange(n_timeseries),
                size=train_split,
                replace=False
            )

            # Set lookup instead of scanning the index array for every file
            selected = set(train_idx.tolist())

            # Replace indexes with file names
            train_set.extend(os.path.join(dataset, fnames[i]) for i in train_idx)
            test_set.extend(
                os.path.join(dataset, fnames[i])
                for i in range(n_timeseries) if i not in selected
            )

        return train_set, val_set, test_set
    

# Script entry point. When a file is run directly, __name__ is "__main__"
# (it is never "main"), so the previous comparison could not be true.
if __name__ == "__main__":

    pass
