from functools import cached_property, lru_cache
import nninfo
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib import patches
from external.SxPID.sxpid import SxPID
import numpy as np
import functools


class Plot:
    """
    Abstract class that every other plot class inherits from.

    Manages standards for saving and loading etc.
    """

    def __init__(self):
        pass

    @staticmethod
    def performance_plot(epochs, perf_dict, show_plot, save_plot, save_dir, ax=None):
        """Plot every performance measure against ``epochs`` on a log-scaled x axis.

        Args:
            epochs: x values shared by all curves.
            perf_dict: Nested mapping ``{dataset_name: {measure_name: values}}``;
                every ``values`` sequence is drawn against ``epochs`` with the
                label ``"<dataset_name>: <measure_name>"``.
            show_plot: If True, call ``plt.show()`` at the end.
            save_plot: If True, save the figure that contains ``ax`` as
                ``performance_plot.pdf`` and ``performance_plot.png``.
            save_dir: Prefix prepended verbatim to the file names (include a
                trailing separator when it is a directory).
            ax: Axes to draw on; a new figure and axes are created when None.
        """
        if ax is None:
            _, ax = plt.subplots()

        for dataset_name, sub_dict in perf_dict.items():
            for measure_name, values in sub_dict.items():
                ax.semilogx(epochs, values, label=dataset_name + ": " + measure_name)
        ax.legend()
        ax.set_xlabel("Epochs")

        if save_plot:
            # Save the figure that owns ``ax``, not whichever figure pyplot
            # happens to consider "current" (they differ when ``ax`` is passed in).
            figure = ax.figure
            for extension in ("pdf", "png"):
                figure.savefig(save_dir + "performance_plot." + extension)
        if show_plot:
            plt.show()

    @staticmethod
    def plot_loss_accuracy(performance, ax, ax2, quantization=-1, max_log_epoch=5):
        """
        Plot accuracy (on ``ax``) and loss (on ``ax2``) for the train and test sets.

        Only rows with ``rounding_point == 'center_saturating'``,
        ``n_quantization_levels == quantization`` and
        ``dataset_name == 'full_set/<train|test>'`` are used. For every dataset
        the mean accuracy and standard deviation at the largest epoch in
        ``performance`` are printed.

        Args:
            performance: DataFrame with columns ``rounding_point``,
                ``n_quantization_levels``, ``dataset_name``, ``epoch``, ``acc``
                and ``loss``.
            ax: Axes for accuracy; formatted as a broken symlog x axis.
            ax2: Axes for loss.
            quantization: Value of ``n_quantization_levels`` to keep.
            max_log_epoch: Exponent of the right x limit of ``ax``.
        """
        nninfo.plot.format_figure_broken_axis(ax, max_exp=max_log_epoch)

        dataset_names = {'train': 'Train Set', 'test': 'Test Set'}
        # Taken over all rows (not only the filtered ones), as before.
        max_epoch = performance['epoch'].max()

        for dataset, pretty_name in dataset_names.items():
            mask = (
                (performance.rounding_point == 'center_saturating')
                & (performance.n_quantization_levels == quantization)
                & (performance.dataset_name == f'full_set/{dataset}')
            )
            performance_filtered = performance[mask][['epoch', 'acc', 'loss']]

            # Plot accuracy
            nninfo.plot.plot_mean_and_interval(
                performance_filtered, ax, x='epoch', y='acc',
                label=f'{pretty_name} Accuracy', lw=1)
            # Empty proxy line: adds a legend entry on ``ax`` for the loss curve
            # that is drawn on ``ax2``.
            ax.plot([], label=f'{pretty_name} Loss')

            # Plot loss; the empty plot advances ax2's color cycle, intended to
            # keep the loss color aligned with the proxy entry on ``ax``.
            ax2.plot([])
            nninfo.plot.plot_mean_and_interval(
                performance_filtered, ax2, x='epoch', y='loss', lw=1)

            final_acc = performance_filtered.loc[performance_filtered['epoch'] == max_epoch, 'acc']
            print(f'Avg. {pretty_name} accuracy at epoch {max_epoch}',
                  final_acc.mean(),
                  '+-',
                  final_acc.std(ddof=1))

        ax.set_xlabel("Epochs")
        ax.set_ylabel("Accuracy")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Loss")

