import os
import glob

import torch
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from torch.utils.data import DataLoader

from .pipeline_preprocessing import ApplyThreshold, _ConcatDataFrames, _SeparateDataFrames, CreateConcatDataset
from .custom_datasets import sequential_dataset, unsupervised_sequential_dataset
from .reproducibility_utils import seed_worker

def load_data(
    train_path,
    train_ratio,
    batchsize,
    columns_to_standardize,
    columns_to_drop,
    sequence_length,
    threshold,
    threshold_value,
    threshold_column,
    target=None,
    scaling=True,
    seed=42,
    ):
    """
    Load and preprocess training data from a folder of csv files.

    Every csv in ``train_path`` is read into its own DataFrame (in sorted
    filename order), pushed through an optional threshold step and an optional
    standardization step, turned into one concatenated sequence dataset and
    finally split randomly into a training and a validation part.

    Parameters:
    train_path (str): Path to the training data. Folder that contains (multiple) csv's.
    train_ratio (float): Fraction of the dataset assigned to the training split;
        the remaining ``1 - train_ratio`` becomes the validation split.
    batchsize (int): Batch size used by both returned DataLoaders.
    columns_to_standardize (list): List of features to standardize.
    columns_to_drop (list): List of features to drop (forwarded to CreateConcatDataset).
    sequence_length (int): Length of the sequences.
    threshold (bool): Whether to apply a threshold. Any falsy value disables the step.
    threshold_value (float): The threshold value. Only used if ``threshold`` is truthy.
    threshold_column (str): The column to apply the threshold to. Only used if
        ``threshold`` is truthy.
    target (str, optional): The target column. If None, the unsupervised (ssl)
        dataset class is used, otherwise the supervised one. Defaults to None.
    scaling (bool, optional): Whether to apply scaling. Any falsy value disables
        the step. Defaults to True.
    seed (int, optional): Random seed. Defaults to 42.

    Returns:
    train_dataloader (DataLoader): DataLoader for the training data (shuffled).
    val_dataloader (DataLoader): DataLoader for the validation data (not shuffled).
    mean_dict (dict): Maps each standardized column to the mean fitted by the
        StandardScaler. Empty if scaling is disabled.
    std_dict (dict): Maps each standardized column to the scale (std) fitted by
        the StandardScaler. Empty if scaling is disabled.

    Raises:
    FileNotFoundError: If ``train_path`` contains no csv files.
    """

    csv_files = sorted(glob.glob(os.path.join(train_path, '*.csv')))
    if not csv_files:
        # Fail early with a clear message instead of an obscure error from
        # somewhere deep inside the preprocessing pipeline.
        raise FileNotFoundError(f"No csv files found in '{train_path}'.")

    dfs = []
    for filename in csv_files:
        print(filename)
        dfs.append(pd.read_csv(filename))

    # One seeded generator drives the train/val split and both DataLoaders.
    g = torch.Generator()
    g.manual_seed(seed)

    # Assemble only the steps that are actually requested.
    steps = []

    if threshold:
        steps.append(('threshold', ApplyThreshold(threshold=threshold_value, by=threshold_column, seq_length=sequence_length)))

    ct = None
    if scaling:
        # Standardize the selected columns; all other columns pass through unchanged.
        ct = ColumnTransformer([("stand", StandardScaler(), columns_to_standardize)],
                               remainder="passthrough",
                               verbose_feature_names_out=False)
        # The frames are concatenated before the scaler is fitted and separated
        # again afterwards, so a single scaler is fitted on all files together.
        steps.extend([
            ('concat', _ConcatDataFrames()),
            ("stand", ct.set_output(transform="pandas")),
            ('separate', _SeparateDataFrames()),
        ])

    # check if data for ssl or fine-tuning should be loaded
    if target is None:
        # ssl
        dataset_step = CreateConcatDataset(unsupervised_sequential_dataset, seq_length=sequence_length, columns_to_drop=columns_to_drop)
    else:
        # supervised
        dataset_step = CreateConcatDataset(sequential_dataset, target=target, seq_length=sequence_length, columns_to_drop=columns_to_drop)
    steps.append(('concat dataset', dataset_step))

    # run pipeline
    full_dataset = Pipeline(steps).fit_transform(dfs)

    # Expose the fitted statistics per column.
    mean_dict = {}
    std_dict = {}
    if scaling:
        scaler = ct.transformers_[0][1]
        mean_dict = dict(zip(columns_to_standardize, scaler.mean_.tolist()))
        std_dict = dict(zip(columns_to_standardize, scaler.scale_.tolist()))

    # split dataset into train and validation
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_ratio, 1-train_ratio], generator=g)
    train_dataloader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, worker_init_fn=seed_worker, generator=g)
    val_dataloader = DataLoader(val_dataset, batch_size=batchsize, shuffle=False, worker_init_fn=seed_worker, generator=g)

    return train_dataloader, val_dataloader, mean_dict, std_dict


def load_test_data(
    test_path,
    batchsize,
    columns_to_drop,
    sequence_length,
    target=None,
    threshold=0.0,
    threshold_column="vxCG",
    seed=42,
    return_dfs=False
    ):
    """
    Load test data from a folder of csv files and apply the threshold step.

    Parameters:
    test_path (str): Path to the test data. Folder that contains (multiple) csv's.
    batchsize (int): Batch size of the returned DataLoader.
    columns_to_drop (list): List of features to drop (forwarded to CreateConcatDataset).
    sequence_length (int): Length of the sequences.
    target (str, optional): The target column (forwarded to CreateConcatDataset).
        Defaults to None.
        NOTE(review): unlike load_data, a None target still selects
        sequential_dataset here - confirm this is intended.
    threshold (float, optional): The threshold value. Defaults to 0.0.
    threshold_column (str, optional): The column to apply the threshold to.
        Defaults to "vxCG".
    seed (int, optional): Random seed for the DataLoader generator. Defaults to 42.
    return_dfs (bool, optional): If truthy, return the thresholded data instead
        of a DataLoader. Defaults to False.

    Returns:
    test_dataloader (DataLoader): Unshuffled DataLoader for the test data, or,
        if return_dfs is truthy, the output of ApplyThreshold.transform.
    """

    rng = torch.Generator()
    rng.manual_seed(seed)

    # Read the csv files in sorted filename order.
    csv_pattern = os.path.join(test_path, '*.csv')
    frames = []
    for csv_file in sorted(glob.glob(csv_pattern)):
        print(csv_file)
        frames.append(pd.read_csv(csv_file))

    # apply threshold
    threshold_step = ApplyThreshold(threshold=threshold, by=threshold_column, seq_length=sequence_length)
    frames = threshold_step.transform(frames)

    if return_dfs:
        return frames

    dataset_builder = CreateConcatDataset(sequential_dataset, target=target, seq_length=sequence_length, columns_to_drop=columns_to_drop)
    test_dataset = dataset_builder.transform(frames)

    return DataLoader(test_dataset, batch_size=batchsize, shuffle=False, worker_init_fn=seed_worker, generator=rng)