import torch.optim as optim
from sklearn.utils import check_random_state


class Optimizer:
    """Hold training hyper-parameters and, for PyTorch models, a ``torch.optim`` optimizer.

    The underlying optimizer object is not created in ``__init__``: it needs
    the model's parameters, so it is built by :meth:`set_from_model`.
    """

    def __init__(
        self,
        optimizer_name,
        lr=0.01,
        batch_size=None,
        epochs=None,
        max_epochs=100,
        weight_init=None,
        tol=1e-4,
        random_state=None,
        **kwargs,
    ):
        """Store the configuration; no optimizer is built yet.

        Args:
            optimizer_name: Name of an optimizer class in ``torch.optim``
                (e.g. ``"SGD"``), resolved in :meth:`set_from_model`.
            lr: Learning rate passed to the optimizer constructor.
            batch_size: Stored and returned by :meth:`get_batch_size`.
            epochs: Fixed epoch count; when ``None``, :meth:`get_num_epochs`
                falls back to ``max_epochs``.
            max_epochs: Fallback epoch count used when ``epochs`` is ``None``.
            weight_init: Stored and returned by :meth:`get_weight_init`.
            tol: Stored as ``self.tol``; not read anywhere in this class.
            random_state: Normalised through sklearn's ``check_random_state``.
            **kwargs: Extra keyword arguments forwarded to the optimizer
                constructor in :meth:`set_from_model`.
        """
        self.optimizer_name = optimizer_name
        self.lr = lr
        self.optimizer = None
        self.framework = None
        # Initialised here so the attribute exists before set_from_model runs.
        self.model = None
        self.batch_size = batch_size
        self.epochs = epochs
        self.max_epochs = max_epochs
        self.tol = tol
        self.weight_init = weight_init
        self.random_state = check_random_state(random_state)
        self.kwargs = kwargs

    def get_weight_init(self):
        """Return the ``weight_init`` value given to the constructor."""
        return self.weight_init

    def set_from_model(self, model):
        """Bind to *model* and, if it looks like a PyTorch model, build the optimizer.

        A model is treated as PyTorch when it has a ``parameters`` attribute.
        Other models are stored on ``self.model``, but no optimizer is built.

        Raises:
            ValueError: If ``optimizer_name`` does not name an optimizer
                class in ``torch.optim``.
        """
        self.model = model
        if not hasattr(model, "parameters"):
            return
        # Look up with a default so an unknown name reaches the ValueError
        # below rather than escaping as an AttributeError. The subclass check
        # also rejects non-optimizer members of torch.optim (e.g. the
        # ``lr_scheduler`` submodule) and the abstract ``Optimizer`` base.
        optimizer_class = getattr(optim, self.optimizer_name, None)
        is_valid = (
            isinstance(optimizer_class, type)
            and issubclass(optimizer_class, optim.Optimizer)
            and optimizer_class is not optim.Optimizer
        )
        if not is_valid:
            raise ValueError(
                f"Optimizer {self.optimizer_name} is not a valid PyTorch optimizer."
            )
        # Only mark the framework once the optimizer can actually be built.
        self.framework = "torch"
        self.optimizer = optimizer_class(
            model.parameters(), lr=self.lr, **self.kwargs
        )

    def zero_grad(self):
        """Clear gradients via the torch optimizer; no-op if none is bound."""
        if self.framework == "torch" and self.optimizer is not None:
            self.optimizer.zero_grad()

    def step(self, closure=None):
        """Take one optimizer step; no-op if no torch optimizer is bound.

        Args:
            closure: Optional callable forwarded to ``optimizer.step`` (some
                torch optimizers, e.g. LBFGS, require one).
        """
        if self.framework == "torch" and self.optimizer is not None:
            if closure is not None:
                self.optimizer.step(closure)
            else:
                self.optimizer.step()

    def get_optimizer(self):
        """Return the underlying optimizer object, or ``None`` if not built."""
        return self.optimizer

    def get_lr(self):
        """Return the configured learning rate."""
        return self.lr

    def get_num_epochs(self):
        """Return ``epochs`` if it was set, otherwise ``max_epochs``."""
        return self.epochs if self.epochs is not None else self.max_epochs

    def get_batch_size(self):
        """Return the configured batch size (may be ``None``)."""
        return self.batch_size