import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from model import LSTMModel, Bi_LSTMModel  # Import both models
from data_loading import preprocess_data
from evaluate import evaluate_model, compute_metrics

def prepare_data(X, y, batch_size=32, shuffle=True):
    """Wrap a feature table and a target vector in a ``DataLoader``.

    Each row of ``X`` is turned into a length-1 sequence by ``unsqueeze(1)``, so
    (assuming ``X.values`` is 2-D and ``y.values`` is 1-D) batches have shapes
    ``(batch, 1, n_features)`` for features and ``(batch, 1)`` for targets, both
    ``float32``.

    Args:
        X: Object exposing ``.values`` (e.g. a pandas DataFrame) with numeric
            features.
        y: Object exposing ``.values`` (e.g. a pandas Series) with numeric
            targets aligned with ``X``.
        batch_size: Mini-batch size.
        shuffle: Reshuffle rows at every iteration over the loader. Defaults to
            ``True`` for backward compatibility; pass ``False`` for evaluation
            loaders where a deterministic order is preferable.

    Returns:
        A ``DataLoader`` yielding ``(features, targets)`` tensor pairs.
    """
    features = torch.tensor(X.values.astype(np.float32)).unsqueeze(1)
    targets = torch.tensor(y.values.astype(np.float32)).unsqueeze(1)
    dataset = TensorDataset(features, targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def _build_model(model_type, input_dim, hidden_dim):
    """Instantiate ``Bi_LSTMModel`` when ``model_type == 'Bi_LSTM'``, otherwise ``LSTMModel``."""
    model_cls = Bi_LSTMModel if model_type == 'Bi_LSTM' else LSTMModel
    return model_cls(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=1)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass of ``model`` over every batch in ``loader``."""
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(-1), labels.view(-1))
        loss.backward()
        optimizer.step()


def _predict_scores(model, loader, device):
    """Return ``(scores, labels)`` for all of ``loader`` as flat 1-D CPU tensors.

    Scores are the raw model outputs. The training loss is
    ``BCEWithLogitsLoss``, so they are logits. ROC AUC depends only on their
    ordering, so no sigmoid is applied.
    """
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for inputs, batch_labels in loader:
            scores.append(model(inputs.to(device)).cpu())
            labels.append(batch_labels.cpu())
    # reshape(-1) always yields 1-D, unlike squeeze(), which can collapse to 0-d.
    return torch.cat(scores).reshape(-1), torch.cat(labels).reshape(-1)


def _fit_with_early_stopping(model, train_loader, val_loader, lr, epochs, device, patience):
    """Train ``model`` with Adam and return the best validation ROC AUC reached.

    Training stops early once validation AUC has not improved for ``patience``
    consecutive epochs. Returns ``-inf`` if ``epochs`` is 0, because nothing
    was evaluated.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_auc = -float('inf')
    stale_epochs = 0
    for _ in range(epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, device)
        scores, labels = _predict_scores(model, val_loader, device)
        val_auc = roc_auc_score(labels.numpy(), scores.numpy())
        if val_auc > best_auc:
            best_auc = val_auc
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    return best_auc


def train_model(X, y, dynamic_predictors, learning_rates, batch_sizes, epochs, hidden_dims, inner_kf_splits=2, model_type='LSTM',
                outer_kf_splits=2, patience=3):
    """Nested cross-validated training and evaluation of an (Bi-)LSTM binary classifier.

    For every outer fold:

    1. A grid search runs over ``learning_rates`` x ``batch_sizes`` x
       ``hidden_dims``. Each configuration is scored by the mean, over
       ``inner_kf_splits`` inner folds, of its best early-stopped validation
       ROC AUC.
    2. The best configuration is retrained on the full outer-training split for
       ``epochs`` epochs (no early stopping).
    3. It is evaluated on the outer test split with ``evaluate_model`` and
       ``compute_metrics``.
    4. Its ``state_dict`` is saved to ``lstm_model_fold_<k>.pth``.

    Args:
        X: Feature table supporting ``.iloc`` (e.g. a pandas DataFrame).
        y: Binary targets aligned with ``X`` and supporting ``.iloc``.
        dynamic_predictors: Only ``len(dynamic_predictors)`` is used, as the
            model input dimension.
        learning_rates: Grid of Adam learning rates.
        batch_sizes: Grid of mini-batch sizes.
        epochs: Maximum epochs per inner run, and the exact epoch count for the
            outer refit.
        hidden_dims: Grid of hidden sizes.
        inner_kf_splits: Number of inner CV folds.
        model_type: ``'Bi_LSTM'`` selects ``Bi_LSTMModel``; any other value
            selects ``LSTMModel``.
        outer_kf_splits: Number of outer CV folds.
        patience: Early-stopping patience, in epochs, for inner runs.

    Returns:
        dict mapping ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'test_auc'`` and ``'test_f1'`` to ``(mean, std)`` across outer folds.

    Raises:
        ValueError: if no hyperparameter configuration could be scored (e.g. an
            empty grid or ``epochs == 0``). Also raised by ``roc_auc_score`` when
            an inner validation fold contains only one class.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    outer_kf = KFold(n_splits=outer_kf_splits, shuffle=True, random_state=42)
    inner_kf = KFold(n_splits=inner_kf_splits, shuffle=True, random_state=42)
    input_dim = len(dynamic_predictors)
    # Same visiting order as nested loops: lr outermost, hidden_dim innermost.
    grid = [(lr, bs, hd) for lr in learning_rates for bs in batch_sizes for hd in hidden_dims]

    outer_test_auc, outer_test_f1 = [], []
    fold_accuracies, fold_precisions, fold_recalls = [], [], []

    for fold, (train_index, test_index) in enumerate(outer_kf.split(X)):
        X_train_outer, X_test_outer = X.iloc[train_index], X.iloc[test_index]
        y_train_outer, y_test_outer = y.iloc[train_index], y.iloc[test_index]

        best_hyperparams = None
        best_val_auc = -float('inf')

        for lr, batch_size, hidden_dim in grid:
            fold_aucs = []
            for train_idx, val_idx in inner_kf.split(X_train_outer):
                train_loader = prepare_data(X_train_outer.iloc[train_idx], y_train_outer.iloc[train_idx], batch_size)
                val_loader = prepare_data(X_train_outer.iloc[val_idx], y_train_outer.iloc[val_idx], batch_size)
                model = _build_model(model_type, input_dim, hidden_dim).to(device)
                fold_aucs.append(
                    _fit_with_early_stopping(model, train_loader, val_loader, lr, epochs, device, patience)
                )
            # One score per inner fold (its best epoch), so a configuration is
            # not rewarded merely for running more epochs before stopping.
            avg_val_auc = float(np.mean(fold_aucs))

            if avg_val_auc > best_val_auc:
                best_val_auc = avg_val_auc
                best_hyperparams = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs, 'hidden_dim': hidden_dim}

        if best_hyperparams is None:
            raise ValueError(
                "No hyperparameter configuration could be scored; check that "
                "learning_rates, batch_sizes and hidden_dims are non-empty and epochs >= 1."
            )

        # Refit the selected configuration on the whole outer-training split.
        # NOTE(review): this runs the full `epochs` with no early stopping.
        train_loader = prepare_data(X_train_outer, y_train_outer, best_hyperparams['batch_size'])
        model = _build_model(model_type, input_dim, best_hyperparams['hidden_dim']).to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
        for _ in range(best_hyperparams['epochs']):
            _train_one_epoch(model, train_loader, criterion, optimizer, device)

        test_loader = prepare_data(X_test_outer, y_test_outer, best_hyperparams['batch_size'])
        test_outputs, test_labels = evaluate_model(model, test_loader, device)
        metrics = compute_metrics(test_outputs, test_labels)

        fold_accuracies.append(metrics['accuracy'])
        fold_precisions.append(metrics['precision'])
        fold_recalls.append(metrics['recall'])
        outer_test_auc.append(metrics['auc'])
        outer_test_f1.append(metrics['f1'])

        # Save the model for the current fold.
        # NOTE(review): model_type is not part of the filename, so LSTM and
        # Bi_LSTM runs in the same directory overwrite each other's checkpoints.
        model_path = f"lstm_model_fold_{fold + 1}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    return {
        'accuracy': (np.mean(fold_accuracies), np.std(fold_accuracies)),
        'precision': (np.mean(fold_precisions), np.std(fold_precisions)),
        'recall': (np.mean(fold_recalls), np.std(fold_recalls)),
        'test_auc': (np.mean(outer_test_auc), np.std(outer_test_auc)),
        'test_f1': (np.mean(outer_test_f1), np.std(outer_test_f1))
    }
