import time
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from data_loading import load_data, preprocess_data, prepare_times_tensor
from Right_Align.Transformer.T_transformer import TimeAwareTransformer,BidirectionalTimeAwareTransformer
from train import train_model
#from evaluate import evaluate_model, compute_metrics, print_classification_report


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Time-varying columns handed to preprocess_data.
dynamic_predictors = [
    'MinutesFromArrival',
    'HeartRate', 'RespiratoryRate', 'PulseOx', 'SystolicBP', 'DiastolicBP',
    'Temperature',
    'BUN', 'Lactate', 'Platelet', 'Creatinine', 'BiliRubin',
    'WBC', 'CReactiveProtein', 'OxygenFlow',
]
target = 'Observed_Shock'
early_prediction = 4

# Model shape.
# One predictor is excluded from the feature width (presumably the time
# column, which is passed separately as `times` — TODO confirm against
# preprocess_data).
input_dim = len(dynamic_predictors) - 1
hidden_dim = 50
output_dim = 1
num_heads = 5
num_layers = 2
time_embedding_dim = 5
num_time_indices = 1
use_bidirectional = False

# Optimisation / evaluation settings.
batch_size = 4
lr = 0.001
num_epochs = 20
num_folds = 2

# NOTE(review): a check that (input_dim + time_embedding_dim) is divisible by
# num_heads is intentionally disabled; with the values above that sum is 19,
# which is not a multiple of 5, so re-enabling it as-is would fail.

# Wall-clock reference point for the whole run.
start_time = time.time()

# Prefer a GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Read the cleaned CSV and convert it into features, labels and timestamps.
df = load_data('cleaned_mimic.csv')
X, y, times = preprocess_data(df, dynamic_predictors, target, early_prediction)

# Two-stage split with a fixed seed: 33% of all rows are held out as the test
# set, then 33% of the remainder becomes the validation set.
# NOTE(review): X_val / y_val / times_val are not used anywhere in this script.
(X_train_val, X_test,
 y_train_val, y_test,
 times_train_val, times_test) = train_test_split(
    X, y, times, test_size=0.33, random_state=42)
(X_train, X_val,
 y_train, y_val,
 times_train, times_val) = train_test_split(
    X_train_val, y_train_val, times_train_val, test_size=0.33, random_state=42)

# Candidate hyper-parameters handed to train_model (presumably searched with
# num_folds-fold cross-validation on the training split — TODO confirm).
learning_rates = [0.0001, 0.001, 0.01]
batch_sizes = [64, 32, 16, 8]
hidden_dims = [512, 256, 128, 64]

# NOTE(review): `results` is not consumed below; the model evaluated later is
# built from scratch rather than taken from these training results.
results = train_model(
    X_train, y_train, times_train,
    dynamic_predictors,
    learning_rates, batch_sizes, num_epochs, hidden_dims,
    num_heads, num_layers, time_embedding_dim, num_time_indices,
    num_folds, device, use_bidirectional,
)

# Build the model that is scored on the held-out test split; the flag picks
# between the one-directional and bidirectional variants, which share the
# same constructor arguments.
model = (
    BidirectionalTimeAwareTransformer if use_bidirectional else TimeAwareTransformer
)(
    input_dim, hidden_dim, output_dim, num_heads, num_layers,
    num_time_indices, time_embedding_dim,
).to(device)

# The model is expected to emit raw logits: BCEWithLogitsLoss applies the
# sigmoid internally.
criterion = torch.nn.BCEWithLogitsLoss()
# Not used by the evaluation code in this script.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# NOTE(review): no trained weights are loaded, so the model evaluated below is
# freshly initialised. To score a saved checkpoint instead, e.g.:
# model.load_state_dict(torch.load("time_aware_transformer_model_fold_1.pth"))

# Score the model on the held-out test split and report loss, AUC and a
# per-class classification report.
times_tensor_test = prepare_times_tensor(times_test, num_time_indices, device)

# X gains a singleton axis at dim 1 and y becomes a column before batching.
test_loader = DataLoader(
    TensorDataset(
        torch.tensor(X_test.values.astype(np.float32)).unsqueeze(1),
        times_tensor_test,
        torch.tensor(y_test.values.astype(np.float32)).unsqueeze(1),
    ),
    shuffle=False,
    batch_size=batch_size,
)

# Switch off training-only behaviour (e.g. dropout) before scoring.
model.eval()

total_test_loss = 0.0
all_test_outputs = []
all_y_test = []

with torch.no_grad():
    for X_batch, times_batch, y_batch in test_loader:
        X_batch = X_batch.to(device)
        times_batch = times_batch.to(device)
        y_batch = y_batch.to(device)
        # Flatten to one logit / one label per sample (assumes the model emits
        # a single logit per sample, output_dim == 1). A bare .squeeze() would
        # turn a final batch of size 1 into a 0-d tensor, which torch.cat
        # refuses to concatenate.
        test_outputs = model(X_batch, times_batch).reshape(-1)
        y_flat = y_batch.reshape(-1)
        loss = criterion(test_outputs, y_flat)
        # criterion returns the batch mean; re-weight by batch size so the
        # final average is per-sample even when the last batch is short.
        total_test_loss += loss.item() * X_batch.size(0)
        all_test_outputs.append(test_outputs.cpu())
        all_y_test.append(y_flat.cpu())

all_test_outputs = torch.cat(all_test_outputs)
all_y_test = torch.cat(all_y_test)
avg_test_loss = total_test_loss / len(test_loader.dataset)

# Compute test metrics. The outputs are logits (the criterion is
# BCEWithLogitsLoss), so the 0.5 probability cutoff corresponds to logit > 0.
test_probs = torch.sigmoid(all_test_outputs).numpy()
test_preds = (all_test_outputs > 0).float().numpy()
y_test_numpy = all_y_test.numpy()

test_auc = roc_auc_score(y_test_numpy, test_probs)
classification_rep = classification_report(y_test_numpy, test_preds, digits=4)

print(classification_rep)
print(f'Test Loss: {avg_test_loss:.4f}')
print(f'Test AUC: {test_auc:.4f}')

# Total wall-clock time since start_time, followed by a short summary of the
# configuration that produced the numbers above.
end_time = time.time()
runtime = end_time - start_time

print(f"Using the early_prediction as: {early_prediction}")
print(f"Runtime: {runtime:.2f} seconds")
print(f'Using {"Bidirectional " if use_bidirectional else ""}Time-Aware Transformer')
print('MIMIC')  # dataset tag