class PIDPlot(Plot):
    """Plotting helpers for partial information decomposition (PID) measurements."""

    def __init__(self):
        super().__init__()
        # pyplot module handle; set by ``plot`` when it first creates a figure.
        self.plt = None

    def plot(
        self,
        df,
        x_column="epoch_id",
        y_column="average_pid",
        scatter=False,
        stackplot=False,
        unpack_y_dict=True,
        **kwargs
    ):
        """
        Standard function that is called when plotting PID data.

        Args:
            df (pandas dataframe): The measurements you want to visualize.
            x_column (str): Column used for the x axis.
            y_column (str): Column holding the y values; with
                ``unpack_y_dict`` each entry is expanded into columns.
            scatter (bool): Draw scatter plots instead of lines.
            stackplot (bool): Draw one stack plot of all y columns (takes
                precedence over ``scatter``).
            unpack_y_dict (bool): Expand ``y_column`` entries via ``pd.Series``.
            **kwargs: Passed to ``plt.figure`` the first time a figure is created.

        Returns:
            The ``matplotlib.pyplot`` module used for drawing.
        """
        x_df, y_df = self.get_x_y(
            df, x_column=x_column, y_column=y_column, unpack_y_dict=unpack_y_dict
        )

        # Only the first call creates a figure; later calls draw onto it.
        if self.plt is None:
            plt.figure(**kwargs)
            self.plt = plt

        if isinstance(y_df, pd.Series):
            y_df = y_df.to_frame(name=y_column)

        if stackplot:
            self.plt.stackplot(x_df, y_df.T, labels=y_df.columns)
        else:
            draw = self.plt.scatter if scatter else self.plt.plot
            for name in y_df.columns:
                draw(x_df, y_df[name], label=name)

        self.plt.xlabel(x_column)
        self.plt.ylabel("Information Content (bits)")

        return self.plt

    @staticmethod
    def plot_bivariate_pid_plane(
        data,
        ax=None,
        runs=None,
        show_individual_runs=False,
        show_MI=True,
        filter=None,
        title="Bivariate PID information plane",
    ):
        """
        Plot each bivariate PID atom w.r.t. target X against the same atom w.r.t. target Y.

        Args:
            data: String - measurement_dir, loaded with
                    ``nninfo.file_io.MeasurementManager(data).load(tuple_as_str=True)``
                dataframe - data
            ax: Axes to draw on; a new 10x6 figure is created when None.
            runs: None  - average over all runs
                int   - only one run
                [int] - average over selection of runs
            show_individual_runs: Additionally draw every run's trajectory
                with low opacity.
            show_MI: Additionally draw the sum over the four atoms, labelled "MI".
            filter: Optional ``{column: value}`` mapping; only rows where every
                listed column equals its value are kept. None means no filter.
            title: Axes title.

        Returns:
            None. Prints a message and returns early if no rows survive filtering.

        Raises:
            NotImplementedError: If ``data`` is neither a str nor a DataFrame.
        """
        # ``filter`` keeps its public name; use a local alias to avoid the builtin.
        column_filter = {} if filter is None else filter

        # Load data
        if isinstance(data, str):
            mm = nninfo.file_io.MeasurementManager(data)
            df = mm.load(tuple_as_str=True)
        elif isinstance(data, pd.DataFrame):
            df = data
        else:
            raise NotImplementedError

        # Flatten binning_params into one column per key
        df = pd.concat(
            [df.drop(["binning_params"], axis=1), df["binning_params"].apply(pd.Series)],
            axis=1,
        )

        # Make T hashable by converting lists to X or Y
        df["T"] = df["T"].apply(lambda l: "X" if "X" in "".join(l) else "Y")

        df["S1"] = df["S1"].apply("".join)
        df["S2"] = df["S2"].apply("".join)

        for key, value in column_filter.items():
            df = df[df[key] == value]

        # Keep only the selected runs
        if runs is not None:
            if isinstance(runs, int):
                runs = [runs]
            df = df[df["run_id"].isin(runs)]

        # NOTE(review): hard-coded chapter_id cutoff kept from the original; confirm intent.
        df = df[df["chapter_id"] < 32]

        if len(df) == 0:
            print("No data left after filtering.")
            return

        # Unpack average_pid into one column per atom
        df = pd.concat(
            [df.drop(["average_pid"], axis=1), df["average_pid"].apply(pd.Series)],
            axis=1,
        )

        new_cols = [
            "((1,), (2,))",
            "((1,),)",
            "((2,),)",
            "((1, 2),)",
        ]

        pid = pd.pivot_table(
            df, values=new_cols, index=["epoch_id", "run_id"], columns="T"
        )
        pid.columns.names = ["atoms", "T"]

        X = pid.xs("X", level=1, axis=1)
        Y = pid.xs("Y", level=1, axis=1)

        # Averages over runs
        avg_pid = pid.groupby("epoch_id").mean()
        avg_X = avg_pid.xs("X", level=1, axis=1)
        avg_Y = avg_pid.xs("Y", level=1, axis=1)

        # Mutual information as the sum of the four atoms, per (epoch, run)
        MI = pid.stack("atoms").groupby(["epoch_id", "run_id"]).sum()
        MI_X = MI.xs("X", axis=1)
        MI_Y = MI.xs("Y", axis=1)

        # Average mutual information over runs
        avg_MI = MI.groupby("epoch_id").mean()
        avg_MI_X = avg_MI.xs("X", axis=1)
        avg_MI_Y = avg_MI.xs("Y", axis=1)

        # Explicit colors so mean and per-run curves of one atom match
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Mean trajectories, with their last point marked
        for i, atom in enumerate(new_cols):
            ax.plot(
                avg_X[atom],
                avg_Y[atom],
                "x-",
                color=colors[i],
                label=atom,
                markevery=3,
                zorder=2,
            )
            ax.plot(
                [avg_X[atom].to_numpy()[-1]],
                [avg_Y[atom].to_numpy()[-1]],
                "ok",
                markersize=4,
                zorder=3,
            )

        if show_MI:
            ax.plot(avg_MI_X, avg_MI_Y, "kx-", label="MI", markevery=3, zorder=2)
            ax.plot(
                avg_MI_X.to_numpy()[-1], avg_MI_Y.to_numpy()[-1], "ok", label="Endpoint"
            )
        else:
            ax.plot([], [], "ok", label="Endpoint")  # For legend only

        # Freeze the limits so the per-run curves below do not rescale the axes.
        ax.autoscale(False)

        if show_individual_runs:
            for i, atom in enumerate(new_cols):
                for (_, x_run), (_, y_run) in zip(
                    X[atom].groupby("run_id"), Y[atom].groupby("run_id")
                ):
                    ax.plot(
                        x_run.to_numpy(),
                        y_run.to_numpy(),
                        color=colors[i],
                        alpha=0.2,
                        zorder=1,
                    )
            if show_MI:
                for (_, x_run), (_, y_run) in zip(
                    MI_X.groupby("run_id"), MI_Y.groupby("run_id")
                ):
                    ax.plot(
                        x_run.to_numpy(), y_run.to_numpy(), "k", alpha=0.2, zorder=1
                    )

        ax.set_title(title)
        # Raw strings: "\p" is an invalid escape sequence in a normal literal.
        ax.set_xlabel(r"$\pi(X: L1_1, L1_2)$ (bits)")
        ax.set_ylabel(r"$\pi(Y: L1_1, L1_2)$ (bits)")

        ax.legend(title="PID atoms", loc="lower right")
        ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))
        ax.yaxis.set_major_locator(plt.MultipleLocator(0.2))

        ax.grid()

    def scatter(self, df, x_column="epoch_id", y_column="average_pid"):
        """
        Standard function that is called when plotting PID data as scatter plot.

        Args:
            df (pandas dataframe): The measurements you want to visualize.
            x_column (str): Column used for the x axis.
            y_column (str): Column holding the y values.

        Returns:
            The ``matplotlib.pyplot`` module used for drawing.
        """
        # ``plot`` selects scatter mode via ``scatter=True``; any other keyword
        # would land in ``**kwargs`` and be forwarded to ``plt.figure``.
        return self.plot(df, x_column, y_column, scatter=True)

    @staticmethod
    def get_x_y(
        df, x_column="epoch_id", y_column="average_pid", unpack_y_dict=True, sort=True
    ):
        """Return the x column and the y data of ``df``.

        Rows are sorted by ``x_column`` first when ``sort`` is True. With
        ``unpack_y_dict`` each entry of ``y_column`` is expanded into columns
        via ``pd.Series`` (dict keys become column names).
        """
        temp_df = df.sort_values(x_column) if sort else df
        x_df = temp_df[x_column]
        y_df = temp_df[y_column]
        if unpack_y_dict:
            y_df = y_df.apply(pd.Series)
        return x_df, y_df

    @staticmethod
    def plot_backbone_atoms(df, ax, n=4, target='Y', dataset='full_set/train', layer=2, quantile_level=0.95):
        """Plot, per backbone m = 1..n, the summed PID atoms of one layer over epochs.

        Atoms are grouped by the backbone value from ``get_backbones(n)`` and
        summed per (epoch, run). Up to epoch 10**4 the median over runs is
        drawn with a band between the ``(1 - quantile_level) / 2`` and
        ``(1 + quantile_level) / 2`` quantiles; from epoch 10**4 on only the
        values of run 0 are drawn, faintly.

        Args:
            df: Measurement DataFrame with columns S1, S2, T, dataset_name,
                epoch_id, run_id and the per-atom column chosen by ``target``.
            ax: Axes to draw on.
            n: Number of sources; selects the atoms of ``get_backbones(n)``.
            target: ``'Y'`` uses ``average_pid``, ``'X'`` uses ``informative_pid``.
            dataset: Value of ``dataset_name`` to keep.
            layer: Only rows whose S1 contains ``f'L{layer}'`` are kept.
            quantile_level: Coverage of the shaded band.

        Raises:
            ValueError: If ``target`` is neither ``'X'`` nor ``'Y'``.
        """
        if target == 'Y':
            pid_column = "average_pid"
        elif target == 'X':
            pid_column = "informative_pid"
        else:
            raise ValueError(f"target must be 'X' or 'Y', got {target!r}")

        # Work on a copy so the caller's DataFrame is left untouched.
        df = df.copy()
        df['S1'] = df['S1'].map(str)
        df['S2'] = df['S2'].map(str)
        df['T'] = df['T'].map(str)

        df = df[df['S1'].str.contains(f'L{layer}')]
        df = df[df['dataset_name'] == dataset]
        # informative part of PID with Y is equal to PID with X!
        df = df[df['T'].str.contains('Y')]

        # Unpack the per-atom column into one column per atom
        df = pd.concat(
            [df.drop([pid_column], axis=1), df[pid_column].apply(pd.Series)],
            axis=1,
        )

        df = df.set_index(['epoch_id', 'run_id'])

        # Atom labels and their backbone values, in matching order
        atoms, backbones = get_backbones(n)
        df = df[list(atoms)]

        # Sum the atoms that share each backbone value
        df_backbone = pd.DataFrame()
        for m in range(1, n + 1):
            df_backbone[m] = df[df.columns[backbones == m]].sum(axis=1)

        by_epoch = df_backbone.groupby('epoch_id')
        df_backbone_median = by_epoch.median()
        df_backbone_lower = by_epoch.quantile(.5 - quantile_level / 2)
        df_backbone_upper = by_epoch.quantile(.5 + quantile_level / 2)

        # Run 0 alone, drawn past epoch 10**4
        is_run_0 = df_backbone.index.get_level_values('run_id') == 0
        df_backbone_extended = df_backbone[is_run_0].groupby('epoch_id').mean()

        ax.set_xscale('symlog', linthresh=1, linscale=.6)

        median_head = df_backbone_median[df_backbone_median.index <= 10000]
        lower_head = df_backbone_lower[df_backbone_lower.index <= 10000]
        upper_head = df_backbone_upper[df_backbone_upper.index <= 10000]
        extended_tail = df_backbone_extended[df_backbone_extended.index >= 10000]

        for backbone in range(1, n + 1):
            line = ax.plot(median_head[backbone], zorder=1,
                           label=f'$m = {backbone}$', solid_capstyle='butt')
            color = line[-1].get_color()
            ax.fill_between(lower_head.index, lower_head[backbone].values,
                            upper_head[backbone].values, color=color,
                            alpha=0.3, zorder=0, linewidth=0)
            ax.plot(extended_tail[backbone], zorder=-2, color=color,
                    linestyle='-', alpha=0.3, linewidth=1)

        # White rectangle over x in [10**3, 10**4], y in [0, 5], at zorder -1
        # (below the bands and median lines, above the run-0 tail).
        ax.add_patch(patches.Rectangle((10**3, 0), 10**4 - 10**3, 5,
                                       color='w', linewidth=0, zorder=-1))

        ax.set_xlim(0, max(df_backbone_extended.index))

        ax.set_xticks([0] + [10**i for i in range(5)])
        ax.set_xticklabels(['$0$', '$1$', '', '$10^2$', '', '$10^4$'])

        ax.set_ylim(0, 1.6 if target == 'Y' else 5)

        # Broken-axis marker at 5% of the axes width
        d = .01
        broken_x = 0.05
        breakspacing = 0.015
        # White strokes blank out the spines inside the gap
        for y_spine in (0, 1):
            ax.plot((broken_x - breakspacing * 0.9, broken_x + breakspacing * 0.9), (y_spine, y_spine),
                    color='w', transform=ax.transAxes, clip_on=False, linewidth=.85, zorder=3)

        mark_style = dict(transform=ax.transAxes, color='k',
                          clip_on=False, linewidth=1, zorder=4)
        for x_offset in (-breakspacing, breakspacing):
            for y_spine in (0, 1):
                ax.plot((broken_x - d + x_offset, broken_x + d + x_offset),
                        (y_spine - 3 * d, y_spine + 3 * d), **mark_style)

    @staticmethod
    def plot_representational_complexity(df, ax, target='Y', dataset='full_set/train', layer=2, quantile_level=0.95, use_median=False, **kwargs):
        """Plot the representational complexity ``c`` of one layer over epochs.

        ``c`` comes from ``compute_backbone`` and is drawn with
        ``plot_mean_and_interval``; the x axis is formatted by
        ``format_figure_broken_axis`` up to the largest ``epoch_id``.

        Args:
            df: Measurement DataFrame with columns S1, S2, T, dataset_name,
                epoch_id, run_id and the PID columns read by ``compute_backbone``.
            ax: Axes to draw on.
            target: Forwarded to ``compute_backbone``.
            dataset: Value of ``dataset_name`` to keep.
            layer: Only rows whose S1 contains ``f'L{layer}'`` are kept.
            quantile_level: Coverage of the shaded band.
            use_median: Draw the median instead of the mean.
            **kwargs: Forwarded to ``plot_mean_and_interval``; ``label``
                defaults to ``$L_{<layer>}$``.
        """
        # Work on a copy so the caller's DataFrame is left untouched.
        df = df.copy()
        df['S1'] = df['S1'].map(str)
        df['S2'] = df['S2'].map(str)
        df['T'] = df['T'].map(str)

        df = df[df['dataset_name'] == dataset]
        # informative part of PID with Y is equal to PID with X!
        df = df[df['T'].str.contains('Y')]

        df = df[df['S1'].str.contains(f'L{layer}')]

        df_backbone = compute_backbone(df, target=target)

        kwargs.setdefault('label', f'$L_{{{layer}}}$')

        plot_mean_and_interval(
            df_backbone, ax, y='c', quantile_level=quantile_level, use_median=use_median, **kwargs)

        max_epoch = df['epoch_id'].max()

        format_figure_broken_axis(ax, int(np.log10(max_epoch)))
        ax.set_xlim(0, max_epoch)


