import torch
try:
    import neptune.new as neptune
except ImportError:
    pass

class TrainConfig:
    """Container for training hyperparameters.

    Every constructor argument is stored on the instance unchanged, except
    ``checkpoint_dir``, which is stored as ``self.ckpt``. The values are only
    consumed by training code outside this class; the precise meaning of
    some of them (e.g. ``use_sqrt_cost``, ``checkpoint_last_epochs``) is
    defined there.

    Args:
        n_epochs: Number of training epochs (required).
        checkpoint_dir: Checkpoint location, or ``None``; stored as ``self.ckpt``.
        batch_size: Mini-batch size.
        checkpoint_last_epochs: Checkpointing setting; presumably the number
            of final epochs that are checkpointed -- confirm in the training loop.
        log_step: Logging interval -- unit (steps vs. epochs) defined by the
            training loop; confirm there.
        clip_grad_norm: Gradient-norm clipping value.
        lr: Learning rate.
        device: ``torch.device`` to train on (CPU by default).
        use_sqrt_cost: Boolean flag interpreted by the training code.
        rng_seed: Random seed.
    """

    def __init__(self, n_epochs, checkpoint_dir=None, batch_size=64, checkpoint_last_epochs=50,
                 log_step=1, clip_grad_norm=1.0, lr=0.0001, device=torch.device('cpu'),
                 use_sqrt_cost=False, rng_seed=123):
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.checkpoint_last_epochs = checkpoint_last_epochs
        self.device = device
        self.clip_grad_norm = clip_grad_norm
        self.log_step = log_step
        self.ckpt = checkpoint_dir
        self.lr = lr
        self.use_sqrt_cost = use_sqrt_cost
        self.rng_seed = rng_seed

    def log(self, neptune_run):
        """Record all hyperparameters under ``parameters/<name>`` on a run.

        Args:
            neptune_run: A Neptune run handle, or any object supporting item
                assignment with string keys (e.g. a plain ``dict``).

        ``device`` is logged as its string form. The checkpoint location is
        logged (as a string, under ``parameters/checkpoint_dir``) only when it
        is set, so that no ``None`` value is ever assigned to a field.
        """
        params = {
            "batch_size": self.batch_size,
            "clip_grad_norm": self.clip_grad_norm,
            "device": str(self.device),
            "use_sqrt_cost": self.use_sqrt_cost,
            "checkpoint_last_epochs": self.checkpoint_last_epochs,
            # Previously omitted, but needed to identify/reproduce a run.
            "n_epochs": self.n_epochs,
            "lr": self.lr,
            "log_step": self.log_step,
            "rng_seed": self.rng_seed,
        }
        if self.ckpt is not None:
            params["checkpoint_dir"] = str(self.ckpt)
        for name, value in params.items():
            neptune_run[f"parameters/{name}"] = value
