import math
from logging import Logger
from typing import List

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, RandomSampler, DataLoader

from algorithms.convergence_algorithms.typing import SizedDataset


def _evaluate_test_loss(gradient_loss_func: Module, test_loader: DataLoader) -> float:
    """Return the sum of ``gradient_loss_func`` over all batches of ``test_loader``.

    Runs under ``torch.no_grad()``. Per-batch losses are accumulated as tensors
    and converted to a Python float once at the end. An empty loader yields 0.0.
    """
    total_loss = 0
    with torch.no_grad():
        for test_x_i, test_x_j, test_y_i, test_y_j in test_loader:
            total_loss = total_loss + gradient_loss_func(
                test_x_i, test_x_j, test_y_i, test_y_j
            )
    return float(total_loss)


def train_gradient_network(
    gradient_loss_func: Module,
    optimizer,
    dataset: SizedDataset,
    batch_size: int,
    logger: Logger,
    testset: SizedDataset = None,
    max_test_batch_size: int = None,
    early_stopping_patience: int = 10,
) -> List[float]:
    """Train for one pass over ``dataset``, optionally tracking a test loss.

    Every optimizer step uses one mini-batch of ``batch_size`` random indices
    drawn from ``dataset``. The last batch may be smaller. Each item of
    ``dataset`` unpacks to ``(x_i, x_j, y_i, y_j)``.
    ``gradient_loss_func(x_i, x_j, y_i, y_j)`` must return a scalar tensor,
    because ``.backward()`` is called on it without arguments.

    After every step, when a non-empty ``testset`` is given, the summed loss
    over the whole test set is recorded. Training stops early once more than
    ``early_stopping_patience`` consecutive evaluations fail to strictly
    improve on the best test loss seen so far.

    Args:
        gradient_loss_func: Loss module called with ``(x_i, x_j, y_i, y_j)``.
        optimizer: Optimizer over the parameters the loss depends on.
        dataset: Training data, indexed with lists of indices.
        batch_size: Number of training samples per optimizer step.
        logger: Receives an info message when early stopping triggers.
        testset: Optional evaluation data. ``None`` or empty disables evaluation.
        max_test_batch_size: Upper bound on the evaluation batch size.
            ``None`` evaluates the whole test set in a single batch.
        early_stopping_patience: Number of consecutive non-improving
            evaluations tolerated before stopping. Defaults to 10.

    Returns:
        The recorded test losses, one float per evaluated step. The list is
        empty when no evaluation ran.
    """
    test_loss = []
    best_loss = math.inf
    no_improvement_counter = 0

    sampler = BatchSampler(
        RandomSampler(range(len(dataset))), batch_size=batch_size, drop_last=False
    )
    # The sampler yields whole index lists. With DataLoader's default
    # batch_size=1, each fetched batch gets an extra leading dimension of
    # size 1, which is stripped with ``[0]`` below.
    data_loader = DataLoader(dataset, sampler=sampler)

    # The test loader is identical for every evaluation, so build it once.
    test_loader = None
    if testset is not None and len(testset) > 0:
        test_batch_size = len(testset)
        if max_test_batch_size is not None:
            test_batch_size = min(max_test_batch_size, test_batch_size)
        test_loader = DataLoader(testset, batch_size=test_batch_size)

    for i, (x_i, x_j, y_i, y_j) in enumerate(data_loader):
        optimizer.zero_grad()
        x_i, x_j, y_i, y_j = x_i[0], x_j[0], y_i[0], y_j[0]
        loss = gradient_loss_func(x_i, x_j, y_i, y_j)
        loss.backward()
        optimizer.step()

        if test_loader is None:
            continue

        current_loss = _evaluate_test_loss(gradient_loss_func, test_loader)
        test_loss.append(current_loss)
        if current_loss < best_loss:
            no_improvement_counter = 0
            best_loss = current_loss
        else:
            no_improvement_counter += 1
        if no_improvement_counter > early_stopping_patience:
            logger.info(f"Stopping early after {i}")
            break
    return test_loss


def step_model_with_gradient(model, gradient: Tensor, optimizer: Optimizer):
    """Take one optimizer step driven by an externally supplied flat gradient.

    The optimizer's gradients are zeroed first. ``gradient`` is then consumed
    front to back. Each parameter of each direct child of ``model`` takes the
    next ``numel()`` entries, in ``model.children()`` / ``child.parameters()``
    order. Those entries are reshaped to the parameter's shape, cloned, and
    installed as its ``.grad``. Finally ``optimizer.step()`` is called.

    Notes:
        * Parameters registered directly on ``model``, not inside a child,
          are not visited.
        * Entries of ``gradient`` beyond the total size of the visited
          parameters are ignored.
        * A gradient that is too short makes ``reshape`` fail.
    """
    optimizer.zero_grad()
    cursor = 0
    for submodule in model.children():
        for param in submodule.parameters():
            size = param.numel()
            chunk = gradient[cursor : cursor + size]
            # Clone so the installed .grad does not alias ``gradient``'s storage.
            new_grad = chunk.reshape(param.shape).clone()
            with torch.no_grad():
                param.grad = new_grad
            cursor += size
    optimizer.step()