def backbone_from_atom(atom):
    """Return the length of the shortest member of ``atom``.

    ``atom`` is an iterable of sized collections, such as an element of
    ``SxPID.load_achains(n)``. Raises ``ValueError`` when ``atom`` is empty.
    """
    shortest = min(atom, key=len)
    return len(shortest)


@lru_cache(maxsize=5)
def get_backbones(n):
    """Return the atom labels for ``n`` sources together with each atom's backbone.

    Returns:
        Tuple ``(atoms, backbones)`` of equally long numpy arrays in the order
        of ``SxPID.load_achains(n)``: ``atoms[i]`` is ``str`` of the i-th
        antichain and ``backbones[i]`` is ``backbone_from_atom`` of it.
        Both arrays are read-only because they are shared through the cache.
    """
    # Load once and reuse for both arrays.
    achains = list(SxPID.load_achains(n))
    atoms = np.array([str(alpha) for alpha in achains])
    backbones = np.array([backbone_from_atom(alpha) for alpha in achains])
    # The cache hands the same arrays to every caller; freeze them so an
    # in-place modification by one caller cannot corrupt later results.
    atoms.flags.writeable = False
    backbones.flags.writeable = False
    return atoms, backbones


def plot_mean_and_interval(df, ax, x='epoch_id', y='c', use_median=False, quantile_level=0.95, zorder_shift=0, **kwargs):
    """Plot the per-``x`` mean (or median) of column ``y`` with a shaded quantile band.

    Args:
        df: DataFrame grouped by ``x`` (a column or index-level name).
        ax: Axes to draw on.
        x: Grouping key; its values form the horizontal axis.
        y: Column to aggregate and plot.
        use_median: Draw the median instead of the mean as the center line.
        quantile_level: Coverage of the band, which spans the
            ``(1 - quantile_level) / 2`` and ``(1 + quantile_level) / 2`` quantiles.
        zorder_shift: Added to the z-orders of the line (1) and the band (0).
        **kwargs: Forwarded to ``ax.plot``; ``label`` defaults to ``''``.
    """
    # Aggregate only the plotted column: one groupby instead of three, and
    # unrelated (e.g. non-numeric) columns cannot break mean/median/quantile.
    grouped = df.groupby(x)[y]
    center = grouped.median() if use_median else grouped.mean()
    lower = grouped.quantile(.5 - quantile_level / 2)
    upper = grouped.quantile(.5 + quantile_level / 2)

    kwargs.setdefault('label', '')

    line = ax.plot(center, zorder=1 + zorder_shift, solid_capstyle='butt', **kwargs)
    # Band uses the same color as the line just drawn.
    ax.fill_between(lower.index, lower.values, upper.values,
                    color=line[-1].get_color(), alpha=0.3,
                    zorder=0 + zorder_shift, linewidth=0)


