from moduleloader import outervar, print_info as print
import torch
import numpy as np
from scipy import optimize, stats
import matplotlib.pyplot as plt
import pandas as pd

def get_new_direction(state, event, name, w, mode="linear"):
    """Create a walk direction for the weight tensor ``w``.

    Every direction is first normalised per filter (per index of dim 0, norm taken
    over dims 1, 2, 3) and then rescaled according to ``state["filtermode"]``
    (filter normalization, ref. Li et al. 2018, "Visualizing the Loss Landscape
    of Neural Nets"):

    * ``"norm"``: each filter gets the norm of the corresponding filter of ``w``;
    * ``"abs"``: multiplied elementwise by ``|w|``;
    * anything else: kept at unit filter norm.

    Finally the direction is multiplied by ``state["step_size"]``.

    :param state: mapping providing ``"filtermode"`` and ``"step_size"``.
    :param event: unused here; part of the hook signature.
    :param name: unused here; name of the parameter ``w``.
    :param w: weight tensor with (at least) 4 dims; norms are summed over dims 1-3.
    :param mode: ``"linear"`` - uniform random direction in [-1, 1) before scaling;
        ``"linear_grad"`` - negative gradient ``-w.grad`` (all zeros if ``w.grad`` is None);
        ``"circle"`` - a list of two independent random directions.
    :return: tensor shaped like ``w``, or a list of two such tensors for ``"circle"``.
    :raises ValueError: if ``mode`` is not one of the values above.
    """
    def scale(direction):
        # Unit norm per filter, then filter normalization and step size (in place).
        direction /= (direction ** 2).sum((1, 2, 3), keepdim=True).sqrt()
        if state["filtermode"] == "norm":
            direction *= (w ** 2).sum((1, 2, 3), keepdim=True).sqrt()
        elif state["filtermode"] == "abs":
            direction *= w.abs()
        direction *= state["step_size"]
        return direction

    def random_direction():
        # uniform samples in [-1, 1)
        return scale(torch.rand_like(w) * 2 - 1)

    if mode == "linear":
        return random_direction()
    if mode == "linear_grad":
        if w.grad is None:
            # No gradient known yet (e.g. before training): take a zero step.
            return w * 0
        # ``-w.grad.data`` is a fresh tensor, so scaling it in place is safe.
        return scale(-w.grad.data)
    if mode == "circle":
        return [random_direction(), random_direction()]
    # Raise (not return) so an unknown mode cannot be stored as a "direction".
    raise ValueError("Mode not implemented")

def walk_step(state):
    """Move ``state["net_cached"]`` one step along the stored direction ``state["dw"]``.

    Path modes (``state["path_mode"]``):

    * ``"linear"`` / ``"linear_grad"``: add ``dw[name]`` to every walked weight.
      For ``"linear_grad"`` the directions are re-initialised once from
      ``state.all["net"]`` on the first call, because no gradient was known
      before training.
    * ``"circle"``: add ``dw[name][0] * cos(t) + dw[name][1] * sin(t)`` with
      ``t = 2 * pi * state["current_step"] / state["steps"]``.

    Any other mode leaves the network unchanged. Only 4-D parameters are walked,
    optionally restricted to names containing ``state["layer_filter"]``. Weights are
    updated in place under ``torch.no_grad()``.
    """
    directions = state["dw"]
    walker = state["net_cached"]
    path_mode = state["path_mode"]

    def walked_params():
        # Yield (name, param) for every 4-D parameter that passes the layer filter.
        for param_name, param in walker.named_parameters():
            if param.dim() != 4:
                continue
            layer_filter = state["layer_filter"]
            if layer_filter and layer_filter not in param_name:
                continue
            yield param_name, param

    with torch.no_grad():
        # Before training no gradient was known, so redo init_new_direction once here.
        if path_mode == "linear_grad" and "init_new_direction_done" not in state:
            state["init_new_direction_done"] = True
            init_new_direction(state, None, net=state.all["net"])
            print("init_new_direction")

        if path_mode in ("linear", "linear_grad"):
            for param_name, param in walked_params():
                param.data += directions[param_name].data
        elif path_mode == "circle":
            angle = state["current_step"] / state["steps"] * np.pi * 2
            for param_name, param in walked_params():
                # the derivative of a circle is, again, a circle
                cos_part = directions[param_name][0].data * np.cos(angle)
                sin_part = directions[param_name][1].data * np.sin(angle)
                param.data += cos_part + sin_part

