"""Training loop module for EBM training.

This module provides the `train_one_epoch` function, which handles the core logic
for training the energy-based model for a single epoch. It orchestrates data
loading, negative sampling, encoding, loss calculation, backpropagation, and
metric tracking.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
import logging

import torch
from torch.optim import Optimizer
from tqdm import tqdm

from .config import DataConfig, TrainingConfig, TrainingState
from .losses import (
    compute_accuracies,
    individual_contrastive_losses,
    individual_infonce_losses,
    infonce_loss,
    summed_contrastive_loss,
)
from .negative_sampling import sample_negative_prompts, sample_negative_responses
from .utils import (
    apply_invalid_mask, compute_invalid_masks,
    prepare_text_sets, encode_procedural_texts, 
    update_encoding_cache, prepare_encoded_negatives,
    compute_batch_energy_stats, calculate_averages_from_history,
    format_pbar_postfix, get_cached_batch,
    save_checkpoint, upload_file_to_drive,
)

if TYPE_CHECKING:
    from torch.utils.data import DataLoader, Dataset


def _backward_and_step(
    loss: torch.Tensor,
    model: torch.nn.Module,
    opt: Optimizer,
    max_grad_norm: float = 1.0,
) -> None:
    """Runs a single optimizer update for `loss` with gradient-norm clipping.

    Args:
        loss (torch.Tensor): The loss tensor to backpropagate.
        model (torch.nn.Module): The model whose parameter gradients are clipped.
        opt (Optimizer): The optimizer that is zeroed and then stepped.
        max_grad_norm (float): Maximum total gradient norm passed to
            `torch.nn.utils.clip_grad_norm_`.
    """
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    opt.step()


def train_one_epoch(
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    full_ds: Dataset,
    device: torch.device,
    train_cfg: TrainingConfig,
    data_cfg: DataConfig,
    opt: Optimizer,
    state: TrainingState,
) -> None:
    """Trains the energy model for a single epoch.

    For every batch this function samples negative prompts/responses, encodes
    the texts (caching the "stable" ones in `state.encoding_cache`), computes
    positive and negative energies, applies the loss selected by
    `train_cfg.loss_strategy`, and records per-batch metrics. At the end of the
    epoch the per-batch histories are averaged into the `state.epoch_avg_*`
    histories. All results are written into `state`; nothing is returned.

    Any exception raised while processing a batch is logged (with traceback)
    and the batch is skipped, so a single bad batch does not abort the epoch.

    Args:
        epoch (int): The current epoch number, used for logging and checkpoints.
        model (torch.nn.Module): The energy model to be trained. Must expose
            `forward_from_encoded`.
        train_loader (DataLoader): The DataLoader for the training set; it is
            expected to yield `(indices, prompts, responses)` tuples.
        full_ds (Dataset): The complete dataset object, passed to the negative
            samplers that need access to other examples (e.g. 'off_context').
        device (torch.device): The device on which invalid-sample masks are
            created.
        train_cfg (TrainingConfig): Training parameters. Fields read here:
            `epochs`, `loss_strategy`, `off_context_weight`, `margin`,
            `temperature`, `save_every_n_batches`, `results_dir`,
            `upload_to_gdrive`, `gdrive_folder_id`, `gdrive_creds_path`.
        data_cfg (DataConfig): Data parameters, forwarded to the 'human' and
            'gpt2' negative samplers.
        opt (Optimizer): The optimizer for updating model weights.
        state (TrainingState): The mutable object holding all dynamic training
            state (metric histories, `global_step`, `encoding_cache`,
            `resume_batch_idx`, ...), which is updated in place.
    """
    model.train()
    sum_train_energy = {m: 0.0 for m in state.epoch_avg_energy_train}
    n_batches = 0

    # Remember where this epoch starts in the (cross-epoch) batch histories so
    # that only this epoch's entries are averaged at the end.
    num_batches_before = len(state.batch_loss_train)

    pbar = tqdm(
        train_loader, desc=f"Epoch {epoch}/{train_cfg.epochs} [Train]",
        leave=False, dynamic_ncols=True,
    )
    for batch_idx, (indices, prompts, responses) in enumerate(pbar):
        # Skip batches that were already processed before a resume.
        # NOTE(review): `resume_batch_idx` is not reset in this function;
        # presumably the caller clears it after the resumed epoch — confirm.
        if batch_idx < state.resume_batch_idx:
            continue
        try:
            # 1) Generate all negative text samples for the batch
            response_negs = {
                "sentence_masking": sample_negative_responses(
                    sampling_type="sentence_masking", responses=responses
                ),
                "token_masking": sample_negative_responses(
                    sampling_type="token_masking", responses=responses
                ),
                "off_context": sample_negative_responses(
                    sampling_type="off_context", responses=responses,
                    indices=indices, dataset=full_ds
                ),
                "human": sample_negative_responses(
                    sampling_type="human", responses=responses, indices=indices,
                    dataset=full_ds, data_cfg=data_cfg
                ),
                "gpt2": sample_negative_responses(
                    sampling_type="gpt2", responses=responses, indices=indices,
                    dataset=full_ds, data_cfg=data_cfg
                ),
            }

            prompt_negs = {
                "sentence_masking_prompt": sample_negative_prompts(
                    sampling_type="sentence_masking_prompt", prompts=prompts
                ),
                "token_masking_prompt": sample_negative_prompts(
                    sampling_type="token_masking_prompt", prompts=prompts
                ),
                "off_context_prompt": sample_negative_prompts(
                    sampling_type="off_context_prompt", prompts=prompts,
                    indices=indices, dataset=full_ds
                ),
            }

            # 2) Centralized encoding step
            stable_texts, procedural_texts = prepare_text_sets(
                prompts, responses, response_negs, prompt_negs
            )

            # Update cache with missing stable texts
            cache_misses = [t for t in stable_texts if t not in state.encoding_cache]
            update_encoding_cache(cache_misses, state.encoding_cache, model)

            # Encode procedural texts for one-time use
            encoded_procedural = encode_procedural_texts(procedural_texts, model)

            # 3) Retrieve all required encoded tensors
            pos_x_encoded, pos_x_mask = get_cached_batch(prompts, state.encoding_cache)
            pos_y_encoded, pos_y_mask = get_cached_batch(responses, state.encoding_cache)
            pos_x = (pos_x_encoded, pos_x_mask)
            pos_y = (pos_y_encoded, pos_y_mask)

            encoded_response_negs, encoded_prompt_negs = prepare_encoded_negatives(
                response_negs, prompt_negs, state.encoding_cache, encoded_procedural
            )

            # 4) Compute energies and handle invalid samples
            pos_e = model.forward_from_encoded(pos_x, pos_y).flatten()

            # Response negatives are scored against the true prompt ...
            response_neg_e = {
                m: model.forward_from_encoded(pos_x, encoded_negs).flatten()
                for m, encoded_negs in encoded_response_negs.items()
            }
            # ... and prompt negatives against the true response.
            prompt_neg_e = {
                m: model.forward_from_encoded(encoded_negs, pos_y).flatten()
                for m, encoded_negs in encoded_prompt_negs.items()
            }

            # Mask out invalid sentence-masked samples
            if "sentence_masking" in response_neg_e:
                invalid_mask = compute_invalid_masks(
                    responses, response_negs["sentence_masking"], device
                )
                response_neg_e["sentence_masking"] = apply_invalid_mask(
                    response_neg_e["sentence_masking"], invalid_mask
                )
            if "sentence_masking_prompt" in prompt_neg_e:
                invalid_mask = compute_invalid_masks(
                    prompts, prompt_negs["sentence_masking_prompt"], device
                )
                prompt_neg_e["sentence_masking_prompt"] = apply_invalid_mask(
                    prompt_neg_e["sentence_masking_prompt"], invalid_mask
                )

            all_neg_e = {**response_neg_e, **prompt_neg_e}

            # 5) Loss calculation and backward pass based on strategy
            main_loss = None
            individual_losses = {}

            if train_cfg.loss_strategy in ["sum", "weighted_sum"]:
                weights = (
                    {
                        "off_context": train_cfg.off_context_weight,
                        "off_context_prompt": train_cfg.off_context_weight,
                    }
                    if train_cfg.loss_strategy == "weighted_sum"
                    else None
                )
                main_loss = summed_contrastive_loss(
                    pos_e, all_neg_e, train_cfg.margin, weights
                )
                individual_losses = individual_contrastive_losses(
                    pos_e, all_neg_e, train_cfg.margin
                )

                # Only update when there is a strictly positive tensor loss.
                if isinstance(main_loss, torch.Tensor) and main_loss > 0:
                    _backward_and_step(main_loss, model, opt)

            elif train_cfg.loss_strategy == "sequential":
                individual_losses = individual_contrastive_losses(
                    pos_e, all_neg_e, train_cfg.margin
                )
                # `torch.isnan` only accepts tensors, so the NaN check is done
                # on the loss tensor itself rather than on `loss.item()`.
                valid_losses = [
                    loss for loss in individual_losses.values()
                    if not bool(torch.isnan(loss).any())
                ]
                # NOTE(review): every loss here is built from the same `pos_e`
                # forward pass; backpropagating them one after another with an
                # optimizer step in between may fail after the first loss
                # (graph already freed / parameters modified) — confirm.
                for loss in valid_losses:
                    _backward_and_step(loss, model, opt)
                # `sum([])` is the int 0, which is reported as 0.0 below.
                main_loss = sum(valid_losses)

            elif train_cfg.loss_strategy in ["infonce", "infonce_expanded"]:
                expanded = train_cfg.loss_strategy == "infonce_expanded"

                neg_energies_for_loss = all_neg_e.copy()
                if expanded:
                    # The sampled 'off_context' set is replaced by the
                    # in-batch negatives generated below.
                    neg_energies_for_loss.pop("off_context", None)

                neg_energy_list = list(neg_energies_for_loss.values())

                if expanded:
                    # Create B-1 new sets of in-batch negatives
                    batch_size = len(prompts)
                    off_context_batch_energies_list = []
                    for i in range(1, batch_size):
                        off_context_negs = sample_negative_responses(
                            sampling_type="off_context_batch", responses=responses,
                            offset=i,
                        )
                        negs_encoded, negs_mask = get_cached_batch(
                            off_context_negs, state.encoding_cache
                        )
                        energies = model.forward_from_encoded(
                            pos_x, (negs_encoded, negs_mask)
                        ).flatten()
                        off_context_batch_energies_list.append(energies)
                        neg_energy_list.append(energies)

                    if off_context_batch_energies_list:
                        # Track the per-sample mean over all in-batch offsets
                        # as one extra metric entry.
                        avg_e = torch.stack(off_context_batch_energies_list, dim=1).mean(dim=1)
                        all_neg_e["off_context_batch_avg"] = avg_e

                main_loss = infonce_loss(pos_e, neg_energy_list, temperature=train_cfg.temperature)
                individual_losses = individual_infonce_losses(
                    pos_e, all_neg_e, temperature=train_cfg.temperature
                )

                _backward_and_step(main_loss, model, opt)

            # Counts processed batches, whether or not an optimizer step ran.
            state.global_step += 1

            # 6) Log metrics and update progress bar
            with torch.no_grad():
                batch_energy_means = compute_batch_energy_stats(pos_e, all_neg_e)
                state.batch_energy_train.append(batch_energy_means)

                batch_loss_means = {
                    "overall": main_loss.item() if isinstance(main_loss, torch.Tensor) else 0.0,
                    **{m: loss.item() for m, loss in individual_losses.items()},
                }
                state.batch_loss_train.append(batch_loss_means)

                accuracies = compute_accuracies(pos_e, all_neg_e)
                state.batch_accuracy_train.append(accuracies)

                # Update epoch-level sums for final averaging; non-finite
                # (masked) energies are excluded from the batch mean.
                for m, e_tensor in all_neg_e.items():
                    finite_e = e_tensor[torch.isfinite(e_tensor)]
                    if finite_e.numel() > 0 and sum_train_energy.get(m) is not None:
                        sum_train_energy[m] += finite_e.mean().item()
                # Guarded like the negative keys above so that a state without
                # a "positive" history does not fail every batch.
                if "positive" in sum_train_energy:
                    sum_train_energy["positive"] += pos_e.mean().item()
                n_batches += 1

                pbar.set_postfix(
                    format_pbar_postfix(
                        batch_loss_stats=batch_loss_means,
                        batch_acc_stats=accuracies,
                        batch_energy_stats=batch_energy_means,
                        loss_history=state.batch_loss_train,
                    )
                )

            # 7) Save checkpoint every N batches if specified
            if train_cfg.save_every_n_batches > 0 and state.global_step % train_cfg.save_every_n_batches == 0:
                print("")  # move off the progress-bar line before checkpoint output
                checkpoint_path = save_checkpoint(
                    epoch, state.global_step, model, opt,
                    state.epoch_avg_energy_train, state.epoch_avg_energy_val,
                    state.epoch_avg_loss_train, state.epoch_avg_loss_val,
                    state.epoch_avg_accuracy_train, state.epoch_avg_accuracy_val,
                    state.batch_loss_train, state.batch_loss_val,
                    state.batch_accuracy_train, state.batch_accuracy_val,
                    state.batch_energy_train, state.batch_energy_val,
                    state.best_acc, state.encoding_cache, train_cfg.results_dir,
                )
                if train_cfg.upload_to_gdrive:
                    upload_file_to_drive(
                        train_cfg.gdrive_folder_id, checkpoint_path,
                        train_cfg.gdrive_creds_path, train_cfg.results_dir,
                    )

        except Exception as e:
            # Best-effort: skip the failing batch, but keep the traceback so
            # the failure can actually be diagnosed.
            logging.exception("\nError in training batch %d: %s", batch_idx, e)

    # 8) Finalize epoch averages
    for m in state.epoch_avg_energy_train:
        state.epoch_avg_energy_train[m].append(
            sum_train_energy.get(m, 0.0) / n_batches if n_batches > 0
            else 0.0)

    current_epoch_losses = state.batch_loss_train[num_batches_before:]
    current_epoch_accs = state.batch_accuracy_train[num_batches_before:]

    final_avg_losses = calculate_averages_from_history(current_epoch_losses)
    for m in state.epoch_avg_loss_train:
        state.epoch_avg_loss_train[m].append(final_avg_losses.get(m, 0.0))

    final_avg_accs = calculate_averages_from_history(current_epoch_accs)
    for m in state.epoch_avg_accuracy_train:
        state.epoch_avg_accuracy_train[m].append(final_avg_accs.get(m, 0.0))
