import logging
from collections.abc import Callable
from typing import Any
from time import time

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from models.base_model import SequentialModel
from compute.param_vector import get_vector_of_params
from compute.lipschitz import compute_lipschitz_bounds

from train_utils.train_test_model import (
    train_model,
    test_model,
    compute_loss_and_grad_vector_of_params,
)


def run_epoch(
    epoch_number: int,
    model: SequentialModel,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    loss_f: Callable,
    optimiser: torch.optim.Optimizer,
    lr_scheduler: Any,
    device: torch.device,
    layers_to_look_at: list,
    writer: SummaryWriter,
    batch_counter: int,
    min_num_epochs: int,
    max_num_epochs: int,
    theta_0: torch.Tensor = None,
    compute_lip_this_epoch: bool = True,
) -> tuple[float, float, float, float, float, torch.Tensor, int]:
    """Run one epoch of the model.
    This includes training, testing, computing the Lipschitz constant and logging epoch details.
    Note that model updates do not happen for the 0th epoch.

    Parameters
    ----------
    epoch_number
        Epoch number. If `epoch_number=0`, training is not performed (only stats are computed).
        Otherwise used for logging.
    model
        Model object.
    train_dataloader
        Train Dataloader.
    test_dataloader
        Test Dataloader.
    loss_f
        Loss function with reduction="sum".
    optimiser
        Optimiser object.
    lr_scheduler
        LR Scheduler object.
    device
        Torch device.
    layers_to_look_at
        List of layers to compute the Lipschitz constant for.
        By default contains one element with the last layer.
    writer
        Tensorboard Writer.
    batch_counter
        Global counter for the number of batches trained for.
    min_num_epochs
        Minimum number of epochs that the model will train for. Used for logging.
    max_num_epochs
        Maximum number of epochs that the model is allowed to train for. Used for logging.
    theta_0, optional
        Parameter vector at initialisation, by default None. Only used for logging the distance
        between the current parameter vector and the initial one. Not required for the 0th epoch.
    compute_lip_this_epoch
        Flag that controls whether to compute the Lipschitz constant bounds this epoch or not.
        By default True.

    Returns
    -------
        Return the following end-of-epoch stats, computed for the model after the last update:
        mean `train_loss`, mean `train_accuracy`, mean `test_loss`, mean `test_accuracy`,
        `norm_grad_of_params_this_epoch`, `theta_t` parameter vector and the `batch_counter`.
    """
    epoch_start_time = time()

    # Epoch 0 only evaluates the model as initialised, so no parameter update is performed.
    if epoch_number != 0:
        model.train()
        batch_counter = train_model(
            model,
            train_dataloader,
            loss_f,
            optimiser,
            lr_scheduler,
            device,
            batch_counter,
        )

    # End-of-epoch metrics: the train loss/gradient pass runs with the model in train mode
    # (set explicitly, because the training branch above is skipped at epoch 0), the test
    # pass runs with the model in eval mode.
    model.train()
    train_loss, train_accuracy, grad_of_params_this_epoch = compute_loss_and_grad_vector_of_params(
        model, train_dataloader, loss_f, device
    )
    model.eval()
    test_loss, test_accuracy = test_model(model, test_dataloader, loss_f, device)

    # A single L2 norm taken over the whole gradient vector (not one norm per parameter).
    norm_grad_of_params_this_epoch = torch.linalg.norm(grad_of_params_this_epoch, 2).item()

    theta_t = get_vector_of_params(model)

    # Lipschitz bounds are optional; `None` tells the logger to skip them.
    lipschitz_bounds = None
    if compute_lip_this_epoch:
        lipschitz_bounds = compute_lipschitz_bounds(
            model, layers_to_look_at, train_dataloader, device
        )

    # Keyword arguments guard against silently swapping the many same-typed values.
    log_epoch_stats(
        writer=writer,
        epoch_number=epoch_number,
        train_loss=train_loss,
        test_loss=test_loss,
        train_accuracy=train_accuracy,
        test_accuracy=test_accuracy,
        norm_grad_of_params_this_epoch=norm_grad_of_params_this_epoch,
        min_num_epochs=min_num_epochs,
        max_num_epochs=max_num_epochs,
        epoch_start_time=epoch_start_time,
        layers_to_look_at=layers_to_look_at,
        lr_scheduler=lr_scheduler,
        theta_t=theta_t,
        theta_0=theta_0,
        lipschitz_bounds=lipschitz_bounds,
    )

    return (
        train_loss,
        train_accuracy,
        test_loss,
        test_accuracy,
        norm_grad_of_params_this_epoch,
        theta_t,
        batch_counter,
    )


## Log epoch stats


