from matplotlib.lines import Line2D
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
from collections import defaultdict
import os
import pickle
import bz2


class GradViz():
    """Accumulate per-element gradient statistics for Conv2d/Linear weights.

    Call :meth:`update` after every ``loss.backward()`` to fold the current
    weight gradients into running statistics. Then call :meth:`save` (or
    :meth:`save_reset`) to write them to
    ``<logdir>/gradients/<n_saves>/values.pickle.bz2``.

    Args:
        model: the ``nn.Module`` whose gradients are tracked.
        logdir: root directory under which snapshots are written.
        graphics: if True, also redraw a live bar chart of per-parameter
            mean/max absolute gradients on every :meth:`update`.
    """

    def __init__(self, model, logdir, graphics=False):
        # Stored as a list (not the one-shot generator returned by
        # ``named_parameters()``) so it can be iterated more than once.
        self.named_parameters = list(model.named_parameters())
        self.model = model
        self.graphics = graphics
        if graphics:
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(111)

        self.n_saves = 0
        self.logdir = logdir
        self.reset()

    def reset(self):
        """Discard all accumulated statistics.

        ``self.gradients`` maps a module name to the list
        ``[pos_sum, neg_sum, pos_count, neg_count, zero_count]``. The scalar
        zeros are replaced by flattened tensors on the first :meth:`update`.
        """
        self.gradients = defaultdict(lambda: [0., 0., 0, 0, 0])

    def _tracked_modules(self):
        """Yield ``(name, module)`` for every Conv2d/Linear in the model."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                yield name, module

    @torch.no_grad()
    def update(self):
        """Fold the current weight gradients into the running statistics.

        Call this after ``loss.backward()``. For every Conv2d/Linear module,
        each non-bias parameter that requires grad contributes the following
        (all tensors flattened, element-wise over the parameter):

          [0] running sum of the positive gradient values,
          [1] running sum of the negative gradient values,
          [2] number of updates in which the gradient was > 0,
          [3] number of updates in which the gradient was < 0,
          [4] number of updates in which the gradient was exactly 0.

        Parameters whose ``.grad`` is None (e.g. a layer that did not take
        part in the forward pass) are skipped instead of raising.
        """
        if self.graphics:
            self.draw_()

        for m_name, m in self._tracked_modules():
            for n, p in m.named_parameters(recurse=False):
                if not p.requires_grad or "bias" in n:
                    continue
                if p.grad is None:
                    # No gradient was produced for this parameter this step.
                    continue

                grad = p.grad.detach().cpu().flatten()
                stats = self.gradients[m_name]

                # Sum up the positive and negative gradients separately.
                stats[0] = torch.clamp(grad, min=0) + stats[0]
                stats[1] = torch.clamp(grad, max=0) + stats[1]

                # Count how often each element was positive / negative / zero.
                # All counters are flattened so they line up element-wise
                # with the sums above and with the weight saved by save().
                stats[2] = (grad > 0).int() + stats[2]
                stats[3] = (grad < 0).int() + stats[3]
                stats[4] = (grad == 0).int() + stats[4]

    @torch.no_grad()
    def save(self):
        """Write the accumulated statistics to a bz2-compressed pickle.

        The file is ``<logdir>/gradients/<n_saves>/values.pickle.bz2``. It
        contains a dict mapping each module name to
        ``[pos_sum, neg_sum, pos_count, neg_count, zero_count, weight]``,
        where ``weight`` is a flattened CPU copy of the module's current
        weight. ``self.gradients`` itself is left untouched, so saving twice
        without :meth:`reset` does not accumulate duplicate weights.
        """
        save_dir = os.path.join(self.logdir, "gradients", str(self.n_saves))
        os.makedirs(save_dir, exist_ok=True)

        named_modules = dict(self._tracked_modules())

        # Build a separate snapshot rather than appending to the live lists.
        snapshot = {}
        for k, stats in self.gradients.items():
            weight = named_modules[k].weight.detach().cpu().clone().flatten()
            snapshot[k] = list(stats) + [weight]

        # Save all values in compressed format. NOTE: pickle files can run
        # arbitrary code when loaded; only load snapshots from trusted paths.
        path = os.path.join(save_dir, "values.pickle.bz2")
        with bz2.BZ2File(path, 'wb') as fp:
            pickle.dump(snapshot, fp)
        self.n_saves += 1

    def save_reset(self):
        """Save the current statistics, then clear them."""
        self.save()
        self.reset()

    def draw_(self):
        """Plot per-parameter mean and max absolute gradient as bar charts.

        Useful for spotting vanishing or exploding gradients. Called by
        :meth:`update` when ``graphics=True``, so gradients must already have
        been computed (i.e. after ``loss.backward()``). Parameters without a
        gradient are left out of the plot.
        """
        layers = []
        ave_grads = []
        max_grads = []
        # Query the model on every call: the generator returned by
        # ``named_parameters()`` can only be iterated once.
        for n, p in self.model.named_parameters():
            if p.requires_grad and "bias" not in n and p.grad is not None:
                layers.append(n)
                grad = p.grad.detach().abs()
                # .item() yields Python floats that matplotlib can plot even
                # when the gradient lives on a GPU.
                ave_grads.append(grad.mean().item())
                max_grads.append(grad.max().item())

        positions = np.arange(len(max_grads))
        ax = self.ax
        ax.clear()
        ax.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
        ax.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
        ax.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
        # Use the Axes API rather than pyplot's global state so drawing always
        # targets this object's figure, even if another figure is current.
        ax.set_xticks(range(0, len(ave_grads), 1))
        ax.set_xticklabels(layers, rotation="vertical")
        ax.set_xlim(left=0, right=len(ave_grads))
        ax.set_ylim(bottom=-0.001, top=0.02)  # zoom in
        ax.set_xlabel("Layers")
        ax.set_ylabel("average gradient")
        ax.set_title("Gradient flow")
        ax.grid(True)
        ax.legend([Line2D([0], [0], color="c", lw=4),
                   Line2D([0], [0], color="b", lw=4),
                   Line2D([0], [0], color="k", lw=4)],
                  ['max-gradient', 'mean-gradient', 'zero-gradient'])
        self.fig.canvas.draw_idle()
        plt.pause(0.05)
