from spaghettini import quick_register
from abc import ABC

import torch
import pytorch_lightning as pl

from src.utils.misc import prepend_string_to_dict_keys

# Small positive constant (1e-8). It is not referenced anywhere in the visible part of this file.
SMALL = 1e-8


@quick_register
class LJNDoubleBase(pl.LightningModule, ABC):
    def __init__(self, verifier, prover1, prover2, train_loader, verifier_optimizer_init, prover1_optimizer_init,
                 prover2_optimizer_init, prover_lr, classification_loss_fn, verifier_cls_loss_coeff,
                 prover1_cls_loss_coeff, prover2_cls_loss_coeff, aux_loss_fn, verifier_aux_loss_coeff,
                 prover1_aux_loss_coeff, prover2_aux_loss_coeff, col_weighting, adv_weighting_fn, verifier_scheduler,
                 prover1_scheduler, prover2_scheduler, give_correct_proof_until, block_proof_until, proof_size,
                 prover_1_mode, prover_2_mode, ys_aux_type, log_every_n_batch, task_name, train_p_aux_head,
                 train_v_aux_head, lookahead=None, use_hybrid_loss=False, **kwargs):
        """Store the three networks and every training hyper-parameter on the module.

        Args:
            verifier: Network called as ``verifier(v_inputs, [proofs])``; must return
                ``(preds, verifier_aux, model_dict)``.
            prover1, prover2: Networks called as ``prover(p_inputs)``; must return ``(proofs, prover_aux)``.
            train_loader: Zero-argument callable; ``train_dataloader`` returns ``train_loader()``.
            verifier_optimizer_init: Callable taking the verifier parameters and returning an optimizer.
            prover1_optimizer_init, prover2_optimizer_init: Callables taking the prover parameters and an
                ``lr`` keyword, returning an optimizer.
            prover_lr: Learning rate passed to both prover optimizer initialisers.
            classification_loss_fn: Callable ``(preds, targets)`` returning the classification loss.
            verifier_cls_loss_coeff, prover1_cls_loss_coeff, prover2_cls_loss_coeff: Multipliers applied to
                the classification loss, selected by the network being trained.
            aux_loss_fn: Callable ``(aux_output, ys_aux)`` returning the auxiliary loss.
            verifier_aux_loss_coeff: Multiplier for the auxiliary loss when training the verifier (used as a
                plain value, it is not called).
            prover1_aux_loss_coeff, prover2_aux_loss_coeff: Callables invoked as ``coeff(epoch=...)`` that
                return the auxiliary loss multiplier when training the respective prover.
            col_weighting: Multiplier for the collaborative part of the classification loss.
            adv_weighting_fn: Callable of the current epoch returning the multiplier for the adversarial
                part of the classification loss.
            verifier_scheduler, prover1_scheduler, prover2_scheduler: Callables taking an optimizer and
                returning a learning-rate scheduler, or ``None``.
            give_correct_proof_until: Epochs below this value replace the prover output with ``ys_aux``.
            block_proof_until: Epochs below this value zero out the prover output.
            proof_size: Width the correct proofs are zero-padded to in ``use_correct_proofs``.
            prover_1_mode, prover_2_mode: ``"collaborative"`` or ``"ljn"``.
            ys_aux_type: ``"causal_feature"`` or ``"label"``; selects what ``get_ys_aux`` returns.
            log_every_n_batch: Metrics are logged when ``batch_nb`` is a multiple of this value.
            task_name: Stored as-is; not read by the code visible in this class.
            train_p_aux_head, train_v_aux_head: Truthy to compute the prover / verifier auxiliary losses.
            lookahead: Optional callable wrapping each optimizer in ``configure_optimizers``.
            use_hybrid_loss: Stored as-is; not read by the code visible in this class.
            **kwargs: Must contain the key ``"visualize_proof"`` (a ``KeyError`` is raised otherwise).

        Raises:
            AssertionError: If a prover mode or ``ys_aux_type`` is not one of the accepted strings.
            KeyError: If ``visualize_proof`` is missing from ``kwargs``.
        """
        super().__init__()
        assert prover_1_mode in ["collaborative", "ljn"] and prover_2_mode in ["collaborative", "ljn"]
        assert ys_aux_type in ["causal_feature", "label"]

        # NOTE(review): if any of the objects assigned below is an nn.Module, the assignment order presumably
        # determines sub-module registration order -- keep the order stable.
        self.verifier = verifier
        self.prover1 = prover1
        self.prover2 = prover2

        # Data loaders.
        self.train_loader = train_loader

        # Optimizers and learning rate schedules.
        self.v_opt_init = verifier_optimizer_init
        self.p1_opt_init = prover1_optimizer_init
        self.p2_opt_init = prover2_optimizer_init
        self.prover_lr = prover_lr

        self.verifier_scheduler_init = verifier_scheduler
        self.prover1_scheduler_init = prover1_scheduler
        self.prover2_scheduler_init = prover2_scheduler

        # Loss function coefficients.
        self.classification_loss_fn = classification_loss_fn
        self.verifier_cls_loss_coeff = verifier_cls_loss_coeff
        self.prover1_cls_loss_coeff = prover1_cls_loss_coeff
        self.prover2_cls_loss_coeff = prover2_cls_loss_coeff

        self.aux_loss_fn = aux_loss_fn
        self.verifier_aux_loss_coeff = verifier_aux_loss_coeff
        self.prover1_aux_loss_coeff = prover1_aux_loss_coeff
        self.prover2_aux_loss_coeff = prover2_aux_loss_coeff

        self.col_weighting = col_weighting
        self.adv_weighting_fn = adv_weighting_fn

        # Training setup.
        self.give_correct_proof_until = give_correct_proof_until
        self.block_proof_until = block_proof_until

        self.proof_size = proof_size
        self.prover_1_mode = prover_1_mode
        self.prover_2_mode = prover_2_mode
        self.ys_aux_type = ys_aux_type
        self.task_name = task_name

        # Convenience lookups.
        self.optimizers = None  # Filled in by configure_optimizers().
        # Indexed by net/prover index: 0 = verifier, 1 = prover 1, 2 = prover 2.
        self.nets = [self.verifier, self.prover1, self.prover2]

        # Other.
        self.log_every_n_batch = log_every_n_batch
        self.visualize_proof = kwargs["visualize_proof"]  # Required entry of **kwargs.
        self.train_p_aux_head = train_p_aux_head
        self.train_v_aux_head = train_v_aux_head
        self.lookahead = lookahead
        self.use_hybrid_loss = use_hybrid_loss

        # Warnings.
        print("Warning 1: The mean operation in the loss computation is WRONG if classes are not balanced. ")

    def forward(self, p_inputs, v_inputs, prover_idx, correct_proofs):
        """Produce a proof with the selected prover and classify with the verifier.

        Args:
            p_inputs: Input handed to the prover.
            v_inputs: Input handed to the verifier.
            prover_idx: Index into ``self.nets`` (1 = prover 1, 2 = prover 2).
            correct_proofs: Ground-truth proofs substituted for the prover output while
                ``current_epoch < give_correct_proof_until``.

        Returns:
            Tuple ``(preds, prover_aux, verifier_aux, model_dict)``.
        """
        active_prover = self.nets[prover_idx]
        prover_logs = dict()

        # The prover must return exactly two items: the proofs and its auxiliary head output.
        proofs, prover_aux = active_prover(p_inputs)

        # Early in training the proofs can be zeroed out ...
        if self.current_epoch < self.block_proof_until:
            proofs = self.block_proofs(curr_proofs=proofs)

        # ... and/or replaced by the ground-truth proofs (applied after blocking, so it wins).
        if self.current_epoch < self.give_correct_proof_until:
            proofs = self.use_correct_proofs(curr_proofs=proofs, correct_proofs=correct_proofs, inputs=p_inputs)

        # The verifier receives the proofs wrapped in a single-element list.
        verifier_outputs = self.verifier(v_inputs, [proofs])
        preds, verifier_aux, model_dict = verifier_outputs
        model_dict.update(prover_logs)

        return preds, prover_aux, verifier_aux, model_dict

    # ____ Training. ____
    def training_step(self, data_batch, batch_nb, optimizer_idx, *args, **kwargs):
        """Compute the training loss for the network selected by ``optimizer_idx``.

        ``optimizer_idx`` 0 trains the verifier: it is run once with each prover, always against the
        true labels, and the two losses are averaged. ``optimizer_idx`` 1 / 2 trains prover 1 / 2 against
        the targets chosen by its mode (see ``_prover_training_targets``).

        Args:
            data_batch: Batch forwarded to ``unpack_data_batch``.
            batch_nb: Index of the batch; metrics are logged when it is a multiple of
                ``log_every_n_batch``.
            optimizer_idx: 0 (verifier), 1 (prover 1) or 2 (prover 2).

        Returns:
            Dict with the keys ``"loss"`` and ``"logs_dict"``.

        Raises:
            ValueError: If ``optimizer_idx`` is not 0, 1 or 2, or a prover mode is invalid.
        """
        # ____ Unpack data batch. ____
        p_xs, v_xs, ys_true, coords = self.unpack_data_batch(data_batch)

        if optimizer_idx == 0:  # Training verifier.
            prover_losses, metric_logs, logs_dict = [], dict(), dict()
            for prover_idx in (1, 2):
                ys_aux = self.get_ys_aux(causal_feature=coords, ys_true=ys_true, prover_idx=prover_idx)
                p_loss, p_metric_logs, p_logs_dict = self.common_step(p_xs=p_xs, v_xs=v_xs, ys_train=ys_true,
                                                                      ys_true=ys_true, ys_aux=ys_aux, net_idx=0,
                                                                      prover_idx=prover_idx, batch_nb=batch_nb,
                                                                      prepend_key="training/")
                prover_losses.append(p_loss)
                # Keys carry a per-prover prefix, so the two runs do not overwrite each other.
                metric_logs.update(p_metric_logs)
                logs_dict.update(p_logs_dict)
            loss = 0.5 * (prover_losses[0] + prover_losses[1])

        elif optimizer_idx in (1, 2):  # Training prover 1 or prover 2.
            prover_idx = optimizer_idx
            ys_train = self._prover_training_targets(prover_idx=prover_idx, ys_true=ys_true)
            ys_aux = self.get_ys_aux(causal_feature=coords, ys_true=ys_true, prover_idx=prover_idx)
            loss, metric_logs, logs_dict = self.common_step(p_xs=p_xs, v_xs=v_xs, ys_train=ys_train,
                                                            ys_true=ys_true, ys_aux=ys_aux, net_idx=prover_idx,
                                                            prover_idx=prover_idx, batch_nb=batch_nb,
                                                            prepend_key="training/")
        else:
            # Raise instead of calling exit(): the failure stays catchable and carries a traceback.
            raise ValueError(f"Invalid optimizer idx {optimizer_idx}. Aborting")

        # ____ Log metrics. ____
        if batch_nb % self.log_every_n_batch == 0:
            self.logger.experiment.log(metric_logs)

        return {"loss": loss, "logs_dict": logs_dict}

    def _prover_training_targets(self, prover_idx, ys_true):
        """Return the labels prover ``prover_idx`` (1 or 2) is trained to make the verifier output."""
        mode = self.prover_1_mode if prover_idx == 1 else self.prover_2_mode
        if mode == "collaborative":
            return ys_true
        if mode == "ljn":
            # In "ljn" mode prover 1 always targets class 0 and prover 2 always targets class 1.
            return torch.zeros_like(ys_true) if prover_idx == 1 else torch.ones_like(ys_true)
        raise ValueError(f"Invalid prover {prover_idx} mode {mode}")

    # ____ Common step. ____
    def common_step(self, p_xs, v_xs, ys_train, ys_true, ys_aux, net_idx, prover_idx, batch_nb, prepend_key=""):
        """Run a forward pass and assemble the total loss and the logging dictionaries.

        Samples whose training target equals the true label form the "collaborative" subset, the others
        the "adversarial" subset; the classification loss of each subset is weighted separately.

        Args:
            p_xs, v_xs: Prover and verifier inputs.
            ys_train: Targets the classification loss is computed against.
            ys_true: Ground-truth labels.
            ys_aux: Targets of the auxiliary heads (also passed to ``forward`` as the correct proofs).
            net_idx: Network being trained (0 = verifier, 1 / 2 = provers); selects the loss coefficients.
            prover_idx: Prover used to generate the proofs.
            batch_nb: Batch index, forwarded to the logging helpers.
            prepend_key: ``"training/"`` in training mode, ``"validation/"`` or ``"testing/"`` otherwise.

        Returns:
            Tuple ``(total_loss, metric_logs, logs_dict)``.
        """
        allowed_prefixes = ("training/",) if self.training else ("validation/", "testing/")
        assert prepend_key in allowed_prefixes
        # Extend the prefix with the (trained net, prover) pair.
        prepend_key = f"{prepend_key}t_{net_idx}_p_{prover_idx}/"

        # Perform forward pass.
        preds, p_aux, v_aux, model_dict = self.forward(p_inputs=p_xs, v_inputs=v_xs, prover_idx=prover_idx,
                                                       correct_proofs=ys_aux)

        # ____ Classification loss, split into collaborative and adversarial samples. ____
        col_idxs = ys_train == ys_true
        adv_idxs = ys_train != ys_true

        def subset_cls_loss(mask):
            # An empty subset contributes a plain 0 rather than a loss over zero samples.
            if mask.sum() > 0:
                return self.classification_loss_fn(preds[mask], ys_train[mask])
            return 0

        col_cls_loss = subset_cls_loss(col_idxs)
        adv_cls_loss = subset_cls_loss(adv_idxs)

        cls_loss = self.col_weighting * col_cls_loss + self.adv_weighting_fn(self.current_epoch) * adv_cls_loss
        cls_loss_coeff = self.get_cls_loss_coeff(net_idx=net_idx)

        # ____ Auxiliary losses. ____
        # The whole-batch values returned by compute_aux_losses are discarded: the per-head loss used here
        # is the sum of the collaborative and adversarial parts.
        aux_terms = self.compute_aux_losses(p_xs=p_xs, p_aux=p_aux, v_aux=v_aux, ys_aux=ys_aux,
                                            col_idxs=col_idxs, adv_idxs=adv_idxs)
        _, p_aux_col_loss, p_aux_adv_loss, _, v_aux_col_loss, v_aux_adv_loss = aux_terms
        p_aux_loss = p_aux_col_loss + p_aux_adv_loss
        v_aux_loss = v_aux_col_loss + v_aux_adv_loss
        aux_loss = (p_aux_loss + v_aux_loss)

        # ____ Total loss. ____
        aux_loss_coeff = self.get_aux_loss_coeff(epoch=self.current_epoch, net_idx=net_idx)
        total_loss = cls_loss_coeff * cls_loss + aux_loss_coeff * aux_loss

        # ____ Log relevant data. ____
        logs_dict = dict(
            p_xs=p_xs, v_xs=v_xs, ys_train=ys_train, ys_true=ys_true, ys_aux=ys_aux, preds=preds,
            p_aux=p_aux, v_aux=v_aux,
            classification_loss=cls_loss,
            prover_aux_loss=p_aux_loss,
            prover_aux_col_loss=p_aux_col_loss,
            prover_aux_adv_loss=p_aux_adv_loss,
            verifier_aux_loss=v_aux_loss,
            verifier_aux_col_loss=v_aux_col_loss,
            verifier_aux_adv_loss=v_aux_adv_loss,
            col_idxs=col_idxs, adv_idxs=adv_idxs,
            model_logs=model_dict,
            verifier=self.verifier,
            net_idx=net_idx, prover_idx=prover_idx, batch_nb=batch_nb, prepend_key=prepend_key,
        )
        metric_logs = self.prepare_metric_logs(**logs_dict)
        metric_logs = self.task_specific_logging(metric_logs=metric_logs, **logs_dict)
        logs_dict = prepend_string_to_dict_keys(prepend_key=prepend_key, dictinary=logs_dict)

        return total_loss, metric_logs, logs_dict

    def compute_aux_losses(self, p_xs, p_aux, v_aux, ys_aux, col_idxs, adv_idxs):
        col_idxs_aux = col_idxs
        adv_idxs_aux = adv_idxs

        if len(ys_aux.shape) == 1:
            ys_aux = ys_aux[..., None].float()  # So that the dimensionalities match with the output of the aux heads.
        p_aux_loss = self.aux_loss_fn(p_aux, ys_aux) if self.train_p_aux_head else 0.
        p_aux_col_loss = self.aux_loss_fn(p_aux[col_idxs_aux], ys_aux[col_idxs_aux]) if self.train_p_aux_head else 0
        p_aux_adv_loss = self.aux_loss_fn(p_aux[adv_idxs_aux], ys_aux[adv_idxs_aux]) if self.train_p_aux_head else 0

        v_aux_loss = self.aux_loss_fn(v_aux, ys_aux) if self.train_v_aux_head else 0.
        v_aux_col_loss = self.aux_loss_fn(v_aux[col_idxs_aux], ys_aux[col_idxs_aux]) if self.train_v_aux_head else 0
        v_aux_adv_loss = self.aux_loss_fn(v_aux[adv_idxs_aux], ys_aux[adv_idxs_aux]) if self.train_v_aux_head else 0

        return p_aux_loss, p_aux_col_loss, p_aux_adv_loss, v_aux_loss, v_aux_col_loss, v_aux_adv_loss

    # ____ Optimizers. ____
    def configure_optimizers(self):
        v_opt = self.v_opt_init(self.verifier.parameters())
        p1_opt = self.p1_opt_init(self.prover1.parameters(), lr=self.prover_lr)
        p2_opt = self.p2_opt_init(self.prover2.parameters(), lr=self.prover_lr)

        # Use the lookahead optimizer if asked.
        if self.lookahead is not None:
            v_opt = self.lookahead(v_opt)
            p1_opt = self.lookahead(p1_opt)
            p2_opt = self.lookahead(p2_opt)

        self.optimizers = [v_opt, p1_opt, p2_opt]

        # Use a learning rate scheduler if asked.
        if self.verifier_scheduler_init is not None:
            v_sch = self.verifier_scheduler_init(v_opt)
            p1_sch = self.prover1_scheduler_init(p1_opt)
            p2_sch = self.prover2_scheduler_init(p2_opt)

            return [v_opt, p1_opt, p2_opt], [v_sch, p1_sch, p2_sch]

        return [v_opt, p1_opt, p2_opt]

    def on_before_zero_grad(self, optimizer):
        """Clear the gradients of every parameter of this module, whichever optimizer is passed.

        NOTE(review): presumably a Lightning hook invoked before each optimizer's own ``zero_grad`` --
        confirm against the Lightning version in use. The ``optimizer`` argument is ignored.
        """
        # Calling zero_grad() on the Pytorch Lightning module zeros the gradients on all the parameters,
        # i.e. verifier and both provers, not only those owned by `optimizer`.
        self.zero_grad()

    def unpack_data_batch(self, data_batch):
        """Split a raw batch into the pieces used by ``training_step``; must be overridden.

        ``training_step`` unpacks the result as ``p_xs, v_xs, ys_true, coords``, so an implementation has
        to return those four items in that order.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def train_dataloader(self):
        """Return the training data loader by calling the ``train_loader`` factory given to ``__init__``."""
        # `train_loader` is a callable, not a loader instance: a new result is requested on every call.
        return self.train_loader()

    # ____ Logging. ____
    def prepare_metric_logs(self, p_xs, v_xs, ys_train, ys_true, ys_aux, preds, p_aux, v_aux, classification_loss,
                            prover_aux_loss, prover_aux_col_loss, prover_aux_adv_loss, verifier_aux_col_loss,
                            verifier_aux_adv_loss, verifier_aux_loss,
                            model_logs, net_idx, prover_idx, batch_nb, prepend_key, **kwargs):
        """Collect the scalar metrics of one step into a dict whose keys are prefixed by ``prepend_key``.

        The dict contains the epoch, the four ground-truth/prediction combinations for the classes 0 and 1
        (stored under ``gt_<g>_v_<p>_count`` as a fraction of the batch size), the classification and
        auxiliary losses converted to ``float``, and the currently scheduled loss coefficients.

        Returns:
            The metrics dict after ``prepend_string_to_dict_keys`` has been applied to it.
        """
        metric_logs = {"epoch": float(self.current_epoch)}

        # ____ Confusion-matrix entries, normalised by the batch size. ____
        predicted_classes = preds.argmax(dim=1)
        batch_size = ys_true.shape[0]
        for gt_class in (0, 1):
            for pred_class in (0, 1):
                hits = ((ys_true == gt_class) * (predicted_classes == pred_class)).float().sum()
                metric_logs[f"gt_{gt_class}_v_{pred_class}_count"] = (hits / batch_size)

        # ____ Losses, stored as plain Python floats. ____
        metric_logs["cls_loss"] = float(classification_loss)

        prover_terms = (("prover_aux_loss", prover_aux_loss),
                        ("prover_aux_col_loss", prover_aux_col_loss),
                        ("prover_aux_adv_loss", prover_aux_adv_loss))
        for key, value in prover_terms:
            metric_logs[key] = float(value)

        # Verifier auxiliary losses are keyed by the prover that generated the proofs.
        verifier_terms = (("verifier_aux_loss", verifier_aux_loss),
                          ("verifier_aux_col_loss", verifier_aux_col_loss),
                          ("verifier_aux_adv_loss", verifier_aux_adv_loss))
        for key, value in verifier_terms:
            metric_logs[f"{key}_p_{prover_idx}"] = float(value)

        # ____ Scheduled quantities. ____
        metric_logs["adv_loss_coeff"] = self.adv_weighting_fn(self.current_epoch)
        # Despite the key name, this is the aux coefficient of the network selected by `net_idx`.
        metric_logs["prover_aux_loss_coeff"] = self.get_aux_loss_coeff(epoch=self.current_epoch, net_idx=net_idx)

        # ____ Prepend training mode to all keys. ____
        return prepend_string_to_dict_keys(prepend_key=prepend_key, dictinary=metric_logs)

    def task_specific_logging(self, metric_logs, **kwargs):
        """Hook for adding task-dependent metrics; this base implementation returns ``metric_logs`` unchanged.

        ``common_step`` calls it with the metrics dict and, as keyword arguments, every entry of its
        ``logs_dict``; the return value replaces the metrics dict.
        """
        return metric_logs

    # ____ Miscellaneous. ____
    def get_ys_aux(self, causal_feature, ys_true, prover_idx):
        if self.ys_aux_type == "causal_feature":
            return causal_feature
        elif self.ys_aux_type == "label":
            return ys_true

    def get_cls_loss_coeff(self, net_idx):
        if net_idx == 0:  # Verifier.
            return self.verifier_cls_loss_coeff
        elif net_idx == 1:  # Prover 1
            return self.prover1_cls_loss_coeff
        elif net_idx == 2:  # Prover 2.
            return self.prover2_cls_loss_coeff

    def get_aux_loss_coeff(self, epoch, net_idx):
        if net_idx == 0:  # Verifier.
            return self.verifier_aux_loss_coeff
        elif net_idx == 1:  # Prover 1
            return self.prover1_aux_loss_coeff(epoch=epoch)
        elif net_idx == 2:  # Prover 2.
            return self.prover2_aux_loss_coeff(epoch=epoch)

    def use_correct_proofs(self, curr_proofs, correct_proofs, inputs):
        # Append zeros to the correct to fix its shape.
        bs = inputs.shape[0]
        correct_proofs = correct_proofs.view(bs, -1)
        correct_proof_dim = correct_proofs.shape[1]
        new_correct_proofs = torch.cat((correct_proofs,
                                        torch.zeros((bs,
                                                     self.proof_size - correct_proof_dim)).type_as(curr_proofs)),
                                       dim=1)
        new_correct_proofs = new_correct_proofs.type_as(curr_proofs)

        # Swap the proofs with the correct ones if asked to do so.
        return torch.zeros_like(curr_proofs) * curr_proofs + new_correct_proofs

    @staticmethod
    def block_proofs(curr_proofs):
        return torch.zeros_like(curr_proofs) * curr_proofs
