from spaghettini import quick_register

import numpy as np
import torch
import wandb

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from src.learners.pvn_classification.ljn_single_base import LJNSingleBase

# Tableau palette color names, indexed modulo their count to color plotted curves.
colors = list(mcolors.TABLEAU_COLORS)
# Line width for the highlighted (emphasised) curves in the logging plots.
LINEWIDTH = 12
# Magnitude assigned (negated) to forbidden proof-token logits in ``channel_fn``.
VERY_BIG = 1.0e10

from src.utils.gumbel_softmax import gumbel_softmax
from src.utils.misc import enlarge_matplotlib_defaults


@quick_register
class LJNSingleBinaryErasure(LJNSingleBase):
    """Prover/verifier learner for (judging by its name) a binary erasure channel game.

    ``channel_fn`` treats the prover's outputs as logits over proof tokens, forbids
    the token belonging to the opposite input bit and samples a token with
    ``gumbel_softmax``. ``unpack_data_batch`` zeroes the verifier's direct input,
    so presumably the verifier must rely on the proof (TODO confirm in
    ``LJNSingleBase``).

    When ``net_idx == 0``, ``task_specific_logging`` records trajectories of
    input-conditional prover logits, proof-conditional verifier logits and
    per-label accuracies, and periodically logs a plot of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Histories appended to once per ``task_specific_logging`` call (net_idx == 0).
        self.prover_logits_traj = list()    # entries: row i = prover logits averaged over inputs == i
        self.verifier_logits_traj = list()  # entries: (num_proof_tokens, 2) verifier logits per token
        self.acc_0_traj = list()
        self.acc_1_traj = list()
        # Temperature passed to ``gumbel_softmax`` in ``channel_fn``.
        self.channel_temperature = 1.

    def channel_fn(self, proofs, inputs):
        """Sample one proof token per example from the prover's logits.

        Args:
            proofs: prover outputs, interpreted as categorical logits over the last
                dimension (``num_categories`` proof tokens).
            inputs: per-example inputs; presumably one-hot over the two input bits
                (TODO confirm). ``argmax(1 - inputs, dim=1)`` picks the token to forbid.

        Returns:
            The result of ``gumbel_softmax`` on the masked logits (a straight-through
            Gumbel-Softmax sample, per the original intent).
        """
        # Assume that the proofs constitute the logits for a categorical distribution.
        logits = proofs
        num_categories = logits.shape[-1]

        # A fresh generator seeded with the global step makes sampling reproducible per step.
        # NOTE(review): this is a CPU generator; confirm gumbel_softmax copes with CUDA logits.
        rng = torch.Generator()
        rng = rng.manual_seed(self.global_step)

        # If input is not 0, make sure the logit of sending 0 is very small.
        # Likewise, if input is not 1, make sure the logit of sending 1 is very small.
        flipped_inputs = torch.ones_like(inputs) - inputs
        forbidden_token = torch.argmax(flipped_inputs, dim=1)
        token_ids = torch.arange(num_categories).type_as(flipped_inputs)
        logit_subtract_one_hot = (forbidden_token[..., None] == token_ids[None, ...]).float()
        # Multiplying by (1 - mask) zeroes the forbidden logit (and its gradient) before
        # it is set to -VERY_BIG; arithmetic (not masked_fill) keeps the float32 mask dtype.
        logits = logits * (1. - logit_subtract_one_hot) - VERY_BIG * logit_subtract_one_hot

        # Sample tokens using the Straight-through Gumbel-Softmax trick (deterministic via rng).
        proofs = gumbel_softmax(logits=logits, temperature=self.channel_temperature, rng=rng)

        return proofs

    def unpack_data_batch(self, data_batch):
        """Split ``(xs, ys)`` into prover inputs, verifier inputs, labels and extras.

        The prover receives ``xs``; the verifier receives an all-zeros tensor shaped
        like ``xs``. ``other_data`` carries ``correct_proofs=None``.
        """
        xs, ys = data_batch

        p_xs, v_xs = xs, torch.zeros_like(xs)

        other_data = dict(correct_proofs=None)

        return p_xs, v_xs, ys, other_data

    def task_specific_logging(self, metric_logs, logs_dict, **kwargs):
        """Track prover/verifier/accuracy trajectories and periodically plot them.

        Only acts when ``logs_dict["net_idx"] == 0``. Each such call appends one entry
        to every trajectory; when ``global_step != 0`` and ``logs_dict["batch_nb"] == 0``
        the full trajectories are plotted and logged via ``self.logger.experiment.log``.

        Returns:
            ``metric_logs``, unchanged.
        """
        if logs_dict["net_idx"] == 0:
            prepend_key = logs_dict["prepend_key"]
            self._track_logging_trajectories(metric_logs, logs_dict)

            enlarge_matplotlib_defaults(plt_object=plt)
            if self.global_step != 0 and logs_dict["batch_nb"] == 0:
                self._plot_logging_trajectories(prepend_key)

        return metric_logs

    def _track_logging_trajectories(self, metric_logs, logs_dict):
        """Append this step's conditional logits and per-label accuracies to the histories."""
        model_logs = logs_dict["model_logs"]
        prover_logits = model_logs["prover_logits"].detach().cpu().numpy()
        proof_samples = model_logs["proofs"].detach().cpu().numpy()
        verifier_logits = model_logs["verifier_logits"].detach().cpu().numpy()
        xs = logs_dict["p_xs"].detach().cpu().numpy()

        # Prover logits averaged over examples whose input is 0 (row 0) or 1 (row 1).
        # An input value with no examples yields a NaN row (mean of an empty slice).
        averaged_logits = np.stack(
            (prover_logits[xs[:, 0] == 1.].mean(axis=0),
             prover_logits[xs[:, 1] == 1.].mean(axis=0)),
            axis=0,
        )

        # Verifier logits conditioned on the proof token: for each token, the logits of
        # the first batch example whose sampled proof is that token, NaN if none sent it.
        dim_proof = proof_samples.shape[-1]
        proof_tokens = np.argmax(proof_samples, axis=1)
        verifier_conditional_logits = np.full((dim_proof, 2), np.nan)
        for token_id in range(dim_proof):
            matching_examples = np.nonzero(proof_tokens == token_id)[0]
            if matching_examples.size > 0:
                verifier_conditional_logits[token_id, :] = verifier_logits[matching_examples[0], :]

        # Metric keys may carry a "prefix/" namespace; match on the last path component.
        metric_logs_clean = {k.split("/")[-1]: v for k, v in metric_logs.items()}
        acc0 = float(metric_logs_clean["acc_0"])
        acc1 = float(metric_logs_clean["acc_1"])

        # Append only after everything above succeeded so all histories stay equally long.
        self.prover_logits_traj.append(averaged_logits)
        self.verifier_logits_traj.append(verifier_conditional_logits)
        self.acc_0_traj.append(acc0)
        self.acc_1_traj.append(acc1)

    def _plot_logging_trajectories(self, prepend_key):
        """Plot the histories in three stacked panels and log the figure to the experiment."""
        num_plots = 3
        fig, axs = plt.subplots(num_plots, 1, figsize=(15, 20))
        try:
            # ____ Prover logits. ____
            lp_traj = np.array(self.prover_logits_traj)
            ts = np.arange(len(lp_traj))

            # Thin lines: tokens with index >= 2, dotted for input 0, dashdot for input 1.
            for pi in range(lp_traj.shape[1]):
                linestyle = "dotted" if pi == 0 else "dashdot"
                for pj in range(2, lp_traj.shape[2]):
                    axs[0].plot(ts, lp_traj[:, pi, pj], linestyle=linestyle,
                                color=colors[pj % len(colors)])

            # Bold lines: logit of sending the token that matches the input bit.
            axs[0].plot(ts, lp_traj[:, 0, 0], color=colors[0], linestyle="dotted",
                        label="input 0 - output 0 logits", linewidth=LINEWIDTH)
            axs[0].plot(ts, lp_traj[:, 1, 1], color=colors[1], linestyle="dashdot",
                        label="input 1 - output 1 logits", linewidth=LINEWIDTH)

            axs[0].set_ylabel("logits")
            axs[0].set_xlabel("iterations")
            axs[0].legend()
            axs[0].set_title("Prover Logits")

            # ____ Verifier logits: output 0 minus output 1, per proof token. ____
            lv_traj = np.array(self.verifier_logits_traj)
            lv0_traj = lv_traj[..., 0] - lv_traj[..., 1]

            for vi in range(2, lv0_traj.shape[1]):
                axs[1].plot(ts, lv0_traj[:, vi], linestyle="solid", color=colors[vi % len(colors)])

            axs[1].plot(ts, lv0_traj[:, 0], linestyle="solid", color=colors[0],
                        label="input 0 - output 0 logits", linewidth=LINEWIDTH)
            axs[1].plot(ts, lv0_traj[:, 1], linestyle="solid", color=colors[1],
                        label="input 1 - output 0 logits", linewidth=LINEWIDTH)

            axs[1].set_ylabel("logits")
            axs[1].set_xlabel("iterations")
            axs[1].legend()
            axs[1].set_title("Verifier logits")

            # ____ Per-label accuracies. ____
            axs[2].plot(ts, self.acc_0_traj, label="accuracy given input = 0")
            axs[2].plot(ts, self.acc_1_traj, label="accuracy given input = 1")
            axs[2].set_xlabel("iterations")
            axs[2].set_ylabel("accuracy")
            axs[2].legend()
            axs[2].set_title("Label Conditioned Accuracy Values")
            fig.tight_layout(pad=1.5)

            self.logger.experiment.log({f"{prepend_key}_logit_visualizations": wandb.Image(fig),
                                        "epoch": self.current_epoch, "global_step": self.global_step,
                                        "game_step": self.game_step})
        finally:
            # Always release the figure, even if plotting or logging raised.
            plt.close('all')
