"""Code for training guides."""
import os
import time

from tqdm import tqdm
import torch
import wandb

from calnf.datasets.dataset import Dataset
from calnf.guides.guide import Guide
from calnf.metrics.metric import Metric


def train(
    name: str,
    device: torch.device,
    dataset: Dataset,
    guide: Guide,
    epochs: int,
    metrics: list[tuple[Metric, int]],
    visualize_every_n: int = 0,
    lr: float = 1e-3,
) -> None:
    """Train the guide.

    Each epoch iterates over ``dataset.nominal_train_loader``; every minibatch
    is paired with one batch drawn from ``dataset.target_train_loader`` and
    used for a single optimizer step. Per-step losses and gradient norms are
    logged to wandb under ``Train/``, per-epoch averages under ``Train/Avg/``
    and metric values under ``Test/``. A checkpoint is written after every
    epoch and the final one is uploaded to wandb.

    Args:
        name (str): Name of the run (used in the checkpoint path).
        device (torch.device): Device to use for training.
        dataset (Dataset): Dataset.
        guide (Guide): Guide.
        epochs (int): Number of epochs. If 0, no training is performed and no
            checkpoint is saved or uploaded.
        metrics (list[tuple[Metric, int]]): List of tuples with metrics and their
            frequency (0 = only run at the last step).
        visualize_every_n (int): Visualize the guide every n epochs (0 = only at the
            last step).
        lr (float): Learning rate.
    """
    # Set up the dataloaders for nominal and target data
    dataset.configure_nominal_data(device)
    dataset.configure_target_data(device)

    # Get the optimizer
    optimizer = guide.configure_optimizer(lr=lr)

    # Make somewhere to save checkpoints
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(f"checkpoints/{name}/{timestamp}/", exist_ok=True)
    # NOTE(review): checkpoints below are written to checkpoints/{name}/, not
    # to the timestamped directory created above, so runs sharing a name
    # overwrite each other. Paths are left as-is in case other tooling reads
    # them -- confirm the intended location.

    # Path of the most recent checkpoint; stays None if no epoch runs.
    checkpoint_path = None

    # Run the training loop
    pbar = tqdm(range(epochs))
    for epoch in pbar:
        # The nominal dataset is quite large, so break it into mini-batches
        total_losses = {}  # Track total losses for logging
        minibatches = 0
        for obs_nominal_minibatch in dataset.nominal_train_loader:
            # A fresh iterator is created on every step and only its first
            # batch is used (presumably a single batch covers the whole target
            # training set, which is small enough to fit in memory).
            obs_target = next(iter(dataset.target_train_loader))

            # Perform a training step
            optimizer.zero_grad()

            loss, losses = guide.loss(
                dataset=dataset,
                n_nominal=len(obs_nominal_minibatch),
                obs_nominal=obs_nominal_minibatch.to(device),
                n_target=len(obs_target),
                obs_target=obs_target.to(device),
            )

            loss.backward()
            grad_norms = guide.clip_grad_norm()
            # Skip the parameter update when the loss is NaN or +/-inf so a
            # single bad batch does not push non-finite values into the guide.
            if torch.isfinite(loss):
                optimizer.step()

            pbar.set_description(f"Loss: {loss.detach().cpu().item():.2f}")

            # Update total losses
            minibatches += 1
            losses.update(grad_norms)
            for key, value in losses.items():
                # If a value is a tensor, detach it so the running total does
                # not keep this minibatch's autograd graph alive.
                if isinstance(value, torch.Tensor):
                    value = value.detach()
                total_losses[key] = total_losses.get(key, 0.0) + value

            wandb.log({f"Train/{key}": value for key, value in losses.items()})

        # Per-epoch averages, for saving to wandb. total_losses is empty when
        # the nominal loader yielded nothing, so this never divides by zero.
        log_packet = {
            f"Train/Avg/{key}": total / minibatches
            for key, total in total_losses.items()
        }

        # Visualize the guide
        last_epoch = epoch == epochs - 1
        if last_epoch or (visualize_every_n != 0 and epoch % visualize_every_n == 0):
            dataset.visualize(guide.nominal_distribution(), guide.target_distribution())

        # Compute metrics
        for metric, freq in metrics:
            # freq == 0 means "only at the last epoch"; the explicit check also
            # avoids a modulo-by-zero.
            skip_epoch = freq == 0 or epoch % freq != 0
            if not last_epoch and skip_epoch:
                continue

            metric_values = metric(
                device=device,
                dist=guide.target_distribution(),
                nominal_test_loader=dataset.nominal_test_loader,
                target_test_loader=dataset.target_test_loader,
            )

            for key, value in metric_values.items():
                log_packet[f"Test/{key}"] = value

        # Log to wandb
        wandb.log(log_packet)

        # Save the guide
        checkpoint_path = f"checkpoints/{name}/epoch_{epoch}.pt"
        guide.save(checkpoint_path)

    # Save the final checkpoint to wandb (nothing to upload if no epoch ran)
    if checkpoint_path is not None:
        wandb.save(checkpoint_path)
