import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset
from data_loading import prepare_times_tensor
from model import TimeAwareTransformer, BidirectionalTimeAwareTransformer

def prepare_data(X, y, times, batch_size, num_time_indices, device):
    """Bundle features, time indices and labels into a shuffled DataLoader.

    Args:
        X: Feature container exposing ``.values`` (e.g. a pandas DataFrame).
        y: Label container exposing ``.values`` (e.g. a pandas Series).
        times: Time information, passed unchanged to ``prepare_times_tensor``.
        batch_size: Number of samples per batch.
        num_time_indices: Forwarded to ``prepare_times_tensor``.
        device: Forwarded to ``prepare_times_tensor``; the feature and label
            tensors built here are not explicitly moved to it.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` yielding
        ``(features, times, labels)`` batches.
    """
    def to_float_tensor(container):
        # Cast to float32, then insert a singleton axis at position 1.
        as_float32 = container.values.astype(np.float32)
        return torch.tensor(as_float32).unsqueeze(1)

    features = to_float_tensor(X)
    targets = to_float_tensor(y)
    time_indices = prepare_times_tensor(times, num_time_indices, device)

    return DataLoader(
        TensorDataset(features, time_indices, targets),
        batch_size=batch_size,
        shuffle=True,
    )

def _build_model(use_bidirectional, input_dim, hidden_dim, num_heads, num_layers,
                 num_time_indices, time_embedding_dim, device):
    """Instantiate the requested transformer variant and move it to ``device``."""
    model_cls = BidirectionalTimeAwareTransformer if use_bidirectional else TimeAwareTransformer
    model = model_cls(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=1,
        num_heads=num_heads,
        num_layers=num_layers,
        num_time_indices=num_time_indices,
        time_embedding_dim=time_embedding_dim,
    )
    model.to(device)
    return model


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run a single optimisation pass over every batch in ``loader``."""
    model.train()
    for inputs, times_batch, labels in loader:
        inputs, times_batch, labels = inputs.to(device), times_batch.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs, times_batch)
        loss = criterion(outputs.squeeze(-1), labels.view(-1))
        loss.backward()
        optimizer.step()


def _validation_auc(model, loader, device):
    """Return the ROC AUC of the model's raw outputs over ``loader``."""
    # Imported locally: the file's top-level imports do not bring this name in.
    from sklearn.metrics import roc_auc_score

    model.eval()
    scores, targets = [], []
    with torch.no_grad():
        for inputs, times_batch, labels in loader:
            inputs, times_batch = inputs.to(device), times_batch.to(device)
            # Flatten per batch so a single-sample fold stays 1-D
            # (a bare ``squeeze()`` would collapse it to a 0-d tensor).
            scores.append(model(inputs, times_batch).reshape(-1).cpu())
            targets.append(labels.reshape(-1).cpu())
    # AUC is rank-based, so raw logits can be scored without a sigmoid.
    return roc_auc_score(torch.cat(targets).numpy(), torch.cat(scores).numpy())


