import os
import time
import h5py
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, TensorDataset
import torch
import random
from os.path import basename as opb, splitext as ops


def downsample_unit(group, sampling):
    """
    Downsample a DataFrame by averaging consecutive, non-overlapping row blocks.

    Rows are bucketed by position (not by index label): rows
    ``0 .. sampling-1`` form block 0, the next ``sampling`` rows form block 1,
    and so on. A trailing block with fewer than ``sampling`` rows is averaged
    on its own rather than dropped.

    Args:
        group: DataFrame whose rows are to be averaged block-wise.
        sampling: Number of consecutive rows per block; must be positive.

    Returns:
        DataFrame with one row per block, indexed by block number (0, 1, ...),
        each cell holding the column mean over that block.

    Raises:
        ValueError: If ``sampling`` is not positive.
    """
    if sampling <= 0:
        raise ValueError(f"sampling must be a positive integer, got {sampling!r}")

    # Integer-dividing the row position by the block size yields the block id.
    block_ids = np.arange(len(group)) // sampling
    return group.groupby(block_ids).mean()


def df_all_creator(data_filepath, sampling=100, units_file="File_DevUnits_TestUnits.csv",
                   kmeans=True, operating_condition=0, kmeans_clusters=3):
    """
    Load an h5 data file and return downsampled, normalized train/test DataFrames.

    The ``*_dev`` and ``*_test`` arrays of the file are concatenated, assembled
    into one DataFrame, averaged in blocks of ``sampling`` rows within every
    (unit, RUL) group, split into train and test units according to
    ``units_file``, and min-max scaled (scaler fitted on the train split only).
    Optionally only the rows of one KMeans operating-condition cluster are kept.

    Args:
        data_filepath: Path to the h5 file. Its basename is used to look up
            the matching row in ``units_file``.
        sampling: Number of consecutive rows averaged into one output row.
        units_file: CSV with columns ``File``, ``Dev Units`` and ``Test Units``;
            the unit columns hold bracketed, space-separated numbers.
        kmeans: If True, cluster on ['alt', 'Mach', 'TRA', 'T2'] and keep only
            rows predicted to belong to ``operating_condition``.
        operating_condition: Cluster label to keep when ``kmeans`` is True.
        kmeans_clusters: Number of KMeans clusters.

    Returns:
        Tuple ``(df_train, df_test)``. When ``kmeans`` is True both frames
        carry an extra ``operating_condition`` column and their index restarts
        at 0 for every unit.

    Raises:
        IndexError: If ``units_file`` has no row for the data file.
    """
    start_time = time.process_time()

    with h5py.File(data_filepath, 'r') as hdf:
        data = {key: np.concatenate([np.array(hdf.get(f'{key}_dev')), np.array(hdf.get(f'{key}_test'))], axis=0)
                for key in ['W', 'X_s', 'X_v', 'T', 'Y', 'A']}
        var_names = {key: [name.decode('utf-8') for name in hdf.get(f'{key}_var')] for key in ['W', 'X_s', 'X_v', 'A']}

    df_all = pd.concat([
        pd.DataFrame(data['W'], columns=var_names['W']),
        pd.DataFrame(data['X_s'], columns=var_names['X_s']),
        # Only the first two X_v columns are used, under fixed names.
        pd.DataFrame(data['X_v'][:, :2], columns=['T40', 'P30']),
        pd.DataFrame(data['Y'], columns=['RUL']),
        pd.DataFrame(data['A'], columns=var_names['A']).drop(columns=['Fc', 'hs'])
    ], axis=1)

    # Iterate over the groups explicitly instead of using groupby.apply: the
    # frames yielded by iteration always contain the 'unit' and 'RUL' columns,
    # whereas apply() stops passing grouping columns in newer pandas versions,
    # which would break the 'unit' lookups below.
    pieces = [downsample_unit(unit_rul_df, sampling)
              for _, unit_rul_df in df_all.groupby(['unit', 'RUL'], sort=False)]
    df_all_smp = pd.concat(pieces) if pieces else df_all.iloc[:0]

    print(f"Data loaded in {(time.process_time() - start_time) / 60:.2f} minutes")
    print("Sampled Data shape:", df_all_smp.shape)

    file_devtest_df = pd.read_csv(units_file)
    file_name = opb(data_filepath)

    def parse_units(column):
        # Cells look like "[1. 2. 3.]": strip the brackets, then parse the numbers.
        cells = file_devtest_df.loc[file_devtest_df.File == file_name, column].values
        if len(cells) == 0:
            # IndexError is what the bare `.values[0]` lookup would raise;
            # keep the type but say what is actually missing.
            raise IndexError(f"No entry for file {file_name!r} in units file {units_file!r}")
        return np.fromstring(cells[0][1:-1], dtype=float, sep=' ').tolist()

    units_index_train = parse_units("Dev Units")
    units_index_test = parse_units("Test Units")

    df_train = df_all_smp[df_all_smp['unit'].isin(units_index_train)].reset_index(drop=True)
    df_test = df_all_smp[df_all_smp['unit'].isin(units_index_test)].reset_index(drop=True)

    # Scale every column except the target and the identifiers; the scaler is
    # fitted on the train split and only applied to the test split.
    cols_normalize = df_train.columns.difference(['RUL', 'unit', 'cycle'])
    min_max_scaler = MinMaxScaler((0, 1))
    df_train[cols_normalize] = min_max_scaler.fit_transform(df_train[cols_normalize])
    df_test[cols_normalize] = min_max_scaler.transform(df_test[cols_normalize])

    if kmeans:
        estimator = KMeans(n_clusters=kmeans_clusters, random_state=0)
        estimator.fit(df_train[['alt', 'Mach', 'TRA', 'T2']].values)

        df_train = _filter_operating_condition(df_train, estimator, operating_condition)
        df_test = _filter_operating_condition(df_test, estimator, operating_condition)

    return df_train, df_test


