import os
import random as pyrand

import numpy as np

import torch
from torch import nn


def get_epsilon(device=None):
    """Return the constant 1e-7 as a 0-dim float32 tensor.

    Args:
        device: Device passed to ``torch.tensor``; ``None`` uses torch's
            default placement.

    Returns:
        A scalar ``torch.float32`` tensor holding ``1e-7``.
    """
    return torch.tensor(1e-7, dtype=torch.float32, device=device)

def reduce_mean(value):
    """Collapse ``value`` to its mean unless it is already 0-dimensional.

    Args:
        value: Object exposing ``ndim`` and ``mean()`` (e.g. a tensor).

    Returns:
        ``value`` itself when ``value.ndim == 0``, otherwise ``value.mean()``.
    """
    already_scalar = value.ndim == 0
    return value if already_scalar else value.mean()

def get_num_parameters(model, trainable_only=False):
    """Count the elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()``.
        trainable_only: When true, parameters whose ``requires_grad`` is
            false are left out of the count.

    Returns:
        The summed ``numel()`` of the counted parameters.
    """
    total = 0
    for tensor in model.parameters():
        if trainable_only and not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total


def scheduler_update_step(scheduler, scheduler_step, batch_idx, num_batches):
    """Step the LR scheduler on its interval or when ``batch_idx == num_batches``.

    Args:
        scheduler: Scheduler exposing ``step()`` and ``get_last_lr()``, or
            ``None`` to do nothing.
        scheduler_step: Step every ``scheduler_step`` batches; ``None``
            disables interval stepping.
        batch_idx: Index of the current batch.
        num_batches: Value of ``batch_idx`` that always triggers a step.

    Returns:
        The first entry of ``scheduler.get_last_lr()``, or ``None`` when no
        scheduler was given.
    """
    if scheduler is None:
        return None

    on_interval = (scheduler_step is not None) and (batch_idx % scheduler_step == 0)
    at_final_batch = batch_idx == num_batches
    if on_interval or at_final_batch:
        scheduler.step()

    return scheduler.get_last_lr()[0]

def optimizer_update_step(model, optimizer, scaler, grad_accum_step, batch_idx, num_batches, max_grad_norm):
    """Apply an optimizer step at gradient-accumulation boundaries.

    A step is taken when ``batch_idx`` is a multiple of ``grad_accum_step`` or
    when ``batch_idx == num_batches``; otherwise the call leaves the
    accumulated gradients untouched.

    Args:
        model: Module whose ``parameters()`` are clipped.
        optimizer: Optimizer to step; ``None`` makes the call a no-op.
        scaler: Optional gradient scaler (``unscale_``/``step``/``update``
            interface). When ``None`` the optimizer is stepped directly.
        grad_accum_step: Number of batches between optimizer steps.
        batch_idx: Index of the current batch.
        num_batches: Value of ``batch_idx`` that always triggers a step.
        max_grad_norm: Maximum gradient norm for clipping, or ``None`` to
            skip clipping.

    Returns:
        None.
    """
    if optimizer is None:
        return

    at_boundary = (batch_idx % grad_accum_step == 0) or (batch_idx == num_batches)
    if not at_boundary:
        return

    if max_grad_norm is not None:
        if scaler is not None:
            # Gradients produced through a scaler are still multiplied by the
            # loss scale; unscale them first so the clipping threshold is
            # applied to the true gradient magnitudes.
            # NOTE(review): assumes the caller has not already called
            # scaler.unscale_(optimizer) for this step - confirm at call sites.
            scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    if scaler is None:
        optimizer.step()
    else:
        scaler.step(optimizer)
        scaler.update()

    optimizer.zero_grad()

def get_grads(model):
    """Collect the norm of every populated parameter gradient.

    Args:
        model: Object exposing ``parameters()``.

    Returns:
        A list with ``param.grad.norm()`` for each parameter whose ``grad``
        is set, in ``model.parameters()`` order.
    """
    return [
        weight.grad.norm()
        for weight in model.parameters()
        if getattr(weight, 'grad', None) is not None
    ]

