import logging
import os
from typing import Union, Optional, List, Iterator

import torch
import torch.nn as nn

from . import Callback


class ClipGradValue(Callback):
    """Clip every gradient element to ``[-value, value]`` after each backward pass.

    Thin callback wrapper around :func:`torch.nn.utils.clip_grad_value_`.

    Args:
        value: Maximum absolute value allowed for any gradient element.
        net_params: Parameters whose gradients are clipped. Accepts a single
            tensor or any iterable of tensors (including a generator such as
            ``module.parameters()``). When ``None``, the parameters of
            ``self.trainer.net`` are used, resolved lazily on the first hook call.
    """

    def __init__(self, value, net_params: Optional[Iterator[torch.nn.Parameter]] = None):
        super().__init__()
        self.value = value
        # Materialize once: a generator would be exhausted by the first clipping
        # call, silently turning every subsequent call into a no-op.
        self.net_params = self._as_param_list(net_params)

    @staticmethod
    def _as_param_list(params):
        """Return ``params`` as a reusable list (``None`` passes through unchanged)."""
        if params is None:
            return None
        if isinstance(params, torch.Tensor):
            # A bare tensor must not be iterated (that would split it along dim 0).
            return [params]
        return list(params)

    def init(self):
        """Resolve the default parameter set from the trainer's network if none was given.

        NOTE(review): assumes ``self.trainer`` has been attached (presumably by the
        ``Callback`` base / trainer) and exposes ``.net`` — confirm. Parameters are
        snapshotted on the first call; parameters added to the network afterwards
        are not clipped.
        """
        if self.net_params is None:
            net: nn.Module = self.trainer.net
            # ``parameters()`` returns a one-shot generator; keep a list so it
            # can be reused on every batch.
            self.net_params = list(net.parameters())

    def on_backward_end(self, batch_number: int):
        """Clip gradients in place after backward; ``batch_number`` is unused."""
        self.init()
        torch.nn.utils.clip_grad_value_(self.net_params, self.value)


class ClipGradNorm(Callback):
    """Rescale gradients so their total norm does not exceed ``norm`` after each backward pass.

    Thin callback wrapper around :func:`torch.nn.utils.clip_grad_norm_`, using
    that function's default norm type.

    Args:
        norm: Maximum total gradient norm (``max_norm`` of ``clip_grad_norm_``).
        net_params: Parameters whose gradients are clipped. Accepts a single
            tensor or any iterable of tensors (including a generator such as
            ``module.parameters()``). When ``None``, the parameters of
            ``self.trainer.net`` are used, resolved lazily on the first hook call.
    """

    def __init__(self, norm, net_params: Optional[Iterator[torch.nn.Parameter]] = None):
        super().__init__()
        self.norm = norm
        # Materialize once: a generator would be exhausted by the first clipping
        # call, silently turning every subsequent call into a no-op.
        self.net_params = self._as_param_list(net_params)

    @staticmethod
    def _as_param_list(params):
        """Return ``params`` as a reusable list (``None`` passes through unchanged)."""
        if params is None:
            return None
        if isinstance(params, torch.Tensor):
            # A bare tensor must not be iterated (that would split it along dim 0).
            return [params]
        return list(params)

    def init(self):
        """Resolve the default parameter set from the trainer's network if none was given.

        NOTE(review): assumes ``self.trainer`` has been attached (presumably by the
        ``Callback`` base / trainer) and exposes ``.net`` — confirm. Parameters are
        snapshotted on the first call; parameters added to the network afterwards
        are not clipped.
        """
        if self.net_params is None:
            net: nn.Module = self.trainer.net
            # ``parameters()`` returns a one-shot generator; keep a list so it
            # can be reused on every batch.
            self.net_params = list(net.parameters())

    def on_backward_end(self, batch_number: int):
        """Clip gradients in place after backward; ``batch_number`` is unused."""
        self.init()
        # The total norm returned by clip_grad_norm_ is intentionally discarded.
        torch.nn.utils.clip_grad_norm_(self.net_params, self.norm)