def _filter_operating_condition(df, estimator, operating_condition):
    """Label each row with its predicted cluster and keep one cluster's rows, unit by unit."""
    condition_cols = ['alt', 'Mach', 'TRA', 'T2']
    kept = []
    for unit in df['unit'].unique():
        unit_df = df[df['unit'] == unit].reset_index(drop=True)
        unit_df['operating_condition'] = estimator.predict(unit_df[condition_cols].values)
        kept.append(unit_df[unit_df['operating_condition'] == operating_condition])

    if not kept:
        # No units at all: return an empty frame that still has the expected
        # columns so that later column lookups keep working.
        empty = df.iloc[:0].copy()
        empty['operating_condition'] = np.empty(0, dtype=int)
        return empty

    # A single concat avoids re-copying the accumulated frame on every
    # iteration and avoids concatenating with an empty placeholder frame.
    return pd.concat(kept)


def generate_sequences(df, sequence_length, stride, sequence_cols, label_col='RUL'):
    """
    Generate sample and label sequences using a sliding window approach.

    Windows never cross unit boundaries: each distinct value of ``df['unit']``
    is windowed separately, in order of first appearance. The label of a
    window is the ``label_col`` value of the window's last row.

    Args:
        df: DataFrame with a 'unit' column, the ``sequence_cols`` columns and
            the ``label_col`` column.
        sequence_length: Number of rows per window.
        stride: Row offset between the starts of consecutive windows.
        sequence_cols: List-like of feature column names.
        label_col: Name of the label column.

    Returns:
        Tuple ``(samples, labels)`` of float32 arrays with shapes
        ``(n_windows, sequence_length, len(sequence_cols))`` and
        ``(n_windows, 1, 1)``. Units with fewer than ``sequence_length`` rows
        contribute no windows; if no window exists at all, ``n_windows`` is 0
        but both arrays keep their full rank.
    """
    samples, labels = [], []
    # One grouped pass instead of re-scanning the whole frame for every unit.
    for _, unit_data in df.groupby('unit', sort=False):
        data_matrix = unit_data[sequence_cols].values
        label_matrix = unit_data[label_col].values

        # Non-positive when the unit is shorter than one window -> empty range.
        num_samples = (len(data_matrix) - sequence_length) // stride + 1
        for start in range(0, num_samples * stride, stride):
            end = start + sequence_length
            samples.append(data_matrix[start:end, :])
            labels.append(label_matrix[end - 1])

    if not samples:
        # np.array([]) would have shape (0,), which cannot be concatenated with
        # the 3-D arrays produced for other units; keep the rank explicit.
        empty_samples = np.empty((0, sequence_length, len(sequence_cols)), dtype=np.float32)
        empty_labels = np.empty((0, 1, 1), dtype=np.float32)
        return empty_samples, empty_labels

    samples_array = np.array(samples, dtype=np.float32)
    labels_array = np.array(labels, dtype=np.float32).reshape(-1, 1, 1)
    return samples_array, labels_array


class InputGen:
    """Build the sliding-window sequences of a single unit from the train or test frame."""

    def __init__(self, df_train, df_test, cols_normalize, sequence_length, stride, unit_index):
        """
        Store the data frames and windowing settings.

        Args:
            df_train: Training DataFrame; must have a 'unit' column.
            df_test: Test DataFrame; must have a 'unit' column.
            cols_normalize: Columns handed to ``generate_sequences`` as the
                sequence (feature) columns.
            sequence_length: Window length forwarded to ``generate_sequences``.
            stride: Window stride forwarded to ``generate_sequences``.
            unit_index: Value of the 'unit' column selecting the unit to window.
        """
        self.df_train = df_train
        self.df_test = df_test
        self.cols_normalize = cols_normalize
        self.sequence_length = sequence_length
        self.stride = stride
        self.unit_index = unit_index

    def seq_gen(self):
        """
        Return ``generate_sequences`` output for the rows of ``unit_index``.

        The training frame is used when it contains the unit; otherwise the
        rows are taken from the test frame.
        """
        train_units = self.df_train['unit'].unique()
        source = self.df_train if self.unit_index in train_units else self.df_test
        unit_rows = source[source['unit'] == self.unit_index]
        return generate_sequences(
            unit_rows,
            self.sequence_length,
            self.stride,
            self.cols_normalize,
        )


