
import torch
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

from train.pytorch_wrapper.lr_scheduler import get_linear_lr_update
from train.pytorch_wrapper.training_strategy import TrainingStrategy
from train.pytorch_wrapper.criteria import CriterionWrapper


class LogMSELoss(torch.nn.Module):
    """Log-compressed mean squared error: ``log(1 + MSE(x, t))``.

    The wrapped ``MSELoss`` is created with its default settings, so the
    squared errors are averaged into a single scalar before the logarithm
    is applied.
    """

    def __init__(self):
        """Create the underlying ``MSELoss``; takes no arguments."""
        super(LogMSELoss, self).__init__()
        self.loss = MSELoss()

    def forward(self, x, t):
        """Return ``log(1 + mse)`` for predictions ``x`` and targets ``t``.

        Args:
            x: tensor passed to ``MSELoss`` as its input.
            t: tensor passed to ``MSELoss`` as its target.

        Returns:
            Scalar tensor holding ``log(1 + MSE(x, t))``.
        """
        mse_loss = self.loss(x, t)
        # log1p stays accurate when the MSE is tiny; log(mse + 1) would round
        # ``mse + 1`` to exactly 1 in float32 and report a loss of 0.
        return torch.log1p(mse_loss)


# Evaluation criteria: none are registered in this configuration.
eval_criteria = {}

# Optimization objective: one wrapped loss per (output_key, target_key) pair.
# NOTE(review): the keys presumably select entries of the network output and
# of the training batch -- confirm against CriterionWrapper.
criterion_disc = CriterionWrapper(
    BCEWithLogitsLoss(),
    output_key="logits",
    target_key="binary_actions",
)
criterion_cam = CriterionWrapper(
    LogMSELoss(),
    output_key="camera",
    target_key="camera_actions",
)
criterion_val = CriterionWrapper(
    LogMSELoss(),
    output_key="value",
    target_key="values",
)

# One (name, float, criterion) triple per loss term.
# NOTE(review): the float is presumably the weight of that term in the total
# loss -- confirm against TrainingStrategy.
criterion = [
    ('disc_loss', 1.0, criterion_disc),
    ('cam_loss', 0.1, criterion_cam),
    ('val_loss', 0.5, criterion_val),
]


def compile_training_strategy(lr=0.1, lr_schedule="step", num_epochs=300, patience=300,
                              weight_decay=0.0, batch_size=256, test_batch_size=256):
    """Assemble the ``TrainingStrategy`` used for this experiment.

    Args:
        lr: learning rate given to the SGD optimizer.
        lr_schedule: ``"linear"`` selects a ``LambdaLR`` driven by
            ``get_linear_lr_update`` (decay starting at ``num_epochs // 4``);
            ``"step"`` selects ``ReduceLROnPlateau`` with ``factor=0.5``.
        num_epochs: forwarded to ``TrainingStrategy`` and also used to
            parameterize both schedulers.
        patience: forwarded to ``TrainingStrategy`` only.
        weight_decay: weight decay given to the SGD optimizer.
        batch_size: forwarded as ``tr_batch_size``.
        test_batch_size: forwarded as ``va_batch_size``.

    Returns:
        The configured ``TrainingStrategy`` instance.

    Raises:
        ValueError: from the scheduler factory, at the time it is invoked,
            if ``lr_schedule`` is neither ``"linear"`` nor ``"step"``.
    """

    # Use the GPU whenever one is visible to torch, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    def compile_optimizer(net):
        # SGD with a fixed momentum of 0.9 over all parameters of ``net``.
        return torch.optim.SGD(
            net.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum=0.9,
        )

    def compile_scheduler(optimizer):
        if lr_schedule == "linear":
            decay_start = num_epochs // 4
            lr_lambda = get_linear_lr_update(max_epochs=num_epochs, start_decay_at=decay_start)
            return LambdaLR(optimizer, lr_lambda)

        if lr_schedule == "step":
            # NOTE(review): the plateau patience is ``num_epochs``, not the
            # ``patience`` argument -- confirm this is intended.
            return ReduceLROnPlateau(optimizer, mode="min", patience=num_epochs, factor=0.5)

        raise ValueError("Selected lr_schedule is not supported!")

    strategy_kwargs = dict(
        num_epochs=num_epochs,
        criterion=criterion,
        eval_criteria=eval_criteria,
        compile_optimizer=compile_optimizer,
        compile_scheduler=compile_scheduler,
        tr_batch_size=batch_size,
        va_batch_size=test_batch_size,
        full_set_eval=False,
        patience=patience,
        device=device,
        augmentation_params=None,
        best_model_by=("va_loss_total", "min"),
        checkpoint_every_k=3,
        clip_grad_norm=0.5,
    )
    return TrainingStrategy(**strategy_kwargs)
