import torch
from performances import Performances
from prediction_changes import compute_changed_predictions
import time
import torch.nn.functional as F
from cnd import cnd


def model_evaluate(model, loader, device, args,
                   stop_criteria_enabled=False, threshold=1e-3, patience=5):
    """
    Evaluate the model, optionally stopping early once the running accuracy
    has converged over the most recent batches.

    Args:
    - model: callable returning a ``(logits, extra)`` pair; only the logits
      are used. It is switched to eval mode and left in that mode.
    - loader: iterable yielding ``(images, labels, extra)`` triples.
    - device: device the images and labels are moved to.
    - args: namespace-like object; ``_last_eval_nll`` and
      ``_last_eval_avg_logp`` are written onto it as a side effect.
    - stop_criteria_enabled: when True, stop iterating as soon as the running
      accuracy varied by less than ``threshold`` over the last ``patience``
      batches.
    - threshold: maximum spread (max - min) of the running accuracy inside
      the window that counts as converged.
    - patience: window length in batches (expected to be >= 1).

    Returns:
    - accuracy: fraction of correct predictions over the processed samples
      (0 when no sample was processed)
    - predictions: 1-D long tensor of predicted class indices, in loader
      order. If the early stop triggered, it covers only the batches that
      were actually processed.

    Side effects:
    - args._last_eval_nll: mean negative log-likelihood over the processed
      samples (NaN when no sample was processed)
    - args._last_eval_avg_logp: the negated value of the above
    """
    model.eval()
    correct = 0
    total = 0
    nll_sum = 0.0
    # Per-batch predictions are gathered here and concatenated once at the
    # end; concatenating inside the loop re-copies the whole prefix each time.
    prediction_chunks = []
    recent_accuracies = []

    with torch.no_grad():
        for images, labels, _ in loader:
            images, labels = images.to(device), labels.to(device)
            output, _ = model(images)

            # Numerically stable log-likelihood; summed (not averaged) per
            # batch so the final mean is exact even with uneven batch sizes.
            log_probs = F.log_softmax(output, dim=1)
            nll_sum += F.nll_loss(log_probs, labels, reduction='sum').item()

            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            prediction_chunks.append(predicted)

            if stop_criteria_enabled:
                # Sliding window over the running (cumulative) accuracy.
                recent_accuracies.append(correct / total if total > 0 else 0)
                if len(recent_accuracies) > patience:
                    recent_accuracies.pop(0)
                if len(recent_accuracies) == patience:
                    max_diff = max(recent_accuracies) - min(recent_accuracies)
                    if max_diff < threshold:
                        break

    accuracy = correct / total if total > 0 else 0

    if prediction_chunks:
        predictions = torch.cat(prediction_chunks)
    else:
        predictions = torch.empty(0, dtype=torch.long).to(device)

    # Stash the smooth metrics on args so the return signature stays as is.
    if total > 0:
        avg_nll = nll_sum / total
        avg_logp = -avg_nll
    else:
        avg_nll = float("nan")
        avg_logp = float("nan")
    args._last_eval_nll = avg_nll
    args._last_eval_avg_logp = avg_logp

    return accuracy, predictions


def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
    - mean_loss: mean of the per-batch loss values (0.0 if the loader
      yielded no batch)
    - accuracy: fraction of correctly classified training samples (0 if no
      sample was seen)
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels, _ in train_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs, _ = model(images)  # the model returns (logits, extra)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()
        num_batches += 1

    # Count batches explicitly: relying on the loop index breaks when the
    # loader is empty (the index is then unbound or stale).
    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total != 0 else 0
    return mean_loss, accuracy


