"""
Training utilities for HKAN experiments
"""
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
import numpy as np


def train_hkan_model(model, train_data, val_data, test_data, y_train, y_val, y_test, config, device='cpu',
                     *, restore_best=True):
    """
    Train HKAN model with proper train/validation/test split

    Training is full-batch: each epoch does one forward/backward pass over the
    whole ``train_data`` followed by one validation pass. Early stopping
    monitors validation AUC, while the learning-rate scheduler
    (ReduceLROnPlateau) monitors validation loss.

    Args:
        model: HKAN model to train. Must expose ``get_total_regularization_loss()``
            and ``count_parameters()``; if it exposes a truthy
            ``last_factors_for_fqs`` after a forward pass, a factor quality
            score (FQS) is computed from it.
        train_data: Training data dict (passed to ``model`` unchanged)
        val_data: Validation data dict
        test_data: Test data dict
        y_train, y_val, y_test: Binary target labels (array-like of 0/1)
        config: Configuration object; reads ``learning_rate``, ``weight_decay``,
            ``epochs``, ``grad_clip`` and ``patience`` (the FQS helper may read more)
        device: PyTorch device
        restore_best: If True (default), reload the weights from the epoch with
            the best validation AUC before evaluating on the test set, so the
            reported test metrics belong to the same model as ``best_val_auc``.

    Returns:
        Dictionary with training results and metrics
    """
    # BCEWithLogitsLoss needs float targets; keep a numpy copy of the
    # validation labels for sklearn (loop-invariant, so computed once).
    y_train_tensor = torch.tensor(np.asarray(y_train, dtype=np.float32), device=device)
    y_val_tensor = torch.tensor(np.asarray(y_val, dtype=np.float32), device=device)
    y_val_np = y_val_tensor.cpu().numpy()

    # Optimizer and loss function
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    criterion = torch.nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)

    best_val_auc = 0.0
    best_val_acc = 0.0
    best_val_fqs = 0.0
    best_state = None       # weights snapshot from the best-validation-AUC epoch
    patience_count = 0
    epochs_trained = 0      # stays 0 if config.epochs == 0 (loop never runs)
    train_losses = []
    val_aucs = []

    print("Starting HKAN model training...")
    print(f"Epochs: {config.epochs}, Learning rate: {config.learning_rate}")

    for epoch in range(config.epochs):
        epochs_trained = epoch + 1

        # Training phase
        model.train()
        optimizer.zero_grad()

        # reshape(-1) instead of squeeze(): squeeze() collapses a (1, 1) output
        # to a 0-d tensor, which BCEWithLogitsLoss rejects against a (1,) target.
        train_logits = model(train_data).reshape(-1)
        train_loss = criterion(train_logits, y_train_tensor)

        # Add total regularization loss
        reg_loss = model.get_total_regularization_loss()
        total_loss = train_loss + reg_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        # Validation phase
        model.eval()
        with torch.no_grad():
            val_logits = model(val_data).reshape(-1)
            val_loss = criterion(val_logits, y_val_tensor).item()

            # Calculate metrics on validation set
            val_prob = torch.sigmoid(val_logits).cpu().numpy()
            val_pred = (val_prob > 0.5).astype(int)

            val_auc = roc_auc_score(y_val_np, val_prob)
            val_acc = accuracy_score(y_val_np, val_pred)

            # Calculate FQS if available (0.5 is the fallback when the model
            # does not expose factors)
            val_fqs = 0.5
            if hasattr(model, 'last_factors_for_fqs') and model.last_factors_for_fqs:
                from models import calculate_factor_quality_score
                val_fqs = calculate_factor_quality_score(model.last_factors_for_fqs, config)

        scheduler.step(val_loss)
        train_loss_value = train_loss.item()
        train_losses.append(train_loss_value)
        val_aucs.append(val_auc)

        # Early stopping based on validation AUC
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_acc = val_acc
            best_val_fqs = val_fqs
            # state_dict() tensors share storage with the live parameters, so
            # clone them or the next optimizer.step() would overwrite the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_count = 0
        else:
            patience_count += 1

        # Print progress every 20 epochs
        if (epoch + 1) % 20 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{config.epochs} | "
                  f"Train Loss: {train_loss_value:.4f} | "
                  f"Val AUC: {val_auc:.4f} | "
                  f"Val Acc: {val_acc:.4f}")

        if patience_count >= config.patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Final evaluation on test set
    print(f"Training completed. Best validation AUC: {best_val_auc:.4f}")

    # Evaluate the model that achieved best_val_auc, not the last-epoch model.
    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    model.eval()
    with torch.no_grad():
        test_logits = model(test_data).reshape(-1)
        test_probs = torch.sigmoid(test_logits).cpu().numpy()
        test_pred = (test_probs > 0.5).astype(int)

        test_acc = accuracy_score(y_test, test_pred)
        test_auc = roc_auc_score(y_test, test_probs)
        test_f1 = f1_score(y_test, test_pred)

    # Calculate final FQS (from the factors of the test-set forward pass above)
    final_fqs = 0.5
    if hasattr(model, 'last_factors_for_fqs') and model.last_factors_for_fqs:
        from models import calculate_factor_quality_score
        final_fqs = calculate_factor_quality_score(model.last_factors_for_fqs, config)

    results = {
        'best_val_auc': best_val_auc,
        'best_val_acc': best_val_acc,
        'best_val_fqs': best_val_fqs,
        'test_acc': test_acc,
        'test_auc': test_auc,
        'test_f1': test_f1,
        'final_fqs': final_fqs,
        'total_params': model.count_parameters(),
        'epochs_trained': epochs_trained,
        'train_losses': train_losses,
        'val_aucs': val_aucs
    }

    return results


