
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, StepLR

from train.pytorch_wrapper.lr_scheduler import get_linear_lr_update
from train.pytorch_wrapper.training_strategy import TrainingStrategy
from train.pytorch_wrapper.criteria import CriterionWrapper

from train.behavioral_cloning.train_strategies.criteria import CamSoftmaxCELoss

# Evaluation criteria: none are configured for this strategy, so it stays empty.
eval_criteria = {}


# Optimization objective: one CriterionWrapper per action head. Each wrapper
# pairs a loss module with an output_key and a target_key (presumably looked up
# in the model outputs and batch targets respectively — confirm against
# CriterionWrapper).
criterion_disc = CriterionWrapper(BCEWithLogitsLoss(),
                                  output_key="logits",
                                  target_key="binary_actions")
criterion_cam = CriterionWrapper(CamSoftmaxCELoss(),
                                 output_key="camera",
                                 target_key="camera_actions")


def _cross_entropy_head(key):
    """Wrap a fresh CrossEntropyLoss whose output key and target key are both ``key``."""
    return CriterionWrapper(CrossEntropyLoss(), output_key=key, target_key=key)


criterion_equip = _cross_entropy_head("act_equip")
criterion_place = _cross_entropy_head("act_place")
criterion_craft = _cross_entropy_head("act_craft")
criterion_nearbyCraft = _cross_entropy_head("act_nearbyCraft")
criterion_nearbySmelt = _cross_entropy_head("act_nearbySmelt")

# (name, factor, criterion) triples handed to TrainingStrategy. Every factor is
# currently 1.0. NOTE(review): the factor is presumably a per-loss weight that
# TrainingStrategy multiplies in — confirm.
criterion = [
    ("disc_loss", 1.0, criterion_disc),
    ("cam_loss", 1.0, criterion_cam),
    ("equip_loss", 1.0, criterion_equip),
    ("place_loss", 1.0, criterion_place),
    ("craft_loss", 1.0, criterion_craft),
    ("nearbycraft_loss", 1.0, criterion_nearbyCraft),
    ("nearbysmelt_loss", 1.0, criterion_nearbySmelt),
]


def compile_training_strategy(lr=0.01, lr_schedule="step", num_epochs=100, patience=100,
                              weight_decay=0.0, batch_size=64, test_batch_size=100):
    """Compile the training strategy for this model.

    Builds a TrainingStrategy that optimizes the module-level ``criterion`` with
    SGD (momentum 0.9) and evaluates with the module-level ``eval_criteria``.

    Args:
        lr: Initial learning rate for SGD.
        lr_schedule: Learning-rate scheduler to use; one of "linear", "plateau"
            or "step".
        num_epochs: Number of training epochs. Also parameterizes every
            scheduler (see the notes in ``compile_scheduler``).
        patience: Forwarded unchanged to TrainingStrategy (presumably its
            early-stopping patience — confirm).
        weight_decay: Weight decay passed to SGD.
        batch_size: Training batch size.
        test_batch_size: Validation batch size.

    Returns:
        The configured TrainingStrategy.

    Raises:
        ValueError: If ``lr_schedule`` is not supported. This is checked
            immediately, instead of only when the scheduler is first compiled
            during training.
    """
    supported_schedules = ("linear", "plateau", "step")
    # Fail fast: the scheduler factory below runs lazily, so without this
    # check a typo would only surface once training had already started.
    if lr_schedule not in supported_schedules:
        raise ValueError(
            f"Selected lr_schedule {lr_schedule!r} is not supported! "
            f"Expected one of: {', '.join(supported_schedules)}"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compile_optimizer(net):
        """Create an SGD optimizer over all parameters of ``net``."""
        return torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)

    def compile_scheduler(optimizer):
        """Create the learning-rate scheduler selected by ``lr_schedule``."""
        if lr_schedule == "linear":
            update_lr = get_linear_lr_update(max_epochs=num_epochs, start_decay_at=num_epochs // 4)
            return LambdaLR(optimizer, update_lr)
        if lr_schedule == "plateau":
            # NOTE(review): uses num_epochs, not the ``patience`` argument, as
            # the plateau patience, so the LR is halved only after num_epochs
            # non-improving scheduler steps. Confirm this is intended.
            return ReduceLROnPlateau(optimizer, mode="min", patience=num_epochs, factor=0.5)
        # Only "step" remains after the up-front validation.
        # NOTE(review): StepLR decays by gamma (default 0.1) every step_size
        # scheduler steps. With step_size=num_epochs and (presumably) one step
        # per epoch, the LR stays constant for the entire run.
        return StepLR(optimizer, step_size=num_epochs)

    train_strategy = TrainingStrategy(num_epochs=num_epochs, criterion=criterion, eval_criteria=eval_criteria,
                                      compile_optimizer=compile_optimizer, compile_scheduler=compile_scheduler,
                                      tr_batch_size=batch_size, va_batch_size=test_batch_size,
                                      full_set_eval=False, patience=patience, device=device,
                                      augmentation_params=None, best_model_by=None,
                                      checkpoint_every_k=10)

    return train_strategy