def format_figure_broken_axis(ax, max_exp=4):
    """Give ``ax`` a symlog x axis from 0 to ``10**max_exp`` with a broken-axis mark.

    Ticks sit at 0 and at every power of ten up to ``10**max_exp``. Labels
    are shown at 0, 1 and even exponents only. Short diagonal strokes are
    drawn over white gaps in the bottom and top spines, at 7% of the axes
    width.

    Args:
        ax: Axes to format.
        max_exp: Exponent of the right x limit.
    """
    ax.set_xscale('symlog', linthresh=1, linscale=.6)

    ax.set_xlim(0, 10**max_exp)
    ax.set_xticks([0] + [10**i for i in range(max_exp + 1)])
    # Braces keep multi-digit exponents (e.g. 10^{10}) fully superscripted.
    ax.set_xticklabels(['$0$', '$1$'] + ['' if i % 2 == 1 else f'$10^{{{i}}}$'
                                         for i in range(1, max_exp + 1)])

    # Broken axis
    d = .01
    broken_x = 0.07
    breakspacing = 0.015
    # White strokes blank out the spines inside the gap
    for y_spine in (0, 1):
        ax.plot((broken_x - breakspacing * 0.9, broken_x + breakspacing * 0.9), (y_spine, y_spine),
                color='w', transform=ax.transAxes, clip_on=False, linewidth=.8, zorder=3)

    mark_style = dict(transform=ax.transAxes, color='k',
                      clip_on=False, linewidth=.8, zorder=4)
    # One diagonal stroke on each side of the gap, on both spines
    for x_offset in (-breakspacing, breakspacing):
        for y_spine in (0, 1):
            ax.plot((broken_x - d + x_offset, broken_x + d + x_offset),
                    (y_spine - 3 * d, y_spine + 3 * d), **mark_style)

