from spaghettini import quick_register
from abc import ABC
import copy
import traceback

import numpy as np
import torch
import pytorch_lightning as pl

from src.utils.misc import prepend_string_to_dict_keys
from src.dl.models.wrappers.proximal_reg_wrapper import ProximalRegWrapper
from src.learners.probe_training.probe_trainer import ProbeTrainer
from src.learners.verifier_attacing.verifier_attacker import VerifierAttacker
from src.learners.prover_pretraining.prover_pretrainer import ProverPretrainer

# Small positive constant. It is not referenced in the visible part of this file;
# presumably a numerical-stability epsilon used further down (TODO confirm).
SMALL = 1e-8


@quick_register
class LJNSingleBase(pl.LightningModule, ProbeTrainer, VerifierAttacker, ProverPretrainer, ABC):
    def __init__(self, verifier, prover, train_loader, verifier_optimizer_init,
                 prover_optimizer_init, classification_loss_fn, use_matching_verifier_loss, nm_loss_weighting,
                 prover_aux_losses, verifier_aux_losses, label_0_weighting, label_1_weighting,
                 verifier_scheduler, prover_scheduler, num_verifier_steps, num_prover_steps,
                 proximal_reg,
                 give_correct_proof_for_n_game_steps, block_prover_output_for_n_game_steps, prover_mode, l2_proof_reg,
                 do_inspection_every_n_batches, probe_specs, probe_optimizer_init, probe_dataloader,
                 max_probe_training_batches, track_last_n_batch_probe_outputs,
                 v_attack_p_optimizer_init, max_v_attack_p_training_batches,
                 log_forward_metrics_every_n_game_steps,
                 log_forward_plots_every_n_game_steps, v_attack_proof_optim_specs=None,
                 fixed_minibatch_getter=None, pretraining_specs=None,
                 plotting_fns=None, lookahead=None, perform_verifier_attack=True,
                 adaptive_prover_step_specs=None, label_flipping=None, collapse_preventing_reg=None,
                 max_num_batches_in_buffer=75, **kwargs):
        """Store the verifier, the prover and all training/inspection/logging options.

        Almost every argument is kept unchanged as an attribute for other methods to
        consume; the meaning of each option is defined where it is used. Decisions
        taken here:

        * ``prover_mode`` must be "collaborative" or "ljn".
        * When ``proximal_reg`` is given, both agents must be ``ProximalRegWrapper``s.
        * ``train_loader`` is a callable returning a dataloader. It is called once here
          when ``fixed_minibatch_getter`` is given, to build the inspection minibatch.
        * Extra ``**kwargs`` are accepted and ignored.

        Raises:
            AssertionError: On an unknown ``prover_mode``, or when ``proximal_reg`` is set
                but an agent is not wrapped in ``ProximalRegWrapper``.
        """
        super().__init__()
        # The message must be a string: the previous `print(...)` expression evaluated to None.
        assert prover_mode in ["collaborative", "ljn"], f"Prover mode {prover_mode} not recognized."

        self.verifier = verifier
        self.prover = prover

        # Data loaders.
        self.train_loader = train_loader

        # Optimizers and learning rate schedules.
        self.v_opt_init = verifier_optimizer_init
        self.p_opt_init = prover_optimizer_init
        self.lookahead = lookahead

        self.verifier_scheduler_init = verifier_scheduler
        self.prover_scheduler_init = prover_scheduler

        self.num_verifier_steps = num_verifier_steps
        self.num_prover_steps = num_prover_steps
        self.adaptive_prover_step_specs = adaptive_prover_step_specs

        # Loss functions and coefficients.
        self.classification_loss_fn = classification_loss_fn
        self.use_matching_verifier_loss = use_matching_verifier_loss
        self.nm_loss_weighting = nm_loss_weighting
        self.prover_aux_losses = prover_aux_losses
        self.verifier_aux_losses = verifier_aux_losses
        self.label_0_weighting = label_0_weighting
        self.label_1_weighting = label_1_weighting
        self.label_flipping = label_flipping

        # Additional regularization terms.
        self.prox_reg = proximal_reg
        if proximal_reg is not None:
            # Proximal regularization reads `.curr_model` / `.trailing_model` from both agents.
            assert isinstance(prover, ProximalRegWrapper) and isinstance(verifier, ProximalRegWrapper), \
                "Proximal regularization requires the prover and the verifier to be ProximalRegWrapper instances."
        self.l2_proof_reg = l2_proof_reg
        self.collapse_preventing_reg = collapse_preventing_reg

        # Training setup.
        self.give_correct_proof_for_n_game_steps = give_correct_proof_for_n_game_steps
        self.block_prover_output_for_n_game_steps = block_prover_output_for_n_game_steps
        self.prover_mode = prover_mode

        # Pretraining related.
        self.pretraining_specs = pretraining_specs

        # Probe related.
        self.probe_specs = probe_specs
        self.probe_optimizer_init = probe_optimizer_init
        self.probe_dataloader = probe_dataloader
        self.max_probe_training_batches = max_probe_training_batches
        self.max_num_batches_in_buffer = max_num_batches_in_buffer
        self.track_last_n_batch_probe_outputs = track_last_n_batch_probe_outputs

        # Verifier attack related.
        self.v_attack_p_optimizer_init = v_attack_p_optimizer_init
        self.max_v_attack_p_training_batches = max_v_attack_p_training_batches
        self.v_attack_proof_optim_specs = v_attack_proof_optim_specs
        self.perform_verifier_attack = perform_verifier_attack

        # Plotting (for logging purposes) related.
        self.plotting_fns = plotting_fns

        # Inspection minibatch: built from a freshly created train loader, or absent.
        if fixed_minibatch_getter is not None:
            self.fixed_minibatch = fixed_minibatch_getter(train_loader=train_loader(),
                                                          unpack_fn=self.unpack_data_batch)
        else:
            self.fixed_minibatch = None

        # Periodic stuff.
        self.log_forward_metrics_every_n_game_steps = log_forward_metrics_every_n_game_steps
        self.log_forward_plots_every_n_game_steps = log_forward_plots_every_n_game_steps
        self.do_inspection_every_n_batches = do_inspection_every_n_batches

        # Additional states, including feature buffer (for probe training).
        self.feat_buffer_dict = dict()  # TODO: Rename this to be the "probe buffer dict".
        self.state_buffer_dict = dict()

    @property
    def game_step_length(self):
        """Sum of the trainer's per-optimizer frequencies (the divisor used by ``game_step``)."""
        frequencies = self.trainer.optimizer_frequencies
        return sum(frequencies)

    @property
    def game_step(self):
        """Index of the current game step: ``global_step`` floor-divided by ``game_step_length``."""
        elapsed_steps = self.global_step
        return elapsed_steps // self.game_step_length

    def on_train_start(self):
        """Run prover pretraining when specs were provided and training is at global step 0."""
        if self.pretraining_specs is None:
            return
        # Only pretrain when no step has been taken yet (global_step is still 0).
        if self.global_step == 0:
            self.run_prover_pretraining()

    # ____ Forward pass. ____
    def channel_fn(self, proofs, inputs):
        """Identity channel: hand back ``proofs`` untouched; ``inputs`` is not used here."""
        return proofs

    def run_prover(self, p_inputs, prover, correct_proofs, other_data, **kwargs):
        """Call ``prover`` on ``p_inputs`` with ``other_data`` expanded into keyword arguments.

        ``correct_proofs`` and ``**kwargs`` are accepted but not used in this implementation.

        Returns:
            The prover's three return values, in order: outputs, auxiliary outputs, extra dict.
        """
        outputs, aux_outputs, extras = prover(p_inputs, **other_data)
        return outputs, aux_outputs, extras

    def process_prover_outputs(self, p_inputs, prover_outputs, correct_proofs, **kwargs):
        """Turn raw prover outputs into the proofs that are handed to the verifier.

        Steps, in order: optionally replace the outputs by a detached, grad-enabled copy
        (``retain_prover_output_grads=True``), optionally block them, apply ``channel_fn``,
        and optionally substitute the hard-coded correct proofs.

        Returns:
            Tuple ``(proofs, prover_outputs)`` where ``prover_outputs`` is the possibly
            detached/blocked tensor the proofs were derived from.
        """
        # Logging aid: a detached leaf copy lets the gradient w.r.t. the proof vectors be read.
        # The flag must be exactly ``True``; any other value leaves the outputs untouched.
        if kwargs.get("retain_prover_output_grads") is True:
            prover_outputs = prover_outputs.clone().detach()  # TODO: This might be avoidable.
            prover_outputs.requires_grad = True
            prover_outputs.retain_grad()

        # NOTE(review): both thresholds below are named "..._for_n_game_steps" but are compared
        # against ``current_epoch`` - confirm which unit is intended.
        blocking_active = self.current_epoch < self.block_prover_output_for_n_game_steps
        if blocking_active:
            prover_outputs = self.block_prover_outputs(curr_prover_outputs=prover_outputs)

        # Channel function (or constraint); the base implementation is the identity.
        proofs = self.channel_fn(proofs=prover_outputs, inputs=p_inputs)

        give_correct = self.current_epoch < self.give_correct_proof_for_n_game_steps
        if give_correct:
            proofs = self.use_correct_proofs(curr_proofs=proofs, correct_proofs=correct_proofs, inputs=p_inputs)

        return proofs, prover_outputs

    def run_verifier(self, proofs, v_inputs, verifier, other_data, **kwargs):
        """Call ``verifier`` on ``v_inputs`` and a one-element list holding ``proofs``.

        ``other_data`` is expanded into keyword arguments; ``**kwargs`` is not used here.

        Returns:
            The verifier's three return values, in order: predictions, auxiliary outputs, extra dict.
        """
        proof_list = [proofs]  # The verifier is called with a list of proofs.
        predictions, aux_outputs, extras = verifier(v_inputs, proof_list, **other_data)
        return predictions, aux_outputs, extras

    # ____ Forward Pass. ____
    def forward(self, p_inputs, v_inputs, prover, verifier, correct_proofs, other_data, **kwargs):
        """Full prover -> proof processing -> verifier pass with explicitly supplied networks.

        Returns:
            Tuple ``(preds, proofs, prover_outputs, prover_aux, verifier_aux, model_dict)``.
            ``model_dict`` merges the verifier dict, then the prover dict (prover entries win
            on key clashes), plus a ``"proofs"`` entry.
        """
        raw_outputs, prover_aux, prover_dict = self.run_prover(
            p_inputs=p_inputs, prover=prover, correct_proofs=correct_proofs, other_data=other_data, **kwargs)

        proofs, prover_outputs = self.process_prover_outputs(
            p_inputs=p_inputs, prover_outputs=raw_outputs, correct_proofs=correct_proofs, **kwargs)

        preds, verifier_aux, verifier_dict = self.run_verifier(
            proofs=proofs, v_inputs=v_inputs, verifier=verifier, other_data=other_data, **kwargs)

        # Merge order matters: later sources overwrite earlier ones.
        model_dict = dict(verifier_dict)
        model_dict.update(prover_dict)
        model_dict["proofs"] = proofs

        return preds, proofs, prover_outputs, prover_aux, verifier_aux, model_dict

    def training_step(self, batch, batch_idx, optimizer_idx, *args, **kwargs):
        """Run one optimisation step for the verifier (``optimizer_idx == 0``) or the prover (``1``).

        Args:
            batch: Raw data batch; unpacked with ``unpack_data_batch``.
            batch_idx: Index of the batch, forwarded to ``common_step`` for logging.
            optimizer_idx: 0 trains the verifier, 1 trains the prover.

        Returns:
            Dict with the scalar ``"loss"`` and the ``"logs_dict"`` produced by ``common_step``.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        # Fail fast with a real exception instead of print + exit(-1): `exit` is a helper
        # installed by the `site` module and not meant for use inside library code.
        if optimizer_idx not in (0, 1):
            raise ValueError(f"Invalid optimizer idx {optimizer_idx}. Aborting")

        # ____ Unpack data batch. ____
        # Get the inputs and ground truth labels and other data.
        p_xs, v_xs, ys_true, other_data = self.unpack_data_batch(batch)

        # ____ Deal with operations that need to be done periodically. ____
        # If using proximal regularization, synchronize the current and trailing model copies once in a while.
        if self.prox_reg is not None:
            self.prox_reg_sync()

        # ____ Run inspection. Includes things like 1) probe training 2) freezing verifier and training prover. ____
        self.run_inspection()

        # ____ Run forward pass. ____
        if optimizer_idx == 0:  # Training verifier: it is always trained on the true labels.
            net_idx = 0
            ys_train = ys_true
        else:  # Training prover: targets depend on the prover mode.
            net_idx = 1
            ys_train = self.get_prover_ys(ys_true=ys_true, prover_mode=self.prover_mode)
        loss, forward_metrics, logs_dict = self.common_step(p_xs=p_xs, v_xs=v_xs, ys_train=ys_train,
                                                            ys_true=ys_true, other_data=other_data, net_idx=net_idx,
                                                            batch_nb=batch_idx, prepend_key="training/")

        # ____ Log forward metrics and plots. ____
        # Within a logging period, log at offset 0 and at offset `num_verifier_steps`
        # (presumably the first verifier step and the first prover step of the game step).
        log_offsets = (0, self.num_verifier_steps)
        metrics_period = self.log_forward_metrics_every_n_game_steps * self.game_step_length
        if self.global_step % metrics_period in log_offsets:
            self.logger.experiment.log(forward_metrics)

        plots_period = self.log_forward_plots_every_n_game_steps * self.game_step_length
        if self.global_step % plots_period in log_offsets:
            self.log_forward_plots(forward_metrics, logs_dict)

        # ____ Adapt prover strength, if asked. ____
        if self.adaptive_prover_step_specs is not None:
            self.adapt_prover_strength(forward_metrics)

        return {"loss": loss, "logs_dict": logs_dict}

    # ____ Common step. ____
    def common_step(self, p_xs, v_xs, ys_train, ys_true, other_data, net_idx, batch_nb, prepend_key=""):
        """Run the forward pass(es), assemble the loss for one agent and gather logging data.

        Args:
            p_xs: Prover inputs.
            v_xs: Verifier inputs.
            ys_train: Targets used in the classification loss.
            ys_true: Ground-truth labels; used to split the batch into label-0 and label-1 parts.
            other_data: Dict of extra batch data. The optional keys "correct_proofs" and
                "causal_features" are read here; the whole dict is also forwarded to the models.
            net_idx: 0 when the verifier is being trained, 1 when the prover is being trained.
            batch_nb: Batch index; only stored in the logs.
            prepend_key: Prefix for log keys; ``"t_{net_idx}/"`` is appended to it.

        Returns:
            Tuple ``(total_loss, forward_metrics, logs_dict)``; the keys of ``logs_dict`` are
            prefixed with the final prepend key.
        """
        # Update the prepend key.
        prepend_key = prepend_key + f"t_{net_idx}/"

        # ____ Perform forward passes. ____
        # Run forward pass using matching proof-verifier input pairs.
        correct_proofs = other_data["correct_proofs"] if "correct_proofs" in other_data else None
        preds, proofs, p_outs, p_aux_dict, v_aux_dict, model_dict = self.forward(p_inputs=p_xs, v_inputs=v_xs,
                                                                                 prover=self.prover,
                                                                                 verifier=self.verifier,
                                                                                 correct_proofs=correct_proofs,
                                                                                 other_data=other_data)

        # If asked, run forward pass using non-matching proof-verifier input pairs.
        if net_idx == 0 and (self.use_matching_verifier_loss or "proof_input_matching" in self.verifier_aux_losses):
            # Run the verifier with non-matching input-proof pairs. "nm" stands for non-matching.
            (p_xs_nm, v_xs_nm, correct_proofs_nm, p_xs_ys_true,
             v_xs_ys_true, ys_train_nm, other_data_nm) = self.get_nonmatching_examples(p_xs=p_xs, v_xs=v_xs,
                                                                                       ys_true=ys_true,
                                                                                       correct_proofs=correct_proofs,
                                                                                       other_data=other_data)
            # Only the predictions and the verifier aux outputs of this pass are used below.
            preds_nm, _, _, _, v_aux_dict_nm, _ = self.forward(p_inputs=p_xs_nm, v_inputs=v_xs_nm,
                                                               prover=self.prover, verifier=self.verifier,
                                                               correct_proofs=correct_proofs_nm,
                                                               other_data=other_data_nm)
        else:
            v_aux_dict_nm = dict()
            p_xs_nm, v_xs_nm, ys_train_nm, preds_nm = None, None, None, None
            p_xs_ys_true, v_xs_ys_true = None, None

        # ____ Compute losses. ____
        # If asked, apply a bit of label flipping when training the prover.
        if self.label_flipping is not None and net_idx == 1:
            flip = self.label_flipping.flip_probability * torch.ones_like(ys_true)
            flip = torch.bernoulli(flip)
            ys_train = torch.logical_xor(ys_train, flip).long()

        # Compute task loss. Decompose the 0-input and 1-input components separately.
        idx_0, idx_1 = (ys_true == torch.zeros_like(ys_true)), (ys_true == torch.ones_like(ys_true))
        cls_loss_0 = self.classification_loss_fn(preds[idx_0], ys_train[idx_0]) if idx_0.sum() > 0 else 0
        cls_loss_1 = self.classification_loss_fn(preds[idx_1], ys_train[idx_1]) if idx_1.sum() > 0 else 0

        # If asked, train the verifier network on input-proof matching loss.
        nm_loss = 0.
        if net_idx == 0 and self.use_matching_verifier_loss:
            # The answers on non-matching samples must always be "reject", or "1".
            nm_loss = nm_loss + self.classification_loss_fn(preds_nm, ys_train_nm)

        # Compute the total classification loss.
        if net_idx == 0 and self.use_matching_verifier_loss:
            assert self.label_0_weighting == 1. and self.label_1_weighting == 1.
            assert 0. <= self.nm_loss_weighting <= 1., "Sum of matching and non-matching losses must be 1. "
            matching_term_weighting = (1. - self.nm_loss_weighting) / 2
            cls_loss = matching_term_weighting * (cls_loss_0 + cls_loss_1) + self.nm_loss_weighting * nm_loss
        else:
            # No impact of the matching loss. Directly compute the loss.
            cls_loss = (1 / 2) * self.label_0_weighting * cls_loss_0 + (1 / 2) * self.label_1_weighting * cls_loss_1

        # Compute the prover auxiliary losses (only when the prover is being trained).
        p_aux_loss, p_aux_losses_dict = 0., dict()
        if net_idx == 1:
            p_aux_losses_dict = self.compute_prover_aux_losses(p_aux_dict=p_aux_dict,
                                                               p_xs=p_xs, ys_true=ys_true,
                                                               other_data=other_data)
            p_aux_loss = p_aux_loss + sum(p_aux_losses_dict.values())

        # Compute the verifier auxiliary losses (only when the verifier is being trained).
        v_aux_loss, v_aux_losses_dict = 0., dict()
        if net_idx == 0:
            v_aux_losses_dict = self.compute_verifier_aux_losses(v_aux_dict=v_aux_dict, v_aux_dict_nm=v_aux_dict_nm,
                                                                 ys_true=ys_true)
            v_aux_loss = v_aux_loss + sum(v_aux_losses_dict.values())

        # Compute additional regularization losses.
        additional_reg_term = 0.
        if self.prox_reg is not None:
            # Compute forward pass using the trailing model weights.
            preds_prox, _, p_outs_prox, _, _, _ = self.forward(p_inputs=p_xs, v_inputs=v_xs,
                                                               prover=self.prover.trailing_model,
                                                               verifier=self.verifier.trailing_model,
                                                               correct_proofs=correct_proofs,
                                                               other_data=other_data)
            verifier_prox_term = self.prox_reg.verifier_coeff * self.prox_reg.distance_fn(preds, preds_prox)
            prover_prox_term = self.prox_reg.prover_coeff * self.prox_reg.distance_fn(p_outs, p_outs_prox)
            additional_reg_term = additional_reg_term + verifier_prox_term + prover_prox_term

        if self.l2_proof_reg is not None:
            # Penalize the mean squared value of the prover outputs.
            l2_proof_reg = self.l2_proof_reg.coeff * (p_outs ** 2).mean()
            additional_reg_term = additional_reg_term + l2_proof_reg

        if self.collapse_preventing_reg is not None and net_idx == 1:
            # Add collapse preventing loss. Apply to prover features as well as proofs.
            max_cp_samples = self.collapse_preventing_reg.max_samples
            feats = model_dict["prover_feats"]
            cpl_on_feats = self.collapse_preventing_reg.reg_fn(feats=feats[:max_cp_samples],
                                                               targets=ys_true[:max_cp_samples])
            cpl_on_proofs = self.collapse_preventing_reg.reg_fn(feats=proofs[:max_cp_samples],
                                                                targets=ys_true[:max_cp_samples])
            additional_reg_term = additional_reg_term + cpl_on_feats + cpl_on_proofs

        # Compute total loss.
        total_loss = cls_loss + p_aux_loss + v_aux_loss + additional_reg_term

        # ____ Log relevant data. ____
        causal_features = other_data["causal_features"] if "causal_features" in other_data else None
        # NOTE(review): both networks are deep-copied on every call - confirm this cost is intended.
        try:
            prover_clone = copy.deepcopy(self.prover)
            verifier_clone = copy.deepcopy(self.verifier)
        except Exception:
            # Best effort: fall back to the live networks. `Exception` (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            print("Failed to copy prover and verifier using deepcopy. Not copying. PROCEED WITH CAUTION.")
            prover_clone = self.prover
            verifier_clone = self.verifier
        logs_dict = dict(p_xs=p_xs, v_xs=v_xs, ys_train=ys_train,
                         ys_true=ys_true, correct_proofs=correct_proofs,
                         causal_features=causal_features,
                         p_xs_nm=p_xs_nm, v_xs_nm=v_xs_nm,
                         ys_train_nm=ys_train_nm, nm_p_xs_labels=p_xs_ys_true,
                         nm_v_xs_labels=v_xs_ys_true, preds_nm=preds_nm,
                         other_data=other_data, preds=preds, p_outs=p_outs, proofs=proofs,
                         p_aux_dict=p_aux_dict,
                         p_aux_losses_dict=p_aux_losses_dict,
                         v_aux_dict=v_aux_dict,
                         v_aux_dict_nm=v_aux_dict_nm,
                         v_aux_losses_dict=v_aux_losses_dict,
                         classification_loss=cls_loss,
                         cls_loss_0=cls_loss_0, cls_loss_1=cls_loss_1,
                         label_0_idxs=idx_0, label_1_idxs=idx_1,
                         model_logs=model_dict,
                         verifier=verifier_clone,
                         prover=prover_clone,
                         classification_loss_fn=self.classification_loss_fn,
                         forward_fn=self.forward,
                         net_idx=net_idx,
                         batch_nb=batch_nb, global_step=self.global_step, game_step=self.game_step,
                         current_epoch=self.current_epoch,
                         prepend_key=prepend_key)
        forward_metrics = self.prepare_forward_metrics(**logs_dict)
        forward_metrics = self.task_specific_logging(metric_logs=forward_metrics, logs_dict=logs_dict, **logs_dict)
        logs_dict = prepend_string_to_dict_keys(prepend_key=prepend_key, dictinary=logs_dict)

        # ____ Update the feature buffer (to be used for probe training.) ____
        self.update_feat_buffer(model_dict, **logs_dict)

        return total_loss, forward_metrics, logs_dict

    def compute_prover_aux_losses(self, p_aux_dict, p_xs, ys_true, other_data):
        """Evaluate every configured prover auxiliary loss that has a matching aux output.

        Entries of ``p_aux_dict`` whose name is not a key of ``self.prover_aux_losses`` are skipped.

        Returns:
            Dict mapping loss name to ``coeff * loss_fn(...)``.
        """
        specs = self.prover_aux_losses
        weighted_losses = dict()
        for head_name, head_output in p_aux_dict.items():
            if head_name not in specs.keys():
                continue
            spec = specs[head_name]
            raw_loss = spec.loss_fn(p_aux_out=head_output, p_xs=p_xs, ys_true=ys_true, other_data=other_data)
            weighted_losses[head_name] = spec.coeff * raw_loss

        return weighted_losses

    def compute_verifier_aux_losses(self, **kwargs):
        """Evaluate every configured verifier auxiliary loss that has a matching aux output.

        ``kwargs`` must contain ``"v_aux_dict"``. Each selected loss function receives the complete
        ``kwargs`` (not just its own aux output).

        Returns:
            Dict mapping loss name to ``coeff * loss_fn(**kwargs)``.
        """
        specs = self.verifier_aux_losses
        weighted_losses = dict()
        for head_name, _ in kwargs["v_aux_dict"].items():
            if head_name not in specs.keys():
                continue
            spec = specs[head_name]
            raw_loss = spec.loss_fn(**kwargs)
            weighted_losses[head_name] = spec.coeff * raw_loss

        return weighted_losses

    # ____ Backward Pass related. ____
    def configure_optimizers(self):
        """Build the verifier and prover optimizer configurations for the trainer.

        With proximal regularization enabled, only the parameters of each wrapper's ``curr_model``
        are optimized. Each optimizer is optionally wrapped by ``self.lookahead``.

        A scheduler is attached to an agent whenever that agent's scheduler factory is set.
        (Previously a scheduler was silently dropped unless both factories were provided.)

        Returns:
            Tuple ``(verifier_config, prover_config)`` of dicts with the keys ``optimizer``,
            ``frequency`` and, when a scheduler is configured, ``lr_scheduler``.
        """
        if self.prox_reg is None:
            v_opt = self.v_opt_init(self.verifier.parameters())
            p_opt = self.p_opt_init(self.prover.parameters())
        else:
            v_opt = self.v_opt_init(self.verifier.curr_model.parameters())
            p_opt = self.p_opt_init(self.prover.curr_model.parameters())

        # Use the lookahead optimizer if asked.
        if self.lookahead is not None:
            v_opt = self.lookahead(v_opt)
            p_opt = self.lookahead(p_opt)

        v_config = dict(optimizer=v_opt, frequency=self.num_verifier_steps)
        p_config = dict(optimizer=p_opt, frequency=self.num_prover_steps)

        # Attach learning rate schedulers independently for each agent.
        if self.verifier_scheduler_init is not None:
            v_config["lr_scheduler"] = self.verifier_scheduler_init(v_opt)
        if self.prover_scheduler_init is not None:
            p_config["lr_scheduler"] = self.prover_scheduler_init(p_opt)

        return v_config, p_config

    def on_before_zero_grad(self, optimizer):
        """Clear the gradients of the whole module; the ``optimizer`` argument is not used."""
        # zero_grad() on the module (rather than on one optimizer) resets the gradients of
        # all parameters held by the module.
        self.zero_grad()

    # ____ Data related. ____
    def unpack_data_batch(self, data_batch):
        """Abstract hook to be implemented by each subclass.

        ``training_step`` unpacks the result as ``(p_xs, v_xs, ys_true, other_data)``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_nonmatching_examples(self, p_xs, v_xs, ys_true, correct_proofs, other_data):
        """Build mismatched pairs by shifting the batch by one position.

        Prover input ``i`` is paired with verifier input ``i + 1``, so the result holds one example
        less than the batch. All non-matching training targets are 1. ``other_data`` is not
        forwarded: an empty dict is returned in its place.

        Returns:
            Tuple ``(p_xs_nm, v_xs_nm, correct_proofs_nm, p_xs_ys_true, v_xs_ys_true, ys_train_nm,
            other_data_nm)``.
        """
        num_pairs = int(p_xs.shape[0] - 1)

        # Prover side: everything but the last example.
        prover_inputs = p_xs[:num_pairs]
        prover_labels = ys_true[:num_pairs]
        shifted_correct_proofs = None if correct_proofs is None else correct_proofs[:num_pairs]

        # Verifier side: everything but the first example.
        verifier_inputs = v_xs[1:]
        verifier_labels = ys_true[1:]

        targets = torch.ones_like(verifier_labels)
        empty_other_data = dict()  # For now, we can leave the other_data field to be empty.

        return (prover_inputs, verifier_inputs, shifted_correct_proofs, prover_labels, verifier_labels,
                targets, empty_other_data)

    def train_dataloader(self):
        """Create the training dataloader by calling the stored ``train_loader`` factory."""
        loader = self.train_loader()
        return loader

    # ____ Logging. ____
    def prepare_forward_metrics(self, **kwargs):
        """Assemble the scalar metrics of one forward pass into a flat, prefixed dict.

        Reads (at least) ``preds``, ``ys_true``, ``classification_loss``, ``cls_loss_0``,
        ``cls_loss_1``, ``p_aux_losses_dict``, ``v_aux_losses_dict`` and ``prepend_key`` from
        ``kwargs``; the helper methods called below read further keys.

        Returns:
            Dict of metrics whose keys are prefixed with "<prepend key joined by '_'>_metrics/".
        """
        # ____ Progress counters. ____
        metric_logs = dict()
        for counter_name, counter_value in (("epoch", self.current_epoch),
                                            ("global_step", self.global_step),
                                            ("game_step", self.game_step)):
            metric_logs[counter_name] = float(counter_value)

        # ____ Verifier accuracy, overall and per ground-truth label. ____
        ys_true = kwargs['ys_true']
        predicted = kwargs['preds'].argmax(dim=1)

        def count(label, predicted_class):
            # Number of examples with ground truth `label` that the verifier assigned `predicted_class`.
            return ((ys_true == label) * (predicted == predicted_class)).float().sum()

        right_0, wrong_0 = count(0, 0), count(0, 1)
        wrong_1, right_1 = count(1, 0), count(1, 1)
        # These stay tensors; a label that is absent from the batch yields 0/0 for its accuracy.
        metric_logs["acc_0"] = right_0 / (right_0 + wrong_0)
        metric_logs["acc_1"] = right_1 / (right_1 + wrong_1)
        metric_logs["acc"] = (right_0 + right_1) / ys_true.shape[0]

        # ____ Accuracy of the prover's classifier aux head. ____
        metric_logs = self.prepare_classifier_p_aux_accuracy_forward_metrics(metric_logs=metric_logs, **kwargs)

        # ____ Non-matching accuracies - both of verifier and verifier aux head. ____
        metric_logs = self.prepare_nonmatching_forward_metrics(metric_logs=metric_logs, **kwargs)

        # ____ Training losses. ____
        metric_logs["cls_loss"] = float(kwargs["classification_loss"])
        metric_logs["cls_loss_0"] = float(kwargs["cls_loss_0"])
        metric_logs["cls_loss_1"] = float(kwargs["cls_loss_1"])

        metric_logs.update({f"prover_aux_loss_{name}": float(loss)
                            for name, loss in kwargs["p_aux_losses_dict"].items()})
        metric_logs.update({f"verifier_aux_loss_{name}": float(loss)
                            for name, loss in kwargs["v_aux_losses_dict"].items()})

        # ____ Current number of steps per agent. ____
        v_step_num, p_step_num = self.trainer.optimizer_frequencies
        metric_logs["verifier_step_num"] = v_step_num
        metric_logs["prover_step_num"] = p_step_num

        # ____ Prefix every key with the training mode. ____
        key_parts = kwargs["prepend_key"].split("/")[:-1]
        metrics_prefix = "_".join(key_parts) + "_metrics/"
        return prepend_string_to_dict_keys(prepend_key=metrics_prefix, dictinary=metric_logs)

    def prepare_classifier_p_aux_accuracy_forward_metrics(self, metric_logs, **kwargs):
        """Add the per-label accuracy of the prover's "classification" aux head to ``metric_logs``.

        Reads ``kwargs["ys_true"]`` and ``kwargs["p_aux_dict"]["classification"]``; the predicted
        class is the argmax over axis 1.

        Returns:
            The same ``metric_logs`` dict, with "classify_p_aux_acc_0" and "classify_p_aux_acc_1" added.
        """
        ys_true = kwargs["ys_true"]
        is_label_0 = ys_true == torch.zeros_like(ys_true)
        is_label_1 = ys_true == torch.ones_like(ys_true)

        predicted = kwargs["p_aux_dict"]["classification"].argmax(axis=1)
        for metric_name, mask in (("classify_p_aux_acc_0", is_label_0),
                                  ("classify_p_aux_acc_1", is_label_1)):
            hits = (predicted[mask] == ys_true[mask]).float()
            metric_logs[metric_name] = float(hits.mean())

        return metric_logs

    def prepare_nonmatching_forward_metrics(self, metric_logs, **kwargs):
        """Add accuracies on matching / non-matching proof-input pairs to ``metric_logs``.

        Nothing is added unless the verifier is being trained (``net_idx == 0``). The metrics are
        split by the labels of the prover-side and verifier-side inputs of each non-matching pair
        (suffixes "00", "01", "10", "11": prover label first, verifier label second).

        Returns:
            The same ``metric_logs`` dict, possibly extended.
        """
        ys_true = kwargs["ys_true"]
        idx0 = ys_true == torch.zeros_like(ys_true)
        idx1 = ys_true == torch.ones_like(ys_true)
        net_idx = kwargs["net_idx"]

        # Masks over the non-matching pairs, one per (prover label, verifier label) combination.
        if net_idx == 0 and (
                "proof_input_matching" in self.verifier_aux_losses or self.use_matching_verifier_loss):
            p_labels, v_labels = kwargs['nm_p_xs_labels'], kwargs['nm_v_xs_labels']
            p_is_0 = p_labels == torch.zeros_like(p_labels)
            p_is_1 = p_labels == torch.ones_like(p_labels)
            v_is_0 = v_labels == torch.zeros_like(v_labels)
            v_is_1 = v_labels == torch.ones_like(v_labels)
            nm_idx_00 = torch.logical_and(p_is_0, v_is_0)
            nm_idx_01 = torch.logical_and(p_is_0, v_is_1)
            nm_idx_10 = torch.logical_and(p_is_1, v_is_0)
            nm_idx_11 = torch.logical_and(p_is_1, v_is_1)
        else:
            nm_idx_00 = nm_idx_01 = nm_idx_10 = nm_idx_11 = None

        # Accuracy of the verifier's proof-input matching aux head.
        if "proof_input_matching" in self.verifier_aux_losses and net_idx == 0:
            # Matching pairs: the target is always 0.
            matching_logits = kwargs["v_aux_dict"]["proof_input_matching"]
            for metric_name, mask in (("v_matching_aux_matching_0_acc", idx0),
                                      ("v_matching_aux_matching_1_acc", idx1)):
                metric_logs[metric_name] = self._get_acc_by_index(
                    logits=matching_logits, targets=torch.zeros_like(ys_true), idx=mask)

            # Non-matching pairs: the target is always 1.
            nonmatching_logits = kwargs["v_aux_dict_nm"]["proof_input_matching"]
            for metric_name, mask in (("v_matching_aux_nonmatching_00_acc", nm_idx_00),
                                      ("v_matching_aux_nonmatching_01_acc", nm_idx_01),
                                      ("v_matching_aux_nonmatching_10_acc", nm_idx_10),
                                      ("v_matching_aux_nonmatching_11_acc", nm_idx_11)):
                metric_logs[metric_name] = self._get_acc_by_index(
                    logits=nonmatching_logits, targets=torch.ones_like(kwargs["ys_true"][1:]), idx=mask)

        # Accuracy of the verifier itself in classifying non-matching pairs.
        if self.use_matching_verifier_loss and net_idx == 0:
            for metric_name, mask in (("nm_acc_00", nm_idx_00), ("nm_acc_01", nm_idx_01),
                                      ("nm_acc_10", nm_idx_10), ("nm_acc_11", nm_idx_11)):
                metric_logs[metric_name] = self._get_acc_by_index(
                    logits=kwargs["preds_nm"], targets=kwargs["ys_train_nm"], idx=mask)

        return metric_logs

    def log_forward_plots(self, forward_metrics, logs_dict):
        """Call every configured plotting function; a failing one is reported and skipped.

        Does nothing when ``self.plotting_fns`` is None.
        """
        if self.plotting_fns is None:
            return

        for plot in self.plotting_fns:
            try:
                plot(logger=self.logger, forward_metrics=forward_metrics, logs_dict=logs_dict)
            except Exception as err:
                # Best effort: plotting problems must not interrupt training.
                traceback.print_exc()
                print(f"Encountered error while running {plot}. Error: \n {err} \n Continuing. ")

    def task_specific_logging(self, metric_logs, logs_dict, **kwargs):
        """Extension point for subclasses; the base version returns ``metric_logs`` unchanged."""
        return metric_logs

    # ____ Saving and loading. ____
    def on_save_checkpoint(self, checkpoint):
        # Dump optimizer frequencies.
        opt_freq = self.trainer.optimizer_frequencies
        checkpoint["optimizer_frequencies"] = opt_freq

        # Dump buffers.
        checkpoint["feat_buffer_dict"] = self.feat_buffer_dict
        checkpoint["state_buffer_dict"] = self.state_buffer_dict

    def on_load_checkpoint(self, checkpoint):
        # Load optimizer frequencies.
        self.trainer.optimizer_frequencies = checkpoint["optimizer_frequencies"]

        # Load buffers.
        self.feat_buffer_dict = checkpoint["feat_buffer_dict"]
        self.state_buffer_dict = checkpoint["state_buffer_dict"]

    # ____ Miscellaneous. ____
    @staticmethod
    def get_prover_ys(ys_true, prover_mode):
        # Pick which mode to train prover with.
        if prover_mode == "collaborative":
            ys_train = ys_true
        elif prover_mode == "ljn":
            ys_train = torch.zeros_like(ys_true)
        else:
            ys_train = None
            print(f"Invalid prover mode {prover_mode}")
        return ys_train

    def use_correct_proofs(self, curr_proofs, correct_proofs, inputs):
        # Append zeros to the correct proofs to fix its shape.
        proof_size = curr_proofs.shape[-1]
        bs = inputs.shape[0]
        correct_proofs = correct_proofs.view(bs, -1)
        correct_proof_dim = correct_proofs.shape[1]
        new_correct_proofs = torch.cat((correct_proofs,
                                        torch.zeros((bs, proof_size - correct_proof_dim)).type_as(curr_proofs)), dim=1)
        new_correct_proofs = new_correct_proofs.type_as(curr_proofs)

        # Swap the proofs with the correct ones if asked to do so.
        return torch.zeros_like(curr_proofs) * curr_proofs + new_correct_proofs

    def prox_reg_sync(self):
        if self.global_step % (self.prox_reg.sync_every_n_game_steps * self.game_step_length):
            self.verifier.sync()
            self.prover.sync()

    @staticmethod
    def block_prover_outputs(curr_prover_outputs):
        return torch.zeros_like(curr_prover_outputs) * curr_prover_outputs

    def adapt_prover_strength(self, forward_metrics):
        horizon = self.adaptive_prover_step_specs["update_prover_strength_n_game_steps_after_verifier_past_cutoff"]
        upper_acc_cutoff = self.adaptive_prover_step_specs["upper_cutoff_accuracy"]
        lower_acc_cutoff = self.adaptive_prover_step_specs["lower_cutoff_accuracy"]
        middle_value = 0.5 * (lower_acc_cutoff + upper_acc_cutoff)

        # ____ Update the "verifier accuracy" buffer every new game step.  ____
        if self.global_step % (max(1, self.game_step) * self.game_step_length) == 0:
            forward_metrics = {k.split("/")[-1]: v for k, v in forward_metrics.items()}
            if "acc" not in self.state_buffer_dict:
                # Hack to make sure there always exist "lst
                self.state_buffer_dict["acc"] = [middle_value] * horizon
            self.state_buffer_dict["acc"].append(float(forward_metrics["acc"]))

        # ____ If verifier is doing well, increase prover strength and clear accuracy buffer. ____
        above_recent_cutoff = np.array(self.state_buffer_dict["acc"][int(-horizon):]) > upper_acc_cutoff
        if np.all(above_recent_cutoff):
            # Increase prover strength by allowing it to take more steps.
            prover_idx = 1
            if self.trainer.optimizer_frequencies[prover_idx] < self.adaptive_prover_step_specs["max_prover_steps"]:
                self.trainer.optimizer_frequencies[prover_idx] += 1

            # Clear the accuracy buffer.
            self.state_buffer_dict["acc"] = [middle_value] * horizon
        # ____ If verifier is doing poorly, decrease prover strength and clear accuracy buffer. ____
        below_recent_cutoff = np.array(self.state_buffer_dict["acc"][int(-horizon):]) < lower_acc_cutoff
        if np.all(below_recent_cutoff):
            # Decrease prover strength by reducing the steps it takes per game step by 1.
            prover_idx = 1
            if self.trainer.optimizer_frequencies[prover_idx] > 1:
                self.trainer.optimizer_frequencies[prover_idx] -= 1

            # Clear the accuracy buffer.
            self.state_buffer_dict["acc"] = [middle_value] * horizon

    # ____ Inspection related. ____
    def run_inspection(self):
        # Only inspect every self.do_inspection_every_n_batches.
        if not (self.global_step % self.do_inspection_every_n_batches == 0 and self.global_step != 0):
            return

        # Train the probe and log.
        self.train_probe_and_log()

        # Freeze the verifier and attack it.
        if self.perform_verifier_attack:
            self.attack_verifier()

    def _get_acc_by_index(self, logits, targets, idx):
        if idx.nelement() == 0:
            return float("nan")
        return float((logits.argmax(dim=1)[idx] == targets[idx]).float().mean())