def train_model(X, y, times, dynamic_predictors, learning_rates, batch_sizes, epochs, hidden_dims, num_heads, num_layers, time_embedding_dim, num_time_indices, num_folds, device, use_bidirectional):
    """Train and evaluate a time-aware transformer with nested cross-validation.

    The outer ``KFold`` (``num_folds`` splits) provides held-out test folds.
    For each outer fold, a 2-split inner ``KFold`` scores every combination of
    ``learning_rates`` x ``batch_sizes`` x ``hidden_dims``; the combination
    with the highest mean inner validation AUC is retrained on the full outer
    training split for ``epochs`` epochs and evaluated on the outer test split.

    Side effects: the retrained model's ``state_dict`` for each outer fold is
    saved to ``time_aware_transformer_model_fold_<k>.pth`` in the current
    working directory, and the path is printed.

    Args:
        X: Features; indexed with ``.iloc`` (pandas-style).
        y: Labels; indexed with ``.iloc`` (pandas-style).
        times: Per-sample time information, indexed with integer index arrays
            (presumably a NumPy array -- verify against the caller).
        dynamic_predictors: Sized collection; the model ``input_dim`` is
            ``len(dynamic_predictors) - 1``.
        learning_rates: Candidate Adam learning rates.
        batch_sizes: Candidate batch sizes.
        epochs: Maximum number of training epochs (inner search uses early
            stopping with a patience of 3; the final fit runs all epochs).
        hidden_dims: Candidate hidden dimensions.
        num_heads: Forwarded to the model constructor.
        num_layers: Forwarded to the model constructor.
        time_embedding_dim: Forwarded to the model constructor.
        num_time_indices: Forwarded to the model constructor and
            ``prepare_data``.
        num_folds: Number of outer cross-validation splits.
        device: Device the model and batches are moved to.
        use_bidirectional: If true, use ``BidirectionalTimeAwareTransformer``;
            otherwise ``TimeAwareTransformer``.

    Returns:
        Dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'test_auc'`` and ``'test_f1'``, each mapped to a ``(mean, std)``
        tuple computed across the outer folds.

    Raises:
        ValueError: If no hyperparameter combination could be selected for an
            outer fold (e.g. an empty grid or undefined validation AUC).
    """
    from itertools import product

    outer_kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    inner_kf = KFold(n_splits=2, shuffle=True, random_state=42)

    input_dim = len(dynamic_predictors) - 1
    patience = 3

    outer_test_auc = []
    outer_test_f1 = []
    # Three independent lists; unpacking a single ``[]`` into three names
    # raises ValueError.
    fold_accuracies, fold_precisions, fold_recalls = [], [], []

    for fold, (train_index, test_index) in enumerate(outer_kf.split(X)):
        X_train_outer, X_test_outer = X.iloc[train_index], X.iloc[test_index]
        y_train_outer, y_test_outer = y.iloc[train_index], y.iloc[test_index]
        times_train_outer, times_test_outer = times[train_index], times[test_index]

        best_hyperparams = None
        # Start below any attainable AUC so a configuration scoring exactly 0
        # can still be selected.
        best_val_auc = float('-inf')

        for lr, batch_size, hidden_dim in product(learning_rates, batch_sizes, hidden_dims):
            avg_val_auc = 0
            for train_index_inner, val_index_inner in inner_kf.split(X_train_outer):
                train_loader = prepare_data(
                    X_train_outer.iloc[train_index_inner],
                    y_train_outer.iloc[train_index_inner],
                    times_train_outer[train_index_inner],
                    batch_size, num_time_indices, device,
                )
                val_loader = prepare_data(
                    X_train_outer.iloc[val_index_inner],
                    y_train_outer.iloc[val_index_inner],
                    times_train_outer[val_index_inner],
                    batch_size, num_time_indices, device,
                )

                model = _build_model(use_bidirectional, input_dim, hidden_dim, num_heads,
                                     num_layers, num_time_indices, time_embedding_dim, device)
                criterion = torch.nn.BCEWithLogitsLoss()
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)

                best_val_auc_early_stopping = 0
                patience_counter = 0

                for epoch in range(epochs):
                    _train_one_epoch(model, train_loader, criterion, optimizer, device)
                    val_auc = _validation_auc(model, val_loader, device)

                    if val_auc > best_val_auc_early_stopping:
                        best_val_auc_early_stopping = val_auc
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    if patience_counter >= patience:
                        break

                # Each inner fold contributes exactly once (its best epoch).
                # Accumulating inside the epoch loop would sum over epochs and
                # reward configurations merely for training longer.
                avg_val_auc += best_val_auc_early_stopping / inner_kf.n_splits

            if avg_val_auc > best_val_auc:
                best_val_auc = avg_val_auc
                best_hyperparams = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs, 'hidden_dim': hidden_dim}

        if best_hyperparams is None:
            raise ValueError(
                f"No hyperparameter combination could be selected for outer fold {fold + 1}; "
                "check that learning_rates, batch_sizes and hidden_dims are non-empty "
                "and that validation AUC is defined."
            )

        # Retrain on the whole outer training split with the selected settings.
        train_loader = prepare_data(X_train_outer, y_train_outer, times_train_outer, best_hyperparams['batch_size'], num_time_indices, device)
        model = _build_model(use_bidirectional, input_dim, best_hyperparams['hidden_dim'], num_heads,
                             num_layers, num_time_indices, time_embedding_dim, device)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=best_hyperparams['lr'])

        for epoch in range(best_hyperparams['epochs']):
            _train_one_epoch(model, train_loader, criterion, optimizer, device)

        test_loader = prepare_data(X_test_outer, y_test_outer, times_test_outer, best_hyperparams['batch_size'], num_time_indices, device)
        # NOTE(review): evaluate_model and compute_metrics are not defined or
        # imported in the visible part of this file -- confirm they are in scope.
        # compute_metrics must return a mapping with the keys read below.
        test_outputs, test_labels = evaluate_model(model, test_loader, device)
        metrics = compute_metrics(test_outputs, test_labels)

        fold_accuracies.append(metrics['accuracy'])
        fold_precisions.append(metrics['precision'])
        fold_recalls.append(metrics['recall'])
        outer_test_auc.append(metrics['auc'])
        outer_test_f1.append(metrics['f1'])

        # Save the model for the current fold
        model_path = f"time_aware_transformer_model_fold_{fold + 1}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    return {
        'accuracy': (np.mean(fold_accuracies), np.std(fold_accuracies)),
        'precision': (np.mean(fold_precisions), np.std(fold_precisions)),
        'recall': (np.mean(fold_recalls), np.std(fold_recalls)),
        'test_auc': (np.mean(outer_test_auc), np.std(outer_test_auc)),
        'test_f1': (np.mean(outer_test_f1), np.std(outer_test_f1))
    }
