"""
Implementation of a frequency-domain filter based on

https://arxiv.org/pdf/1901.06523.pdf
(see also http://old.ins.sjtu.edu.cn/files/paper/20190712160920_F-Principle190703SJTU_ZhiqinXu.pdf)

The idea here is to decompose the output of our network into the parts that were generated by various frequency bands.
The hypothesis is that the noisy, high-frequency parts of one learning signal are low-frequency parts of another, and
that therefore these tasks should be learned separately.

I will test this theory by evaluating the contribution of 5 different frequency bands to either the value or the policy.
If these (normalized) frequency bands match, I will reject my hypothesis.
"""

import numbers

import torch


@torch.no_grad()
def compute_band_pass(x: torch.Tensor, y: torch.Tensor, delta):
    """
    Smooth the targets ``y`` with a Gaussian kernel defined over the inputs ``x``.

    For every sample j this computes the kernel-weighted average

        y_low[j] = sum_i y[i] * G(x_i, x_j) / sum_i G(x_i, x_j),
        G(x_i, x_j) = exp(-|x_i - x_j|^2 / (2 * delta)),

    so ``delta`` plays the role of sigma^2 of the Gaussian: a larger ``delta`` gives a wider
    kernel and a smoother (lower-frequency) result.

    @param x: the N inputs, of dims [N] or [N, ...]; each row is flattened to a vector and
        cast to float32 before distances are computed.
    @param y: the N targets, of dims [N] or [N, ...] (e.g. [N, D_y]).
    @param delta: kernel width (expected to be positive), or an iterable of widths.
    @output: a tensor with the same shape as ``y`` (floating dtype) when ``delta`` is a single
        real number, otherwise a list with one such tensor per entry of ``delta``.
    @raises ValueError: if ``x`` and ``y`` do not have the same number of samples.
    """

    x = x.to(dtype=torch.float32)

    N = len(x)
    if len(y) != N:
        raise ValueError(f"x and y must have the same length, got {N} and {len(y)}")

    # numbers.Real also accepts numpy scalars, which the old `type(...) in [...]` check rejected.
    scalar_result = isinstance(delta, numbers.Real)
    delta_list = [delta] if scalar_result else list(delta)

    # Result dtype follows normal type promotion of y with the float32 kernel
    # (int targets -> float32, float64 targets -> float64).
    out_dtype = torch.promote_types(y.dtype, torch.float32)

    if N == 0:
        # Nothing to average; reshape(1, 0, -1) below would fail on empty input.
        results = [torch.empty(y.shape, dtype=out_dtype, device=y.device) for _ in delta_list]
        return results[0] if scalar_result else results

    # Flatten targets to [N, D_y] so 1-D and multi-dimensional y share one code path.
    y_flat = y.reshape(N, -1).to(out_dtype)

    # cdist wants [B, P, M]: use a single batch and flatten each input row to 1D.
    x = x.reshape(1, N, -1)

    # I seem to get errors if I use mm, this is probably a bit slower... maybe I should increase to double?
    sqr_distances = torch.cdist(x, x, p=2, compute_mode='donot_use_mm_for_euclid_dist')[0] ** 2  # [N, N]

    def compute_result(d):
        G_d = torch.exp(-sqr_distances / (2 * d))  # [N, N]; diagonal is exp(0) = 1
        C = G_d.sum(dim=0)  # [N] normaliser per output sample j
        # Row j of the product is sum_i G_d[i, j] * y[i]  -> [N, D_y]
        weighted = G_d.to(out_dtype).T @ y_flat
        return (weighted / C.to(out_dtype)[:, None]).reshape(y.shape)

    results = [compute_result(d) for d in delta_list]

    return results[0] if scalar_result else results






