import copy
from tqdm import tqdm

import wandb
import torch
import dgl
from torch.nn.parameter import Parameter
import numpy as np
import matplotlib.pyplot as plt

from src.utils.misc import enlarge_matplotlib_defaults

# Apply matplotlib's bundled "ggplot" style sheet process-wide, so every figure created afterwards
# (including the attack plots logged below) uses it.
plt.style.use(style="ggplot")


class VerifierAttacker(object):
    """Stress-test a verifier by optimizing a prover (or its proofs) against a frozen verifier copy.

    This class relies on attributes and helpers it does not define itself, presumably supplied by
    the class it is combined with: ``self.prover``, ``self.verifier``, ``self.device``,
    ``self.train_dataloader``, ``self.forward``, ``self.run_prover``, ``self.run_verifier``,
    ``self.process_prover_outputs``, ``self.unpack_data_batch``, ``self.classification_loss_fn``,
    ``self.logger``, ``self.current_epoch``, ``self.global_step``, ``self.game_step``,
    ``self.max_v_attack_p_training_batches``, ``self.v_attack_p_optimizer_init`` and
    ``self.v_attack_proof_optim_specs``.
    """

    def attack_verifier(self):
        """Train a copy of the prover against a frozen copy of the verifier and log the results.

        A deep copy of the prover is trained, using the optimizer returned by
        ``self.v_attack_p_optimizer_init``, for exactly ``self.max_v_attack_p_training_batches``
        batches (fewer only if the dataloader is empty). The goal is to make a deep copy of the
        verifier output label 0 on every input. Loss and accuracy curves, overall and split by true
        label, are logged as one image to ``self.logger.experiment``. If
        ``self.v_attack_proof_optim_specs`` is not None, the proofs are additionally optimized
        directly via ``attack_verifier_by_optimizing_proofs``.

        Only the clones are trained; ``self.prover`` and ``self.verifier`` are left untouched.

        Raises:
            AssertionError: if the verifier clone's parameters no longer match ``self.verifier``.
        """
        # Free GPU memory if possible; this is best effort only.
        try:
            torch.cuda.empty_cache()
        except Exception:
            print(f"Couldn't clear GPU cache for some reason (maybe not training on GPU). Moving on. ")

        # Clone the prover and verifier so that the training is not interfered with.
        prover_clone = copy.deepcopy(self.prover).to(self.device)
        verifier_clone = copy.deepcopy(self.verifier).to(self.device)

        # Put both clones in eval mode, which fixes the behaviour of layers such as dropout or
        # batch-norm. eval() does NOT freeze parameters: the verifier stays fixed only because its
        # parameters are never handed to the optimizer below.
        verifier_clone.eval()
        prover_clone.eval()

        # Make every prover_clone parameter trainable.
        for p in prover_clone.parameters():
            p.requires_grad = True

        # Prepare for training the prover, keeping the verifier fixed.
        v_attack_dataloader = self.train_dataloader()
        v_attack_p_optimizer = self.v_attack_p_optimizer_init(prover_clone.parameters())

        # Prepare for logging losses and accuracies.
        losses_log, losses_0_log, losses_1_log = list(), list(), list()
        accs_log, accs_0_log, accs_1_log = list(), list(), list()

        # Train for exactly `max_batches` batches, re-iterating the dataloader as often as needed.
        max_batches = self.max_v_attack_p_training_batches
        batch_counter = 0
        with tqdm(total=max_batches, desc="Attacking verifier. ") as progress:
            while batch_counter < max_batches:
                batches_this_pass = 0
                for batch in v_attack_dataloader:
                    batches_this_pass += 1
                    batch_counter += 1

                    loss, loss_0, loss_1, acc, acc_0, acc_1 = self._v_attack_prover_step(
                        batch=batch, prover_clone=prover_clone, verifier_clone=verifier_clone,
                        optimizer=v_attack_p_optimizer)
                    progress.update(1)

                    # Log.
                    losses_log.append(loss)
                    losses_0_log.append(loss_0)
                    losses_1_log.append(loss_1)
                    accs_log.append(acc)
                    accs_0_log.append(acc_0)
                    accs_1_log.append(acc_1)

                    if batch_counter >= max_batches:
                        break
                # An empty dataloader would otherwise make the `while` loop spin forever.
                if batches_this_pass == 0:
                    break

        # Make sure that the verifier parameters have not been altered.
        assert self._compare_models(model1=self.verifier, model2=verifier_clone)

        # ____ Log the loss and accuracy plots. ____
        enlarge_matplotlib_defaults(plt_object=plt)
        fig, axes = plt.subplots(2, figsize=(15, 20))
        # Loss plot.
        axes[0].plot(losses_log, label="average loss")
        axes[0].plot(losses_0_log, label="loss given input = 0")
        axes[0].plot(losses_1_log, label="loss given input = 1")
        axes[0].set_title(f"Prover Loss Against Frozen Verifier")
        axes[0].set_xlabel("iterations")
        axes[0].set_ylabel("loss")
        axes[0].legend()

        # Accuracy plot.
        axes[1].plot(accs_log, label="average accuracy")
        axes[1].plot(accs_0_log, label="accuracy given input = 0")
        axes[1].plot(accs_1_log, label="accuracy given input = 1")
        axes[1].set_title(f"Accuracy of Frozen Verifier Against Optimized Prover")
        axes[1].set_xlabel("iterations")
        axes[1].set_ylabel("accuracy")
        axes[1].legend()

        self.logger.experiment.log({f"v_attack_plots/soundness_and_completeness_plot": wandb.Image(plt),
                                    "epoch": self.current_epoch, "global_step": self.global_step,
                                    "game_step": self.game_step})

        plt.close('all')

        # If asked, also directly optimize the proof vectors.
        if self.v_attack_proof_optim_specs is not None:
            self.attack_verifier_by_optimizing_proofs(prover_clone=prover_clone, verifier_clone=verifier_clone,
                                                      v_attack_dataloader=v_attack_dataloader)

    def _v_attack_prover_step(self, batch, prover_clone, verifier_clone, optimizer):
        """Run one optimizer step of the prover clone on `batch` and return logging metrics.

        The prover clone is updated so that the verifier clone predicts label 0 for every input.
        Accuracies are measured against the *true* labels.

        Returns:
            Tuple of Python floats ``(loss, loss_0, loss_1, acc, acc_0, acc_1)``. The ``_0`` and
            ``_1`` variants are restricted to inputs whose true label is 0 or 1. If the batch has
            no such inputs, the accuracy is NaN (mean of an empty tensor). The loss in that case
            depends on ``classification_loss_fn``.
        """
        optimizer.zero_grad()

        # Unpack batch, form prover outputs and correct proofs.
        p_xs, v_xs, ys_true, other_data = self.unpack_data_batch(data_batch=batch)
        correct_proofs = other_data.get("correct_proofs")

        # Move the variables onto the target device.
        p_xs, v_xs, ys_true = p_xs.to(self.device), v_xs.to(self.device), ys_true.to(self.device)
        if "graph" in other_data:
            other_data["graph"] = other_data["graph"].to(self.device)

        # Masks selecting the inputs whose true label is 0 or 1.
        idx_0 = ys_true == 0
        idx_1 = ys_true == 1

        # Attack target: label 0 for every input.
        ys_train = torch.zeros_like(ys_true)
        preds, _, _, _, _, _ = self.forward(p_inputs=p_xs, v_inputs=v_xs, prover=prover_clone,
                                            verifier=verifier_clone, correct_proofs=correct_proofs,
                                            other_data=other_data)
        loss = self.classification_loss_fn(preds, ys_train)

        # Backward pass and optimizer step (only the prover clone's parameters are in the optimizer).
        loss.backward()
        optimizer.step()

        # The remaining quantities are only logged, so no autograd graph is needed for them.
        with torch.no_grad():
            preds = preds.detach()
            loss_0 = self.classification_loss_fn(preds[idx_0], ys_train[idx_0])
            loss_1 = self.classification_loss_fn(preds[idx_1], ys_train[idx_1])
            correct = (preds.argmax(dim=1) == ys_true).float()
            acc, acc_0, acc_1 = correct.mean(), correct[idx_0].mean(), correct[idx_1].mean()

        return float(loss), float(loss_0), float(loss_1), float(acc), float(acc_0), float(acc_1)

    def attack_verifier_by_optimizing_proofs(self, prover_clone, verifier_clone, v_attack_dataloader):
        """Optimize single-sample prover outputs directly against the frozen verifier and log results.

        The procedure visits up to ``self.v_attack_proof_optim_specs["max_attack_proof_samples"]``
        samples. For each one, the prover clone's output is copied into a fresh ``Parameter``. That
        parameter is optimized with the optimizer built by
        ``self.v_attack_proof_optim_specs["proof_optimizer"]`` so that the verifier predicts label 0.
        The following are then logged:

        - the verifier's accuracy on the optimized proofs, measured against the true labels and
          split by true label (NaN for a class that was never seen);
        - the mean of ``n_iter / max_iter`` reported by the optimizer.

        The ``n_iter``/``max_iter`` bookkeeping assumes an LBFGS-style optimizer such as
        ``torch.optim.LBFGS``; TODO confirm for other optimizers.

        Args:
            prover_clone: prover copy used to produce the initial proofs.
            verifier_clone: verifier copy the proofs are optimized against; put in eval mode here.
            v_attack_dataloader: iterable of data batches understood by ``self.unpack_data_batch``.
        """
        # Eval mode fixes dropout / batch-norm behaviour. It does not freeze parameters; the
        # verifier's parameters are simply not given to the proof optimizer.
        verifier_clone.eval()

        max_samples = self.v_attack_proof_optim_specs["max_attack_proof_samples"]
        proof_optimizer_init = self.v_attack_proof_optim_specs["proof_optimizer"]

        # Prepare for logging accuracies.
        num_0_inputs, num_1_inputs = 0., 0.
        num_0_correct, num_1_correct = 0., 0.
        relative_num_steps = list()

        idx_counter = 0
        for batch in tqdm(v_attack_dataloader, desc="Optimizing proofs. "):
            # Decide whether to terminate or not.
            if idx_counter >= max_samples:
                break

            # Unpack batch, form prover outputs and correct proofs.
            p_xs, v_xs, ys_true, other_data = self.unpack_data_batch(data_batch=batch)
            correct_proofs = other_data.get("correct_proofs")

            # Move the variables onto the target device.
            p_xs, v_xs, ys_true = p_xs.to(self.device), v_xs.to(self.device), ys_true.to(self.device)
            if "graph" in other_data:
                other_data["graph"] = other_data["graph"].to(self.device)

            # Training labels for the attack: all zeros.
            ys_train = torch.zeros_like(ys_true)

            # The prover's outputs only seed the proofs that get optimized below, so there is no
            # need to keep an autograd graph through the prover alive.
            with torch.no_grad():
                prover_outputs, _, _ = self.run_prover(p_inputs=p_xs, prover=prover_clone,
                                                       correct_proofs=correct_proofs,
                                                       other_data=other_data)

            # Unbatch.
            (ub_prover_outputs, ub_p_xs, ub_v_xs,
             ub_ys_train, ub_ys_true, ub_other_data) = self.unbatch_verifier_inputs(prover_outputs, p_xs, v_xs,
                                                                                    ys_train, ys_true, other_data)

            # Optimize each proof separately (LBFGS in the intended configuration).
            for i in range(ys_true.shape[0]):
                # Decide whether to terminate or not.
                if idx_counter >= max_samples:
                    break
                idx_counter += 1
                if idx_counter % 25 == 0:
                    print(f"Optimizing proof num {idx_counter}. ")

                # Get the inputs needed to run the verifier.
                (curr_prover_outputs, curr_p_inputs, curr_v_inputs, curr_ys_train,
                 curr_ys_true, curr_other_data) = self._get_verifier_inputs(i=i, ub_prover_outputs=ub_prover_outputs,
                                                                            ub_p_xs=ub_p_xs,
                                                                            ub_v_xs=ub_v_xs,
                                                                            ub_ys_train=ub_ys_train,
                                                                            ub_ys_true=ub_ys_true,
                                                                            ub_other_data=ub_other_data)

                # Optimize a private copy of this sample's prover output.
                # - Cloning stops the optimizer's in-place updates from writing into storage shared
                #   with `prover_outputs`.
                # - Moving to the device *before* wrapping keeps it a leaf Parameter. Calling
                #   `.to()` on a Parameter that actually moves returns a plain, non-leaf tensor.
                curr_p_out_params = Parameter(curr_prover_outputs.detach().clone().to(self.device))
                p_out_optim = proof_optimizer_init(params=[curr_p_out_params])

                # Run the optimization procedure. The closure is consumed immediately by `step`,
                # so capturing the loop variables here is safe.
                def closure():
                    if torch.is_grad_enabled():
                        p_out_optim.zero_grad()
                    # Process prover outputs and get proofs.
                    curr_proofs_, _ = self.process_prover_outputs(curr_p_inputs, curr_p_out_params,
                                                                  correct_proofs=None)
                    curr_preds, _, _ = self.run_verifier(proofs=curr_proofs_, v_inputs=curr_v_inputs,
                                                         verifier=verifier_clone, other_data=curr_other_data)
                    curr_loss = self.classification_loss_fn(curr_preds, curr_ys_train)
                    curr_loss.backward()
                    return curr_loss

                p_out_optim.step(closure=closure)

                # Get the verifier outputs with the optimized proof (evaluation only).
                # NOTE(review): the closure above uses correct_proofs=None, but this call passes
                # the *batch-level* `correct_proofs` with single-sample inputs. Confirm that
                # process_prover_outputs handles this as intended.
                with torch.no_grad():
                    curr_proofs, _ = self.process_prover_outputs(curr_p_inputs, curr_p_out_params,
                                                                 correct_proofs)
                    curr_pred, _, _ = self.run_verifier(proofs=curr_proofs, v_inputs=curr_v_inputs,
                                                        verifier=verifier_clone, other_data=curr_other_data)

                # Keep track of accuracy values and other optimizer values to log later.
                curr_acc = float((curr_pred.argmax(dim=1) == curr_ys_true).float())
                if bool(curr_ys_true == 0):
                    num_0_inputs += 1
                    num_0_correct += curr_acc
                if bool(curr_ys_true == 1):
                    num_1_inputs += 1
                    num_1_correct += curr_acc
                relative_num_step = p_out_optim.state_dict()['state'][0]['n_iter'] / p_out_optim.defaults['max_iter']
                relative_num_steps.append(relative_num_step)

        # Log. A class that never occurred yields NaN instead of raising ZeroDivisionError.
        acc_0 = self._safe_ratio(num_0_correct, num_0_inputs)
        acc_1 = self._safe_ratio(num_1_correct, num_1_inputs)
        avg_rel_num_steps = float(np.mean(relative_num_steps)) if relative_num_steps else float("nan")

        self.logger.experiment.log({f"v_attack_proof_optim/acc_0": float(acc_0),
                                    f"v_attack_proof_optim/acc_1": float(acc_1),
                                    f"v_attack_proof_optim/avg_relative_num_steps": float(avg_rel_num_steps),
                                    "epoch": self.current_epoch,
                                    "global_step": self.global_step,
                                    "game_step": self.game_step})

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator`` as a float, or NaN when ``denominator`` is zero."""
        if not denominator:
            return float("nan")
        return float(numerator) / denominator

    def unbatch_verifier_inputs(self, prover_outputs, p_xs, v_xs, ys_train, ys_true, other_data):
        """Split batched verifier inputs into per-sample pieces along the first dimension.

        Every tensor argument is split with ``Tensor.split(1)``, so each piece keeps a leading
        dimension of size 1.

        Returns:
            ``(ub_prover_outputs, ub_p_xs, ub_v_xs, ub_ys_train, ub_ys_true, ub_other_data)``.
            The first five are tuples of size-1 slices. ``ub_other_data`` is a list with
            ``ys_true.shape[0]`` dicts.

        NOTE(review): only ``torch.Tensor`` values of ``other_data`` are kept, and they are sliced
        as ``v[i:i + 1]``, which assumes their first dimension is the batch. Non-tensor entries are
        dropped. That includes ``other_data["graph"]`` if it is, e.g., a DGL graph rather than a
        tensor. Confirm this is intended.
        """
        bs = ys_true.shape[0]
        ub_prover_outputs = prover_outputs.split(1)
        ub_p_xs = p_xs.split(1)
        ub_v_xs = v_xs.split(1)
        ub_ys_train = ys_train.split(1)
        ub_ys_true = ys_true.split(1)
        ub_other_data = [
            {k: v[i:i + 1] for k, v in other_data.items() if isinstance(v, torch.Tensor)}
            for i in range(bs)
        ]

        return ub_prover_outputs, ub_p_xs, ub_v_xs, ub_ys_train, ub_ys_true, ub_other_data

    def _get_verifier_inputs(self, i, ub_prover_outputs, ub_p_xs, ub_v_xs, ub_ys_train, ub_ys_true, ub_other_data):
        """Return the ``i``-th entry of each output of ``unbatch_verifier_inputs``, in the same order."""
        curr_prover_outputs = ub_prover_outputs[i]
        curr_p_inputs = ub_p_xs[i]
        curr_v_inputs = ub_v_xs[i]
        curr_ys_train = ub_ys_train[i]
        curr_ys_true = ub_ys_true[i]
        curr_other_data = ub_other_data[i]
        return curr_prover_outputs, curr_p_inputs, curr_v_inputs, curr_ys_train, curr_ys_true, curr_other_data

    @staticmethod
    def _compare_models(model1, model2):
        """Return True iff both models have the same parameter names, shapes and (close) values.

        Shapes are checked explicitly because ``torch.allclose`` would otherwise broadcast
        mismatched shapes, or raise on non-broadcastable ones. Values from ``model2`` are moved
        to the device of ``model1``'s parameter before comparing.
        """
        model1_params_dict = dict(model1.named_parameters())
        model2_params_dict = dict(model2.named_parameters())
        if set(model1_params_dict.keys()) != set(model2_params_dict.keys()):
            return False

        for name, param in model1_params_dict.items():
            other = model2_params_dict[name]
            if param.shape != other.shape:
                return False
            if not torch.allclose(param, other.to(param.device)):
                return False
        return True
