from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from mlwiz.static import TRAINING, VALIDATION, TEST
from mlwiz.training.callback.plotter import Plotter
from mlwiz.training.event.state import State

from model import AWN


class WidthPlotter(Plotter):
    """
    Plotter that additionally tracks the layer widths of an AWN model.

    At the end of every epoch, after the base :class:`Plotter` logic runs,
    the width of each variational layer (the length of the vector returned by
    its ``compute_probability_vector()``) is logged to Tensorboard under the
    ``"Width"`` scalar group and appended to ``stored_metrics["model_widths"]``.
    """

    def on_epoch_end(self, state: State):
        """
        Log and store the current widths of ``state.model.variational_Ws``.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information; ``state.model`` must
                expose ``variational_Ws``.
        """
        super().on_epoch_end(state)

        # One width per variational layer: the length of its probability
        # vector (its first dimension).
        widths = [
            variational_W.compute_probability_vector().shape[0]
            for variational_W in state.model.variational_Ws
        ]

        # Series are named width_1, width_2, ... (1-based layer index).
        self.writer.add_scalars(
            "Width",
            {f"width_{i}": width for i, width in enumerate(widths, start=1)},
            state.epoch,
        )

        self.stored_metrics.setdefault("model_widths", []).append(widths)

        if self.store_on_disk:
            # Best effort: a failed save is reported but does not stop training.
            try:
                torch.save(self.stored_metrics, self.stored_metrics_path)
            except RuntimeError as e:
                print(e)