def get_deg_of_syn(atom, n):
    """Return the length of the shortest member of a PID atom.

    Args:
        atom: Either an element of ``SxPID.load_achains(n)`` (an iterable of
            sized collections) or its ``str`` representation.
        n: Number of sources; only used to resolve a string ``atom``.

    Raises:
        ValueError: If ``atom`` is a string that matches no element of
            ``SxPID.load_achains(n)``.
    """
    if isinstance(atom, str):
        label = atom
        atom = next((alpha for alpha in SxPID.load_achains(n) if str(alpha) == label), None)
        # A bare ``next`` would leak StopIteration (a RuntimeError inside generators).
        if atom is None:
            raise ValueError(f"{label!r} is not an atom for n={n} sources")

    return min(len(a) for a in atom)

def compute_backbone(df, target='Y'):
    """Sum PID atoms per backbone value and compute the complexity ``c``.

    Args:
        df: Measurement rows with ``epoch_id`` and ``run_id`` columns and PID
            values in columns labelled ``(pid_name, atom)`` tuples. The number
            of sources ``n`` is 5 if an ``S5`` column exists, else 4 if ``S4``
            exists, else 3.
        target: ``'Y'`` reads ``average_pid`` columns; any other value reads
            ``informative_pid`` columns.

    Returns:
        DataFrame indexed by ``(epoch_id, run_id)`` with columns ``1..n``
        (the summed atoms of each backbone value) and ``'c'``, the weighted
        mean ``sum(m * backbone_m) / sum(backbone_m)``.

    Raises:
        ValueError: If the selected PID columns do not equal, in order, the
            atom labels returned by ``get_backbones(n)``.
    """
    if 'S5' in df.columns:
        n = 5
    elif 'S4' in df.columns:
        n = 4
    else:
        n = 3

    target_pid = 'average_pid' if target == 'Y' else 'informative_pid'

    # Select the PID columns; the second tuple element is the atom label.
    top_level = df.columns.map(lambda x: x[0] if isinstance(x, tuple) else x)
    df_pid = df.iloc[:, top_level == target_pid]
    df_pid.columns = df_pid.columns.map(lambda x: x[1])

    df = pd.concat([df[['epoch_id', 'run_id']], df_pid], axis=1)

    df = df.set_index(['epoch_id', 'run_id'])

    atoms, backbones = get_backbones(n)

    # Columns must be ordered exactly like ``backbones``. Use an explicit
    # check (asserts vanish under -O); compare lengths first because
    # element-wise == raises on a length mismatch.
    if len(df.columns) != len(atoms) or not (df.columns == atoms).all():
        raise ValueError(
            f"PID columns do not match the {len(atoms)} atoms expected for n={n} sources")

    degrees = range(1, n + 1)
    df_backbone = pd.DataFrame()
    for m in degrees:
        df_backbone[m] = df[df.columns[backbones == m]].sum(axis=1)

    weighted = sum(m * df_backbone[m] for m in degrees)
    total = sum(df_backbone[m] for m in degrees)
    df_backbone['c'] = weighted / total

    return df_backbone