def prepare_and_load_dataset(data_dir='NCMAPSS', data_file='N-CMAPSS_DS02-006.h5', sequence_length=50, stride=1,
                             sampling=100, units_file="File_DevUnits_TestUnits.csv", batch_size=256,
                             kmeans=True, kmeans_clusters=3, operating_condition=0):
    """
    Build (or load from cache) windowed tensors and wrap them in DataLoaders.

    The data file is looked up in ``<directory of this module>/<data_dir>``.
    Generated tensors are cached as ``.pt`` files under
    ``<data_dir>/Samples_whole/<data file stem>/``; the file names encode the
    windowing, sampling and operating-condition settings. Python's and NumPy's
    global random seeds are reset to 0 as a side effect.

    Args:
        data_dir: Directory (relative to this module) holding the h5 file.
        data_file: Name of the h5 file.
        sequence_length: Window length in (downsampled) rows.
        stride: Offset between consecutive windows.
        sampling: Downsampling block size passed to ``df_all_creator``.
        units_file: CSV listing dev/test units, passed to ``df_all_creator``
            unchanged (so a relative path is resolved against the current
            working directory, not this module).
        batch_size: Batch size of every returned DataLoader.
        kmeans: Whether to keep only one KMeans operating condition.
        kmeans_clusters: Number of KMeans clusters.
        operating_condition: Cluster label to keep when ``kmeans`` is True.

    Returns:
        Tuple ``(train_dataloader, test_dataloaders)``: one shuffled loader
        over all training windows and a list with one unshuffled loader per
        test unit.

    Raises:
        ValueError: If no training unit yields a single window.
    """
    import pickle  # local import: only needed to recognise unreadable cache files

    seed = 0
    random.seed(seed)
    np.random.seed(seed)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_filepath = os.path.join(current_dir, data_dir, data_file)
    sample_dir_path = os.path.join(current_dir, data_dir, 'Samples_whole', ops(opb(data_filepath))[0])

    # The cluster settings only affect the data when kmeans filtering is on.
    # Unfiltered data gets its own tag so it can never be confused with a
    # cache built from filtered data; names for kmeans=True are unchanged.
    condition_tag = f'c{kmeans_clusters}_oc{operating_condition}' if kmeans else 'nokmeans'
    file_stem = f'win{sequence_length}_str{stride}_smp{sampling}_{condition_tag}.pt'
    train_save_path = os.path.join(sample_dir_path, f'train_{file_stem}')
    test_save_path = os.path.join(sample_dir_path, f'test_{file_stem}')

    train_tensor = train_label = test_data = None
    if os.path.exists(train_save_path) and os.path.exists(test_save_path):
        try:
            train_tensor, train_label = torch.load(train_save_path)
            test_data = torch.load(test_save_path)
        except (pickle.UnpicklingError, EOFError, RuntimeError) as err:
            # Truncated files, or caches written with NumPy-scalar dict keys
            # that newer torch.load defaults refuse to unpickle: rebuild.
            print(f"Ignoring unreadable sample cache ({err}); regenerating.")
            test_data = None

    if test_data is None:
        df_train, df_test = df_all_creator(data_filepath, sampling, units_file, kmeans, operating_condition, kmeans_clusters)

        cols_normalize = df_train.columns.difference(['RUL', 'unit', 'operating_condition', 'cycle']).tolist()

        train_samples, train_labels = [], []
        for unit_index in df_train['unit'].unique():
            data_gen = InputGen(df_train, df_test, cols_normalize, sequence_length, stride, unit_index)
            sample_array, label_array = data_gen.seq_gen()
            if len(sample_array) == 0:
                # A unit shorter than one window contributes nothing, and its
                # empty array may not have the rank np.concatenate needs.
                continue
            train_samples.append(sample_array)
            train_labels.append(label_array)

        if not train_samples:
            raise ValueError(
                f"No training windows of length {sequence_length} could be generated from {data_file!r}"
            )

        train_tensor = torch.from_numpy(np.concatenate(train_samples)).float()
        train_label = torch.from_numpy(np.concatenate(train_labels)).float()

        test_data = {}
        for unit_index in df_test['unit'].unique():
            data_gen = InputGen(df_train, df_test, cols_normalize, sequence_length, stride, unit_index)
            sample_array, label_array = data_gen.seq_gen()
            # Plain float keys keep the saved dict loadable by torch.load's
            # restricted unpickler, which rejects NumPy scalar objects.
            test_data[float(unit_index)] = (torch.from_numpy(sample_array).float(),
                                            torch.from_numpy(label_array).float())

        os.makedirs(sample_dir_path, exist_ok=True)
        torch.save((train_tensor, train_label), train_save_path)
        torch.save(test_data, test_save_path)

    train_dataset = TensorDataset(train_tensor, train_label)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    test_dataloaders = [
        DataLoader(TensorDataset(sample, label), batch_size=batch_size, shuffle=False, drop_last=False)
        for sample, label in test_data.values()
    ]
    return train_dataloader, test_dataloaders
