import torch
import torch.optim as optim
from train.pytorch_wrapper.lr_scheduler import ConstantLR

# Default device for training: the CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    available_device = torch.device('cuda')
else:
    available_device = torch.device('cpu')


def compile_optimizer(net, lr=0.001):
    """Build the default Adam optimizer over every parameter of ``net``.

    Args:
        net: module whose ``parameters()`` are handed to the optimizer.
        lr: learning rate for :class:`torch.optim.Adam`. Defaults to ``0.001``,
            the value that was previously hard-coded, so ``compile_optimizer(net)``
            behaves exactly as before.

    Returns:
        A ``torch.optim.Adam`` instance bound to ``net``'s parameters.
    """
    return optim.Adam(net.parameters(), lr=lr)


def compile_default_scheduler(optimizer):
    """Wrap ``optimizer`` in the project's ``ConstantLR`` scheduler.

    Args:
        optimizer: the optimizer whose learning rate the scheduler drives.

    Returns:
        A ``ConstantLR`` instance constructed with only ``optimizer``.
    """
    scheduler = ConstantLR(optimizer)
    return scheduler


class TrainingStrategy(object):
    """Bundle of settings describing how a network is to be trained.

    Stores epochs, batch sizes, loss/evaluation criteria, device and other
    training options, plus factories for the optimizer and learning-rate
    scheduler. The optimizer and scheduler themselves start as ``None`` and are
    only built by :meth:`register_optimizer` once the network is available.
    """

    def __init__(self, num_epochs, criterion, tr_batch_size=100, va_batch_size=100, eval_criteria=None,
                 compile_optimizer=compile_optimizer, compile_scheduler=compile_default_scheduler,
                 best_model_by=("va_loss", "min"), full_set_eval=False, device=None, patience=None,
                 augmentation_params=None, clip_grad_norm=None, checkpoint_every_k=None, load_checkpoint: str = None):
        """Store the training configuration.

        Args:
            num_epochs: number of training epochs.
            criterion: either a list of criterion entries, stored as-is, or a
                single loss callable, which is wrapped as
                ``[('loss', 1.0, criterion)]``. Presumably the list entries are
                ``(name, weight, fn)`` triples; confirm against the training loop.
            tr_batch_size: training batch size.
            va_batch_size: validation batch size.
            eval_criteria: extra evaluation criteria. ``None`` (the default)
                gives each instance its own empty dict.
            compile_optimizer: factory called as ``compile_optimizer(net)`` by
                :meth:`register_optimizer`.
            compile_scheduler: factory called as ``compile_scheduler(optimizer)``
                by :meth:`register_optimizer`.
            best_model_by: presumably ``(metric_name, "min" | "max")`` used to
                select the best model. TODO: confirm with the consumer.
            full_set_eval: flag stored for the training loop (its use is not
                shown here).
            device: target device. ``None`` falls back to the module-level
                ``available_device``.
            patience: presumably early-stopping patience. ``None`` likely means
                no early stopping; confirm with the training loop.
            augmentation_params: augmentation settings. ``None`` (the default)
                gives each instance its own empty dict.
            clip_grad_norm: presumably the gradient-norm clipping threshold.
                ``None`` likely disables clipping; confirm.
            checkpoint_every_k: checkpointing interval (units not shown here).
            load_checkpoint: optional checkpoint reference, presumably a path,
                to resume from. Confirm with the consumer.
        """
        self.num_epochs = num_epochs
        # Normalise a single criterion into the list form used for multiple criteria.
        self.criterion = criterion if isinstance(criterion, list) else [('loss', 1.0, criterion)]
        self.compile_optimizer = compile_optimizer
        self.optimizer = None  # built by register_optimizer()
        self.compile_scheduler = compile_scheduler
        self.lr_scheduler = None  # built by register_optimizer()
        # Fresh dict per instance. A literal `{}` default is evaluated once and
        # would be shared by every strategy created without this argument.
        self.eval_criteria = {} if eval_criteria is None else eval_criteria
        self.best_model_by = best_model_by
        self.full_set_eval = full_set_eval
        self.device = available_device if device is None else device
        self.patience = patience
        self.tr_batch_size = tr_batch_size
        self.va_batch_size = va_batch_size
        # Same shared-mutable-default concern as eval_criteria above.
        self.augmentation_params = {} if augmentation_params is None else augmentation_params
        self.clip_grad_norm = clip_grad_norm
        self.load_checkpoint = load_checkpoint
        self.checkpoint_every_k = checkpoint_every_k

    def register_optimizer(self, net):
        """Create the optimizer for ``net`` and the scheduler wrapping it.

        Calls ``self.compile_optimizer(net)`` and then passes the result to
        ``self.compile_scheduler``. Both results are stored on the instance,
        replacing any previously registered optimizer and scheduler.
        """
        self.optimizer = self.compile_optimizer(net)
        self.lr_scheduler = self.compile_scheduler(self.optimizer)
