from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import wandb
from lightning import LightningModule

from common.utils import create_dirs_if_not_exist, save_in_pickle
from metrics.plot_hist_timeseries import get_plot_hist, get_plot_ts


class Logger(LightningModule):
    """LightningModule base class adding plotting / data-dumping helpers for evaluation.

    Figures can be sent to the attached Lightning logger (W&B) and/or written
    under ``<path_storage>/plots/``. Synthetic samples can be pickled under
    ``<path_storage>/synthetic_data``.
    """

    def __init__(
        self, path_storage: str, save_plots_in_file: bool, file_extension: str, save_synthetic_data: bool,
        n_samples_evaluation: int, save_plots_in_wandb: bool, feature_names: List[str], max_lag: int
    ) -> None:
        """Store the evaluation / logging configuration.

        Args:
            path_storage: Root directory for saved plots and synthetic data.
            save_plots_in_file: Whether ``log_plot`` writes figures to disk.
            file_extension: Extension appended to saved figure file names (e.g. ``"png"``).
            save_synthetic_data: Whether ``evaluate_synthetic`` pickles its inputs.
            n_samples_evaluation: Number of samples used for evaluation.
            save_plots_in_wandb: Whether ``log_plot`` sends figures to the logger's experiment.
            feature_names: One name per feature; its length defines ``n_features``.
            max_lag: Maximum lag, used to unfold the lag dimension of coefficient tensors.
        """
        super().__init__()

        self.path_storage = path_storage
        self.save_plots_in_file = save_plots_in_file
        self.file_extension = file_extension
        self.save_synthetic_data = save_synthetic_data

        self.n_samples_evaluation = n_samples_evaluation
        self.save_plots_in_wandb = save_plots_in_wandb

        self.feature_names = feature_names
        self.n_features: int = len(feature_names)
        self.max_lag = max_lag

        self.dir_to_save_ts_plots: str = f"{self.path_storage}/plots/"

    def log_plot(
        self, sample_idx: int, ts_predicted: np.ndarray, plot_type: str, folder_name: str,
        timestep: Optional[int] = None, ts_real: Optional[np.ndarray] = None,
        graph: Optional[np.ndarray] = None, val_matrix: Optional[np.ndarray] = None,
        vmin_edges: Optional[float] = .0, vmax_edges: Optional[float] = .0,
    ) -> None:
        """Build one figure for a single sample and log it to W&B and/or disk.

        Does nothing when both ``save_plots_in_wandb`` and ``save_plots_in_file`` are False.

        Args:
            sample_idx: Index of the sample, used in the plot title / file name.
            ts_predicted: Predicted (or generated) series for this sample.
            plot_type: ``'timeseries'`` selects ``get_plot_ts``; any other value selects ``get_plot_hist``.
            folder_name: Top-level group for the W&B key and the on-disk sub-directory.
            timestep: Optional timestep appended to the title (``0`` included).
            ts_real: Reference series, only used by the histogram plot.
            graph, val_matrix: Per-sample edge labels / values, only used by the timeseries plot.
            vmin_edges, vmax_edges: Colour range for edge values, only used by the timeseries plot.
        """
        # if plot_type == 'timeseries':
        #     ts_predicted.shape = [seq_len-max_lag, n_features]
        # elif plot_type == 'hist':
        #     ts_real.shape = ts_predicted.shape = [seq_len, n_features]

        if not (self.save_plots_in_wandb or self.save_plots_in_file):
            return

        title_wandb = f'{folder_name}/Epoch:{self.current_epoch}'
        dir_to_save = f'{self.dir_to_save_ts_plots}/{title_wandb}'
        # `is not None` (not truthiness) so that timestep 0 still appears in the title / file name.
        title_plot = f'Sample:{sample_idx}' + (f'-Timestep:{timestep}' if timestep is not None else '')

        if plot_type == 'timeseries':
            fig = get_plot_ts(ts_predicted, graph, val_matrix, vmin_edges, vmax_edges, title_plot, self.feature_names)
        else:
            fig = get_plot_hist(ts_real, ts_predicted, title_plot, self.feature_names)

        try:
            fig.tight_layout()

            if self.save_plots_in_wandb:
                # NOTE(review): assumes the attached Lightning logger is a W&B logger
                # (its `experiment` exposes `.log`) — confirm in the trainer setup.
                self.logger.experiment.log({title_wandb: wandb.Image(fig)})

            if self.save_plots_in_file:
                create_dirs_if_not_exist(dir_to_save)
                # Save this specific figure rather than pyplot's "current" figure.
                fig.savefig(f"{dir_to_save}/{title_plot}.{self.file_extension}")
        finally:
            # Always release the figure, even if logging/saving raised, to avoid leaking memory.
            plt.close(fig)

    def evaluate_prediction(self, real: np.ndarray, pred: np.ndarray, t: np.ndarray) -> None:
        """Log one histogram figure per sample comparing real vs. predicted values.

        Figures are grouped under the ``"noise"`` folder. Iteration stops at the
        shortest of ``real``, ``pred`` and ``t`` (``zip`` semantics).
        """
        # real.shape = pred.shape = [n_samples_evaluation, seq_len, n_features]
        # t.shape = [n_samples_evaluation]

        # Plot prediction
        for i, (real_, pred_, t_) in enumerate(zip(real, pred, t)):
            self.log_plot(
                sample_idx=i, ts_predicted=pred_, plot_type="hist", folder_name="noise", timestep=t_, ts_real=real_,
            )

    def get_valmatrix_graph(self, coefficients: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Summarise time-varying lagged coefficients into per-lag value and edge matrices.

        Args:
            coefficients: Array of shape
                ``[n_samples, n_features, n_features*max_lag, seq_len-max_lag]``.

        Returns:
            ``(graph, val_matrix)``, both of shape ``[n_samples, n_features, n_features, max_lag + 1]``.
            ``val_matrix`` holds the 95th percentile over time of each coefficient, with an all-zero
            slice prepended at index 0 of the last axis. ``graph`` is ``"-->"`` wherever
            ``val_matrix > 0`` and ``""`` elsewhere (negative values produce no edge).
        """
        # Derive the sample count from the input instead of trusting n_samples_evaluation,
        # so batches of any leading size are handled.
        n_samples = coefficients.shape[0]

        # 95th percentile over the time axis -> [n_samples, n_features, n_features*max_lag]
        c = np.quantile(coefficients, q=.95, axis=-1)
        # Unfold the lag dimension and swap the two feature axes.
        c = c.reshape(n_samples, self.n_features, self.n_features, self.max_lag).transpose((0, 2, 1, 3))
        # Reverse the lag axis, then prepend a zero slice at index 0.
        # Presumably this makes index l correspond to lag l (no contemporaneous edges) — confirm
        # against the coefficient layout produced by the model.
        c = np.flip(c, axis=-1)
        z = np.zeros((n_samples, self.n_features, self.n_features, 1))
        val_matrix = np.concatenate([z, c], axis=-1)
        # val_matrix.shape = [n_samples, n_features, n_features, max_lag + 1]

        graph = np.full_like(val_matrix, '', dtype='<U3')
        graph[val_matrix > 0.] = "-->"
        # graph.shape = [n_samples, n_features, n_features, max_lag + 1]

        return graph, val_matrix

    def evaluate_synthetic(self, synthetic: np.ndarray, coefficients: np.ndarray) -> None:
        """Optionally pickle generated samples and their coefficients for the current epoch.

        The data is written via ``save_in_pickle`` into ``<path_storage>/synthetic_data`` under the
        name ``epoch_<current_epoch>`` when ``save_synthetic_data`` is True. The per-sample timeseries
        plotting below is currently disabled (commented out).
        """
        # synthetic.shape = [n_samples_evaluation, seq_len-max_lag, n_features]
        # coefficients.shape = [n_samples_evaluation, n_features, n_features*max_lag, seq_len-max_lag]

        # graph, val_matrix = self.get_valmatrix_graph(coefficients)
        # graph.shape = val_matrix.shape = [n_samples_evaluation, n_features, n_features, max_lag+1]

        # vmin_edges, vmax_edges = val_matrix.min(), val_matrix.max()

        # Plot synthetic
        # for i, (synthetic_, graph_, val_matrix_) in enumerate(zip(synthetic, graph, val_matrix)):
        #     self.log_plot(
        #         sample_idx=i, ts_predicted=synthetic_, plot_type="timeseries", folder_name="synthetic",
        #         graph=graph_, val_matrix=val_matrix_, vmin_edges=vmin_edges, vmax_edges=vmax_edges
        #     )

        if self.save_synthetic_data:
            d = dict(
                synthetic=synthetic, coefficients=coefficients,
                # graph=graph, val_matrix=val_matrix
            )
            save_in_pickle(f"{self.path_storage}/synthetic_data", f'epoch_{self.current_epoch}', d)
