from collections.abc import Callable
from typing import Any

import torch
from torch.utils.data import DataLoader

from models.base_model import SequentialModel
from compute.grad_wrt_to_params import get_grad_vector_of_params


def train_model(
    model: SequentialModel,
    train_dataloader: DataLoader,
    loss_f: Callable,
    optimiser: torch.optim.Optimizer,
    lr_scheduler: Any,
    device: torch.device,
    batch_counter: int,
    max_grad_norm: float = 50.0,
) -> int:
    """Run one pass over ``train_dataloader``, updating ``model`` after every batch.

    For each batch the loss is computed, divided by the batch size,
    back-propagated, the gradient norm is clipped to ``max_grad_norm`` and
    both the optimiser and the learning-rate scheduler are stepped once.

    Args:
        model: Network to train; called as ``model(x)``.
        train_dataloader: Iterable yielding ``(x, y)`` batches.
        loss_f: Called as ``loss_f(output, y)``. Its result is divided by the
            batch size, so it is presumably expected to return the *sum* of
            the per-sample losses -- confirm against the caller.
        optimiser: Optimiser holding the parameters of ``model``.
        lr_scheduler: Object with a ``step()`` method, stepped once per batch
            (not once per epoch). Pass ``None`` to keep the learning rate
            unchanged.
        device: Device the batches are moved to before the forward pass.
        batch_counter: Number of batches processed before this call.
        max_grad_norm: Maximum total gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``. Defaults to ``50.0``.

    Returns:
        ``batch_counter`` increased by the number of batches processed here.
    """
    for x, y in train_dataloader:
        curr_batch_size = x.shape[0]
        x = x.to(device)
        y = y.to(device)

        # compute the loss, normalised by the size of the current batch
        output = model(x)
        loss = loss_f(output, y) / curr_batch_size

        # back-propagate; gradients are reset first because they accumulate
        optimiser.zero_grad()
        loss.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # step the optimiser, then the lr scheduler (if one was supplied)
        optimiser.step()
        if lr_scheduler is not None:
            lr_scheduler.step()

        # increase batch counter
        batch_counter += 1
    return batch_counter


def compute_loss_and_grad_vector_of_params(
    model: SequentialModel,
    dataloader: DataLoader,
    loss_f: Callable,
    device: torch.device,
) -> tuple[float, float, torch.Tensor]:
    """Compute loss, accuracy and the parameter-gradient vector over a dataset.

    The model is not updated: every batch is passed forward and backward
    once, and the per-batch losses and gradient vectors are summed and then
    divided by the total number of samples.

    Args:
        model: Network to evaluate; called as ``model(x)``.
        dataloader: Iterable yielding ``(x, y)`` batches. ``y`` may hold class
            indices or, if it has more than one dimension, per-class rows
            that are reduced with ``argmax`` for the accuracy computation.
        loss_f: Called as ``loss_f(output, y)``. Because the accumulated
            value is divided by the sample count, it is presumably expected
            to return the *sum* of per-sample losses -- confirm against the
            caller.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(loss, accuracy, grad_vec)`` where ``loss`` and ``grad_vec`` are the
        accumulated loss / gradient vector divided by the number of samples
        and ``accuracy`` is the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.

    Note:
        ``model.zero_grad()`` is only called at the start of each iteration,
        so the gradients of the last batch are still stored on the model
        when this function returns.
    """
    loss = 0
    n_correct = 0
    n_samples = 0

    # running sum of the per-batch gradient vectors; each is whatever
    # get_grad_vector_of_params returns (presumably the flattened gradient of
    # the loss w.r.t. all parameters, in R^n_params -- TODO confirm)
    grad_vec = None

    for x, y in dataloader:
        # discard the previous batch's gradients so that backward() below
        # produces the gradient of the current batch only
        model.zero_grad()

        curr_batch_size = x.shape[0]
        x = x.to(device)
        y = y.to(device)

        # compute loss
        output = model(x)
        curr_loss = loss_f(output, y)
        # Accumulate a detached value: adding the graph-attached tensor would
        # chain the autograd history of every batch into the running total
        # and keep it alive until the function returns.
        loss += curr_loss.detach()

        # compute gradient vector of parameters
        curr_loss.backward()
        # NOTE(review): assumes the helper returns a tensor that does not
        # alias the parameters' .grad buffers, which are reset on the next
        # iteration -- confirm.
        curr_grad_vec = get_grad_vector_of_params(model)
        grad_vec = curr_grad_vec if grad_vec is None else grad_vec + curr_grad_vec

        # compute accuracy
        # get class number
        prediction = torch.argmax(output, 1)
        if len(y.shape) > 1:
            # if labels are one-hot-encoded, transform to class numbers
            y = torch.argmax(y, 1)
        n_correct += (prediction == y).sum().item()
        n_samples += curr_batch_size

    # normalise the loss, accuracy and gradient vector by the sample count
    accuracy = n_correct / n_samples
    loss = loss / n_samples
    grad_vec = grad_vec / n_samples
    return loss.item(), accuracy, grad_vec


def test_model(
    model: SequentialModel,
    dataloader: DataLoader,
    loss_f: Callable,
    device: torch.device,
) -> tuple[float, float]:
    with torch.no_grad():
        loss = 0
        n_correct = 0
        n_samples = 0

        for x, y in dataloader:
            curr_batch_size = x.shape[0]
            x = x.to(device)
            y = y.to(device)

            output = model(x)
            curr_loss = loss_f(output, y)
            loss += curr_loss

            # compute accuracy
            # get class number
            prediction = torch.argmax(output, 1)
            if len(y.shape) > 1:
                # if labels are one-hot-encoded
                # if labels are one-hot-encoded, transform to class numbers
                y = torch.argmax(y, 1)
            n_correct += (prediction == y).sum().item()
            n_samples += curr_batch_size

        accuracy = n_correct / n_samples
        loss = loss / n_samples
        return loss.item(), accuracy
