import itertools
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score

from model import T_VRNN,Bi_T_VRNN
from data_loading import prepare_data, create_sequences
from evaluate import evaluate_model, compute_metrics

def _build_model(model_type, input_dim, hidden_dim, latent_dim, num_time_indices, time_embedding_dim):
    """Instantiate the requested architecture with a single output logit.

    ``'Bi_T_VRNN'`` selects ``Bi_T_VRNN``; any other value selects ``T_VRNN``.
    """
    model_cls = Bi_T_VRNN if model_type == 'Bi_T_VRNN' else T_VRNN
    return model_cls(input_dim=input_dim, hidden_dim=hidden_dim,
                     latent_dim=latent_dim, output_dim=1,
                     num_time_indices=num_time_indices,
                     time_embedding_dim=time_embedding_dim)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over every batch yielded by ``loader``."""
    model.train()
    # The batch variable is deliberately NOT called ``times`` so it can never
    # shadow the caller's ``times`` argument.
    for inputs, labels, batch_times in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_times = batch_times.to(device)
        optimizer.zero_grad()
        outputs, _, _ = model(inputs, batch_times)
        loss = criterion(outputs.view(-1), labels.view(-1))
        loss.backward()
        optimizer.step()


def _validation_auc(model, loader, device):
    """Return the ROC AUC of ``model``'s raw outputs over all batches of ``loader``."""
    model.eval()
    collected_outputs = []
    collected_labels = []
    with torch.no_grad():
        for inputs, labels, batch_times in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_times = batch_times.to(device)
            outputs, _, _ = model(inputs, batch_times)
            collected_outputs.append(outputs)
            collected_labels.append(labels)
    flat_outputs = torch.cat(collected_outputs).view(-1).cpu()
    flat_labels = torch.cat(collected_labels).view(-1).cpu()
    # AUC only depends on the ranking of the scores, so the outputs are used
    # as-is without applying a sigmoid.
    return roc_auc_score(flat_labels, flat_outputs)


