from tqdm import tqdm
import collections

import wandb
import torch
import numpy as np

import matplotlib.pyplot as plt

plt.style.use('ggplot')


class ProverPretrainer(object):
    """Pretrains the prover's auxiliary heads and logs the resulting curves.

    Presumably used as a mixin: nothing below defines the attributes it relies
    on. The host object must provide ``self.device``, ``self.prover``,
    ``self.pretraining_specs``, ``self.train_dataloader()``,
    ``self.unpack_data_batch()``, ``self.compute_prover_aux_losses()``,
    ``self.logger`` (with an ``experiment.log`` method), ``self.current_epoch``,
    ``self.global_step`` and ``self.game_step``.
    """

    def run_prover_pretraining(self):
        """Pretrain the prover on its auxiliary losses, then plot and log the curves.

        Batches are drawn from ``self.train_dataloader()``, restarting it as
        often as needed, for at most ``pretraining_specs["pretrain_for_n_batches"]``
        batches. Pretraining stops early once every classification accuracy in
        the sliding window of the last ``stop_pretraining_acc_horizon`` batches
        exceeds ``stop_pretraining_if_acc_above_threshold``.

        Raises:
            RuntimeError: if the pretraining dataloader yields no batches at all
                (previously this made the loop spin forever).
        """
        specs = self.pretraining_specs
        max_batches = specs["pretrain_for_n_batches"]
        acc_threshold = specs["stop_pretraining_if_acc_above_threshold"]
        use_scheduler = "scheduler" in specs

        print(f"Task: pretraining. Device: {self.device}")
        # Get pretraining dataloader, optimizer and (optional) LR scheduler.
        pretrain_loader = self.train_dataloader()
        pretrain_optim = specs["optimizer"](self.prover.parameters())
        pretrain_lr_scheduler = specs["scheduler"](optimizer=pretrain_optim) if use_scheduler else None

        # Per-batch histories, plotted at the end.
        cls_aux_losses = []
        autoenc_aux_losses = []
        cls_aux_accs = []

        batch_counter = 0
        # The window is seeded with 0., so early stopping cannot fire until
        # `maxlen` real accuracies have pushed the seed out.
        acc_horizon = collections.deque([0.], maxlen=specs["stop_pretraining_acc_horizon"])
        acc_satisfactory = False

        # A single progress bar over the whole batch budget, even if the loader
        # has to be iterated several times.
        with tqdm(total=max_batches, desc="Pretraining") as progress:
            while batch_counter < max_batches and not acc_satisfactory:
                batches_this_pass = 0
                for batch in pretrain_loader:
                    if batch_counter >= max_batches:
                        break
                    batch_counter += 1
                    batches_this_pass += 1
                    progress.update(1)

                    cls_acc, cls_loss, autoenc_loss = self._prover_pretraining_step(
                        batch, pretrain_optim, pretrain_lr_scheduler)

                    # Log.
                    acc_horizon.append(cls_acc)
                    cls_aux_losses.append(cls_loss)
                    autoenc_aux_losses.append(autoenc_loss)
                    cls_aux_accs.append(cls_acc)

                    # Print and directly log every 10 batches.
                    if batch_counter % 10 == 0:
                        print(f"Iteration: {batch_counter}, acc: {cls_acc}")
                        self.logger.experiment.log(dict(classification_acc=cls_acc,
                                                        classification_loss=cls_loss,
                                                        autoenc_loss=autoenc_loss,
                                                        pretraining_iter=batch_counter))
                        # Log learning rate schedule, if possible.
                        if use_scheduler:
                            curr_lr = pretrain_optim.param_groups[0]["lr"]
                            self.logger.experiment.log(dict(curr_lr=curr_lr, pretraining_iter=batch_counter))

                    # Stop as soon as the whole accuracy window is above the threshold.
                    acc_satisfactory = all(acc > acc_threshold for acc in acc_horizon)
                    if acc_satisfactory:
                        break

                if batches_this_pass == 0:
                    # Without this guard an empty loader would never advance
                    # `batch_counter` and the while-loop would never terminate.
                    raise RuntimeError("Pretraining dataloader yielded no batches; cannot pretrain the prover.")

        logging_dict = dict(classification_loss=cls_aux_losses, autoencoding_loss=autoenc_aux_losses,
                            classification_accuracy=cls_aux_accs)
        self._plot_and_log_prover_pretraining(logging_dict)

    def _prover_pretraining_step(self, batch, optimizer, lr_scheduler=None):
        """Run one optimisation step on `batch`.

        Returns:
            A ``(cls_acc, cls_loss, autoenc_loss)`` tuple of Python floats.
        """
        specs = self.pretraining_specs

        # Zero-grad.
        self.prover.zero_grad()

        # Unpack batch and move to the device. The verifier inputs (second
        # element) are not used during prover pretraining.
        p_xs, _, ys_true, other_data = self.unpack_data_batch(batch)
        p_xs = p_xs.to(self.device)
        ys_true = ys_true.to(self.device)

        # Also move the graph to the device (only present for GNN tasks).
        if "graph" in other_data:
            other_data["graph"] = other_data["graph"].to(self.device)

        # Run prover forward pass.
        _, p_aux_dict, prover_dict = self.prover(p_xs, **other_data)

        # Compute and sum all prover auxiliary losses.
        p_aux_losses_dict = self.compute_prover_aux_losses(p_aux_dict=p_aux_dict, p_xs=p_xs, ys_true=ys_true,
                                                           other_data=other_data)
        p_aux_loss = sum(p_aux_losses_dict.values())

        # Label smoothing on the classifier auxiliary head: cross-entropy
        # against the uniform distribution, -(1/K) * sum_k log p_k. The class
        # count K is read from the logits' class axis (dim=1) rather than
        # hard-coded to 2, so non-binary heads are scaled correctly.
        cls_logits = p_aux_dict["classification"]
        num_cls = cls_logits.shape[1]
        label_smt = ((-1. / num_cls) * torch.log_softmax(cls_logits, dim=1)).sum(dim=1).mean()
        scaled_label_smt = specs["label_smoothing_coeff"] * label_smt

        # Collapse-preventing regulariser, applied to (at most `max_samples`)
        # prover features.
        cp_reg = specs.collapse_preventing_reg
        if cp_reg is not None:
            max_cp_samples = cp_reg.max_samples
            feats = prover_dict["prover_feats"]
            cpl_on_feats = cp_reg.reg_fn(feats=feats[:max_cp_samples], targets=ys_true[:max_cp_samples])
        else:
            cpl_on_feats = 0.

        # Backward pass and parameter update.
        total_loss = p_aux_loss + scaled_label_smt + cpl_on_feats
        total_loss.backward()
        optimizer.step()
        self.prover.zero_grad()

        if lr_scheduler is not None:
            lr_scheduler.step()

        # The logits were computed before the update, so this is the accuracy
        # of the model as it was when the batch was scored.
        cls_acc = float((cls_logits.argmax(dim=1) == ys_true).float().mean())
        return (cls_acc,
                float(p_aux_losses_dict["classification"]),
                float(p_aux_losses_dict["autoencoding"]))

    def _plot_and_log_prover_pretraining(self, logging_dict):
        """Plot each history in `logging_dict` in its own panel and log the figure."""
        fig, axs = plt.subplots(nrows=1, ncols=len(logging_dict), figsize=(6, 3), squeeze=False)
        try:
            for ax, (log_name, log_val) in zip(axs[0], logging_dict.items()):
                ax.plot(log_val)
                ax.set_title(log_name)
            fig.tight_layout(pad=1.5)

            self.logger.experiment.log(
                {"pretraining/prover_aux_head_performances": wandb.Image(fig),
                 "epoch": self.current_epoch,
                 "global_step": self.global_step,
                 "game_step": self.game_step})
        finally:
            # Always release the figure, even if logging fails.
            plt.close(fig)
