import gc
from typing import Generator, Optional, Any

import numpy as np
import torch.utils.data

from torch_utils.tensorboard.tensorboard_logger import TensorboardLogger
from utils.logger.logger import Logger


def report_values(step: int, values: dict[str, Any], tensorboard_logger: TensorboardLogger) -> None:
    """Report the values recorded for one step.

    The values are written to the debug log and then handed to the
    tensorboard logger under the same step.

    :param step: step index the values belong to.
    :param values: mapping from value name to the value to report.
    :param tensorboard_logger: logger whose ``log_values`` receives the step and values.
    """
    Logger.debug(f'step {step}: {values}')
    tensorboard_logger.log_values(step, values)


def create_one_hot(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Build a one-hot matrix from an array of class labels.

    :param labels: array of class indices; its first dimension is the number of samples.
        Values are cast to ``int`` before being used as column indices.
    :param n_classes: number of columns in the resulting matrix.
    :return: array of shape ``(labels.shape[0], n_classes)`` filled with zeros,
        with 1.0 at ``[i, labels[i]]`` for every sample ``i``.
    """
    n_samples: int = labels.shape[0]
    row_index: np.ndarray = np.arange(n_samples)
    column_index: np.ndarray = labels.astype(int)

    encoded: np.ndarray = np.zeros((n_samples, n_classes))
    encoded[row_index, column_index] = 1.0
    return encoded


def create_one_hot_torch(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Build a one-hot matrix from a tensor of class labels.

    :param labels: tensor of class indices; its first dimension is the number of samples.
        Values are cast to ``torch.long`` before being used as column indices.
    :param n_classes: number of columns in the resulting matrix.
    :return: ``torch.float64`` tensor of shape ``(labels.shape[0], n_classes)`` on the
        same device as ``labels``, with 1.0 at ``[i, labels[i]]`` for every sample ``i``.
    """
    n_samples: int = labels.shape[0]
    one_hot: torch.Tensor = torch.zeros(size=(n_samples, n_classes), dtype=torch.float64, device=labels.device)
    # Create the row index on the same device as the labels so that the two
    # index tensors and the indexed tensor all live on one device.
    rows: torch.Tensor = torch.arange(n_samples, device=labels.device)
    one_hot[rows, labels.to(torch.long)] = 1.0
    return one_hot


def create_infinite_data_loader(
        data_loader: torch.utils.data.DataLoader,
        sampler: Optional[torch.utils.data.DistributedSampler]
) -> Generator[Any, None, None]:
    """Yield batches from ``data_loader`` forever, restarting it after every pass.

    Before each pass over the loader, ``sampler.set_epoch`` is called with the
    zero-based pass index (when a sampler is given).

    :param data_loader: iterable of batches that can be iterated repeatedly.
    :param sampler: optional sampler whose ``set_epoch`` is called at the start of each pass.
    :return: generator yielding the batches produced by ``data_loader``.
    :raises ValueError: if a full pass over ``data_loader`` produces no batches;
        without this check the generator would loop forever without yielding.
    """
    epoch: int = 0
    while True:
        if sampler is not None:
            sampler.set_epoch(epoch)
        produced_batch: bool = False
        for data in data_loader:
            produced_batch = True
            yield data
        if not produced_batch:
            # An empty pass would otherwise make the outer loop spin forever,
            # blocking the consumer inside next() with no way to recover.
            raise ValueError('data_loader produced no batches; cannot iterate over it infinitely')
        epoch += 1

def extract_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the wrapped module of a (distributed) data-parallel model.

    :param model: a plain module, or one wrapped in ``DistributedDataParallel`` / ``DataParallel``.
    :return: ``model.module`` when ``model`` is one of the two wrappers, otherwise ``model`` itself.
    """
    parallel_wrappers = (
        torch.nn.parallel.DistributedDataParallel,
        torch.nn.parallel.DataParallel,
    )
    if isinstance(model, parallel_wrappers):
        return model.module
    return model


def free_memory() -> None:
    """Run a full garbage collection pass."""
    # Generation 2 is the default of gc.collect(); spelled out to make the full pass explicit.
    gc.collect(generation=2)


def free_cuda_memory() -> None:
    """Empty the CUDA cache; does nothing when CUDA is not available."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


def free_all_memory() -> None:
    """Run ``free_memory`` followed by ``free_cuda_memory``."""
    for release in (free_memory, free_cuda_memory):
        release()