class AWNGradientPlotter(Plotter):
    """
    Plotter that tracks the distribution of non-zero parameter gradients.

    After every backward pass, the non-zero entries of each layer's weight and
    bias gradients are collected for the current epoch. Every
    ``kde_plot_interval`` epochs, KDE plots of those gradients are saved as
    PDFs in ``exp_path``. At the end of each epoch, the raw gradients in
    ``stored_metrics["gradients"][epoch]`` are replaced by per-layer
    ``(mean, std)`` statistics.
    """

    # Save KDE plots of the gradient distributions every this many epochs
    # (must be a positive integer).
    kde_plot_interval = 100

    # Keys of the per-epoch raw-gradient entries in stored_metrics["gradients"].
    _WEIGHT_GRADS = "layer_wise_weight_gradients"
    _BIAS_GRADS = "layer_wise_bias_gradients"

    def _get_layer(self, model, layer_id):
        """Return layer ``layer_id`` of an AWN model (through ``dyn_model``)."""
        return model.dyn_model.get_layer(layer_id)

    @staticmethod
    def _nonzero_flat_grad(param):
        """
        Return the non-zero entries of ``param.grad`` as a flat CPU tensor.

        An empty tensor is returned when the parameter is missing (e.g. a
        layer without bias) or has not received a gradient.
        """
        if param is None or param.grad is None:
            return torch.empty(0)
        flat = param.grad.detach().cpu().reshape(-1)
        return flat[flat != 0.0]

    @staticmethod
    def _concat_layers(chunks_per_layer):
        """Concatenate, for each layer, the gradient chunks of all minibatches."""
        return [torch.cat(chunks, dim=0) for chunks in chunks_per_layer]

    def _epoch_entry(self, epoch):
        """Return the gradient record of ``epoch``, or ``{}`` if there is none."""
        gradients = self.stored_metrics.get("gradients", [])
        return gradients[epoch] if epoch < len(gradients) else {}

    def _save_kde(self, per_layer_values, label_prefix, title, filename):
        """Save one KDE curve per layer of ``per_layer_values`` to a PDF."""
        fig = plt.figure()
        try:
            for layer_id, values in enumerate(per_layer_values):
                sns.kdeplot(
                    values.numpy(),
                    label=f"{label_prefix} {layer_id}",
                    alpha=0.5,
                )
            plt.title(title)
            plt.legend()
            plt.tight_layout()
            plt.savefig(Path(self.exp_path, filename))
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    def on_backward(self, state: State):
        """
        Record the non-zero weight/bias gradients of every layer for a batch.

        Gradients are appended as per-layer chunks and concatenated only when
        the epoch is plotted/summarized, so each minibatch costs time
        proportional to its own gradients, not to everything seen so far.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        epoch = state.epoch
        model = state.model
        num_layers = model.num_hidden_layers + 1

        gradients = self.stored_metrics.setdefault("gradients", [])
        # Pad so that gradients[epoch] exists even if earlier epochs have no
        # record (e.g. resuming with a metrics file that lacks them).
        while len(gradients) <= epoch:
            gradients.append({})

        entry = gradients[epoch]
        if self._WEIGHT_GRADS not in entry:
            # First minibatch of this epoch (or a leftover summary from an
            # interrupted run): start fresh per-layer chunk lists.
            entry = {
                self._WEIGHT_GRADS: [[] for _ in range(num_layers)],
                self._BIAS_GRADS: [[] for _ in range(num_layers)],
            }
            gradients[epoch] = entry

        for layer_id in range(num_layers):
            layer = self._get_layer(model, layer_id)
            entry[self._WEIGHT_GRADS][layer_id].append(
                self._nonzero_flat_grad(layer.weight)
            )
            entry[self._BIAS_GRADS][layer_id].append(
                self._nonzero_flat_grad(getattr(layer, "bias", None))
            )

    def on_training_epoch_end(self, state: State):
        """
        Every ``kde_plot_interval`` epochs, save KDE plots of the gradients.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        epoch = state.epoch
        if epoch % self.kde_plot_interval != 0:
            return

        entry = self._epoch_entry(epoch)
        if self._WEIGHT_GRADS not in entry:
            # No backward pass was recorded for this epoch.
            return

        self._save_kde(
            self._concat_layers(entry[self._WEIGHT_GRADS]),
            "W Matrix",
            f"Weight Gradient Distribution Epoch {epoch+1}",
            f"w_grad_kde_epoch_{epoch}.pdf",
        )
        self._save_kde(
            self._concat_layers(entry[self._BIAS_GRADS]),
            "b Vector",
            f"Bias Gradient Distribution Epoch {epoch+1}",
            f"b_grad_kde_epoch_{epoch}.pdf",
        )

    def on_epoch_end(self, state: State):
        """
        Replace this epoch's raw gradients with per-layer (mean, std) stats.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        epoch = state.epoch
        entry = self._epoch_entry(epoch)

        if self._WEIGHT_GRADS in entry:
            weight_stats = [
                (grads.mean().item(), grads.std().item())
                for grads in self._concat_layers(entry[self._WEIGHT_GRADS])
            ]
            bias_stats = [
                (grads.mean().item(), grads.std().item())
                for grads in self._concat_layers(entry[self._BIAS_GRADS])
            ]

            # Drop the raw gradients: only the summary is kept (and stored).
            self.stored_metrics["gradients"][epoch] = {
                "layer_wise_weight_gradients_stats": weight_stats,
                "layer_wise_bias_gradients_stats": bias_stats,
            }

        super().on_epoch_end(state)


class MLPGradientPlotter(AWNGradientPlotter):
    """
    Gradient plotter for plain MLP models.

    Identical to :class:`AWNGradientPlotter` except that layers are fetched
    directly from the model via ``model.get_layer`` rather than through
    ``model.dyn_model``.
    """

    def _get_layer(self, model, layer_id):
        """Fetch layer ``layer_id`` straight from ``model``."""
        layer_getter = model.get_layer
        return layer_getter(layer_id)


class MiniBatchPlotter(Plotter):
    """
    Plotter that logs losses/scores (and AWN widths) per minibatch.

    Each split (training, validation, test) keeps its own batch counter,
    used as the Tensorboard global step. The epoch-level logging of the base
    :class:`Plotter` is disabled.
    """

    def __init__(
        self, exp_path: str, store_on_disk: bool = False, **kwargs: dict
    ):
        super().__init__(exp_path, store_on_disk, **kwargs)
        # Tensorboard global steps, one per split.
        self.training_batch_counter = 0
        self.validation_batch_counter = 0
        self.test_batch_counter = 0

    def _log_batch_metrics(self, state: State, split, step: int):
        """
        Write every batch loss and (if present) batch score to Tensorboard.

        Each metric ``k`` is written to the scalar group ``k`` under a single
        series named after ``split``, at global step ``step``. Metric names
        are used unchanged.
        """
        metric_dicts = [state.batch_loss]
        if state.batch_score is not None:
            metric_dicts.append(state.batch_score)

        for metric_dict in metric_dicts:
            for metric_name, value in metric_dict.items():
                self.writer.add_scalars(
                    metric_name, {f"{split}": value.detach().cpu()}, step
                )

    def _log_awn_widths(self, model):
        """Log/store the widths of an AWN model at the current training step."""
        # One width per variational layer: the length of its probability
        # vector (its first dimension).
        widths = [
            variational_W.compute_probability_vector().shape[0]
            for variational_W in model.variational_Ws
        ]
        self.writer.add_scalars(
            "Width",
            {f"width_{i}": width for i, width in enumerate(widths, start=1)},
            self.training_batch_counter,
        )

        self.stored_metrics.setdefault("model_widths", []).append(widths)

        if self.store_on_disk:
            # Best effort: a failed save is reported but does not stop training.
            try:
                torch.save(self.stored_metrics, self.stored_metrics_path)
            except RuntimeError as e:
                print(e)

    def on_epoch_end(self, state: State):
        """Disable the base class's epoch-level logging."""
        pass

    def on_training_batch_end(self, state: State):
        """
        Log training losses/scores (and AWN widths) for the finished batch.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        assert state.set == TRAINING

        self._log_batch_metrics(state, TRAINING, self.training_batch_counter)

        if isinstance(state.model, AWN):
            self._log_awn_widths(state.model)

        self.training_batch_counter += 1

    def on_eval_batch_end(self, state: State):
        """
        Log evaluation losses/scores for the finished batch of ``state.set``.

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information

        Raises:
            ValueError: if ``state.set`` is not training, validation or test.
        """
        if state.set == TRAINING:
            split = TRAINING
            # NOTE: the training counter is not advanced here, so evaluation
            # batches on the training set share the current training step.
            step = self.training_batch_counter
        elif state.set == VALIDATION:
            split = VALIDATION
            step = self.validation_batch_counter
        elif state.set == TEST:
            split = TEST
            step = self.test_batch_counter
        else:
            raise ValueError(f"Unexpected evaluation set: {state.set!r}")

        self._log_batch_metrics(state, split, step)

        if split == VALIDATION:
            self.validation_batch_counter += 1
        elif split == TEST:
            self.test_batch_counter += 1

    def on_fit_end(self, state: State):
        """
        Frees resources by closing the Tensorboard writer

        Args:
            state (:class:`~training.event.state.State`):
                object holding training information
        """
        self.writer.close()
