"""
Mozilla Public License Version 2.0

Copyright (c) 2022 MPI-Dortmund

1. Definitions
--------------

1.1. "Contributor"
    means each individual or legal entity that creates, contributes to
    the creation of, or owns Covered Software.

1.2. "Contributor Version"
    means the combination of the Contributions of others (if any) used
    by a Contributor and that particular Contributor's Contribution.

1.3. "Contribution"
    means Covered Software of a particular Contributor.

1.4. "Covered Software"
    means Source Code Form to which the initial Contributor has attached
    the notice in Exhibit A, the Executable Form of such Source Code
    Form, and Modifications of such Source Code Form, in each case
    including portions thereof.

1.5. "Incompatible With Secondary Licenses"
    means

    (a) that the initial Contributor has attached the notice described
        in Exhibit B to the Covered Software; or

    (b) that the Covered Software was made available under the terms of
        version 1.1 or earlier of the License, but not also under the
        terms of a Secondary License.

1.6. "Executable Form"
    means any form of the work other than Source Code Form.

1.7. "Larger Work"
    means a work that combines Covered Software with other material, in
    a separate file or files, that is not Covered Software.

1.8. "License"
    means this document.

1.9. "Licensable"
    means having the right to grant, to the maximum extent possible,
    whether at the time of the initial grant or subsequently, any and
    all of the rights conveyed by this License.

1.10. "Modifications"
    means any of the following:

    (a) any file in Source Code Form that results from an addition to,
        deletion from, or modification of the contents of Covered
        Software; or

    (b) any new file in Source Code Form that contains any Covered
        Software.

1.11. "Patent Claims" of a Contributor
    means any patent claim(s), including without limitation, method,
    process, and apparatus claims, in any patent Licensable by such
    Contributor that would be infringed, but for the grant of the
    License, by the making, using, selling, offering for sale, having
    made, import, or transfer of either its Contributions or its
    Contributor Version.

1.12. "Secondary License"
    means either the GNU General Public License, Version 2.0, the GNU
    Lesser General Public License, Version 2.1, the GNU Affero General
    Public License, Version 3.0, or any later versions of those
    licenses.

1.13. "Source Code Form"
    means the form of the work preferred for making modifications.

1.14. "You" (or "Your")
    means an individual or a legal entity exercising rights under this
    License. For legal entities, "You" includes any entity that
    controls, is controlled by, or is under common control with You. For
    purposes of this definition, "control" means (a) the power, direct
    or indirect, to cause the direction or management of such entity,
    whether by contract or otherwise, or (b) ownership of more than
    fifty percent (50%) of the outstanding shares or beneficial
    ownership of such entity.

2. License Grants and Conditions
--------------------------------

2.1. Grants

Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:

(a) under intellectual property rights (other than patent or trademark)
    Licensable by such Contributor to use, reproduce, make available,
    modify, display, perform, distribute, and otherwise exploit its
    Contributions, either on an unmodified basis, with Modifications, or
    as part of a Larger Work; and

(b) under Patent Claims of such Contributor to make, use, sell, offer
    for sale, have made, import, and otherwise transfer either its
    Contributions or its Contributor Version.

2.2. Effective Date

The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.

2.3. Limitations on Grant Scope

The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:

(a) for any code that a Contributor has removed from Covered Software;
    or

(b) for infringements caused by: (i) Your and any other third party's
    modifications of Covered Software, or (ii) the combination of its
    Contributions with other software (except as part of its Contributor
    Version); or

(c) under Patent Claims infringed by Covered Software in the absence of
    its Contributions.

This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).

2.4. Subsequent Licenses

No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).

2.5. Representation

Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.

2.6. Fair Use

This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.

2.7. Conditions

Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.

3. Responsibilities
-------------------

3.1. Distribution of Source Form

All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.

3.2. Distribution of Executable Form

If You distribute Covered Software in Executable Form then:

(a) such Covered Software must also be made available in Source Code
    Form, as described in Section 3.1, and You must inform recipients of
    the Executable Form how they can obtain a copy of such Source Code
    Form by reasonable means in a timely manner, at a charge no more
    than the cost of distribution to the recipient; and

(b) You may distribute such Executable Form under the terms of this
    License, or sublicense it under different terms, provided that the
    license for the Executable Form does not attempt to limit or alter
    the recipients' rights in the Source Code Form under this License.

3.3. Distribution of a Larger Work

You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).

3.4. Notices

You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.

3.5. Application of Additional Terms

You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.

4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------

If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.

5. Termination
--------------

5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.

5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.

5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.

************************************************************************
*                                                                      *
*  6. Disclaimer of Warranty                                           *
*  -------------------------                                           *
*                                                                      *
*  Covered Software is provided under this License on an "as is"       *
*  basis, without warranty of any kind, either expressed, implied, or  *
*  statutory, including, without limitation, warranties that the       *
*  Covered Software is free of defects, merchantable, fit for a        *
*  particular purpose or non-infringing. The entire risk as to the     *
*  quality and performance of the Covered Software is with You.        *
*  Should any Covered Software prove defective in any respect, You     *
*  (not any Contributor) assume the cost of any necessary servicing,   *
*  repair, or correction. This disclaimer of warranty constitutes an   *
*  essential part of this License. No use of any Covered Software is   *
*  authorized under this License except under this disclaimer.         *
*                                                                      *
************************************************************************

************************************************************************
*                                                                      *
*  7. Limitation of Liability                                          *
*  --------------------------                                          *
*                                                                      *
*  Under no circumstances and under no legal theory, whether tort      *
*  (including negligence), contract, or otherwise, shall any           *
*  Contributor, or anyone who distributes Covered Software as          *
*  permitted above, be liable to You for any direct, indirect,         *
*  special, incidental, or consequential damages of any character      *
*  including, without limitation, damages for lost profits, loss of    *
*  goodwill, work stoppage, computer failure or malfunction, or any    *
*  and all other commercial damages or losses, even if such party      *
*  shall have been informed of the possibility of such damages. This   *
*  limitation of liability shall not apply to liability for death or   *
*  personal injury resulting from such party's negligence to the       *
*  extent applicable law prohibits such limitation. Some               *
*  jurisdictions do not allow the exclusion or limitation of           *
*  incidental or consequential damages, so this exclusion and          *
*  limitation may not apply to You.                                    *
*                                                                      *
************************************************************************

8. Litigation
-------------

Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.

9. Miscellaneous
----------------

This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.

10. Versions of the License
---------------------------

10.1. New Versions

Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.

10.2. Effect of New Versions

You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.

10.3. Modified Versions

If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).

10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses

If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.

Exhibit A - Source Code Form License Notice
-------------------------------------------

  This Source Code Form is subject to the terms of the Mozilla Public
  License, v. 2.0. If a copy of the MPL was not distributed with this
  file, You can obtain one at http://mozilla.org/MPL/2.0/.

If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.

You may add additional accurate notices of copyright ownership.

Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------

  This Source Code Form is "Incompatible With Secondary Licenses", as
  defined by the Mozilla Public License, v. 2.0.
"""

import copy
import os
from typing import Any, Dict, Tuple, Iterable

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.backends import cudnn
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import tomotwin
from tomotwin.modules.common import preprocess
from tomotwin.modules.networks.torchmodel import TorchModel
from tomotwin.modules.training.trainer import Trainer
from tomotwin.modules.training.tripletdataset import TripletDataset


class TorchTrainer(Trainer):
    """
    Trainer for PyTorch.
    """

    def __init__(
        self,
        epochs: int,
        batchsize: int,
        learning_rate: float,
        network: TorchModel,
        criterion: nn.Module,
        training_data: TripletDataset = None,
        test_data: TripletDataset = None,
        workers: int = 0,
        output_path: str = None,
        log_dir: str = None,
        checkpoint: str = None,
        optimizer: str = "Adam",
        amsgrad: bool = False,
        weight_decay: float = 0,
        patience: int = None,
        save_epoch_seperately: bool = False,
    ):
        """
        :param epochs: Number of epochs
        :param batchsize: Training batch size
        :param learning_rate: The learning rate
        """

        super().__init__()
        cudnn.benchmark = True
        self.epochs = epochs
        self.batchsize = batchsize
        self.learning_rate = learning_rate
        self.training_data = training_data
        self.test_data = test_data
        self.patience = patience
        if self.patience is None:
            self.patience = self.epochs
        self.workers = workers
        self.best_model_loss = None
        self.best_model_f1 = None
        self.log_dir = log_dir
        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.criterion = criterion
        self.network = network
        self.network_config = None
        self.output_path = output_path
        self.last_loss = None
        # np.Infinity was removed in NumPy 2.0; np.inf works in all versions
        self.best_val_loss = np.inf

        self.best_val_f1 = 0
        self.current_epoch = None
        self.best_epoch_loss = None
        self.best_epoch_f1 = None
        self.checkpoint = None
        self.start_epoch = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.init_weights()
        self.model = self.network.get_model()
        self.checkpoint = checkpoint
        self.save_epoch_seperately = save_epoch_seperately
        self.f1_improved = False
        self.loss_improved = False

        # Write graph to tensorboard
        dummy_input = torch.zeros([12, 1, 37, 37, 37])
        self.writer.add_graph(self.model, dummy_input)

        self.model = self.model.to(self.device)
        self.optimizer = getattr(optim, optimizer)(
            self.model.parameters(),
            lr=self.learning_rate,
            amsgrad=amsgrad,
            weight_decay=weight_decay,
        )
        # Use self.patience, which falls back to the number of epochs when patience is None
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode="min", patience=self.patience, verbose=True
        )
        model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        params = sum(p.numel() for p in model_params)
        print("Number of parameters:", params)

        self.writer.add_text("Optimizer", type(self.optimizer).__name__)
        self.writer.add_text("Initial learning rate", str(self.learning_rate))

        if self.checkpoint is not None:
            self.load_checkpoint(checkpoint=self.checkpoint)

        self.model = nn.DataParallel(self.model)

    def set_seed(self, seed: int):
        """
        Set the seed for random number generators
        """
        # Do not call torch.seed() here: it re-seeds with a non-deterministic value
        # and would undo the manual seed.
        torch.manual_seed(seed)

    def get_train_test_dataloader(self) -> Tuple[DataLoader, DataLoader]:
        """
        Create dataloaders for the training and validation data
        """
        train_loader = DataLoader(
            self.training_data,
            batch_size=self.batchsize,
            shuffle=True,
            num_workers=self.workers,
            pin_memory=False,
            # prefetch_factor=5,
            timeout=180,
        )

        test_loader = None
        if self.test_data is not None:
            test_loader = DataLoader(
                self.test_data,
                batch_size=self.batchsize,
                shuffle=True,
                num_workers=self.workers,
                pin_memory=False,
                # prefetch_factor=5,
                timeout=60,
            )
        return train_loader, test_loader

    @staticmethod
    def get_best_f1(
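        anchor_label: str, similarities: np.ndarray, sim_labels: Iterable
    ) -> Tuple[float, float]:
        """
        Calculate the best classification F1 score for a given anchor
        :param anchor_label: Label of the anchor; its upper-cased stem identifies the true class
        :param similarities: Similarity between the anchor and each volume
        :param sim_labels: Labels of the volumes, in the same order as the similarities
        :return: Best F1 score and the similarity threshold at which it is reached
        """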
        PDB = os.path.splitext(anchor_label)[0].upper()
        gt_mask = np.array([PDB in p.upper() for p in sim_labels])
        best_f1 = 0
        best_t = None
        for t in np.arange(0, 1, 0.025):
            picked = similarities > t

            true_positive = np.logical_and(gt_mask, picked)
            TP = np.sum(true_positive)
            false_positive = np.logical_and(np.logical_not(gt_mask), picked)
            FP = np.sum(false_positive)
            false_negative = np.logical_and(gt_mask, np.logical_not(picked))
            FN = np.sum(false_negative)
            f1 = 2 * TP / (2 * TP + FP + FN)
            if f1 >= best_f1:
                best_t = t
                best_f1 = f1

        return best_f1, best_t

    @staticmethod
    def calc_avg_f1(anchors: pd.DataFrame, volumes: pd.DataFrame) -> float:
        """
        Calculates the average F1 score.
        Each column in 'anchors' represents an anchor volume.
        Each column in 'volumes' represents a tomogram subvolume.
        :return: Average of the best F1 scores over all anchors
        """
        scores = []
        for col in anchors:
            sim = np.matmul(volumes.T, anchors[col])
            best_f1, _ = TorchTrainer.get_best_f1(
                anchor_label=col, similarities=sim, sim_labels=sim.index.values
            )
            scores.append(best_f1)
        avg_f1 = np.mean(scores)
        return avg_f1

    def classification_f1_score(self, test_loader: DataLoader) -> float:
        """
        Calculates classification f1 score
        :return: F1 score
        """
        self.model.eval()
        t = tqdm(test_loader, desc="Classification accuracy", leave=False)
        anchor_emb = {}  # anchor label -> embedding
        vol_emb = {}  # volume filename -> embedding

        with torch.no_grad():
            for _, batch in enumerate(t):
                anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
                positive_vol = batch["positive"].to(self.device, non_blocking=True)
                negative_vol = batch["negative"].to(self.device, non_blocking=True)
                filenames = batch["filenames"]
                with autocast():
                    # TODO: Concatenating anchor, positive and negative into one batch and running a single forward pass is probably enough.
                    anchor_out = self.model(anchor_vol)
                    positive_out = self.model(positive_vol)
                    negative_out = self.model(negative_vol)

                    anchor_out_np = anchor_out.cpu().detach().numpy()
                    for i, anchor_filename in enumerate(filenames[0]):
                        if preprocess.label_filename(anchor_filename) not in anchor_emb:
                            anchor_emb[
                                preprocess.label_filename(anchor_filename)
                            ] = anchor_out_np[i, :]
                    positive_out_np = positive_out.cpu().detach().numpy()
                    for i, pos_filename in enumerate(filenames[1]):
                        if os.path.basename(pos_filename) not in vol_emb:
                            vol_emb[os.path.basename(pos_filename)] = positive_out_np[
                                i, :
                            ]

                    negative_out_np = negative_out.cpu().detach().numpy()
                    for i, neg_filename in enumerate(filenames[2]):
                        if os.path.basename(neg_filename) not in vol_emb:
                            vol_emb[os.path.basename(neg_filename)] = negative_out_np[
                                i, :
                            ]

        return TorchTrainer.calc_avg_f1(pd.DataFrame(anchor_emb), pd.DataFrame(vol_emb))

    def run_batch(self, batch: Dict) -> torch.Tensor:
        """
        Run a forward pass on one batch and compute its loss.
        :param batch: Dictionary with batch data
        :return: Loss of the batch
        """
        anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
        positive_vol = batch["positive"].to(self.device, non_blocking=True)
        negative_vol = batch["negative"].to(self.device, non_blocking=True)
        with autocast():
            # TODO: Concatenating anchor, positive and negative into one batch and running a single forward pass is probably enough.
            anchor_out = self.model(anchor_vol)
            positive_out = self.model(positive_vol)
            negative_out = self.model(negative_vol)

            loss = self.criterion(
                anchor_out,
                positive_out,
                negative_out,
                label_anchor=batch["label_anchor"],
                label_positive=batch["label_positive"],
                label_negative=batch["label_negative"],
            )
        return loss

    def save_best_loss(self, current_val_loss: float, epoch: int) -> None:
        """
        Update the best model according to the validation loss
        :param current_val_loss: Current validation loss
        :param epoch: Current epoch
        :return:  None
        """
        if current_val_loss < self.best_val_loss:
            self.loss_improved = True
            print(
                f"Validation loss improved from {self.best_val_loss} to {current_val_loss}"
            )
            self.best_epoch_loss = epoch
            self.best_val_loss = current_val_loss
            self.best_model_loss = copy.deepcopy(self.model)

    def save_best_f1(self, current_val_f1: float, epoch: int) -> None:
        """
        Update the best model according to the F1 score
        :param current_val_f1: Current f1 score
        :param epoch: Current epoch
        :return: None
        """
        if current_val_f1 > self.best_val_f1:
            self.f1_improved = True
            print(
                f"Validation F1 score improved from {self.best_val_f1} to {current_val_f1}"
            )
            self.best_epoch_f1 = epoch
            self.best_val_f1 = current_val_f1
            self.best_model_f1 = copy.deepcopy(self.model)

    def validation_loss(self, test_loader: DataLoader) -> float:
        """
        Runs the current model on the validation data
        :return: Validation loss
        """
        val_loss = []
        self.model.eval()
        t = tqdm(test_loader, desc="Validation", leave=False)

        with torch.no_grad():
            for _, batch in enumerate(t):
                valloss = self.run_batch(batch)
                val_loss.append(valloss.cpu().detach().numpy())
                desc_t = f"Validation (running loss: {np.mean(val_loss[-20:]):.4f} "
                t.set_description(desc=desc_t)

        current_val_loss = np.mean(val_loss)
        return current_val_loss

    def load_checkpoint(self, checkpoint: str) -> None:
        """
        Load model checkpoint
        :param checkpoint: Path to checkpoint
        :return: None
        """

        try:
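            # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine
            self.checkpoint = torch.load(checkpoint, map_location=self.device)
        except FileNotFoundError:
            print(f"Checkpoint {checkpoint} can't be found. Ignoring it.")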
            self.checkpoint = None
            return

        self.model.load_state_dict(self.checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(self.checkpoint["optimizer_state_dict"])
        self.start_epoch = self.checkpoint["epoch"] + 1
        self.last_loss = self.checkpoint["loss"]
        self.best_val_loss = self.checkpoint["best_loss"]
        self.best_val_f1 = self.checkpoint["best_f1"]
        print(
            f"Restart from checkpoint. Epoch: {self.start_epoch}, Training loss: {self.last_loss}, Validation loss: {self.best_val_loss}"
        )

    def epoch(self, train_loader: DataLoader) -> float:
        """
        Runs a single epoch
        :param train_loader: Data loader for training data
        :return: Training loss after the epoch
        """

        scaler = GradScaler()
        running_loss = []
        self.model.train()
        t = tqdm(train_loader, desc="Training", leave=False)
        for _, batch in enumerate(t):
            self.optimizer.zero_grad()

            loss = self.run_batch(batch)
            loss_np = loss.cpu().detach().numpy()
            running_loss.append(loss_np)
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()
            desc_t = f"Training (loss: {np.mean(running_loss[-20:]):.4f}) "

            t.set_description(desc=desc_t)

        training_loss = np.mean(running_loss)
        self.last_loss = training_loss
        return training_loss

    def train(self) -> nn.Module:
        """
        Trains the model and returns it.
        :return: Trained model
        """
        if self.training_data is None:
            raise RuntimeError("Training data is not set")

        train_loader, test_loader = self.get_train_test_dataloader()

        # Training Loop
        for epoch in tqdm(
            range(self.start_epoch, self.epochs),
            initial=self.start_epoch,
            total=self.epochs,
            desc="Epochs",
        ):
            self.f1_improved = False
            self.loss_improved = False
            self.current_epoch = epoch
            train_loss = self.epoch(train_loader=train_loader)

            print(f"Epoch: {epoch + 1}/{self.epochs} - Training Loss: {train_loss:.4f}")
            self.writer.add_scalar("Loss/train", train_loss, epoch)

            # Validation
            if test_loader is not None:
                current_val_loss = self.validation_loss(test_loader)
                current_val_f1 = self.classification_f1_score(test_loader=test_loader)
                self.scheduler.step(current_val_loss)
                print(f"Validation Loss: {current_val_loss:.4f}.")
                print(f"Validation F1 Score: {current_val_f1:.4f}.")
                self.writer.add_scalar("Loss/validation", current_val_loss, epoch)
                self.writer.add_scalar("F1/validation", current_val_f1, epoch)
                self.save_best_loss(current_val_loss, epoch)
                self.save_best_f1(current_val_f1, epoch)

            self.writer.flush()

            if self.output_path is not None:
                self.write_results_to_disk(
                    self.output_path, save_each_improvement=self.save_epoch_seperately
                )

        return self.model

    def set_training_data(self, training_data: TripletDataset) -> None:
        """
        Set the training data.
        """
        self.training_data = training_data

    def set_test_data(self, test_data: TripletDataset) -> None:
        """
        Set test (validation) data.
        """
        self.test_data = test_data

    def set_network_config(self, config) -> None:
        """
        Set the network config.
        """
        self.network_config = config

    @staticmethod
    def _write_model(
        path: str,
        model: TorchModel,
        config: Dict,
        optimizer=None,
        loss: float = None,
        epoch: int = None,
        best_loss: float = None,
        best_f1: float = None,
        **kwargs,
    ):
        """
        Adds some metadata to the model and writes the model to disk.

        :param path: Path where the model should be written
        :param model: The model that is saved to disk
        :param config: Configuration of tomotwin
        :param optimizer: Optimizer
        :param loss: Loss
        :param epoch: Current epoch
        :param best_loss: Current best validation loss
        :param best_f1: Current best validation F1 score
        :param kwargs: Additional key/value pairs that are added to the config
        :return: None
        """
        for key, value in kwargs.items():
            config[key] = value
        results_dict = {
            "model_state_dict": model.state_dict(),
            "tomotwin_config": config,
            "tt_version_train": tomotwin.__version__,
        }
        if optimizer is not None:
            results_dict["optimizer_state_dict"] = optimizer.state_dict()

        if loss is not None:
            results_dict["loss"] = loss

        if best_loss is not None:
            results_dict["best_loss"] = best_loss

        if best_f1 is not None:
            results_dict["best_f1"] = best_f1

        if epoch is not None:
            results_dict["epoch"] = epoch

        torch.save(
            results_dict,
            path,
        )

    def write_model_to_disk(
        self, path: str, model_to_save, model_name: str, epoch: int, **kwargs
    ):
        """
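        Writes a model together with the current training state to disk.
        A model wrapped in nn.DataParallel is unwrapped before saving.

        :param path: Path to the folder where the model is saved.
        :param model_to_save: Model to save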
        :param model_name: model filename
        :param epoch: Epoch of the model
        :return: None
        """
        if isinstance(model_to_save, nn.DataParallel):
            model_to_save = model_to_save.module

        self._write_model(
            path=os.path.join(path, model_name),
            model=model_to_save,
            config=self.network_config,
            optimizer=self.optimizer,
            loss=self.last_loss,
            best_loss=self.best_val_loss,
            best_f1=self.best_val_f1,
            epoch=epoch,
            **kwargs,
        )

    def write_results_to_disk(
        self, path: str, save_each_improvement: bool = False, **kwargs
    ) -> None:
        """
        Write the training results to specified folder
        :param path: Path to folder to write the data
        :param save_each_improvement: If true, the model of each epoch is saved separately.
        :param kwargs: Additional key/value pairs that are stored in the model config
        :return: None
        """
        self.write_model_to_disk(path, self.model, "latest.pth", self.current_epoch)

        if self.current_epoch == self.epochs - 1:
            # os.replace overwrites an existing final.pth on all platforms.
            os.replace(
                os.path.join(path, "latest.pth"), os.path.join(path, "final.pth")
            )

        if self.best_model_loss is not None:
            # The best_model can be None, after a training restart.
            self.write_model_to_disk(
                path,
                self.best_model_loss,
                "best_loss.pth",
                self.best_epoch_loss,
                **kwargs,
            )

        if self.best_model_f1 is not None:
            # The best_model can be None, after a training restart.
            self.write_model_to_disk(
                path, self.best_model_f1, "best_f1.pth", self.best_epoch_f1, **kwargs
            )

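        if save_each_improvement:
            mod = self.model
            ep = self.current_epoch
            # Build the filename suffix. The conditional expression binds weaker
            # than "+", so each part is appended separately to keep "_f1" when
            # only the F1 score improved.
            add = "_f1" if self.f1_improved else ""
            add += "_loss" if self.loss_improved else ""

            self.write_model_to_disk(
                path,
                mod,
                f"best_model_{ep + 1:03d}{add}.pth",
                ep,
                **kwargs,
            )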

    def get_model(self) -> Any:
        """
        :return: Trained model
        """
        return self.model
