"""Validation module for EBM training.

This module provides the validation function that evaluates the model
on validation data with comprehensive metrics and loss computation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
import logging

import torch
from tqdm import tqdm

from .config import DataConfig, TrainingConfig, TrainingState
from .losses import (
    compute_accuracies,
    individual_contrastive_losses,
    individual_infonce_losses,
    infonce_loss,
)
from .negative_sampling import sample_negative_prompts, sample_negative_responses
from .utils import (
    apply_invalid_mask, compute_invalid_masks,
    prepare_text_sets, encode_procedural_texts, 
    update_encoding_cache, prepare_encoded_negatives,
    compute_batch_energy_stats, calculate_averages_from_history,
    format_pbar_postfix, get_cached_batch
)

if TYPE_CHECKING:
    from torch.utils.data import DataLoader, Dataset


def _sample_negatives(
    indices,
    prompts,
    responses,
    full_ds: Dataset,
    data_cfg: DataConfig,
) -> tuple[dict, dict]:
    """Generate every response-side and prompt-side negative text set for a batch.

    The sampler calls run in a fixed order (response strategies first, then
    prompt strategies). If the samplers draw from a shared RNG, reordering
    these calls would change the generated negatives, so keep the order.

    Returns:
        tuple[dict, dict]: ``(response_negs, prompt_negs)``, each mapping a
        sampling-strategy name to whatever the corresponding sampler returned.
    """
    response_negs = {
        "sentence_masking": sample_negative_responses(
            sampling_type="sentence_masking", responses=responses
        ),
        "token_masking": sample_negative_responses(
            sampling_type="token_masking", responses=responses
        ),
        "off_context": sample_negative_responses(
            sampling_type="off_context", responses=responses,
            indices=indices, dataset=full_ds
        ),
        "human": sample_negative_responses(
            sampling_type="human", responses=responses, indices=indices,
            dataset=full_ds, data_cfg=data_cfg
        ),
        "gpt2": sample_negative_responses(
            sampling_type="gpt2", responses=responses, indices=indices,
            dataset=full_ds, data_cfg=data_cfg
        ),
    }
    prompt_negs = {
        "sentence_masking_prompt": sample_negative_prompts(
            sampling_type="sentence_masking_prompt", prompts=prompts
        ),
        "token_masking_prompt": sample_negative_prompts(
            sampling_type="token_masking_prompt", prompts=prompts
        ),
        "off_context_prompt": sample_negative_prompts(
            sampling_type="off_context_prompt", prompts=prompts,
            indices=indices, dataset=full_ds
        ),
    }
    return response_negs, prompt_negs


def _off_context_batch_energies(
    model: torch.nn.Module,
    responses,
    pos_x,
    batch_size: int,
    encoding_cache,
) -> list[torch.Tensor]:
    """Score the batch's prompts against in-batch off-context responses.

    For each offset ``1 .. batch_size - 1`` the "off_context_batch" sampler is
    asked for a re-arranged copy of ``responses`` (presumably shifted by
    ``offset`` — verify against ``sample_negative_responses``). That copy is
    looked up in ``encoding_cache`` and scored against the encoded prompts
    ``pos_x``.

    Returns:
        list[torch.Tensor]: One flattened energy tensor per offset, in offset
        order. The list is empty when ``batch_size <= 1``.
    """
    energies = []
    for offset in range(1, batch_size):
        shifted_negs = sample_negative_responses(
            sampling_type="off_context_batch", responses=responses,
            offset=offset,
        )
        negs_encoded, negs_mask = get_cached_batch(shifted_negs, encoding_cache)
        energies.append(
            model.forward_from_encoded(pos_x, (negs_encoded, negs_mask)).flatten()
        )
    return energies


def validate_model(
    model: torch.nn.Module,
    val_loader: DataLoader,
    full_ds: Dataset,
    device: torch.device,
    train_cfg: TrainingConfig,
    data_cfg: DataConfig,
    state: TrainingState,
    epoch: int | None = None,
    is_eval_only: bool = False,
) -> None:
    """Evaluate the energy model on a validation dataset.

    Iterates through ``val_loader``. For each batch it generates all negative
    samples, then computes energies, losses, and accuracies. Stable texts are
    encoded through ``state.encoding_cache``, so only cache misses are encoded.
    Procedurally generated texts are encoded anew for every batch.

    Nothing is returned. Results are recorded on ``state`` in place:

    * ``state.batch_energy_val``, ``state.batch_loss_val`` and
      ``state.batch_accuracy_val`` each receive exactly one entry per batch
      that completes successfully. A batch that raises is logged (with
      traceback) and skipped without touching any of these histories or the
      running energy sums, so the histories stay aligned.
    * ``state.encoding_cache`` is extended with newly encoded stable texts.
    * Unless ``is_eval_only`` is set, one epoch-average value per tracked key
      is appended to ``state.epoch_avg_energy_val``,
      ``state.epoch_avg_loss_val`` and ``state.epoch_avg_accuracy_val``. The
      loss and accuracy averages use only the batches from this call.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model (torch.nn.Module): The energy model to be evaluated. It must
            provide ``forward_from_encoded``.
        val_loader (DataLoader): Yields ``(indices, prompts, responses)``
            batches from the validation set.
        full_ds (Dataset): The full dataset object, required for negative
            sampling strategies that look up other examples (e.g.
            'off_context', 'human', 'gpt2').
        device (torch.device): Device used for the invalid-sample masks.
        train_cfg (TrainingConfig): Training parameters (loss strategy,
            margin, temperature, epoch count).
        data_cfg (DataConfig): Data parameters passed to some samplers.
        state (TrainingState): Dynamic training state (metric histories and
            the encoding cache). It is modified in place.
        epoch (int | None): Current epoch number. It is used only for the
            progress-bar label.
        is_eval_only (bool): When True, use an "[Eval]" progress-bar label
            and skip the epoch-average bookkeeping.
    """
    model.eval()
    # NOTE(review): training mode is not restored here; callers presumably
    # call model.train() themselves before the next training epoch.

    # Running sums of per-batch mean energies, keyed like
    # state.epoch_avg_energy_val. Empty (unused) in eval-only mode.
    sum_val_energy = (
        {} if is_eval_only else {m: 0.0 for m in state.epoch_avg_energy_val}
    )
    n_batches = 0
    # Where this run's entries start, so epoch averages ignore older batches.
    num_batches_before = len(state.batch_loss_val)

    pbar_desc = (
        "  [Eval]" if is_eval_only
        else f"Epoch {epoch}/{train_cfg.epochs} [Val]"
    )

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=pbar_desc, leave=False, dynamic_ncols=True)
        for indices, prompts, responses in pbar:
            try:
                # 1) Generate all negative text samples.
                response_negs, prompt_negs = _sample_negatives(
                    indices, prompts, responses, full_ds, data_cfg
                )

                # 2) Centralized encoding: stable texts go through the cache,
                #    procedural texts are encoded for this batch only.
                stable_texts, procedural_texts = prepare_text_sets(
                    prompts, responses, response_negs, prompt_negs
                )
                cache_misses = [
                    t for t in stable_texts if t not in state.encoding_cache
                ]
                update_encoding_cache(cache_misses, state.encoding_cache, model)
                encoded_procedural = encode_procedural_texts(procedural_texts, model)

                # 3) Retrieve all required encoded tensors.
                pos_x_encoded, pos_x_mask = get_cached_batch(
                    prompts, state.encoding_cache
                )
                pos_y_encoded, pos_y_mask = get_cached_batch(
                    responses, state.encoding_cache
                )
                pos_x = (pos_x_encoded, pos_x_mask)
                pos_y = (pos_y_encoded, pos_y_mask)

                encoded_response_negs, encoded_prompt_negs = prepare_encoded_negatives(
                    response_negs, prompt_negs, state.encoding_cache, encoded_procedural
                )

                # 4) Energies. Response negatives are scored against the true
                #    prompt; prompt negatives against the true response.
                pos_e = model.forward_from_encoded(pos_x, pos_y).flatten()
                response_neg_e = {
                    m: model.forward_from_encoded(pos_x, encoded_negs).flatten()
                    for m, encoded_negs in encoded_response_negs.items()
                }
                prompt_neg_e = {
                    m: model.forward_from_encoded(encoded_negs, pos_y).flatten()
                    for m, encoded_negs in encoded_prompt_negs.items()
                }

                # Sentence masking can produce invalid negatives (as judged
                # by compute_invalid_masks). Their energies are replaced via
                # apply_invalid_mask -- presumably with inf, given the
                # isfinite filtering further down.
                if "sentence_masking" in response_neg_e:
                    invalid_response_mask = compute_invalid_masks(
                        responses, response_negs["sentence_masking"], device
                    )
                    response_neg_e["sentence_masking"] = apply_invalid_mask(
                        response_neg_e["sentence_masking"], invalid_response_mask
                    )
                if "sentence_masking_prompt" in prompt_neg_e:
                    invalid_prompt_mask = compute_invalid_masks(
                        prompts, prompt_negs["sentence_masking_prompt"], device
                    )
                    prompt_neg_e["sentence_masking_prompt"] = apply_invalid_mask(
                        prompt_neg_e["sentence_masking_prompt"], invalid_prompt_mask
                    )

                all_neg_e = {**response_neg_e, **prompt_neg_e}

                # Negative energies fed to the overall (main) loss.
                neg_energies_for_loss = all_neg_e.copy()
                if train_cfg.loss_strategy == "infonce_expanded":
                    # Expanded InfoNCE drops the sampled 'off_context'
                    # negative from the main loss. It uses one in-batch
                    # off-context energy tensor per offset instead.
                    neg_energies_for_loss.pop("off_context", None)
                    in_batch_energies = _off_context_batch_energies(
                        model, responses, pos_x, len(prompts), state.encoding_cache
                    )
                    neg_energy_list = [
                        *neg_energies_for_loss.values(), *in_batch_energies
                    ]
                    if in_batch_energies:
                        # Reported (stats/accuracy/individual loss) as one
                        # averaged mode rather than one entry per offset.
                        all_neg_e["off_context_batch_avg"] = torch.stack(
                            in_batch_energies, dim=1
                        ).mean(dim=1)
                else:
                    neg_energy_list = list(neg_energies_for_loss.values())

                batch_energy_means = compute_batch_energy_stats(pos_e, all_neg_e)

                # Per-mode batch means for the epoch averages. Non-finite
                # entries are filtered out so masked samples don't pollute
                # the logs. Keys not tracked in sum_val_energy are skipped.
                energy_contrib = {}
                if not is_eval_only:
                    for m, e in all_neg_e.items():
                        finite_e = e[torch.isfinite(e)]
                        if finite_e.numel() > 0 and m in sum_val_energy:
                            energy_contrib[m] = finite_e.mean().item()
                    if "positive" in sum_val_energy:
                        energy_contrib["positive"] = pos_e.mean().item()

                # 5) Loss calculation using the configured strategy.
                if train_cfg.loss_strategy in ("sum", "weighted_sum", "sequential"):
                    individual_losses = individual_contrastive_losses(
                        pos_e, all_neg_e, train_cfg.margin
                    )
                    # NaN per-mode losses are excluded from the overall sum.
                    # NOTE(review): if every loss is NaN this is the int 0,
                    # `.item()` below raises, and the batch is logged/skipped.
                    main_loss = sum(
                        loss for loss in individual_losses.values()
                        if not torch.isnan(loss)
                    )
                else:  # infonce and infonce_expanded
                    main_loss = infonce_loss(
                        pos_e, neg_energy_list, temperature=train_cfg.temperature
                    )
                    individual_losses = individual_infonce_losses(
                        pos_e, all_neg_e, temperature=train_cfg.temperature
                    )

                batch_loss_means = {
                    "overall": main_loss.item(),
                    **{m: loss.item() for m, loss in individual_losses.items()},
                }
                accuracies = compute_accuracies(pos_e, all_neg_e)

                # 6) Commit. Everything that can fail has already run, so the
                #    histories and running sums are updated together.
                state.batch_energy_val.append(batch_energy_means)
                state.batch_loss_val.append(batch_loss_means)
                state.batch_accuracy_val.append(accuracies)
                for m, value in energy_contrib.items():
                    sum_val_energy[m] += value
                n_batches += 1

                pbar.set_postfix(
                    format_pbar_postfix(
                        batch_loss_stats=batch_loss_means,
                        batch_acc_stats=accuracies,
                        batch_energy_stats=batch_energy_means,
                        loss_history=state.batch_loss_val,
                    )
                )
            except Exception as e:
                # Best-effort validation: log with traceback and move on.
                logging.exception("\nError in validation batch: %s", e)

    if is_eval_only:
        return

    # Finalize per-epoch averages (0.0 when no batch succeeded).
    for m in state.epoch_avg_energy_val:
        state.epoch_avg_energy_val[m].append(
            sum_val_energy.get(m, 0.0) / n_batches if n_batches > 0 else 0.0
        )

    current_run_losses = state.batch_loss_val[num_batches_before:]
    current_run_accs = state.batch_accuracy_val[num_batches_before:]

    final_avg_losses = calculate_averages_from_history(current_run_losses)
    for m in state.epoch_avg_loss_val:
        state.epoch_avg_loss_val[m].append(final_avg_losses.get(m, 0.0))

    final_avg_accs = calculate_averages_from_history(current_run_accs)
    for m in state.epoch_avg_accuracy_val:
        state.epoch_avg_accuracy_val[m].append(final_avg_accs.get(m, 0.0))