def train_pure_kan_model(model, X_train, X_val, X_test, y_train, y_val, y_test, config, device='cpu',
                         *, restore_best=True):
    """
    Train Pure KAN model with proper train/validation/test split

    Training is full-batch: each epoch does one forward/backward pass over all
    of ``X_train`` followed by one validation pass. Early stopping monitors
    validation AUC, while the learning-rate scheduler (ReduceLROnPlateau)
    monitors validation loss.

    Args:
        model: Pure KAN model to train. Must expose ``get_regularization_loss()``
            and ``count_parameters()``.
        X_train, X_val, X_test: Feature arrays (any array-like convertible by
            ``np.asarray``; cast to float32)
        y_train, y_val, y_test: Binary target labels (array-like of 0/1)
        config: Configuration object; reads ``learning_rate``, ``weight_decay``,
            ``epochs``, ``grad_clip`` and ``patience``
        device: PyTorch device
        restore_best: If True (default), reload the weights from the epoch with
            the best validation AUC before evaluating on the test set, so the
            reported test metrics belong to the same model as ``best_val_auc``.

    Returns:
        Dictionary with training results and metrics
    """
    # Convert data to float32 tensors on the target device.
    X_train_tensor = torch.tensor(np.asarray(X_train, dtype=np.float32), device=device)
    X_val_tensor = torch.tensor(np.asarray(X_val, dtype=np.float32), device=device)
    X_test_tensor = torch.tensor(np.asarray(X_test, dtype=np.float32), device=device)
    y_train_tensor = torch.tensor(np.asarray(y_train, dtype=np.float32), device=device)
    y_val_tensor = torch.tensor(np.asarray(y_val, dtype=np.float32), device=device)
    # numpy copy of the validation labels for sklearn (loop-invariant).
    y_val_np = y_val_tensor.cpu().numpy()

    # Optimizer and loss function
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    criterion = torch.nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)

    best_val_auc = 0.0
    best_val_acc = 0.0
    best_state = None       # weights snapshot from the best-validation-AUC epoch
    patience_count = 0
    epochs_trained = 0      # stays 0 if config.epochs == 0 (loop never runs)
    train_losses = []
    val_aucs = []

    print("Starting Pure KAN model training...")
    print(f"Epochs: {config.epochs}, Learning rate: {config.learning_rate}")

    for epoch in range(config.epochs):
        epochs_trained = epoch + 1

        # Training phase
        model.train()
        optimizer.zero_grad()

        # reshape(-1) instead of squeeze(): squeeze() collapses a (1, 1) output
        # to a 0-d tensor, which BCEWithLogitsLoss rejects against a (1,) target.
        train_logits = model(X_train_tensor).reshape(-1)
        train_loss = criterion(train_logits, y_train_tensor)

        # Add KAN regularization
        reg_loss = model.get_regularization_loss()
        total_loss = train_loss + reg_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        # Validation phase
        model.eval()
        with torch.no_grad():
            val_logits = model(X_val_tensor).reshape(-1)
            val_loss = criterion(val_logits, y_val_tensor).item()

            # Calculate metrics on validation set
            val_prob = torch.sigmoid(val_logits).cpu().numpy()
            val_pred = (val_prob > 0.5).astype(int)

            val_auc = roc_auc_score(y_val_np, val_prob)
            val_acc = accuracy_score(y_val_np, val_pred)

        scheduler.step(val_loss)
        train_loss_value = train_loss.item()
        train_losses.append(train_loss_value)
        val_aucs.append(val_auc)

        # Early stopping based on validation AUC
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_acc = val_acc
            # state_dict() tensors share storage with the live parameters, so
            # clone them or the next optimizer.step() would overwrite the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_count = 0
        else:
            patience_count += 1

        # Print progress every 20 epochs
        if (epoch + 1) % 20 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{config.epochs} | "
                  f"Train Loss: {train_loss_value:.4f} | "
                  f"Val AUC: {val_auc:.4f} | "
                  f"Val Acc: {val_acc:.4f}")

        if patience_count >= config.patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Final evaluation on test set
    print(f"Training completed. Best validation AUC: {best_val_auc:.4f}")

    # Evaluate the model that achieved best_val_auc, not the last-epoch model.
    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    model.eval()
    with torch.no_grad():
        test_logits = model(X_test_tensor).reshape(-1)
        test_probs = torch.sigmoid(test_logits).cpu().numpy()
        test_pred = (test_probs > 0.5).astype(int)

        test_acc = accuracy_score(y_test, test_pred)
        test_auc = roc_auc_score(y_test, test_probs)
        test_f1 = f1_score(y_test, test_pred)

    results = {
        'best_val_auc': best_val_auc,
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'test_auc': test_auc,
        'test_f1': test_f1,
        'total_params': model.count_parameters(),
        'epochs_trained': epochs_trained,
        'train_losses': train_losses,
        'val_aucs': val_aucs
    }

    return results