from moduleloader import outervar
import moduleloader as mf
import torch
import numpy as np
import scipy.stats as st
import pandas as pd
from lola_modules.losspath.path import init_new_direction, slice_landscape, rescorr, slice_resout,mag, cache_net,walk_step,fitpl_pow,pow
import torch
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib as mpl
import numpy as np

from functools import partial

def init(state,event):
    '''Prepare the loss criterion and a scratch copy of the network.

    Writes two entries into ``state``:
      * ``"criterion"``  -- the result of ``event.init_loss(reduction="none")``
      * ``"net_cached"`` -- a freshly initialised net passed through
        ``event.send_net_to_device``
    '''
    criterion = event.init_loss(reduction="none")
    state["criterion"] = criterion

    # build a new, independent copy of the net and move it to the device
    fresh_net = event.init_net()
    state["net_cached"] = event.send_net_to_device(fresh_net)

def _join_series_into_csv(csv_path, series, index_name):
    '''Add ``series`` as one more column of the table stored at ``csv_path``.

    The table is read with its first column as index (an empty table is used
    when the file does not exist yet), outer-joined with ``series`` (a clashing
    new column name gets a ``"_"`` suffix) and written back to the SAME path,
    so repeated calls accumulate columns instead of overwriting earlier ones.
    '''
    try:
        table = pd.read_csv(csv_path, index_col=0)
    except FileNotFoundError:
        table = pd.DataFrame()
    table = table.join(series, how="outer", rsuffix="_")
    table.index.name = index_name
    table.to_csv(csv_path)


def rescorr_walk(state, event, net=outervar, inputs=outervar):
    '''From the current point in weight space, take MULTIPLE slices in random directions
    of the neuronal surface of resnet summands and measure their correlation / SNRI.

    Only does work when ``state.all["step"]`` is a multiple of
    ``state["step_every"]``. ``rescorr`` is called ``state["numdir"]`` times, a
    new direction being requested through ``event.init_new_direction()`` after
    every call, and the stacked results are averaged and plotted.

    ``state["snrmode"]`` selects the evaluation:
      * ``"hplp"`` -- ``rescorr`` yields ``(snri_per_block, snr_per_block)``;
        only plots are produced.
      * ``"full"`` -- ``rescorr`` yields ``(corrs_per_block, mean_corrs_per_block)``;
        a heat map and a per-block curve are plotted and, when ``state["log"]``
        is set, one column per call is added to ``./logs/correlation.csv`` and
        ``./logs/snri_per_layer.csv``.
    Any other mode is silently ignored.
    '''
    # Run every state["step_every"] steps
    if state.all["step"] % state["step_every"] != 0:
        return

    if state["snrmode"] == "hplp":
        snri_per_block = []
        snr_per_block = []
        for _ in range(state["numdir"]):
            snri_per_block_, snr_per_block_ = rescorr(state, event, net=net, inputs=inputs)
            snri_per_block.append(snri_per_block_)
            snr_per_block.append(snr_per_block_)
            event.init_new_direction()
        # average over the sampled directions
        snri_per_block = torch.stack(snri_per_block, axis=0).mean(axis=0)
        snr_per_block = torch.stack(snr_per_block, axis=0).mean(axis=0)

        event.optional.plot_scalar2d(snri_per_block, title="Average SNRI per ResBlock (" + state.all["log.dir"]+")")
        event.optional.plot_scalar2d(torch.log(snr_per_block), title="Average log(SNR) of BN(s1) per ResBlock")
        event.optional.plot_scalar2d(torch.log(snr_per_block*snri_per_block), title="Average log(SNR) of BN(s1+s2) per ResBlock")
        event.optional.plot_scalar(snri_per_block.mean().cpu(), title="Average SNRI")
        event.optional.plot_scalar(snr_per_block.mean().cpu(), title="Average SNR")

    elif state["snrmode"] == "full":
        confusion_per_block = []
        mean_corrs_per_block = []

        for _ in range(state["numdir"]):
            corrs_per_block_, mean_corrs_per_block_ = rescorr(state, event, net=net, inputs=inputs)
            confusion_per_block.append(corrs_per_block_)
            mean_corrs_per_block.append(mean_corrs_per_block_)
            event.init_new_direction()

        confusion_per_block = torch.stack(confusion_per_block, axis=0)
        mean_corrs_per_block = torch.stack(mean_corrs_per_block, axis=0)
        # collapse the direction axis (0) and the following axis (1)
        mean_confusion = confusion_per_block.mean(axis=(0,1)).cpu()
        mean_corrs = mean_corrs_per_block.mean(axis=(0,1)).cpu()
        event.optional.plot_static_hm(mean_confusion, title="Mean SNRI of (X) resp. (Y)", xlabel="x frequency bin", ylabel="y frequency bin")
        event.optional.plot_scalar2d(confusion_per_block[:,:,-1 ,0].mean(0), title="Mean SNRI of highest / lowest bin(" + state.all["log.dir"]+")", xlabel="Batch", ylabel="SNRI")

        if state["log"]:
            column_name = state.all["log.dir"]+"_step_"+str(state.all["step"])

            _join_series_into_csv(
                "./logs/correlation.csv",
                pd.Series(mean_corrs, name=column_name),
                "freq_bin")

            # Previously this table was read from "./logs/snri.csv" but written
            # to "./logs/snri_per_layer.csv", so earlier columns were never
            # picked up and every call overwrote the file. Read and write now
            # use the same path.
            _join_series_into_csv(
                "./logs/snri_per_layer.csv",
                pd.Series(confusion_per_block.mean(0).cpu().numpy()[:,-1,0], name=column_name),
                "resblock")


