import time
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from data_loading import load_data, preprocess_data, create_sequences
from train import train_model
from evaluate import evaluate_model, compute_metrics, print_classification_report
from model import TLSTMModel, BidirectionalTLSTMModel  # Import the models

# Hyperparameters and experiment configuration.

# Time-varying clinical variables used as per-time-step model inputs.
dynamic_predictors = [
    'HeartRate',
    'RespiratoryRate',
    'PulseOx',
    'SystolicBP',
    'DiastolicBP',
    'Temperature',
    'BUN',
    'Lactate',
    'Platelet',
    'Creatinine',
    'BiliRubin',
    'WBC',
    'CReactiveProtein',
    'OxygenFlow',
]
target = 'Observed_Shock'  # label column to predict
early_prediction = 4  # early-prediction window, in hours
batch_size = 4
# One input channel per dynamic predictor, plus one for the TimeStep feature.
input_dim = 1 + len(dynamic_predictors)
hidden_dim = 50
output_dim = 1
lr = 0.001
num_epochs = 10
# 60 steps per hour of the window (presumably one row per minute — TODO confirm).
sequence_length = 60 * early_prediction
# True selects BidirectionalTLSTMModel, False selects TLSTMModel.
use_bidirectional = True

# Wall-clock reference point for the runtime report at the end of the script.
start_time = time.time()

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Load the cleaned dataset and turn it into model inputs (X) and labels (y).
df = load_data('cleaned_mimic.csv')
X, y = preprocess_data(df, dynamic_predictors, target, early_prediction)

# Two-stage split: hold out 33% for testing, then 33% of the remainder for
# validation (about 45% train / 22% validation / 33% test overall).
# NOTE(review): train_test_split shuffles rows by default. If create_sequences
# builds windows from consecutive rows, shuffling here scrambles temporal
# order inside each window — confirm against create_sequences/train_model.
# NOTE(review): X_val / y_val are not used anywhere in this script.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X,
    y,
    test_size=0.33,
    random_state=42,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val,
    y_train_val,
    test_size=0.33,
    random_state=42,
)

# Hyperparameter search grid handed to train_model.
# NOTE(review): the scalar `lr` and `batch_size` settings are not passed to
# train_model; this grid (together with num_epochs) is what training receives.
learning_rates = [1e-4, 1e-3, 1e-2]
batch_sizes = [64, 32, 16, 8]
hidden_dims = [512, 256, 128, 64]
# NOTE(review): `results` is not read anywhere else in this script, so
# whatever train_model produces is discarded — confirm intent.
results = train_model(
    X_train,
    y_train,
    dynamic_predictors,
    learning_rates,
    batch_sizes,
    num_epochs,
    hidden_dims,
    sequence_length,
)

# Construct the network to evaluate; `use_bidirectional` picks the variant.
# NOTE(review): this is a freshly constructed model. Nothing in this script
# loads trained weights into it before it is evaluated — confirm intent.
model = (BidirectionalTLSTMModel if use_bidirectional else TLSTMModel)(
    input_dim, hidden_dim, output_dim
).to(device)

# Evaluate the model on the held-out test split.
X_test_seq, y_test_seq = create_sequences(X_test.values, y_test.values, sequence_length)
# Convert to float32 on the target device in a single step. The previous
# torch.tensor(...).float().to(device) chain copied the data in its source
# dtype, cast it, and then copied it again to the device.
X_test_tensor = torch.as_tensor(X_test_seq, dtype=torch.float32, device=device)
y_test_tensor = torch.as_tensor(y_test_seq, dtype=torch.float32, device=device)

test_outputs, test_labels = evaluate_model(model, X_test_tensor, y_test_tensor, device)
metrics = compute_metrics(test_outputs, test_labels)
# NOTE(review): 0.5 is a probability cut-off. If evaluate_model returns raw
# logits rather than sigmoid probabilities, the matching threshold is 0 —
# confirm what evaluate_model returns.
print_classification_report(test_labels.cpu().numpy(), (test_outputs.cpu().numpy() > 0.5).astype(int))

# NOTE(review): assumes compute_metrics returns 'loss', 'accuracy' and 'auc' keys.
print(f"Test loss: {metrics['loss']:.4f}, accuracy: {metrics['accuracy']:.4f}, AUC: {metrics['auc']:.4f}")

# Report total wall-clock runtime and the configuration that was used.
end_time = time.time()
runtime = end_time - start_time
print(
    f"Runtime: {runtime:.2f} seconds",
    f'Using the early prediction window as: {early_prediction} hours',
    f'Using {"bidirectional " if use_bidirectional else ""}T-LSTM',
    sep="\n",
)