# i.e. before epoch
def init_new_direction(state, event, net=outervar):
    """Draw a fresh walk direction for every 4-D parameter of ``net``.

    Each direction comes from ``get_new_direction`` with ``mode=state["path_mode"]``
    and is stored as ``state["dw"][name]``. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        conv_params = ((pname, p) for pname, p in net.named_parameters() if p.dim() == 4)
        for param_name, param in conv_params:
            state["dw"][param_name] = get_new_direction(
                state, event, param_name, param, mode=state["path_mode"]
            )

# i.e. before training
def cache_net(state, event, net=outervar):
    """Copy the weights of ``net`` into the pre-built network ``state["net_cached"]``.

    The cached network is updated in place via ``load_state_dict``; ``event`` is unused.
    """
    cached = state["net_cached"]
    cached.load_state_dict(net.state_dict())



def slice_landscape(state, event, net=outervar, inputs=outervar, labels=outervar):
    '''From the current point in weight space, take ONE slice of the loss surface in a random direction.

    ``net`` is copied into ``state["net_cached"]``, which is then walked for
    ``state["steps"]`` steps via ``walk_step``. Before each step the loss
    ``state["criterion"](net_cached(inputs), labels)`` is recorded. The per-step losses
    are stacked along dim 1 and averaged over dim 0 (assumes ``criterion`` returns
    per-sample losses, e.g. ``reduction="none"`` - TODO confirm), then a power law
    is fitted to the spectrum of the resulting path (``fitpl_path``).

    :return: ``(batch_path, power_spectrum, exponent, r_squared)``
    '''
    # work on a cached copy so the caller's network is never modified
    cache_net(state, event, net=net)

    criterion = state["criterion"]
    walker = state["net_cached"]
    x = inputs.detach()
    y = labels.detach()

    losses = []
    with torch.no_grad():
        # walk one path
        for step in range(state["steps"]):
            state["current_step"] = step
            losses.append(criterion(walker(x), y))
            walk_step(state)

    batch_path = torch.mean(torch.stack(losses, axis=1), axis=0)

    exponent, _factor, r_squared, power_spectrum = fitpl_path(state, event, batch_path)
    event.optional.plot_scalar(batch_path.cpu(), list(range(state["steps"])), title="Loss surface cross-section")
    return batch_path, power_spectrum, exponent, r_squared


def slice_resout(state, event, net=outervar, inputs=outervar):
    '''From the current point in weight space, take ONE slice of every residual block's output.

    ``net`` is copied into ``state["net_cached"]``, which is walked for
    ``state["steps"]`` steps via ``walk_step``. Before each step the flattened
    ``resblock.out`` of every block in ``net_cached.resblocks`` is recorded. When
    ``state["sample_prop"] < 1``, only a random subset of elements (chosen on the
    first step) is tracked. For each block, the power spectrum along the walk is
    averaged over the tracked elements and a power law is fitted to it (``fitpl_pow``).

    :return: list with the fitted power-law exponent ``a`` of each residual block
        (NaN where the fit failed).
    '''
    # cache net in order to work on cached version
    cache_net(state, event, net=net)
    net_cached = state["net_cached"]
    sample_prop = state["sample_prop"]

    sample_idxs = []      # per block: element indices chosen on the first step
    resout_per_step = []  # per step: one sampled 1-D tensor per block
    with torch.no_grad():
        for step in range(state["steps"]):
            state["current_step"] = step
            net_cached(inputs.detach())

            resout = []
            for num, resblock in enumerate(net_cached.resblocks):
                if step == 0:
                    # NOTE(review): the sample is sized from resblock.s1 but applied to
                    # resblock.out; assumes both have the same number of elements - confirm.
                    num_paths = int(np.prod(resblock.s1.shape))
                    if sample_prop < 1:
                        # random subset, drawn with replacement (np.random.choice default)
                        idx = np.random.choice(num_paths, size=int(np.floor(num_paths * sample_prop)))
                    else:
                        # Every element, one index array per block. An integer index keeps
                        # the result 1-D (a boolean scalar index would add a leading dim).
                        idx = np.arange(num_paths)
                    sample_idxs.append(idx)
                resout.append(resblock.out.flatten()[sample_idxs[num]])
            resout_per_step.append(resout)

            # walk a step in weight space
            walk_step(state)

    a_per_block = []
    # zip(*...) transposes to one tuple of per-step tensors per block
    for paths in zip(*resout_per_step):
        paths = torch.stack(paths, axis=-1)  # (num_sampled, steps)
        # NOTE: torch.rfft is the legacy (PyTorch < 1.8) API; its trailing (real, imag)
        # dimension is squared and summed by pow() into |X|^2.
        ffts = torch.rfft(paths, 1)
        mean_power_spectrum = pow(ffts).cpu().numpy().mean(0)
        a, _b, _r_squared = fitpl_pow(state, event, mean_power_spectrum)
        a_per_block.append(a)

    return a_per_block

def rescorr(state, event, net=outervar, inputs=outervar):
    '''Correlate the two summands of every residual block along a walk in weight space.

    ``net`` is copied into ``state["net_cached"]``, which is walked for
    ``state["steps"]`` steps via ``walk_step``. While walking,
    ``state["track_resnet_summands"]`` is True. After each forward pass every entry of
    ``state["resnet_summands"]`` is read as an ``(s1, s2)`` pair, one per block
    (NOTE(review): that list is filled by the network outside this file - confirm).
    When ``state["sample_prop"] < 1`` only a random subset of elements (chosen on the
    first step) is tracked.

    The per-block signals are analysed in the Fourier domain according to
    ``state["snrmode"]``:

    * ``"hplp"``: split s1/s2 with a Gaussian low-pass of width ``state["sigma"]`` and
      its complementary high-pass. Per block, return
      ``nanmean((1 + corr_lp) / (1 + corr_hp))`` and ``nanmean(var(s1_lp) / var(s1_hp))``.
    * ``"full"``: correlate s1 and s2 within ``state["num_bins"]`` Gaussian bands (or,
      with ``state["corr_fourier"]``, use ``pow(fft1 * fft2)`` directly). Per block,
      return a square matrix whose (i, j) entry is ``nanmean((1 + c_i) / (1 + c_j))``
      over the band correlations ``c``, plus the mean absolute correlation per band.

    ``state["cutoff"]`` samples are trimmed from both ends of the filtered signals
    (presumably to suppress filter edge effects).

    :return: ``(snri_per_block, snr_per_block)`` for ``"hplp"``;
        ``(confusion_per_block, mean_corrs_per_block)`` for ``"full"``.
    :raises ValueError: if ``state["snrmode"]`` is neither ``"hplp"`` nor ``"full"``.
    '''
    # Validate before doing the (expensive) walk.
    snrmode = state["snrmode"]
    if snrmode not in ("hplp", "full"):
        raise ValueError("snrmode not recognized")

    cache_net(state, event, net=net)
    net_cached = state["net_cached"]
    cutoff = state["cutoff"]
    sample_prop = state["sample_prop"]
    sigma = state["sigma"]
    num_bins = state["num_bins"]
    steps = state["steps"]

    S1 = []           # per step: list of sampled s1 tensors, one per block
    S2 = []           # per step: list of sampled s2 tensors, one per block
    sample_idxs = []  # per block: element indices chosen on the first step

    state["track_resnet_summands"] = True
    try:
        # walk one path
        with torch.no_grad():
            for step in range(steps):
                state["current_step"] = step
                state["resnet_summands"] = []
                net_cached(inputs.detach())

                s1_step = []
                s2_step = []
                for num, resblock in enumerate(state["resnet_summands"]):
                    if step == 0:
                        num_paths = len(resblock[1])
                        if sample_prop < 1:
                            # random subset, drawn with replacement (np.random.choice default)
                            idx = np.random.choice(num_paths, size=int(np.floor(num_paths * sample_prop)))
                        else:
                            # Every element, one index array per block. An integer index keeps
                            # the result shaped like the sampled branch (a boolean scalar
                            # index would add a leading dim and break the [:, cutoff:-cutoff]
                            # trimming below).
                            idx = np.arange(num_paths)
                        sample_idxs.append(idx)

                    s1_step.append(resblock[0][sample_idxs[num]])
                    s2_step.append(resblock[1][sample_idxs[num]])

                S1.append(s1_step)
                S2.append(s2_step)

                # walk a step in weight space
                walk_step(state)
    finally:
        # always stop tracking, even if the walk raised
        state["track_resnet_summands"] = False

    # transpose: per-step lists -> per-block tuples (one tensor per step)
    S1 = list(zip(*S1))
    S2 = list(zip(*S2))

    snri_per_block = []
    snr_per_block = []
    confusion_per_block = []
    mean_corrs_per_block = []

    for s1, s2 in zip(S1, S2):
        # expected (num_sampled, steps): the time axis is last
        s1 = torch.stack(s1, axis=-1)
        s2 = torch.stack(s2, axis=-1)

        # legacy torch.rfft (PyTorch < 1.8): trailing dim holds (real, imag)
        fft1 = torch.rfft(s1, 1)
        fft2 = torch.rfft(s2, 1)

        if snrmode == "hplp":
            # highpass / lowpass spectral analysis
            gpu = state.all["gpu"]
            lowpass = torch.Tensor(gaussian(np.linspace(0, 1, fft1.shape[-2]), 0, sigma))
            highpass = 1 - lowpass
            lowpass = lowpass.view(1, -1, 1).to(gpu)
            highpass = highpass.view(1, -1, 1).to(gpu)

            s1_lp = torch.irfft(fft1 * lowpass, 1, signal_sizes=(steps,))
            s1_hp = torch.irfft(fft1 * highpass, 1, signal_sizes=(steps,))
            s2_lp = torch.irfft(fft2 * lowpass, 1, signal_sizes=(steps,))
            s2_hp = torch.irfft(fft2 * highpass, 1, signal_sizes=(steps,))

            if cutoff > 0:
                trim = slice(cutoff, -cutoff)
                s1 = s1[:, trim]
                s2 = s2[:, trim]
                s1_lp = s1_lp[:, trim]
                s1_hp = s1_hp[:, trim]
                s2_lp = s2_lp[:, trim]
                s2_hp = s2_hp[:, trim]

            axis = state["axis"]
            corr_lp = corr(s1_lp, s2_lp, axis=axis)
            corr_hp = corr(s1_hp, s2_hp, axis=axis)

            if state["debug_plot"]:
                s1cpu = s1[0].cpu()
                s2cpu = s2[0].cpu()
                plt.plot(s1cpu, label="First summand s1")
                plt.plot(s2cpu, label="Second summand s2")
                plt.plot((s1cpu + s2cpu) / 2, label="Resnet output (s1+s2)/2")
                plt.title("Effect of summation in ResNets")
                plt.legend(["First summand s1", "Second summand s2", "Resnet output (s1+s2)/2"])
                plt.show()

            snri = (1 + corr_lp) / (1 + corr_hp)
            snr = torch.var(s1_lp, axis=axis) / torch.var(s1_hp, axis=axis)

            snri_per_block.append(nanmean(snri))
            snr_per_block.append(nanmean(snr))

        else:  # snrmode == "full" (validated above): full spectral analysis
            if state["corr_fourier"]:
                # NOTE(review): fft1 * fft2 multiplies real and imaginary parts
                # elementwise; this is neither the complex product nor the
                # cross-spectrum X1 * conj(X2) - confirm intent.
                filtered_corrs = pow(fft1 * fft2).T
            else:
                filtered_corrs = []
                for band in gaussian_family(fft1.shape[-2], num_bins, sigma):
                    band = band.view(1, -1, 1).to(state.all["gpu"])
                    filtered_s1 = torch.irfft(fft1 * band, 1, signal_sizes=(steps,))
                    filtered_s2 = torch.irfft(fft2 * band, 1, signal_sizes=(steps,))

                    if cutoff > 0:
                        filtered_s1 = filtered_s1[:, cutoff:-cutoff]
                        filtered_s2 = filtered_s2[:, cutoff:-cutoff]

                    corr_s1s2 = corr(filtered_s1, filtered_s2, axis=state["axis"])
                    filtered_corrs.append(corr_s1s2)

                    if state["debug_plot"]:
                        plt.plot(list(range(steps)), s1[0].cpu())
                        plt.plot(list(range(steps)), s2[0].cpu())
                        plt.plot(list(range(cutoff, steps - cutoff)), filtered_s1[0].cpu())
                        plt.plot(list(range(cutoff, steps - cutoff)), filtered_s2[0].cpu())
                        plt.legend(["S1", "S2", "filtered S1", "filtered S2"])
                        plt.title("Filtered correlation: %.3f" % corr_s1s2[0])
                        plt.show()

                filtered_corrs = torch.stack(filtered_corrs, axis=0)

            # (row, col) entry compares band `row` against band `col`
            num_rows = len(filtered_corrs)
            confusion_matrix = torch.zeros(size=(num_rows, num_rows))
            for row, numerator in enumerate(filtered_corrs):
                for col, denominator in enumerate(filtered_corrs):
                    confusion_matrix[row, col] = nanmean((1 + numerator) / (1 + denominator))
            confusion_per_block.append(confusion_matrix)
            mean_corrs_per_block.append(nanmean(filtered_corrs.abs(), axis=0))

    if snrmode == "hplp":
        return torch.stack(snri_per_block, axis=0), torch.stack(snr_per_block, axis=0)
    return torch.stack(confusion_per_block, axis=0), torch.stack(mean_corrs_per_block, axis=0)

def cov(a, b, axis):
    '''Calculates the population (1/N) covariance of tensors a and b along ``axis``.

    Uses the centered form ``mean((a - mean(a)) * (b - mean(b)))``, which avoids the
    catastrophic cancellation of ``E[ab] - E[a]E[b]`` when the signals carry a large
    offset relative to their variation.
    '''
    a_centered = a - a.mean(dim=axis, keepdim=True)
    b_centered = b - b.mean(dim=axis, keepdim=True)
    return (a_centered * b_centered).mean(dim=axis)

def corr(a, b, axis):
    '''Calculates the Pearson correlation of tensors a and b along ``axis``.

    Numerator and denominator both use the population (1/N) normalisation, so
    ``corr(x, x)`` is exactly 1. Mixing a 1/N covariance with the unbiased
    ``torch.std`` would shrink every value by (N-1)/N. Constant signals give
    0/0 = NaN, which callers filter out with ``nanmean``.
    '''
    return cov(a, b, axis=axis) / (cov(a, a, axis=axis) * cov(b, b, axis=axis)).sqrt()

def mag(a):
    '''Return the magnitude spectrum of a legacy-format torch FFT.

    The last dimension of ``a`` holds the (real, imag) pair; it is reduced to
    sqrt(real**2 + imag**2).
    '''
    squared = a ** 2
    return torch.sqrt(torch.sum(squared, dim=-1))

def pow(a):
    '''Return the power spectrum of a legacy-format torch FFT.

    Squares and sums the last (real, imag) dimension of ``a``: real**2 + imag**2.
    NOTE: this module-level name shadows the builtin ``pow``.
    '''
    return torch.sum(a ** 2, dim=-1)

def gaussian(x, mu, sig):
    '''Evaluate the unnormalised Gaussian exp(-(x - mu)^2 / (2 * sig^2)) at ``x``.

    The peak value (at ``x == mu``) is 1; ``x`` may be a scalar or an array.
    '''
    squared_distance = np.power(x - mu, 2.)
    twice_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_distance / twice_variance)

def gaussian_filter(len, mu, sig):
    '''Sample ``gaussian(., mu, sig)`` at ``len`` evenly spaced points on [0, 1].'''
    grid = torch.linspace(0, 1, len)
    return gaussian(grid, mu, sig)

def cutoff_filter(start_idx, end_idx, len):
    '''Return a length-``len`` float tensor that is 1 on ``[start_idx, end_idx)`` and 0 elsewhere.

    If ``start_idx == end_idx``, the single element at ``start_idx`` is set instead,
    so a degenerate band still selects one bin.
    '''
    arr = torch.zeros(len)
    if start_idx == end_idx:
        arr[start_idx] = 1
    else:
        arr[start_idx:end_idx] = 1
    return arr

def cutoff_family(len, num_bins, ignore_last=True):
    '''Return a family of adjacent, equally wide cutoff filters covering a spectrum of length ``len``.

    The band width is ``len // num_bins`` (at least 1), and as many full-width bands as
    fit are produced. This can be more than ``num_bins`` when ``len`` is not a multiple
    of ``num_bins``. Any leftover bins at the end get one extra, narrower filter unless
    ``ignore_last`` is True.

    :raises ValueError: if ``num_bins < 1``.
    '''
    if num_bins < 1:
        raise ValueError("num_bins must be at least 1")
    # Clamp to 1 so that num_bins > len does not lead to a division by zero.
    width = max(1, len // num_bins)
    num_full = len // width
    filters = [cutoff_filter(i * width, (i + 1) * width, len) for i in range(num_full)]
    if not ignore_last and len % width != 0:
        filters.append(cutoff_filter(num_full * width, len, len))
    return filters

def gaussian_family(len, num_filters, sig):
    '''Return ``num_filters`` Gaussian filters of length ``len`` with width ``sig``.

    The centres are evenly spaced on [0, 1] (``np.linspace(0, 1, num_filters)``), and
    each filter comes from ``gaussian_filter``.
    '''
    centres = np.linspace(0, 1, num_filters)
    return [gaussian_filter(len, centre, sig) for centre in centres]


def fitpl_path(state, event, path):
    '''Fit a power law to the power spectrum of a 1-D path (e.g. losses along a walk).

    :param state: mapping with ``"steps"`` and ``"step_size"``, forwarded to ``fitpl_pow``.
    :param event: event object forwarded to ``fitpl_pow``, which uses it for plotting.
    :param path: 1-D torch tensor; presumably of length ``state["steps"]`` so that
        its spectrum matches the frequency grid in ``fitpl_pow``.
    :return: ``(a, b, r_squared, power)``: the exponent and factor of the fit
        ``b * f**-a``, its R^2, and the full power spectrum ``|rfft(path)|**2``
        (including the zero-frequency bin).
    :rtype: (float, float, float, numpy.ndarray)
    '''
    # detach() so tensors that track gradients can be converted as well
    samples = path.detach().cpu().numpy()
    power = np.abs(np.fft.rfft(samples)) ** 2

    a, b, r_squared = fitpl_pow(state, event, power)
    return a, b, r_squared, power

def fitpl_pow(state, event, power):
    '''Fit the power law ``power ~ b * f**-a`` to a one-sided power spectrum.

    The zero-frequency bin of ``power`` is dropped (``f**-a`` is undefined at 0). The
    remaining bins are paired with ``np.fft.rfftfreq(state["steps"], state["step_size"])[1:]``,
    so ``power`` is expected to hold ``state["steps"] // 2 + 1`` values. The spectrum
    is also handed to ``event.optional.plot_scalar``.

    :return: ``(a, b, r_squared)``; all three are NaN when ``scipy.optimize.curve_fit``
        fails (no convergence, non-finite input, too few points).
    '''
    # exclude the constant (zero) frequency
    power = power[1:]
    freqs = np.fft.rfftfreq(state["steps"], state["step_size"])[1:]

    def power_law(f, a, b):
        return b * f ** -a

    event.optional.plot_scalar(power, list(range(len(power))),
                               title="Crossection power spectrum ")

    try:
        popt = optimize.curve_fit(power_law, freqs, power, maxfev=1000)[0]
    except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
        # RuntimeError: no convergence; ValueError: NaN/inf input;
        # TypeError: fewer data points than parameters. np.nan (not np.NaN,
        # which NumPy 2.0 removed) keeps this fallback working.
        return np.nan, np.nan, np.nan

    residuals = power - power_law(freqs, popt[0], popt[1])
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((power - np.mean(power)) ** 2)
    r_squared = 1 - (ss_res / ss_tot)
    return popt[0], popt[1], r_squared

def nanmean(arr, axis=0):
    '''NaN-ignoring mean of a 1-D or 2-D torch tensor.

    * 1-D ``arr``: the mean of its non-NaN elements (``axis`` is ignored).
    * 2-D ``arr`` with ``axis=0``: a new tensor of length ``arr.shape[0]`` whose i-th
      entry is the NaN-ignoring mean of row ``arr[i]``. It is created with
      ``torch.Tensor``, i.e. the default tensor type rather than ``arr``'s
      device/dtype. NOTE: unlike ``numpy.nanmean(..., axis=0)``, this averages
      *within* each row.

    :raises ValueError: "axis not supported" for 2-D input with ``axis != 0``,
        "shape not supported" for any other rank.
    '''
    ndim = arr.dim()
    if ndim == 1:
        return torch.mean(nonan(arr))
    if ndim != 2:
        raise ValueError("shape not supported")
    if axis != 0:
        raise ValueError("axis not supported")
    out = torch.Tensor(arr.shape[0])
    for row_idx, row in enumerate(arr):
        out[row_idx] = nanmean(row)
    return out

def nonan(arr):
    '''Return the elements of ``arr`` that are not NaN, as a 1-D tensor.'''
    keep = torch.isnan(arr).logical_not()
    return arr[keep]

def countnan(arr):
    '''Return the number of NaN entries in ``arr`` as a 0-d tensor.'''
    return torch.sum(torch.isnan(arr))