def train_model(X, y, times, dynamic_predictors, learning_rates, batch_sizes,
                epochs, hidden_dims, latent_dims, sequence_length,
                num_time_indices, time_embedding_dim, inner_kf_splits=2,
                model_type='T_VRNN', outer_kf_splits=2):
    """Train and evaluate a T-VRNN style model with nested cross-validation.

    For every outer fold, an inner K-fold grid search over
    ``learning_rates`` x ``batch_sizes`` x ``hidden_dims`` x ``latent_dims``
    picks the combination with the highest mean validation AUC. A fresh model
    of the requested ``model_type`` is then trained on the full outer-training
    split for ``epochs`` epochs, scored on the outer-test split, and its
    ``state_dict`` is saved to ``models/t_vrnn_model_fold_<k>.pth``.

    Args:
        X, y, times: Row-aligned objects indexed positionally via ``.iloc``
            (e.g. pandas DataFrame/Series). They are forwarded to
            ``prepare_data`` after splitting.
        dynamic_predictors: Collection whose length is used as the model's
            ``input_dim``.
        learning_rates, batch_sizes, hidden_dims, latent_dims: Iterables of
            candidate hyperparameter values.
        epochs: Maximum epochs per inner-fold run (early stopping with a
            patience of 3 may end a run sooner) and the exact number of epochs
            used for the final per-fold model.
        sequence_length: Forwarded to ``prepare_data``.
        num_time_indices, time_embedding_dim: Forwarded to the model
            constructor.
        inner_kf_splits: Number of folds for the inner (model selection) CV.
        model_type: ``'Bi_T_VRNN'`` for the bidirectional model; anything else
            selects ``T_VRNN``.
        outer_kf_splits: Number of folds for the outer (evaluation) CV.

    Returns:
        dict mapping ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'test_auc'`` and ``'test_f1'`` to ``(mean, std)`` tuples computed
        across the outer folds.

    Raises:
        ValueError: If the hyperparameter grid is empty, so no configuration
            could be selected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    outer_kf = KFold(n_splits=outer_kf_splits, shuffle=True, random_state=42)
    inner_kf = KFold(n_splits=inner_kf_splits, shuffle=True, random_state=42)
    input_dim = len(dynamic_predictors)
    patience = 3

    outer_test_auc = []
    outer_test_f1 = []
    fold_accuracies, fold_precisions, fold_recalls = [], [], []

    for fold, (train_index, test_index) in enumerate(outer_kf.split(X)):
        X_train_outer, X_test_outer = X.iloc[train_index], X.iloc[test_index]
        y_train_outer, y_test_outer = y.iloc[train_index], y.iloc[test_index]
        times_train_outer, times_test_outer = times.iloc[train_index], times.iloc[test_index]

        best_hyperparams = None
        # Start below any attainable AUC so the first evaluated configuration
        # is always recorded, even if it scores exactly 0.
        best_val_auc = float("-inf")

        # product() iterates in the same order as the equivalent nested loops.
        grid = itertools.product(learning_rates, batch_sizes, hidden_dims, latent_dims)
        for lr, batch_size, hidden_dim, latent_dim in grid:
            inner_fold_aucs = []
            for train_index_inner, val_index_inner in inner_kf.split(X_train_outer):
                X_train_inner = X_train_outer.iloc[train_index_inner]
                X_val_inner = X_train_outer.iloc[val_index_inner]
                y_train_inner = y_train_outer.iloc[train_index_inner]
                y_val_inner = y_train_outer.iloc[val_index_inner]
                times_train_inner = times_train_outer.iloc[train_index_inner]
                times_val_inner = times_train_outer.iloc[val_index_inner]

                train_loader = prepare_data(X_train_inner, y_train_inner, times_train_inner, batch_size, sequence_length)
                val_loader = prepare_data(X_val_inner, y_val_inner, times_val_inner, batch_size, sequence_length)

                model = _build_model(model_type, input_dim, hidden_dim, latent_dim,
                                     num_time_indices, time_embedding_dim)
                model.to(device)
                criterion = nn.BCEWithLogitsLoss()
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)

                best_val_auc_early_stopping = 0
                patience_counter = 0

                for epoch in range(epochs):
                    _train_one_epoch(model, train_loader, criterion, optimizer, device)
                    val_auc = _validation_auc(model, val_loader, device)

                    print(f"Epoch {epoch + 1}/{epochs}, Validation AUC: {val_auc:.4f}")

                    if val_auc > best_val_auc_early_stopping:
                        best_val_auc_early_stopping = val_auc
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    if patience_counter >= patience:
                        print(f"Stopping early at epoch {epoch + 1}")
                        break

                # Score each inner fold once (by its best epoch) so the
                # selection metric does not grow with the number of epochs run.
                inner_fold_aucs.append(best_val_auc_early_stopping)

            avg_val_auc = sum(inner_fold_aucs) / len(inner_fold_aucs)
            if avg_val_auc > best_val_auc:
                best_val_auc = avg_val_auc
                best_hyperparams = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs, 'hidden_dim': hidden_dim, 'latent_dim': latent_dim}

        if best_hyperparams is None:
            raise ValueError(
                "No hyperparameter configuration was evaluated; "
                "learning_rates, batch_sizes, hidden_dims and latent_dims must all be non-empty."
            )

        # Retrain on the whole outer-training split with the selected
        # configuration, using the same architecture that was tuned.
        train_loader = prepare_data(X_train_outer, y_train_outer, times_train_outer, best_hyperparams['batch_size'], sequence_length)
        model = _build_model(model_type, input_dim, best_hyperparams['hidden_dim'],
                             best_hyperparams['latent_dim'], num_time_indices,
                             time_embedding_dim)
        model.to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=best_hyperparams['lr'])

        for _ in range(best_hyperparams['epochs']):
            _train_one_epoch(model, train_loader, criterion, optimizer, device)

        test_loader = prepare_data(X_test_outer, y_test_outer, times_test_outer, best_hyperparams['batch_size'], sequence_length)
        test_outputs, test_labels = evaluate_model(model, test_loader, device)
        metrics = compute_metrics(test_outputs, test_labels)

        fold_accuracies.append(metrics['accuracy'])
        fold_precisions.append(metrics['precision'])
        fold_recalls.append(metrics['recall'])
        outer_test_auc.append(metrics['auc'])
        outer_test_f1.append(metrics['f1'])

        # Save the model for the current fold
        model_path = f"models/t_vrnn_model_fold_{fold + 1}.pth"
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    return {
        'accuracy': (np.mean(fold_accuracies), np.std(fold_accuracies)),
        'precision': (np.mean(fold_precisions), np.std(fold_precisions)),
        'recall': (np.mean(fold_recalls), np.std(fold_recalls)),
        'test_auc': (np.mean(outer_test_auc), np.std(outer_test_auc)),
        'test_f1': (np.mean(outer_test_f1), np.std(outer_test_f1))
    }