def log_epoch_stats(
    writer: SummaryWriter,
    epoch_number: int,
    train_loss: float,
    test_loss: float,
    train_accuracy: float,
    test_accuracy: float,
    norm_grad_of_params_this_epoch: float,
    min_num_epochs: int,
    max_num_epochs: int,
    epoch_start_time: float,
    layers_to_look_at: list,
    lr_scheduler: Any,
    theta_t: torch.Tensor,
    theta_0: torch.Tensor = None,
    lipschitz_bounds: dict = None,
) -> None:
    """Log end-of-epoch statistics to Tensorboard and to the log file.

    Parameters
    ----------
    writer
        Tensorboard Writer. Every scalar is written with `epoch_number` as its step.
    epoch_number
        Epoch number the stats belong to.
    train_loss
        Train loss to log.
    test_loss
        Test loss to log.
    train_accuracy
        Train accuracy to log. Multiplied by 100 and shown as a percentage in the log file.
    test_accuracy
        Test accuracy to log. Multiplied by 100 and shown as a percentage in the log file.
    norm_grad_of_params_this_epoch
        Norm of the gradient vector to log.
    min_num_epochs
        Minimum number of epochs; shown as the denominator of the epoch progress line.
    max_num_epochs
        Maximum number of epochs; only mentioned once `epoch_number` exceeds `min_num_epochs`.
    epoch_start_time
        Value of `time()` taken when the epoch started, used to report the elapsed time.
    layers_to_look_at
        Layers whose Lipschitz bounds are logged. Only used if `lipschitz_bounds` is given.
    lr_scheduler
        LR Scheduler object; the first entry of `get_last_lr()` is logged.
    theta_t
        Current parameter vector.
    theta_0, optional
        Parameter vector at initialisation, by default None. At epoch 0 the distance from
        initialisation is logged as 0 and `theta_0` is not used. At later epochs, if `theta_0`
        is None the distance cannot be computed, so it is skipped and a warning is logged.
    lipschitz_bounds, optional
        Dictionary mapping `str(layer)` to a `(lower, mean, rms, upper)` tuple for every layer
        in `layers_to_look_at`, by default None. If None, no Lipschitz stats are logged.
    """
    # log to tensorboard
    writer.add_scalar("norm_grad_of_params", norm_grad_of_params_this_epoch, epoch_number)
    writer.add_scalar("params/norm_theta_t", torch.linalg.norm(theta_t, 2).item(), epoch_number)

    if epoch_number == 0:
        # At epoch 0 the distance from initialisation is logged as 0 without computing it.
        writer.add_scalar("params/norm_dtheta_t0", 0, epoch_number)
    elif theta_0 is not None:
        writer.add_scalar(
            "params/norm_dtheta_t0", torch.linalg.norm(theta_t - theta_0, 2).item(), epoch_number
        )
    else:
        # Without the initial vector the distance is undefined; do not crash the logger for it.
        logging.warning(
            "theta_0 was not provided for epoch %s; skipping params/norm_dtheta_t0", epoch_number
        )

    writer.add_scalar("loss/train", train_loss, epoch_number)
    writer.add_scalar("accuracy/train", train_accuracy, epoch_number)
    writer.add_scalar("loss/test", test_loss, epoch_number)
    writer.add_scalar("accuracy/test", test_accuracy, epoch_number)

    if lipschitz_bounds is not None:
        for layer in layers_to_look_at:
            lower, mean, rms, upper = lipschitz_bounds[str(layer)]
            writer.add_scalar(f"L_lower/layer_{layer}", lower, epoch_number)
            writer.add_scalar(f"L_mean/layer_{layer}", mean, epoch_number)
            writer.add_scalar(f"L_rms/layer_{layer}", rms, epoch_number)
            writer.add_scalar(f"L_upper/layer_{layer}", upper, epoch_number)

    # log to the log file
    extension = (
        "" if epoch_number <= min_num_epochs else f" (training extended to {max_num_epochs} epochs)"
    )
    logging.info(f"Epoch = {epoch_number}/{min_num_epochs}{extension}")
    logging.info(
        f"Gradient norm = {norm_grad_of_params_this_epoch}, LR at the end of the epoch = {lr_scheduler.get_last_lr()[0]}"
    )
    logging.info("")
    logging.info("Train/test metrics:")
    logging.info(f"Train loss = {train_loss}, train accuracy = {train_accuracy*100}%")
    logging.info(f"Test loss = {test_loss}, test accuracy = {test_accuracy*100}%")

    if lipschitz_bounds is not None:
        logging.info("")
        logging.info("Lipschitz bounds:")

        for layer in layers_to_look_at:
            lower, mean, rms, upper = lipschitz_bounds[str(layer)]
            logging.info(f"Lower Lipschitz @layer_{layer} = {lower}")
            logging.info(f"Mean Lipschitz @layer_{layer} = {mean}")
            logging.info(f"RMS Lipschitz @layer_{layer} = {rms}")
            logging.info(f"Upper Lipschitz @layer_{layer} = {upper}")

    logging.info("")
    time_elapsed = time() - epoch_start_time
    logging.info(f"Time for this epoch = {time_elapsed}")
    logging.info("-" * 50)
