import pandas as pd
import numpy as np
import math
import os

class SplitTrainTest:
    """Build train/val(/test) splits of time-series files and of their feature table.

    Two families of methods:

    * feature-level splits (``create_train_val_feat``, ``reproduce_split_feat``)
      partition the rows of the features CSV and write ``train.csv`` / ``val.csv``
      into a directory whose name starts with ``FEATURES_SAVE_PREFIX``;
    * raw-file splits (``create_splits_raw``) partition the file names found
      under ``path_raw/<dataset>/``.
    """

    # Prefix of the output directories written by the feature-split methods.
    # A class attribute so it can be overridden (per class or per instance).
    FEATURES_SAVE_PREFIX = "data/TSB/features/CATCH22_TSB_clean"

    def __init__(self,
                 path_features: str,
                 path_raw: str = "data/TSB_128/",
                 datasets_names: list = None,
                 train_ratio: float = 0.7,
                 seed: int = 42):
        """
        --------------
        Args :
            - path_features (str) : path to the features CSV; its first column holds
              time-series identifiers of the form "<dataset>/<file>"
            - path_raw (str) : directory with one sub-directory per dataset
              (used by create_splits_raw)
            - datasets_names (list of str) : datasets kept by create_train_val_feat;
              defaults to ['KDD21']
            - train_ratio (float) : fraction of each selected dataset assigned to train
            - seed (int or None) : seed for the random shuffles; None means unseeded
        --------------
        """
        self.path_raw = path_raw
        self.df_features = pd.read_csv(path_features)
        # Fresh list per instance instead of a shared mutable default argument.
        self.datasets_names = ['KDD21'] if datasets_names is None else datasets_names
        self.train_ratio = train_ratio
        self.seed = seed

    def create_train_val_feat(self):
        """
        --------------
        Idea : For each dataset listed in self.datasets_names, shuffle its feature rows
               (seeded with self.seed) and put the first int(len * train_ratio) rows in
               train and the rest in val. Both sets are saved as CSV files.
        --------------
        Note : This does not handle training on one dataset and validating on another.
               self.df_features is not modified.
        --------------
        Returns : (path_train, path_val), paths of the written CSV files
        Raises : ValueError if none of self.datasets_names occurs in the features
        --------------
        """
        # Dataset of each row = prefix of the first column before the first "/".
        # Kept as a standalone Series so no temporary column is added to
        # self.df_features (it would otherwise leak into later CSV exports).
        dataset_of_row = self.df_features.iloc[:, 0].apply(lambda x: x.split("/")[0]).rename("dataset")

        train_parts = []
        val_parts = []
        for dataset, group in self.df_features.groupby(dataset_of_row):
            if dataset not in self.datasets_names:
                continue
            shuffled = group.sample(frac=1, random_state=self.seed)
            split_idx = int(len(shuffled) * self.train_ratio)
            train_parts.append(shuffled.iloc[:split_idx])
            val_parts.append(shuffled.iloc[split_idx:])

        if not train_parts:
            raise ValueError("Aucun dataset trouvé dans les noms donnés")

        # pd.concat also covers the single-dataset case.
        train_df = pd.concat(train_parts)
        val_df = pd.concat(val_parts)

        path_save_dir = f"{self.FEATURES_SAVE_PREFIX}_train_{self.train_ratio}_seed_{self.seed}"
        os.makedirs(path_save_dir, exist_ok=True)

        path_train = os.path.join(path_save_dir, "train.csv")
        path_val = os.path.join(path_save_dir, "val.csv")

        train_df.to_csv(path_train, index=False)
        val_df.to_csv(path_val, index=False)

        print(f"✅ Fichiers du split sauvegardés : train_set ({len(train_df)} lignes) à {path_train} \n et val_set ({len(val_df)} lignes) à {path_val}")
        return path_train, path_val

    def reproduce_split_feat(self, train_set, val_set):
        """
        --------------
        Idea : Reproduce a given split on the feature table. Keep the feature rows whose
               identifier (first column) is listed in train_set / val_set, and save them
               as CSV files. Names absent from the features (e.g. series dropped during
               cleaning) are skipped.
        --------------
        Args :
            - train_set : list of time-series file names (may contain names not in the features)
            - val_set : list of time-series file names (may contain names not in the features)
        --------------
        Returns : (path_train, path_val), paths of the written CSV files
        --------------
        """
        path_save_dir = f"{self.FEATURES_SAVE_PREFIX}_train_{self.train_ratio}_reproduced"
        os.makedirs(path_save_dir, exist_ok=True)

        path_train = os.path.join(path_save_dir, "train.csv")
        path_val = os.path.join(path_save_dir, "val.csv")

        ts_names = self.df_features.iloc[:, 0]
        # isin keeps exactly the listed rows; names missing from the features match
        # nothing, so no separate (quadratic) intersection step is needed.
        train_df_clean = self.df_features[ts_names.isin(list(train_set))]
        val_df_clean = self.df_features[ts_names.isin(list(val_set))]

        train_df_clean.to_csv(path_train, index=False)
        val_df_clean.to_csv(path_val, index=False)

        print(f"✅ Fichiers du split sauvegardés : train_set ({len(train_df_clean)} lignes) à {path_train} \n et val_set ({len(val_df_clean)} lignes) à {path_val}")
        return path_train, path_val

    @staticmethod
    def _drop_nan(values):
        """Return *values* without its float NaN entries.

        Rows of a split CSV have different lengths, so pandas pads the shorter
        rows with NaN when reading the file back. Those padding cells are removed.
        """
        return [x for x in values if not (isinstance(x, float) and math.isnan(x))]

    def create_splits_raw(self, read_from_file=None):
        """Create splits of the raw time-series files (names relative to path_raw).

        Three modes, depending on ``read_from_file``:

        * File with ``train_set`` and ``val_set`` rows: both lists are returned
          as stored (NaN padding removed), and ``test_set`` is empty.
        * File with ``train_set`` and ``test_set`` rows: the file's ``test_set``
          is returned as ``val_set`` (the final evaluation set of this pipeline).
          For every dataset appearing in the file's ``train_set``, the files in
          ``path_raw/<dataset>/`` that are NOT in that evaluation set are split
          randomly into ``train_set`` (ceil of ``train_ratio``) and ``test_set``.
        * No file: every sub-directory of ``path_raw`` is a dataset, and its files
          are split randomly into ``train_set`` (ceil of ``train_ratio``) and
          ``test_set``. ``val_set`` is empty.

        Example split files are in "experiments/supervised_splits" or
        "experiments/unsupervised_splits".

        Note:
        * When ``self.seed`` is not None (0 included), NumPy's global RNG is
          seeded with it, so the random splits are reproducible.
        * File names are sorted before sampling, because ``os.listdir`` order
          is arbitrary.
        * If ``path_raw`` is not a directory, ``self.path_raw`` is replaced by
          its parent directory.

        :param read_from_file: optional path to a CSV file. Its index holds the
            subset names (``train_set`` plus ``val_set`` or ``test_set``), and its
            cells hold "<dataset>/<file>" names.

        :return train_set: list of strings of time series file names
        :return val_set: list of strings of time series file names
        :return test_set: list of strings of time series file names
        :raises ValueError: if the split file matches neither layout above
        """
        train_set = []
        val_set = []
        test_set = []

        # "is not None" so that seed=0 also yields a reproducible split.
        if self.seed is not None:
            np.random.seed(self.seed)

        datasets = None
        if read_from_file is not None:
            df = pd.read_csv(read_from_file, index_col=0)
            subsets = set(df.index)

            if 'train_set' in subsets and 'val_set' in subsets:
                train_set = self._drop_nan(df.loc['train_set'].tolist())
                val_set = self._drop_nan(df.loc['val_set'].tolist())
                return train_set, val_set, test_set

            if not ('train_set' in subsets and 'test_set' in subsets):
                raise ValueError('Did not expect this type of file.')

            # The file's test set is the final evaluation set, called "val_set" here.
            train_test_set = self._drop_nan(df.loc['train_set'].tolist())
            val_set = self._drop_nan(df.loc['test_set'].tolist())
            datasets = sorted({x.split('/')[0] for x in train_test_set})

        # If path_raw points to a file, use its parent directory. This happens before
        # any listing of path_raw, so the no-file branch works with a file path too.
        if not os.path.isdir(self.path_raw):
            self.path_raw = os.path.dirname(self.path_raw)

        if datasets is None:
            datasets = sorted(
                x for x in os.listdir(self.path_raw)
                if os.path.isdir(os.path.join(self.path_raw, x))
            )

        # Evaluation files must never be sampled into train/test (no leakage).
        held_out = set(val_set)

        for dataset in datasets:
            fnames = sorted(
                f for f in os.listdir(os.path.join(self.path_raw, dataset))
                if os.path.join(dataset, f) not in held_out
            )

            n_timeseries = len(fnames)
            n_train = math.ceil(n_timeseries * self.train_ratio)

            train_idx = np.random.choice(
                np.arange(n_timeseries),
                size=n_train,
                replace=False
            )
            in_train = set(train_idx.tolist())

            # Train keeps the sampled order; test keeps the listing order.
            train_set.extend(os.path.join(dataset, fnames[i]) for i in train_idx)
            test_set.extend(
                os.path.join(dataset, fnames[i])
                for i in range(n_timeseries) if i not in in_train
            )

        return train_set, val_set, test_set
