import time
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from data_loading import load_data, preprocess_data, prepare_data
from train import train_model
from evaluate import evaluate_model, compute_metrics, print_classification_report
from model import T_VRNN

# ---------------------------------------------------------------------------
# Fixed experiment configuration
# ---------------------------------------------------------------------------

# Time-varying clinical measurements used as model inputs.
dynamic_predictors = [
    'HeartRate',
    'RespiratoryRate',
    'PulseOx',
    'SystolicBP',
    'DiastolicBP',
    'Temperature',
    'BUN',
    'Lactate',
    'Platelet',
    'Creatinine',
    'BiliRubin',
    'WBC',
    'CReactiveProtein',
    'OxygenFlow',
]
target = 'Observed_Shock'  # outcome column the model predicts
early_prediction = 36  # prediction window, in hours (reported in the run summary)

# Model architecture.
input_dim = len(dynamic_predictors)  # one input feature per predictor
hidden_dim = 256
latent_dim = 32
output_dim = 1
num_time_indices = 1
time_embedding_dim = 10

# Optimisation settings.
# NOTE(review): `lr` is not referenced anywhere else in this script (training
# receives a `learning_rates` grid instead), and `batch_size` is only used to
# build the test loader — confirm these are intentional.
lr = 0.001
batch_size = 64
num_epochs = 28

# 60 sequence steps per hour of the prediction window (presumably one step per
# minute — confirm against preprocess_data/prepare_data).
sequence_length = 60 * early_prediction

# Record the wall-clock start so total runtime can be reported at the end of the run.
start_time = time.time()

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Read the cleaned CSV and turn it into feature/label/time arrays for the
# chosen predictors, target and prediction window.
df = load_data('cleaned_mimic.csv')
X, y, times = preprocess_data(df, dynamic_predictors, target, early_prediction)

# First hold out 33% of samples as the test set, then take 33% of what remains
# as a validation set. The fixed random_state keeps both splits reproducible.
# NOTE(review): X_val / y_val / times_val are not used anywhere else in this
# script — confirm whether validation is meant to happen here or in train_model.
(X_train_val, X_test,
 y_train_val, y_test,
 times_train_val, times_test) = train_test_split(
    X, y, times, test_size=0.33, random_state=42)
(X_train, X_val,
 y_train, y_val,
 times_train, times_val) = train_test_split(
    X_train_val, y_train_val, times_train_val, test_size=0.33, random_state=42)

# Candidate hyperparameter values. train_model receives whole lists rather than
# single values, presumably to sweep over combinations of them (confirm in train.py).
hidden_dims = [512, 256, 128, 64]
latent_dims = [32, 64, 128]
learning_rates = [0.0001, 0.001, 0.01]
batch_sizes = [64, 32, 16, 8]

# NOTE(review): only the training split is passed here — neither the
# validation split nor `device` reaches train_model.
results = train_model(
    X_train, y_train, times_train,
    dynamic_predictors,
    learning_rates, batch_sizes, num_epochs,
    hidden_dims, latent_dims,
    sequence_length, num_time_indices, time_embedding_dim,
)

# Build a T_VRNN from the fixed hyperparameters at the top of the script and
# move it to the selected device.
# NOTE(review): `results` returned by train_model is never read, so this is a
# freshly constructed model rather than one produced by the training call.
# Unless T_VRNN restores saved weights on its own, the metrics below describe
# an untrained network — TODO plug the trained/best model from `results` in here.
model = T_VRNN(
    input_dim,
    hidden_dim,
    latent_dim,
    output_dim,
    num_time_indices,
    time_embedding_dim,
).to(device)

# Score the model on the held-out test split.
test_loader = prepare_data(X_test, y_test, times_test, batch_size, sequence_length)
test_outputs, test_labels = evaluate_model(model, test_loader, device)
metrics = compute_metrics(test_outputs, test_labels)

# Threshold the outputs (presumably probabilities) at 0.5 for the per-class report.
predicted_labels = (test_outputs > 0.5).astype(int)
print_classification_report(test_labels, predicted_labels)

print(f"Test loss: {metrics['loss']:.4f}, accuracy: {metrics['accuracy']:.4f}, AUC: {metrics['auc']:.4f}")

# Report total wall-clock time since `start_time` was recorded at startup,
# followed by a short summary of the run configuration.
end_time = time.time()
runtime = end_time - start_time
print(f"Runtime: {runtime:.2f} seconds")
print(f'Using the early prediction window as: {early_prediction} hours')
print('Using T-VRNN')
