import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def load_data(file_path):
    """Read a CSV source into a pandas DataFrame.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts as its first
            argument (a path, URL, or file-like object).

    Returns:
        The parsed ``pandas.DataFrame``, using pandas' default parsing options.
    """
    return pd.read_csv(file_path)


def preprocess_data(df, dynamic_predictors, target, early_prediction):
    """Trim the last ``early_prediction`` hours of each visit and split out model inputs.

    Cases (``target == 1``) and controls (``target == 0``) are filtered
    separately; rows whose target is neither 0 nor 1 are discarded. Within
    each (class, visit) group a row is kept only if it lies at least
    ``early_prediction * 60`` minutes before that group's latest
    ``MinutesFromArrival``.

    Args:
        df: DataFrame containing the columns ``VisitIdentifier``,
            ``MinutesFromArrival``, ``target`` and every name in
            ``dynamic_predictors``. It is not modified.
        dynamic_predictors: List of feature column names to return as ``X``.
        target: Name of the 0/1 outcome column.
        early_prediction: Prediction horizon in hours.

    Returns:
        Tuple ``(X, y, times)``: the feature DataFrame, the integer ``label``
        Series and the ``MinutesFromArrival`` Series. All three are indexed by
        ``VisitIdentifier`` and sorted by visit, then by time.
    """
    horizon_minutes = early_prediction * 60

    kept_parts = []
    # Cases first, then controls: keeps the concatenation order stable for ties.
    for class_value in (1, 0):
        subset = df[df[target] == class_value]
        # transform() broadcasts the per-visit end time back onto every row.
        # This avoids groupby.apply, whose handling of the grouping column
        # differs between pandas versions, and "max" (rather than "last row")
        # stays correct when the input is not already sorted by time.
        visit_end = subset.groupby('VisitIdentifier')['MinutesFromArrival'].transform('max')
        lead_time = visit_end - subset['MinutesFromArrival']
        # Positional mask so duplicate index labels in ``df`` cannot cause
        # alignment surprises.
        kept_parts.append(subset[(lead_time >= horizon_minutes).to_numpy()])

    df_filtered = pd.concat(kept_parts)
    df_filtered = df_filtered.set_index('VisitIdentifier')
    # 'VisitIdentifier' is now the index name; sort_values accepts index
    # level names alongside column names.
    df_filtered = df_filtered.sort_values(by=['VisitIdentifier', 'MinutesFromArrival'])
    df_filtered['label'] = np.where(df_filtered[target] == 1, 1, 0)

    X = df_filtered[dynamic_predictors]
    y = df_filtered['label']
    times = df_filtered['MinutesFromArrival']

    return X, y, times


def create_sequences(input_data, labels, times, sequence_length):
    """Build overlapping fixed-length windows (stride 1) over row-aligned arrays.

    Window ``i`` covers rows ``i .. i + sequence_length - 1`` and is labelled
    with the label of its last row. Every full window is produced, including
    the one that ends on the final row.

    NOTE(review): windows are cut purely by row position, so a window may
    span rows from two different visits if the input concatenates several
    visits — confirm whether callers expect that.

    Args:
        input_data: Array-like of shape ``(n_rows, ...)`` holding the features.
        labels: Array-like with one label per row.
        times: Array-like with one time value per row.
        sequence_length: Number of rows per window; must be >= 1.

    Returns:
        Tuple ``(data_seq, label_seq, time_seq)`` of NumPy arrays with
        ``n_rows - sequence_length + 1`` windows along axis 0. When the input
        is shorter than ``sequence_length`` the arrays are empty but keep
        their trailing dimensions.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be a positive integer")

    # Positional indexing on plain arrays; pandas objects would otherwise be
    # indexed by label.
    input_data = np.asarray(input_data)
    labels = np.asarray(labels)
    times = np.asarray(times)

    n_windows = len(input_data) - sequence_length + 1
    if n_windows <= 0:
        return (
            np.empty((0, sequence_length) + input_data.shape[1:], dtype=input_data.dtype),
            np.empty((0,) + labels.shape[1:], dtype=labels.dtype),
            np.empty((0, sequence_length) + times.shape[1:], dtype=times.dtype),
        )

    window_starts = range(n_windows)
    data_seq = np.array([input_data[s:s + sequence_length] for s in window_starts])
    time_seq = np.array([times[s:s + sequence_length] for s in window_starts])
    # Label of the last row of each window; raises IndexError if labels is too short.
    label_seq = labels[np.arange(n_windows) + sequence_length - 1]
    return data_seq, label_seq, time_seq


def prepare_data(X, y, times, batch_size=32, sequence_length=240, shuffle=True):
    """Window the inputs with ``create_sequences`` and wrap them in a DataLoader.

    Args:
        X: pandas object holding the features; its ``.values`` are windowed.
        y: pandas object holding one label per row.
        times: pandas object holding one time value per row.
        batch_size: Batch size passed to the ``DataLoader``.
        sequence_length: Number of rows per window.
        shuffle: Whether the ``DataLoader`` reshuffles samples every epoch.
            Defaults to ``True``; pass ``False`` for evaluation loaders where
            sample order must be kept.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` yielding
        ``(features, label, times)`` batches with dtypes float32, float32
        and int64 respectively.
    """
    X_seq, y_seq, times_seq = create_sequences(X.values, y.values, times.values, sequence_length)

    X_tensor = torch.tensor(X_seq, dtype=torch.float32)
    y_tensor = torch.tensor(y_seq, dtype=torch.float32)
    # Integer dtype for time embeddings; fractional time values are truncated by the cast.
    times_tensor = torch.tensor(times_seq, dtype=torch.long)

    dataset = TensorDataset(X_tensor, y_tensor, times_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