def dict_to_device(my_dict, device=None):
    """Move every tensor value of ``my_dict`` to ``device``, in place.

    Non-tensor values are left untouched.

    Args:
        my_dict: Mapping that is modified in place.
        device: Target passed to ``Tensor.to``.

    Returns:
        The same ``my_dict`` object.
    """
    for name in my_dict:
        entry = my_dict[name]
        if isinstance(entry, torch.Tensor):
            my_dict[name] = entry.to(device)
    return my_dict


def get_device():
    """Return ``torch.device('cuda')`` when CUDA is available, else the CPU device."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)

def get_cuda_current_device():
    """Return ``torch.cuda.current_device()`` when ``get_device()`` is CUDA.

    Returns:
        The value of ``torch.cuda.current_device()``, or ``None`` when
        ``get_device()`` does not report ``'cuda'``.
    """
    if str(get_device()) != 'cuda':
        return None
    return torch.cuda.current_device()

def set_seed(torch_seed=0, np_seed=None, py_seed=None):
    """Seed Python's ``random``, NumPy and PyTorch random number generators.

    Args:
        torch_seed: Seed for ``torch.manual_seed``. Passing ``None`` disables
            seeding entirely (the function returns without touching any RNG).
        np_seed: Seed for ``np.random.seed``; ``None`` falls back to
            ``torch_seed``. An explicit ``0`` is honoured.
        py_seed: Seed for ``random.seed`` and ``PYTHONHASHSEED``; ``None``
            falls back to ``torch_seed``. An explicit ``0`` is honoured.

    Returns:
        None.
    """
    if torch_seed is None:
        return

    # Compare against None rather than relying on truthiness, so that a
    # deliberately chosen seed of 0 is not replaced by torch_seed.
    if py_seed is None:
        py_seed = torch_seed
    if np_seed is None:
        np_seed = torch_seed

    # Hash randomisation is fixed at interpreter start-up, so this value is
    # only picked up by processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(py_seed)
    pyrand.seed(py_seed)
    np.random.seed(np_seed)
    torch.manual_seed(torch_seed)

def load_checkpoint(module, path, device=None, strict=True, warn=True):
    """Load a state dict from ``path`` into ``module`` and report missing parameters.

    The checkpoint is loaded immediately when the function is called. (The
    previous implementation was itself a generator, so nothing was loaded
    unless the caller iterated over the result.)

    Args:
        module: Module exposing ``load_state_dict`` and ``named_parameters``.
        path: Location of the checkpoint, passed to ``torch.load``.
        device: ``map_location`` passed to ``torch.load``.
        strict: Forwarded to ``module.load_state_dict``.
        warn: When true and ``strict`` is false, produce one message per
            parameter of ``module`` that is absent from the checkpoint.

    Returns:
        A generator of warning strings; it is empty when ``strict`` is true
        or ``warn`` is false.
    """
    # NOTE: torch.load deserialises with pickle; only load checkpoints that
    # come from a trusted source.
    module_state = torch.load(path, map_location=device)
    module.load_state_dict(module_state, strict=strict)

    missing = []
    if (not strict) and warn:
        missing = [
            p_name
            for p_name, _ in module.named_parameters()
            if p_name not in module_state
        ]

    # Always hand back a generator so callers that iterate (or call next())
    # on the result keep working exactly as before.
    return (
        f'load_checkpoint(...) ---> {p_name} parameter is missing.'
        for p_name in missing
    )

def set_requires_grad(module, requires_grad):
    """Set ``requires_grad`` on every parameter of ``module``.

    Args:
        module: Object exposing ``parameters()``.
        requires_grad: Value assigned to each parameter's ``requires_grad``.

    Returns:
        None.
    """
    for weight in module.parameters():
        weight.requires_grad = requires_grad