def train_and_evaluate_model(
    model, loaders, criterion, optimizer, scheduler, device, args, logger):
    """
    Train the model for up to ``args.epochs`` epochs, evaluating and logging
    after every epoch.

    Args:
    - model: network returning a ``(logits, extra)`` pair.
    - loaders: mapping of data loaders. Keys read here: 'train_loader',
      'test_loader', 'train_loader_fixed' (may be None), and, depending on
      ``args.metrics``, 'train_loader_corrupted',
      'train_loader_known_corrupted' and 'train_loader_not_corrupted'.
    - criterion: loss callable applied to ``(logits, labels)``.
    - optimizer: optimizer whose ``param_groups`` learning rates may be
      rewritten by the 'lr_plato' policy.
    - scheduler: learning-rate scheduler stepped once per epoch, or None.
    - device: device the batches are moved to.
    - args: configuration namespace. Read: ``epochs``, ``metrics`` and,
      optionally, ``early_stopping``, ``early_stopping_patience``,
      ``lr_policy``, ``lr_plato_patience``, ``lr_gamma``, ``noise_type``,
      ``num_classes``. Written: ``current_epoch`` (plus whatever
      ``model_evaluate`` stores on it).
    - logger: logger used for progress messages.

    Returns:
    - model: the trained model
    - performances: the ``Performances`` object updated once per completed
      epoch
    - args: the (mutated) configuration namespace
    """
    epochs = args.epochs
    prev_predictions = None  # predictions from the previous epoch (PC metric)
    performances = Performances()

    # Early-stopping / LR-plateau state, initialised once up front so every
    # name is bound regardless of when the feature is consulted.
    patience = getattr(args, "early_stopping_patience", 10)
    best_test_accuracy = 0
    epochs_without_improvement = 0
    lr_plato_counter = 0

    print(model)
    print(sum(p.numel() for p in model.parameters()))

    for epoch in range(epochs):
        start_time = time.time()

        # Training phase
        args.current_epoch = epoch
        performances_dict = {}
        logger.info(f"Epoch {epoch + 1}/{epochs} ..")

        train_loss, train_accuracy = _train_one_epoch(
            model, loaders['train_loader'], criterion, optimizer, device)

        # Update learning rate
        if scheduler is not None:
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(f"Epoch {epoch + 1}/{epochs}, Current LR: {current_lr:.6f}")

        epoch_duration = time.time() - start_time
        logger.info(f"Epoch {epoch + 1}/{epochs} training completed in {epoch_duration:.2f} seconds.")
        logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy * 100:.2f}%")

        # Evaluation phase. On every epoch but the last, evaluation may stop
        # early once the running accuracy has converged, so the reported
        # accuracy is an estimate; the final epoch uses the full loader.
        stop_criteria_enabled = epoch != epochs - 1
        test_accuracy, _ = model_evaluate(model, loaders['test_loader'], device, args,
                                          stop_criteria_enabled=stop_criteria_enabled)
        logger.info(f"Epoch {epoch + 1}/{epochs}, Test Accuracy: {test_accuracy * 100:.2f}%")

        # Early stopping logic
        if getattr(args, "early_stopping", False):
            if test_accuracy > best_test_accuracy:
                best_test_accuracy = test_accuracy
                epochs_without_improvement = 0
                lr_plato_counter = 0
                logger.info("Early stopping: new best accuracy.")
            else:
                epochs_without_improvement += 1
                lr_plato_counter += 1
                remaining_epochs = patience - epochs_without_improvement
                logger.info(f"Early stopping: {remaining_epochs} epoch(s) left before patience limit is reached.")

                # 'lr_plato': scale the LR of every param group after
                # lr_plato_patience epochs without a new best accuracy.
                if getattr(args, "lr_policy", False) == "lr_plato":
                    if lr_plato_counter >= args.lr_plato_patience:
                        current_lr = optimizer.param_groups[0]['lr']
                        new_lr = current_lr * args.lr_gamma
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = new_lr
                        logger.info(f"LR Policy 'lr_plato': No improvement for {args.lr_plato_patience} epochs. "
                                    f"Reducing LR from {current_lr:.6f} to {new_lr:.6f}.")
                        lr_plato_counter = 0

            if epochs_without_improvement >= patience:
                logger.warning(f"Early stopping activated at epoch {epoch + 1}. No improvement for {patience} epochs.")
                # NOTE(review): leaving here skips the metrics below and
                # performances.update() for this epoch - confirm intended.
                break

        # KPA metrics
        if "KPA" in args.metrics:
            corrupted_accuracy, _ = model_evaluate(model, loaders['train_loader_corrupted'],
                                                   device, args, stop_criteria_enabled=stop_criteria_enabled)
            logger.info(f"Corrupted Sample Accuracy {epoch + 1}/{epochs}: {corrupted_accuracy * 100:.2f}% ")

            # Reference value to compare the corrupted-sample accuracy with.
            # 'hard_noise' is tested first: it also differs from
            # 'clean_label', so testing it second made it unreachable.
            noise_type = getattr(args, "noise_type", None)
            num_classes = getattr(args, "num_classes", 10)
            if noise_type == "hard_noise":
                expected_accuracy = 100 / num_classes
                logger.info(f"Expected {expected_accuracy:.2f}%")
            elif noise_type is not None and noise_type != "clean_label":
                # Test error spread over the other (num_classes - 1) classes.
                expected_accuracy = 100 * (1 - test_accuracy) / max(num_classes - 1, 1)
                logger.info(f"Expected {expected_accuracy:.2f}%")

            kpa_accuracy, _ = model_evaluate(model, loaders['train_loader_known_corrupted'],
                                             device, args)
            logger.info(f"KPA {epoch + 1}/{epochs}: {kpa_accuracy * 100:.2f}% ")
            performances_dict['KPA'] = kpa_accuracy
        else:
            performances_dict['KPA'] = None

        # CND / PC metrics
        if loaders['train_loader_fixed'] is not None:
            current_predictions = None
            if "CND" in args.metrics:
                performances_dict, current_predictions = cnd(loaders['train_loader_fixed'], model, device,
                                                             performances_dict, "CND", args, logger)

            if "PC" in args.metrics:
                if current_predictions is None:
                    # Full evaluation (no early stop) so that predictions of
                    # consecutive epochs cover the same samples.
                    _, current_predictions = model_evaluate(model, loaders['train_loader_not_corrupted'], device, args)
                if prev_predictions is not None:
                    PC = compute_changed_predictions(prev_predictions, current_predictions)
                    logger.info(f"Epoch {epoch + 1}/{epochs}, Changed predictions: {PC}")
                    performances_dict['PC'] = PC
                prev_predictions = current_predictions.clone()

        epoch_duration = time.time() - start_time
        logger.info(f"Epoch {epoch + 1}/{epochs} completed in {epoch_duration:.2f} seconds "
                    f"(training + evaluation + metrics).")

        performances.update(train_loss, train_accuracy, test_accuracy, performances_dict, args)

    return model, performances, args