def fourier_loss_walk(state, event, net=outervar, inputs=outervar, labels=outervar):
    '''From the current point in weight space, take MULTIPLE slices of the loss surface in random directions.

    Only does work when ``state.all["step"]`` is a multiple of
    ``state["step_every"]``.

    For each of ``state["num_init"]`` initialisations (the net is re-created via
    ``event.init_net()`` only when ``num_init > 1``) ``slice_landscape`` is
    called ``state["numdir"]`` times; it yields ``(path, power, a, r)`` and a
    new direction is requested after every call. Per initialisation the power
    spectra of its directions are averaged and passed to ``fitpl_pow``.

    When ``state["log"]`` is set:
      * one row of mean/std statistics is appended to
        ``./logs/smoothness_stats.csv``;
      * the paths with the median ``a`` and the median ``r`` value, together
        with the square root of their power spectra (first coefficient
        dropped), are joined into ``./logs/med_path.csv``.
    '''
    if state.all["step"] % state["step_every"] != 0:
        return

    numdir = state["numdir"]
    a_per_path = []
    r_per_path = []
    paths = []
    pows = []
    a_per_init = []
    r_per_init = []

    for init_idx in range(state["num_init"]):
        # in case we want to loop over multiple initializations (only makes sense when not training)
        if state["num_init"] > 1:
            net = event.init_net()
            net = event.send_net_to_device(net)

        for _ in range(numdir):
            # `power` instead of `pow`: avoids shadowing the imported/builtin name
            path, power, a_per_path_, r_per_path_ = slice_landscape(state,event,net=net, inputs=inputs, labels=labels)
            event.init_new_direction()

            paths.append(path)
            pows.append(power)
            a_per_path.append(a_per_path_)
            r_per_path.append(r_per_path_)
        event.optional.plot_scalar(np.mean(a_per_path),  title="Average Cross section smoothness")

        # All `numdir` spectra of this initialisation. The slice end is already
        # exclusive; the former extra "-1" dropped the last direction and gave
        # an empty slice (NaN mean) for numdir == 1.
        mean_pow_per_init = np.mean(pows[init_idx*numdir:(init_idx+1)*numdir],0)
        a_mean_per_run,_,r_mean_per_run = fitpl_pow(state, event, mean_pow_per_init)
        event.optional.plot_scalar(mean_pow_per_init[1:], list(range(len(mean_pow_per_init)-1)),
                                   title="Mean Power spectrum")

        a_per_init.append(a_mean_per_run)
        r_per_init.append(r_mean_per_run)

    if not state["log"]:
        return

    run_name = state.all["log.dir"]

    # log smoothness stats
    try:
        df = pd.read_csv("./logs/smoothness_stats.csv", index_col=0)
    except FileNotFoundError:
        df = pd.DataFrame()
    stats_row = pd.Series(
        [np.mean(a_per_path),np.std(a_per_path), np.mean(r_per_path),np.std(r_per_path),
         np.mean(a_per_init),np.std(a_per_init),np.mean(r_per_init),np.std(r_per_init)],
        name=run_name+"_step_"+str(state.all["step"]),
        index=["a_per_path_avg", "a_per_path_std","r_per_path_avg","r_per_path_std",
               "a_per_init_avg","a_per_init_std","r_per_init_avg","r_per_init_std"])
    # DataFrame.append was removed in pandas 2.0; concat of a one-row frame is equivalent
    df = pd.concat([df, stats_row.to_frame().T])
    df.to_csv("./logs/smoothness_stats.csv")

    # log median path with power spectrum

    def median_index(data):
        '''Position of the (upper) median element of ``data``.'''
        return np.argsort(data)[len(data)//2]

    # `paths` / `pows` hold one entry per sliced path, so the median has to be
    # located in the per-path statistics. Ranking the per-init lists (as done
    # before) produced an index into the wrong list - always 0 for num_init == 1.
    idx_a = median_index(a_per_path)
    idx_r = median_index(r_per_path)

    med_path_a = paths[idx_a].cpu().numpy()
    med_path_r = paths[idx_r].cpu().numpy()
    pow_med_path_a = pows[idx_a][1:]
    pow_med_path_r = pows[idx_r][1:]

    try:
        df = pd.read_csv("./logs/med_path.csv", index_col=0)
    except FileNotFoundError:
        df = pd.DataFrame()
        df.index.name = 'step'
    # "med_*" columns hold the raw loss path, "mag_*" the magnitude
    # (sqrt of the power). The sqrt used to be applied to the r path
    # instead of its power spectrum.
    df_ = pd.DataFrame([med_path_a, np.sqrt(pow_med_path_a),
                        med_path_r, np.sqrt(pow_med_path_r)],
                       index=["med_a_" + run_name, "mag_a_" + run_name,
                              "med_r_" + run_name, "mag_r_" + run_name]).T
    df_.index.name = 'step'
    df = df.join(df_,how="outer",rsuffix="_")
    df.to_csv("./logs/med_path.csv")

def linear_to_conv_to_resnet(state, event, net=outervar,inputs=outervar, labels=outervar):
    '''Produces path morph from linear->relu->resnet

    HAS TO BE LOADED WITH RESNET SHORT + LRELU, ACTIVATE RESIDUAL MODIFYER

    Two sweeps of ``state["steps"]`` settings each are performed, calling
    ``slice_landscape`` once per setting:
      1. with ``state["nu"] = 0`` the ``negative_slope`` of every entry of
         ``state.all["relus"]`` goes from 1 to 0;
      2. ``state["nu"]`` goes from 0 to 1 (the slopes stay at their last value).
    Both families of paths are shown as 3-d line plots and, when
    ``state["log"]`` is set, written to CSV files under ``./logs/``.

    Note: unlike the other experiments in this file this one is not gated by
    ``state["step_every"]``; it runs on every call. ``state["nu"]`` and the
    relu slopes are left at their final sweep values.
    '''

    def plot_lines_3d(paths,a_s,colorlist=None,cbar_on=False):
        '''Draw every row of ``paths`` as a line, stacked along the y axis.

        Line colours are taken at evenly spaced positions of a colormap built
        from ``colorlist`` (blue -> red when omitted). ``a_s`` is only used for
        the range of the optional colour bar (``cbar_on``).
        '''
        if colorlist is None:
            colorlist = [(0,0,1),(1,0,0)]

        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (matplotlib >= 3.6);
        # add_subplot is the supported way to request a 3-d axes.
        ax = fig.add_subplot(111, projection='3d')

        # plot data as line collection
        xs = np.arange(paths.shape[1])
        zs = np.arange(paths.shape[0])
        sm = plt.cm.ScalarMappable(cmap=cm.autumn, norm=plt.Normalize(vmin=a_s.min(), vmax=a_s.max()))
        verts = [list(zip(xs, path)) for path in paths]
        colormap = mpl.colors.LinearSegmentedColormap.from_list(
            "mylist", colorlist, N=len(paths))
        colors = [colormap(a) for a in np.linspace(0,1,len(paths))]
        poly = LineCollection(verts,colors =colors)
        poly.set_alpha(0.7)
        ax.add_collection3d(poly, zs=zs, zdir='y')

        # labels
        ax.set_xlabel('X')
        ax.set_xlim3d(0, paths.shape[1])
        ax.set_ylabel('Y')
        ax.set_ylim3d(-1, paths.shape[0])
        ax.set_zlabel('Z')
        ax.set_zlim3d(paths.min(), paths.max())
        ax.view_init(50,320)

        # set grid invisible
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_pane_color((0, 0, 0, 0))
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis._axinfo["grid"]['color'] = (0,0,0,0)

        # set cbar
        if cbar_on:
            # `sm` is not attached to any axes, so the axes to take space from
            # must be named explicitly.
            cbar = fig.colorbar(sm, ax=ax, shrink=0.25, aspect=5)
            cbar.set_label("Path smoothness (a)")

        # public equivalent of setting the private `_axis3don` flag to False
        ax.set_axis_off()

        plt.show()

    paths_lintoconv = []
    pows_lintoconv =[]
    a_s_lintoconv = []

    paths_convtores = []
    pows_convtores = []
    a_s_convtores = []

    # set resnet to noshort
    state["nu"] = 0

    # morph from linear to convnet
    for alpha in np.linspace(1,0,state["steps"]):
        # modify negative slope, then walk loss path
        for relu in state.all["relus"]:
            relu.negative_slope = alpha
        path, power, a, _ = slice_landscape(state, event, net=net, inputs=inputs, labels=labels)
        paths_lintoconv.append(path)
        pows_lintoconv.append(power)
        a_s_lintoconv.append(a)

    # morph from convnet to resnet
    for nu in np.linspace(0,1,state["steps"]):
        # modify residual amount, then walk loss path
        state["nu"] = nu

        path, power, a, _ = slice_landscape(state, event, net=net, inputs=inputs, labels=labels)
        paths_convtores.append(path)
        pows_convtores.append(power)
        a_s_convtores.append(a)

    paths_lintoconv = torch.stack(paths_lintoconv).cpu().numpy()
    paths_convtores = torch.stack(paths_convtores).cpu().numpy()
    a_s_lintoconv = np.stack(a_s_lintoconv)
    a_s_convtores = np.stack(a_s_convtores)
    plot_lines_3d(paths_lintoconv,a_s_lintoconv,colorlist=[(0.227,0.298,0.763),(0.859, 0.863, 0.871),(0.702,0.012,0.149)])
    plot_lines_3d(paths_convtores,a_s_convtores,colorlist=[(0.702,0.012,0.149),(0.5,0.5,0.5)])

    if state["log"]:
        # NOTE(review): the column labels assume each path has exactly
        # state["steps"] samples - confirm against slice_landscape.
        path_columns = ["path_"+str(i) for i in range(0,state["steps"])]

        df = pd.DataFrame(paths_convtores,columns=path_columns)
        df.index.name = 'step'
        df.to_csv("./logs/convtores.csv")

        df = pd.DataFrame(paths_lintoconv, columns=path_columns)
        df.index.name = 'step'
        df.to_csv("./logs/lintoconv.csv")

        df = pd.DataFrame(a_s_lintoconv, columns=["a"])
        df.index.name = 'path_no'
        df.to_csv("./logs/a_s_lintoconv.csv")

        df = pd.DataFrame(a_s_convtores, columns=["a"])
        df.index.name = 'path_no'
        df.to_csv("./logs/a_s_convtores.csv")

    # calculated 3d plot

def convnet_spectrum(state, event, net=outervar, inputs=outervar):
    '''Produces spectrum shift figures

    Only does work when ``state.all["step"]`` is a multiple of
    ``state["step_every"]``.

    The net is cached via ``cache_net`` and ``state["net_cached"]`` is
    evaluated on ``inputs`` at ``state["steps"]`` consecutive positions
    (``walk_step`` moves between them). The layer outputs collected in
    ``state["layer_outputs"]`` during each forward pass are stacked,
    Fourier-transformed along the step axis, optionally scaled by the
    coefficient index (``state["rescale"]``), normalised and averaged over the
    leading axis. The log spectrum is plotted when ``state["debug_plot"]`` is
    set and written to ``./logs/spectrumshift_<log.dir>.csv`` when
    ``state["log"]`` is set.
    '''
    if state.all["step"] % state["step_every"] != 0:
        return

    # Record layer outputs only while walking. The flag used to be switched on
    # unconditionally on every call and never reset, so every later forward
    # pass kept appending CPU copies to state["layer_outputs"].
    was_tracking = state["track_layer_outputs"]
    state["track_layer_outputs"] = True
    try:
        cache_net(state, event, net=net)
        net_cached = state["net_cached"]

        # walk one path
        with torch.no_grad():
            features_per_layer = []
            for step in range(state["steps"]):
                state["current_step"] = step
                state["layer_outputs"] = []
                net_cached(inputs.detach())
                features_per_layer.append(torch.stack(state["layer_outputs"]))

                # walk a step in weight space
                walk_step(state)
    finally:
        state["track_layer_outputs"] = was_tracking

    # Reverse the axis order so that the step axis comes last (what `.T` did;
    # `.T` on tensors that are not 2-d is deprecated).
    stacked = torch.stack(features_per_layer,0)
    features_per_layer = stacked.permute(*reversed(range(stacked.dim())))

    if hasattr(torch, "rfft"):
        # PyTorch < 1.8
        fft = torch.rfft(features_per_layer,1)
    else:
        # torch.rfft was removed in PyTorch 1.8; view_as_real restores the
        # trailing (real, imag) axis of the old return value.
        # NOTE(review): presumably `mag` expects that layout - confirm.
        fft = torch.view_as_real(torch.fft.rfft(features_per_layer, dim=-1))
    power = mag(fft)
    if state["rescale"]:
        # coefficient indices 0, 1, ..., n-1 as a tensor matching `power`
        coeff_index = torch.arange(power.shape[-1], dtype=power.dtype, device=power.device)
        power *= coeff_index[None,None,:]
    power /= power.norm(dim=2,keepdim=True)

    mean_power = power.mean(dim=0)
    log_power = mean_power.log().cpu().numpy()
    freq = np.fft.rfftfreq(state["steps"], state["step_size"])
    # frequency 0 is skipped: its log is -inf and it was never plotted or logged
    log_freq = np.log(freq[1:])
    colors = plt.cm.coolwarm(np.linspace(0, 1, state.all["model.convnet.depth"]))

    if state["debug_plot"]:
        for i, log_curve in enumerate(log_power):
            plt.plot(log_freq, log_curve[1:],color=colors[i])

        plt.xlabel("log(freq)")
        plt.ylabel("log(var)")
        plt.legend(["Path depth " + str(d) for d in range(mean_power.shape[0])])
        plt.show()

    if state["log"]:
        df = pd.DataFrame(log_power.T[1:,:], columns=["log_path_"+str(d) for d in range(mean_power.shape[0])])
        df.index.name = "step"
        df_ = pd.DataFrame(log_freq, columns=["log_freq"])
        df_.index.name = "step"
        df = df.join(df_)
        df.to_csv("./logs/spectrumshift_"+state.all["log.dir"]+".csv")


def convnet_loss_walk_per_layer(state, event, net=outervar, inputs=outervar, labels=outervar):
    '''Walks the loss surface, one layer at a time

    Only does work when ``state.all["step"]`` is a multiple of
    ``state["step_every"]``.

    For every ``i`` below ``state.all["model.convnet.depth"]`` the entry
    ``state["layer_filter"]`` is set to ``"convs.<i>.weight"`` and
    ``slice_landscape`` is called. The loss path is drawn into figure 1 and the
    log of the normalised power spectrum (first coefficient dropped) into
    figure 2. When ``state["log"]`` is set the curves are written to
    ``./logs/walk_per_layer.csv``.

    ``state["layer_filter"]`` is left at the value of the last layer.
    '''
    # Nothing to do (and nothing to log) on steps that are not scheduled.
    # The logging part used to run on every step and failed with a NameError
    # because `pows`, `paths` and `log_freq` only exist on scheduled steps.
    if state.all["step"] % state["step_every"] != 0:
        return

    paths = []
    pows = []
    plt.figure(1)
    plt.title("paths")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.figure(2)
    plt.title("pows")
    plt.xlabel("freq")
    plt.ylabel("log(pow)")
    freq = np.fft.rfftfreq(state["steps"],state["step_size"])
    # frequency 0 is skipped: its log is -inf and it was never written out
    log_freq = np.log(freq[1:])
    colors = plt.cm.coolwarm(np.linspace(0, 1, state.all["model.convnet.depth"]))

    for i in range(state.all["model.convnet.depth"]):
        state["layer_filter"] = "convs."+str(i)+".weight"
        # `power` instead of `pow`: avoids shadowing the imported/builtin name
        path, power, a, r = slice_landscape(state,event,net=net, inputs=inputs, labels=labels)
        power = power[1:]
        power /= np.linalg.norm(power, keepdims=True)
        paths.append(path)
        pows.append(power)
        plt.figure(1)
        plt.plot(path.cpu(),color=colors[i])
        # select figure 2; the colour used to be passed as a second positional
        # argument, which pyplot.figure interprets as `figsize`
        plt.figure(2)
        plt.plot(np.log(power),color=colors[i])

    plt.show()

    if state["log"]:
        # NOTE(review): the spectra already lost their first coefficient above
        # and lose another one here, while log_freq only skips one entry -
        # the columns may be shifted by one row against log_freq; confirm.
        df_pows = pd.DataFrame(np.log(np.stack(pows)[:,1:]).T, columns=["log_pow_dw_"+str(d) for d in range(len(pows))])
        df_pows.index.name = "step"
        df_paths = pd.DataFrame(torch.stack(paths).cpu().numpy()[:,1:].T, columns=["path_dw_"+str(d) for d in range(len(pows))])
        df_paths.index.name = "step"
        df_freqs = pd.DataFrame(log_freq, columns=["log_freq"])
        df_freqs.index.name = "step"
        df = df_pows.join(df_paths).join(df_freqs)
        df.to_csv("./logs/walk_per_layer.csv")

def avg_grad(state, event,net=outervar,inputs=outervar,labels=outervar,criterion=outervar,optimizer=outervar):
    '''Measure average gradient magnitude

    Only does work when ``state.all["step"]`` is a multiple of
    ``state["step_every"]``.

    For each of ``state["num_init"]`` rounds the gradient norms of all 4-d
    parameters whose name contains ``state["layer_filter"]`` are summed. When
    ``num_init > 1`` a new net is created after each round and one
    forward/backward pass is run on it, so the next round has gradients to
    measure.

    The mean over all rounds is appended as one row to
    ``./logs/gradient_stats.csv``: absolute value, its log, the value divided
    by the number of parameter elements (``num_filters``, taken from the last
    round) and the log of that.
    '''
    if state.all["step"] % state["step_every"] != 0:
        return

    norms_per_init = []
    for _ in range(state["num_init"]):
        norms = []
        num_filters = []
        with torch.no_grad():
            for name, w in net.named_parameters():
                # only 4-d weights whose name matches the filter
                if w.dim() != 4:
                    continue
                if state["layer_filter"] not in name:
                    continue
                norms.append(w.grad.data.norm())
                num_filters.append(np.prod(w.shape))
        norms = torch.stack(norms,0).abs().sum(0)
        norms_per_init.append(norms)
        num_filters = np.sum(num_filters)
        if state["num_init"] > 1:
            net = event.init_net()
            net = event.send_net_to_device(net)
            # NOTE(review): this forward pass is repeated below and its result
            # is discarded - kept as is; confirm whether it is needed.
            net(inputs)
            optimizer.zero_grad()

            # get result & loss
            outputs = net(inputs)
            current_loss = torch.mean(criterion(outputs, labels))
            regularizer = sum(event.optional.regularizer(net))

            # collect all losses
            if hasattr(event, 'mix_total_loss'):
                total_loss = event.mix_total_loss(current_loss, regularizer)
            else:
                total_loss = current_loss + regularizer

            # populate the gradients of the new net
            total_loss.backward()

    norms_per_init = torch.stack(norms_per_init,0).mean()
    rel_norm = norms_per_init/num_filters
    try:
        df = pd.read_csv("./logs/gradient_stats.csv", index_col=0)
    except FileNotFoundError:
        df = pd.DataFrame()
    stats_row = pd.Series(
        [norms_per_init.cpu().numpy(),norms_per_init.log().cpu().numpy(), rel_norm.cpu().numpy(),rel_norm.log().cpu().numpy()],
        name=state.all["log.dir"] + "_step_" + str(state.all["step"]),
        index=["abs_grad_norm","log_abs_grad_norm","rel_grad_norm","log_rel_grad_norm"])
    # DataFrame.append was removed in pandas 2.0; concat of a one-row frame is equivalent
    df = pd.concat([df, stats_row.to_frame().T])
    df.to_csv("./logs/gradient_stats.csv")

def main(state,event,net=outervar,inputs=outervar,labels=outervar,criterion=outervar,optimizer=outervar):
    '''Run the experiment selected by ``state["experiment"]``.

    Each experiment receives ``state`` and ``event`` plus the subset of
    ``net`` / ``inputs`` / ``labels`` / ``criterion`` / ``optimizer`` it takes.

    Raises:
        ValueError: if ``state["experiment"]`` names no known experiment.
    '''
    selected = state["experiment"]

    # (name, zero-argument runner) pairs, checked in order
    experiments = (
        ("fourier_loss_walk", lambda: fourier_loss_walk(state,event,net,inputs,labels)),
        ("rescorr_walk", lambda: rescorr_walk(state,event,net,inputs)),
        ("linear_to_conv_to_resnet", lambda: linear_to_conv_to_resnet(state,event,net,inputs,labels)),
        ("convnet_spectrum", lambda: convnet_spectrum(state,event,net,inputs)),
        ("convnet_loss_walk_per_layer", lambda: convnet_loss_walk_per_layer(state,event,net,inputs,labels)),
        ("avg_grad", lambda: avg_grad(state,event,net,inputs,labels,criterion,optimizer)),
    )

    for experiment_name, run in experiments:
        if selected == experiment_name:
            run()
            return

    raise ValueError("Experiment unknown")

# Helper events
def residual_branch_modifier(state, event, residual_branch):
    '''Scale ``residual_branch`` in place by ``state["nu"]``.

    Nothing happens when ``state["nu"]`` equals 1. ``event`` is unused.
    '''
    scale = state["nu"]
    if scale != 1:
        # augmented assignment: mutates the object when it supports in-place multiplication
        residual_branch *= scale

def gather_layer_output(state,event,net):
    '''Append ``net.flatten().cpu()`` to ``state["layer_outputs"]``.

    Only active while ``state["track_layer_outputs"]`` is truthy. ``event`` is
    unused.
    '''
    if not state["track_layer_outputs"]:
        return
    flattened = net.flatten().cpu()
    state["layer_outputs"].append(flattened)

def gather_residual_summands(state,event,main,residual):
    '''Append ``[main.flatten(), residual.flatten()]`` to ``state["resnet_summands"]``.

    Only active while ``state["track_resnet_summands"]`` is truthy. ``event``
    is unused.
    '''
    if not state["track_resnet_summands"]:
        return
    summand_pair = [main.flatten(), residual.flatten()]
    state["resnet_summands"].append(summand_pair)


def register(mf):
    '''Register this module's default settings, helper state and event handlers on ``mf``.'''
    default_settings = {
        # General settings
        "filtermode": "norm", # norm/none. Norm is filter normalisation (Li et al.)
        "steps": 100, # total path steps
        "plot.steps": 100,
        "step_size": 1e-2, # step size in loss walk
        "numdir": 10, # number of random directions to sample
        "step_every": 100, # execute losspath every step_every step during training
        "layer_filter":"", # only walk filters containing this string
        "path_mode": "linear", # linear/circle/linear_grad: straight / circular / straight-in-GD-direction paths
        "experiment": "fourier_loss_walk",

        # Fourier loss walk settings
        "median_path": True,  # save median loss path
        "num_init": 1,  # number of reinitializations (only makes sense when not training)
        "log": True,   # output results as csv

        # SNR settings
        "axis": 1, # don't touch
        "snrmode": "full", # multiple bins or only one bin
        "cutoff": 8, # how many values to cut off after fft (linear paths are not periodic)
        "sample_prop": 0.1, # how many neural paths to sample
        "debug_plot": False, # instant plots?
        "sigma": 0.25, # gaussian filter width
        "num_bins": 8, # gaussian filter amount
        "corr_fourier": False, # whether to calculate correlations directly in fourier space

        # Densenet Avg Spectrum settings
        "block_to_track": 0, # which dense block to calculate the avg mag spectrum of

        # Convnet Spectrum settings
        "rescale":True, # scale magnitude spectrum by coefficient index to simulate gradient response
    }
    mf.register_defaults(default_settings)

    helper_state = dict(
        w0={},
        dw={},
        criterion=None,
        nu=1,
        layer_outputs=[],
        resnet_summands=[],
        track_layer_outputs=False,
        track_resnet_summands=False,
    )
    mf.register_helpers(helper_state)

    # (event name, handler) pairs, registered in this order
    handlers = (
        ('before_training', init_new_direction),
        ('init_new_direction', init_new_direction),
        ('before_training', init),
        ('before_actual_step', main),
        # helper events
        ('residual_branch_modifier', residual_branch_modifier),
        ('before_relu', gather_layer_output),
        ('before_addition', gather_residual_summands),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler)
