"""Fault-tolerance example: training progress is checkpointed and resumed.

To test the fault tolerance, run this script multiple times.
"""

import logging

import torch
import torch.nn.functional as F
from torch import nn, optim

import neps


class TheModelClass(nn.Module):
    """Small convolutional classifier with two conv stages and three linear layers.

    Taken from https://pytorch.org/tutorials/beginner/saving_loading_models.html

    The first linear layer expects ``16 * 5 * 5`` features per sample, which is
    what the two conv/pool stages produce for 3-channel 32x32 inputs.
    """

    # Number of features per sample after the convolutional stages.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; the same pooling layer is reused
        # after each convolution.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head ending in 10 output units.
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 output values per sample for the input batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # Collapse everything but the feature dimension before the linear head.
        flat = features.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def get_model_and_optimizer(learning_rate):
    """Build a fresh ``TheModelClass`` and an SGD optimizer over its parameters.

    Taken from https://pytorch.org/tutorials/beginner/saving_loading_models.html

    Args:
        learning_rate: Step size passed to ``optim.SGD`` as ``lr``.

    Returns:
        A ``(model, optimizer)`` tuple; the optimizer uses momentum 0.9.
    """
    net = TheModelClass()
    sgd = optim.SGD(params=net.parameters(), lr=learning_rate, momentum=0.9)
    return net, sgd


def run_pipeline(pipeline_directory, learning_rate):
    """Run a resumable (placeholder) training loop for one configuration.

    Progress is checkpointed to ``pipeline_directory / "checkpoint.pth"`` every
    10 epochs. If that file already exists when the function is called, model
    and optimizer state are restored from it and training continues after the
    stored epoch instead of starting from scratch.

    Args:
        pipeline_directory: Directory for this configuration's files. It must
            support ``/`` joining and yield objects with ``exists()`` and
            ``replace()`` (i.e. behave like a ``pathlib.Path``).
        learning_rate: Learning rate for the optimizer; also used as the
            placeholder return value.

    Returns:
        ``learning_rate`` (stand-in for an actual validation error).

    Raises:
        SystemExit: At epoch 50 when ``learning_rate < 0.2``, to simulate a
            crash such as a job hitting its runtime limit.
    """
    max_epochs = 100
    checkpoint_interval = 10
    simulated_crash_epoch = 50

    model, optimizer = get_model_and_optimizer(learning_rate)
    checkpoint_path = pipeline_directory / "checkpoint.pth"
    tmp_checkpoint_path = pipeline_directory / "checkpoint.pth.tmp"

    # Check if there is a previous state of the model training that crashed
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch_already_trained = checkpoint["epoch"]
        print(f"Read in model trained for {epoch_already_trained} epochs")
    else:
        epoch_already_trained = 0

    # Epochs are 1-based; resume with the first epoch that is not yet trained.
    for epoch in range(epoch_already_trained + 1, max_epochs + 1):

        # Train model here ....

        # Repeatedly save your progress
        if epoch % checkpoint_interval == 0:
            # Write to a temporary file first and rename afterwards, so that a
            # crash in the middle of saving never leaves a truncated
            # checkpoint behind that would break the resume path above.
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                tmp_checkpoint_path,
            )
            tmp_checkpoint_path.replace(checkpoint_path)

        # Here we simulate a crash! E.g., due to job runtime limits
        if epoch == simulated_crash_epoch and learning_rate < 0.2:
            print("Oh no! A simulated crash!")
            # `exit()` is an interactive helper injected by `site` and is not
            # guaranteed to exist; raising SystemExit is the portable way.
            raise SystemExit

    return learning_rate  # Replace with actual error


# Search space: a single float hyperparameter with bounds 0 and 1.
pipeline_space = {
    "learning_rate": neps.FloatParameter(lower=0, upper=1),
}

logging.basicConfig(level=logging.INFO)

# Every invocation of this script uses the same root directory.
neps.run(
    run_pipeline=run_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/fault_tolerance_example",
    max_evaluations_total=15,
)

# Query the state stored under the same root directory after the run.
previous_results, pending_configs = neps.status("results/fault_tolerance_example")
