
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

from train.pytorch_wrapper.lr_scheduler import get_linear_lr_update
from train.pytorch_wrapper.training_strategy import TrainingStrategy
from train.pytorch_wrapper.criteria import CriterionWrapper

from train.behavioral_cloning.train_strategies.criteria import CamSoftmaxCELoss

# evaluation criteria (none are registered for this strategy)
eval_criteria = {}

# optimization objective: two wrapped losses, each bound to an output key and a target key
criterion_disc = CriterionWrapper(
    BCEWithLogitsLoss(),
    output_key="logits",
    target_key="binary_actions",
)
criterion_cam = CriterionWrapper(
    CamSoftmaxCELoss(),
    output_key="camera",
    target_key="camera_actions",
)

# each entry is (name, 1.0, wrapped criterion); the middle value is presumably
# the loss weight -- confirm against TrainingStrategy
criterion = [
    ("disc_loss", 1.0, criterion_disc),
    ("cam_loss", 1.0, criterion_cam),
]


def compile_training_strategy(lr=0.01, lr_schedule="linear", num_epochs=80, patience=80,
                              weight_decay=1e-4, batch_size=100, test_batch_size=100):
    """Build the TrainingStrategy for this configuration.

    The strategy uses the module-level ``criterion`` and ``eval_criteria``,
    an SGD optimizer factory and a learning-rate scheduler factory.

    :param lr: learning rate handed to SGD
    :param lr_schedule: "linear" (LambdaLR driven by get_linear_lr_update) or
        "step" (ReduceLROnPlateau with factor 0.5)
    :param num_epochs: number of epochs handed to the strategy; also
        parameterizes both schedulers
    :param patience: forwarded unchanged to TrainingStrategy
    :param weight_decay: weight decay handed to SGD
    :param batch_size: forwarded as ``tr_batch_size``
    :param test_batch_size: forwarded as ``va_batch_size``
    :return: the configured TrainingStrategy instance

    An unsupported ``lr_schedule`` is not rejected here; the ValueError is
    raised only when the scheduler factory is invoked.
    """

    def build_optimizer(net):
        # SGD with a fixed momentum of 0.9
        return torch.optim.SGD(
            net.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum=0.9,
        )

    def build_scheduler(optimizer):
        if lr_schedule == "linear":
            # decay is configured to start after the first quarter of the epochs
            linear_update = get_linear_lr_update(
                max_epochs=num_epochs,
                start_decay_at=num_epochs // 4,
            )
            return LambdaLR(optimizer, linear_update)

        if lr_schedule == "step":
            # NOTE(review): the plateau patience is num_epochs, not the
            # `patience` argument -- confirm this is intended
            return ReduceLROnPlateau(
                optimizer,
                mode="min",
                patience=num_epochs,
                factor=0.5,
            )

        raise ValueError("Selected lr_schedule is not supported!")

    # use the GPU when one is available, otherwise fall back to the CPU
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    return TrainingStrategy(
        num_epochs=num_epochs,
        criterion=criterion,
        eval_criteria=eval_criteria,
        compile_optimizer=build_optimizer,
        compile_scheduler=build_scheduler,
        tr_batch_size=batch_size,
        va_batch_size=test_batch_size,
        full_set_eval=False,
        patience=patience,
        device=device,
        augmentation_params=None,
        best_model_by=("va_loss_total", "min"),
        checkpoint_every_k=3,
    )
