import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
from scipy.signal import savgol_filter
import yaml
import seaborn as sns
import pickle
import os


class Grapher:
    """Collect per-epoch training metrics and write plots and result files.

    Construction creates a run directory
    ``{base_pt}/{year}-{month}-{day}+{hour}+{minute}_{MODEL_NAME}`` (fields are
    not zero-padded). ``self.path`` is ``'<run dir>/.png'`` and is where
    ``make_graph`` saves its figure; every other output file is derived from it
    by replacing that trailing ``'.png'`` with a file name (see ``_output_file``).
    """

    def __init__(self, base_pt: str, save_pickle: bool = True, model_info: dict = None):
        """Initialise empty metric histories and create the run directory.

        Args:
            base_pt: Parent directory in which the run directory is created.
            save_pickle: When True, ``make_graph`` also calls ``save_data``
                (which writes YAML, despite the flag's name).
            model_info: Optional run metadata. ``model_info["MODEL_NAME"]``
                names the run directory, and all of its items are included in
                the dict returned by ``make_out_dict``. When None, the run
                directory is suffixed with ``'model'`` instead.
        """
        self.time_name = None  # run directory name; set by __set_path
        # Per-epoch histories: add_data appends one entry to each list.
        self.train_loss = []
        self.test_loss = []
        self.train_acc = []
        self.test_acc = []
        self.lr = []

        self.cosine = []
        self.weights = []
        self.dm_value_net = []
        self.dm_value_wordnet = []
        self.dm_value_glove = []
        self.mse = []
        self.mae = []
        self.structural = []
        self.cosine_glove = []
        self.mse_glove = []
        self.mae_glove = []
        self.structural_glove = []

        self.confusion_weights_prod_sum = []
        self.confusion_weights_cos = []
        self.confusion_wordnet = []
        self.confusion_glove = []

        self.weights_min = []
        self.weights_max = []

        self.model_info = model_info
        self.save_pickle = save_pickle
        # Must run after self.model_info is assigned: __set_path reads it.
        self.path = self.__set_path(base_pt)

    def __set_path(self, base_pt: str):
        """Create the timestamped run directory and return ``'<run dir>/.png'``."""
        date = datetime.now()
        # model_info defaults to None; fall back to a generic name instead of
        # failing on the subscript.
        model_name = 'model' if self.model_info is None else self.model_info["MODEL_NAME"]
        run_dir = f'{date.year}-{date.month}-{date.day}+{date.hour}+{date.minute}_{model_name}'
        self.time_name = run_dir
        # exist_ok avoids the check-then-create race (e.g. two runs started in
        # the same minute with the same model name share a directory).
        os.makedirs(f'{base_pt}/{run_dir}', exist_ok=True)
        return f'{base_pt}/{run_dir}/.png'

    def _output_file(self, name: str) -> str:
        """Return ``self.path`` with its last ``'.png'`` replaced by *name*.

        Only the final occurrence is replaced, so a ``'.png'`` appearing
        earlier in the base directory or model name cannot corrupt the path.
        If ``self.path`` contains no ``'.png'`` it is returned unchanged.
        """
        head, sep, tail = self.path.rpartition('.png')
        return f'{head}{name}{tail}' if sep else self.path

    def add_data(self, train_data, test_data, lr, cosine, weights, dm_value_net, dm_value_wordnet, dm_value_glove,
                 mse, mae, weights_min, weights_max, structural, cosine_glove, mse_glove, mae_glove, structural_glove,
                 confusion_weights_prod_sum, confusion_weights_cos, confusion_wordnet, confusion_glove):
        """Append one epoch's metrics to the histories.

        ``train_data`` / ``test_data`` are indexable: element ``[0]`` is
        recorded as the loss and element ``[-1]`` as the accuracy. Every other
        argument is appended unchanged to the attribute of the same meaning.
        """
        self.train_loss.append(train_data[0])
        self.train_acc.append(train_data[-1])
        self.test_loss.append(test_data[0])
        self.test_acc.append(test_data[-1])
        self.cosine.append(cosine)
        self.weights.append(weights)
        self.dm_value_net.append(dm_value_net)
        self.dm_value_wordnet.append(dm_value_wordnet)
        self.dm_value_glove.append(dm_value_glove)

        self.confusion_weights_prod_sum.append(confusion_weights_prod_sum)
        self.confusion_weights_cos.append(confusion_weights_cos)
        self.confusion_wordnet.append(confusion_wordnet)
        self.confusion_glove.append(confusion_glove)

        self.mse.append(mse)
        self.mae.append(mae)
        self.weights_min.append(weights_min)
        self.weights_max.append(weights_max)
        self.structural.append(structural)
        self.cosine_glove.append(cosine_glove)
        self.mse_glove.append(mse_glove)
        self.mae_glove.append(mae_glove)
        self.structural_glove.append(structural_glove)
        self.lr.append(lr)

    def smooth_data(self, data, window_length: int = 11, polyorder: int = 3):
        """Return a Savitzky-Golay smoothed copy of *data* as a Python list.

        ``mode='nearest'`` pads by repeating the edge values, so scipy accepts
        series shorter than ``window_length``. An empty series yields ``[]``.
        """
        if len(data) == 0:
            return []
        return savgol_filter(data, window_length, polyorder, mode='nearest').tolist()

    def make_out_dict(self):
        """Build the results dictionary written by ``save_data``.

        Returns a dict with ``'epoch_num'``, ``'train'``/``'test'`` (raw
        ``loss``/``acc`` plus smoothed ``loss_s``/``acc_s``) and
        ``'similarity'`` histories. The raw history lists are referenced, not
        copied. When ``model_info`` was given, its items are merged in
        underneath (result keys win on collision); ``model_info`` itself is
        left unmodified.
        """
        out_dict = {'epoch_num': len(self.test_acc),
                    'train': {'loss': self.train_loss, 'acc': self.train_acc,
                              'loss_s': self.smooth_data(self.train_loss), 'acc_s': self.smooth_data(self.train_acc)},
                    'test': {'loss': self.test_loss, 'acc': self.test_acc,
                             'loss_s': self.smooth_data(self.test_loss), 'acc_s': self.smooth_data(self.test_acc)},
                    'similarity': {'cosine_wordnet': self.cosine, 'weights': self.weights,
                                   'mse_wordnet': self.mse, 'mae_wordnet': self.mae, 'weights_min': self.weights_min,
                                   'weights_max': self.weights_max, 'structural_wordnet': self.structural,
                                   'cosine_glove': self.cosine_glove, 'mse_glove': self.mse_glove,
                                   'mae_glove': self.mae_glove, 'structural_glove': self.structural_glove,
                                   'dm_network': self.dm_value_net, 'dm_wordnet': self.dm_value_wordnet,
                                   'dm_glove': self.dm_value_glove,
                                   'confusion_weight_prod_sum': self.confusion_weights_prod_sum,
                                   'confusion_weight_cos': self.confusion_weights_cos,
                                   'confusion_glove': self.confusion_glove,
                                   'confusion_wordnet': self.confusion_wordnet}}

        if self.model_info is None:
            return out_dict
        # Merge into a new dict instead of model_info.update(): the caller's
        # dict must not be mutated or filled with result lists.
        return {**self.model_info, **out_dict}

    def save_data(self):
        """Write ``make_out_dict()`` as YAML to ``results.yml`` in the run directory."""
        # Build the results before opening the file so a failure here does not
        # leave a truncated results.yml behind.
        results = self.make_out_dict()
        with open(self._output_file('results.yml'), 'w') as f:
            yaml.dump(results, f)

    def save_matrices(self, matrices, binary_output: bool = True):
        """Save the similarity/confusion matrices as pickles and heatmap images.

        Args:
            matrices: Exactly six matrices, in this order: WordNet similarity,
                initial weights, current weights, GloVe similarity,
                weights-product confusion, confusion matrix. The diagonal mask
                is built from ``np.eye(matrices[0].shape[0])``, so the first
                four are assumed square and equally sized -- TODO confirm.
            binary_output: When True, write ``binary_matrices.pkl``
                (overwritten each call) and append the raw current weights and
                the confusion matrix to the ``binary_matrices_after_epoch_*``
                pickle files, which therefore hold one pickle record per call
                (read back with repeated ``pickle.load``).

        Also writes ``results.yml`` via ``save_data`` regardless of
        ``save_pickle``, then saves one heatmap PNG per matrix.

        Raises:
            ValueError: if ``matrices`` does not contain exactly six items.
        """
        wordnet_raw, init_raw, curr_wg, glove_raw, conf_prod, conf_matrix = matrices

        # Replace the diagonal with NaN (seaborn leaves NaN cells blank),
        # presumably so self-similarity does not dominate the colour scale.
        eye = np.eye(wordnet_raw.shape[0])
        wordnet_sim, init_wg, glove_sim = (np.where(eye == 1, np.nan, m) for m in (wordnet_raw, init_raw, glove_raw))
        # NOTE(review): the current weights are saved and plotted WITHOUT the
        # diagonal mask, unlike the three matrices above -- confirm intended.

        if binary_output:
            binary_dict = {'wordnet_sim': wordnet_sim, 'init_wg': init_wg, 'last_wg': curr_wg,
                           'glove_sim': glove_sim, 'confusion_prod': conf_prod, 'confusion_matrix': conf_matrix}
            with open(self._output_file('binary_matrices.pkl'), 'wb') as f:
                pickle.dump(binary_dict, f)

            # Append mode: each call adds one more pickled record.
            with open(self._output_file('binary_matrices_after_epoch_weights_NCSM.pkl'), 'ab') as f:
                pickle.dump(curr_wg, f)

            with open(self._output_file('binary_matrices_after_epoch_confusion_CCSM.pkl'), 'ab') as f:
                pickle.dump(conf_matrix, f)

        self.save_data()

        heatmaps = (
            (wordnet_sim, 'Word matrix', 'WM.png'),
            (init_wg, 'Initial weights', 'IW.png'),
            (curr_wg, 'Output weights', 'OW.png'),
            (glove_sim, 'GLOVE matrix', 'GM.png'),
            (conf_prod, 'ConfMat - Weights prod', 'CM_W.png'),
            (conf_matrix, 'Confusion Matrix', 'CM.png'),
        )
        for matrix, title, file_name in heatmaps:
            # A dedicated figure per heatmap: drawing on whatever figure is
            # current could land inside another plot (e.g. make_graph's panels).
            fig, ax = plt.subplots()
            sns.heatmap(matrix, ax=ax)
            ax.set_title(title)
            fig.savefig(self._output_file(file_name), bbox_inches='tight')
            plt.close(fig)

    def make_graph(self):
        """Plot loss, accuracy and learning-rate history and save it to ``self.path``.

        The loss and accuracy panels show each raw series faintly (alpha 0.2)
        with its ``smooth_data`` curve on top. When ``save_pickle`` is True the
        results are also written via ``save_data``.
        """
        epochs = list(range(len(self.test_acc)))
        colors = ['steelblue', 'limegreen']
        fig, ax = plt.subplots(1, 3, figsize=(22, 4), gridspec_kw={'wspace': 0.1, 'hspace': 0.1}, sharey=False)

        panels = (
            (ax[0], self.train_loss, self.test_loss, 'Train/test loss'),
            (ax[1], self.train_acc, self.test_acc, 'Train/test accuracy'),
        )
        for axis, train_series, test_series, title in panels:
            for series, color, label in ((train_series, colors[0], 'train'), (test_series, colors[1], 'test')):
                axis.plot(epochs, series, color=color, alpha=0.2)
                axis.plot(epochs, self.smooth_data(series), color=color, label=label)
            axis.set_xlabel('Epoch')
            axis.grid()
            axis.set_title(title)
            axis.legend()

        ax[2].plot(epochs, self.lr, color=colors[0], label='LR')
        ax[2].grid()
        ax[2].set_xlabel('Epoch')
        ax[2].set_title('LR for Epoch')
        ax[2].legend()

        fig.savefig(self.path, bbox_inches='tight')
        # Close the figure: repeated per-epoch calls would otherwise pile up
        # open figures, and later pyplot calls would draw onto this one.
        plt.close(fig)

        if self.save_pickle:
            self.save_data()