class FisherPlot(Plot):
    """Plots of per-layer matrices and of weight/bias trajectories."""

    def __init__(self):
        super().__init__()
        # pyplot module handle; set by ``plot_layer``.
        self.plt = None

    def plot_layer(self, matrix, layer_id):
        """Show ``matrix`` as an image with one tick per row and column, plus a colorbar.

        Columns are labelled as neurons of layer ``layer_id`` and rows as
        neurons of layer ``layer_id + 1``.

        Args:
            matrix: 2-D array-like; lists are converted with ``np.array``.
            layer_id: Index of the layer on the x axis.

        Returns:
            The ``matplotlib.pyplot`` module (also stored on ``self.plt``).
        """
        if isinstance(matrix, list):
            matrix = np.array(matrix)
        print(matrix.shape)
        plt.xticks(np.arange(matrix.shape[1]))
        plt.yticks(np.arange(matrix.shape[0]))
        plt.imshow(matrix)
        plt.xlabel("Neurons in Layer " + str(layer_id))
        plt.ylabel("Neurons in Layer " + str(layer_id + 1))
        plt.colorbar(shrink=0.54)
        self.plt = plt
        return plt

    def plot_network(self, weight_dict):
        """Placeholder; currently does nothing."""

    @staticmethod
    def plot_layer2(x, y, layer_name):
        """Plot the trajectory of every weight or every bias of one layer on a log x axis.

        Args:
            x: x values shared by all curves.
            y: DataFrame with a column ``layer_name``; each entry is expanded
                with ``pd.Series`` into output neurons (and, for weights,
                further into input neurons) — presumably nested ``[out][in]``
                sequences; confirm against the caller.
            layer_name: ``"<prefix>.<layer index>.<weight|bias>"``.

        Returns:
            The ``matplotlib.pyplot`` module.

        Raises:
            ValueError: If ``layer_name`` does not have three dot-separated
                parts ending in ``weight`` or ``bias``.
        """
        def neuron_label(layer_id, index):
            # Layer 0 is the input "X"; later layers are "L<id>".
            name = "X" if layer_id == 0 else "L" + str(layer_id)
            return str((name, (index,)))

        y = pd.DataFrame(y.loc[:, layer_name])
        _, idx, param_kind = layer_name.split(".")
        idx = int(idx)
        if param_kind == "weight":
            is_weight = True
        elif param_kind == "bias":
            is_weight = False
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"layer_name must end in 'weight' or 'bias', got {layer_name!r}")

        for name in y.columns:
            output_neuron = y[name].apply(pd.Series)
            for ocol in output_neuron.columns:
                target = neuron_label(idx + 1, ocol + 1)
                if is_weight:
                    input_neuron = output_neuron[ocol].apply(pd.Series)
                    for icol in input_neuron.columns:
                        label = neuron_label(idx, icol + 1) + " -> " + target
                        plt.plot(x, input_neuron[icol], label=label)
                else:
                    biases = output_neuron[ocol].apply(pd.Series)
                    plt.plot(x, biases, label="B -> " + target)

        plt.xscale("log")
        plt.legend()
        return plt
