import itertools

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold

from data_loading import prepare_data, create_padding_mask
from evaluate import evaluate_model, compute_metrics
from model import TransformerModel, BidirectionalTransformerModel


def _build_model(model_type, input_dim, hidden_dim, num_heads, num_layers):
    """Instantiate the network selected by ``model_type``.

    ``'BidirectionalTransformer'`` selects ``BidirectionalTransformerModel``; any other
    value falls back to ``TransformerModel``. Both get a single output unit, which the
    training code feeds to ``BCEWithLogitsLoss`` as a logit.
    """
    model_cls = BidirectionalTransformerModel if model_type == 'BidirectionalTransformer' else TransformerModel
    return model_cls(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=1,
                     num_heads=num_heads, num_layers=num_layers)


def _forward_batch(model, inputs):
    """Run ``model`` on one batch with an all-zero target and a padding mask built from ``inputs``."""
    tgt_batch = torch.zeros_like(inputs)  # zeros_like already allocates on inputs' device
    padding_mask = create_padding_mask(inputs)
    return model(inputs, tgt_batch, src_mask=padding_mask)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the summed batch loss."""
    model.train()
    total_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = _forward_batch(model, inputs)
        loss = criterion(outputs.squeeze(-1), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss


def _collect_logits(model, loader, device):
    """Return ``(logits, labels)`` over every batch of ``loader`` as flat NumPy arrays."""
    model.eval()
    logits, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = _forward_batch(model, inputs.to(device))
            # Flatten per batch so the result is 1-D even when the model emits (batch, 1).
            logits.append(outputs.reshape(-1).cpu())
            targets.append(labels.reshape(-1).cpu())
    return torch.cat(logits).numpy(), torch.cat(targets).numpy()


def _best_validation_auc(model, train_loader, val_loader, lr, epochs, patience, device):
    """Train ``model`` with early stopping on validation AUC and return the best AUC observed.

    Training stops once ``patience`` consecutive epochs fail to improve on the best AUC,
    or after ``epochs`` epochs, whichever comes first.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_auc = -np.inf
    epochs_without_improvement = 0
    for _ in range(epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_logits, val_targets = _collect_logits(model, val_loader, device)
        val_auc = roc_auc_score(val_targets, val_logits)
        if val_auc > best_auc:
            best_auc = val_auc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_auc


def train_model(X, y, dynamic_predictors, learning_rates, batch_sizes, epochs, hidden_dims, num_heads, num_layers,
                inner_kf_splits=2, model_type='Transformer', outer_kf_splits=2, patience=3):
    """Evaluate a transformer classifier with nested cross-validation.

    For each outer fold, every (learning rate, batch size, hidden dim) combination is scored
    by the mean over ``inner_kf_splits`` inner folds of the best early-stopped validation
    ROC-AUC. The best combination is retrained on the whole outer-training split for
    ``epochs`` epochs (no early stopping) and evaluated on the outer test split.

    Args:
        X: Feature data (``np.ndarray`` is wrapped in a ``DataFrame``). Its layout is whatever
            ``prepare_data`` expects. Presumably one row per sample; TODO confirm.
        y: Binary targets aligned with ``X`` (``np.ndarray`` is wrapped in a ``Series``).
        dynamic_predictors: Only its length is used, as the model's ``input_dim``.
        learning_rates, batch_sizes, hidden_dims: Hyperparameter grid values (any iterables).
        epochs: Maximum epochs per inner fit and exact epochs for the final refit.
        num_heads, num_layers: Fixed architecture settings (not searched).
        inner_kf_splits: Number of inner CV folds used for hyperparameter selection.
        model_type: ``'BidirectionalTransformer'`` or anything else for ``TransformerModel``.
        outer_kf_splits: Number of outer CV folds used for test evaluation.
        patience: Consecutive non-improving epochs tolerated before early stopping.

    Returns:
        Dict mapping ``'accuracy'``, ``'precision'``, ``'recall'``, ``'test_auc'`` and
        ``'test_f1'`` to ``(mean, std)`` over the outer folds.

    Raises:
        ValueError: If ``epochs < 1`` or any hyperparameter grid is empty. ``roc_auc_score``
            also raises ``ValueError`` if a validation fold contains a single class.
        RuntimeError: If no configuration yields a usable validation AUC in an outer fold.

    Side effects:
        Saves each outer fold's final ``state_dict`` to ``transformer_model_fold_<k>.pth``
        in the working directory and prints the path.
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    # Materialise once: iterables (e.g. generators) would otherwise be exhausted after the first outer fold.
    grid = list(itertools.product(learning_rates, batch_sizes, hidden_dims))
    if not grid:
        raise ValueError("learning_rates, batch_sizes and hidden_dims must all be non-empty")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = len(dynamic_predictors)
    # NOTE(review): plain KFold can produce a single-class validation fold on imbalanced data,
    # in which case roc_auc_score raises; StratifiedKFold would avoid this.
    outer_kf = KFold(n_splits=outer_kf_splits, shuffle=True, random_state=42)
    inner_kf = KFold(n_splits=inner_kf_splits, shuffle=True, random_state=42)

    outer_test_auc = []
    outer_test_f1 = []
    fold_accuracies, fold_precisions, fold_recalls = [], [], []

    if isinstance(X, np.ndarray):
        X = pd.DataFrame(X)
    if isinstance(y, np.ndarray):
        y = pd.Series(y)

    X = X.reset_index(drop=True)
    y = y.reset_index(drop=True)

    for fold, (train_index, test_index) in enumerate(outer_kf.split(X)):
        X_train_outer, X_test_outer = X.iloc[train_index], X.iloc[test_index]
        y_train_outer, y_test_outer = y.iloc[train_index], y.iloc[test_index]
        # Fixed random_state -> identical inner folds for every configuration (fair comparison).
        inner_splits = list(inner_kf.split(X_train_outer))

        best_hyperparams = None
        best_val_auc = -np.inf

        for lr, batch_size, hidden_dim in grid:
            inner_fold_aucs = []
            for train_index_inner, val_index_inner in inner_splits:
                train_loader = prepare_data(X_train_outer.iloc[train_index_inner],
                                            y_train_outer.iloc[train_index_inner], batch_size)
                val_loader = prepare_data(X_train_outer.iloc[val_index_inner],
                                          y_train_outer.iloc[val_index_inner], batch_size)
                model = _build_model(model_type, input_dim, hidden_dim, num_heads, num_layers).to(device)
                inner_fold_aucs.append(
                    _best_validation_auc(model, train_loader, val_loader, lr, epochs, patience, device))

            # Score = mean of each inner fold's best (early-stopped) AUC, not a sum over epochs.
            avg_val_auc = float(np.mean(inner_fold_aucs))
            if avg_val_auc > best_val_auc:
                best_val_auc = avg_val_auc
                best_hyperparams = {'lr': lr, 'batch_size': batch_size, 'epochs': epochs,
                                    'hidden_dim': hidden_dim}

        if best_hyperparams is None:
            raise RuntimeError(
                f"Outer fold {fold + 1}: no hyperparameter configuration produced a valid validation AUC")

        # Refit the selected configuration with the SAME architecture that was tuned.
        train_loader = prepare_data(X_train_outer, y_train_outer, best_hyperparams['batch_size'])
        model = _build_model(model_type, input_dim, best_hyperparams['hidden_dim'],
                             num_heads, num_layers).to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
        for _ in range(best_hyperparams['epochs']):
            _train_one_epoch(model, train_loader, criterion, optimizer, device)

        test_loader = prepare_data(X_test_outer, y_test_outer, best_hyperparams['batch_size'])
        model.eval()  # make sure dropout etc. are disabled regardless of what evaluate_model does
        test_outputs, test_labels = evaluate_model(model, test_loader, device)
        metrics = compute_metrics(test_outputs, test_labels)

        fold_accuracies.append(metrics['accuracy'])
        fold_precisions.append(metrics['precision'])
        fold_recalls.append(metrics['recall'])
        outer_test_auc.append(metrics['auc'])
        outer_test_f1.append(metrics['f1'])

        model_path = f"transformer_model_fold_{fold + 1}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    return {
        'accuracy': (np.mean(fold_accuracies), np.std(fold_accuracies)),
        'precision': (np.mean(fold_precisions), np.std(fold_precisions)),
        'recall': (np.mean(fold_recalls), np.std(fold_recalls)),
        'test_auc': (np.mean(outer_test_auc), np.std(outer_test_auc)),
        'test_f1': (np.mean(outer_test_f1), np.std(outer_test_f1))
    }
