import math

import torch

import numpy as np
import matplotlib.pyplot as plt

from resonance import plotting


class AR2MeanGaussian:
    """Gaussian sampler whose mean follows a zero-mean AR(2) process.

    The latent mean evolves as

        X_t = phi_1 * X_{t-1} + phi_2 * X_{t-2} + eps_t,   eps_t ~ N(0, noise_variance)

    where (phi_1, phi_2) are solved so that the mean signal has stationary variance
    ``stationary_variance - sampling_variance`` and its spectral peak sits at
    ``peak_frequency``.  Observations (``batch``) are drawn around the current mean
    with variance ``sampling_variance``.
    """

    def __init__(self, stationary_variance, peak_frequency, sampling_variance, noise_variance):
        assert stationary_variance > sampling_variance

        self.stationary_variance = stationary_variance
        self.peak_frequency = peak_frequency
        self.sampling_variance = sampling_variance
        self.noise_variance = noise_variance

        # Share of the overall variance left for the latent mean once the
        # per-observation sampling noise is accounted for.
        self.mean_signal_stationary_variance = stationary_variance - sampling_variance

        # Process state: the current value and the two lags feeding the recursion.
        self.X_t = self.X_tm1 = self.X_tm2 = 0

        # phi_1 depends on phi_2, so phi_2 must be solved first.
        self.phi_2 = self.compute_phi_2(self.mean_signal_stationary_variance, peak_frequency, noise_variance)
        self.phi_1 = self.compute_phi_1(self.phi_2, self.peak_frequency)
        self.ensure_process_validity()

        self.current_time = 0

    def mean(self):
        """Return the current value of the latent AR(2) mean."""
        return self.X_t

    def advance_time(self):
        """Step the latent mean forward by one time unit (draws from numpy's global RNG)."""
        innovation = np.random.normal(0, np.sqrt(self.noise_variance))
        self.X_t = self.phi_1 * self.X_tm1 + self.phi_2 * self.X_tm2 + innovation
        # Slide the lag window: old X_{t-1} becomes X_{t-2}, the new value becomes X_{t-1}.
        self.X_tm1, self.X_tm2 = self.X_t, self.X_tm1
        self.current_time += 1

    def batch(self, batch_size, num_runs):
        """Draw a ``(batch_size, num_runs)`` tensor of observations around the current mean."""
        sampling_std = math.sqrt(self.sampling_variance)
        return torch.normal(self.mean(), sampling_std, (batch_size, num_runs))

    def total_distribution_sample(self, batch_size, num_runs):
        """Draw a ``(batch_size, num_runs)`` tensor from the assumed overall marginal.

        The marginal is modelled as N(0, stationary_variance).
        """
        # TODO: normality of the stationary marginal is assumed, not yet verified.
        overall_std = math.sqrt(self.stationary_variance)
        return torch.normal(0, overall_std, (batch_size, num_runs))

    def compute_phi_1(self, phi_2, peak_frequency):
        """Return the phi_1 that places the spectral peak at ``peak_frequency`` for this phi_2.

        Inverts the AR(2) peak condition cos(2 pi f) = phi_1 (phi_2 - 1) / (4 phi_2).
        """
        scale = 4 * phi_2 / (phi_2 - 1)
        return scale * math.cos(2 * math.pi * peak_frequency)

    def compute_phi_2(self, mean_signal_stationary_variance, peak_frequency, noise_variance):
        """Solve for phi_2 from the target mean-signal variance and peak frequency.

        Substituting phi_1 (see ``compute_phi_1``) into the AR(2) stationary
        variance gives the quintic (derived via Wolfram Alpha)

            -x^5 + 3x^4 + (16c - 2 - s)x^3 + (16c - 2 + 3s)x^2 + (3 - 3s)x + s - 1 = 0

        with c = cos(2 pi f)^2 and s = noise_variance / mean_signal_stationary_variance.
        The smallest real root is returned.
        """
        cos_sq = np.cos(2 * np.pi * peak_frequency) ** 2
        ratio = noise_variance / mean_signal_stationary_variance

        # Highest-degree coefficient first, as np.roots expects.
        coefficients = [
            -1,
            3,
            16 * cos_sq - 2 - ratio,
            16 * cos_sq - 2 + 3 * ratio,
            3 - 3 * ratio,
            ratio - 1,
        ]
        real_parts = [root.real for root in np.roots(coefficients) if root.imag == 0]
        return min(real_parts)

    def ensure_process_validity(self):
        """Assert that (phi_1, phi_2) define a stationary AR(2) with complex characteristic roots."""
        phi_1, phi_2 = self.phi_1, self.phi_2

        # Stationarity triangle.
        assert abs(phi_2) < 1
        assert phi_2 + phi_1 < 1
        assert phi_2 - phi_1 < 1

        # Negative discriminant -> complex roots; the condition used here for a mid-range peak frequency.
        assert phi_1 ** 2 + 4 * phi_2 < 0


class NoisyLinearTarget:
    """Wrap a target function and add i.i.d. Gaussian observation noise to its outputs."""

    def __init__(self, target_function, observation_variance):
        # Callable applied to a domain batch; evaluated under torch.no_grad() in `batch`.
        self.target_function = target_function
        # Variance of the additive noise; values <= 0 disable noise entirely.
        self.observation_variance = observation_variance

    def batch(self, domain_batch):
        """Evaluate the target on ``domain_batch`` and add N(0, observation_variance) noise.

        The noise is shaped like the target's output, not its input, so target
        functions that change shape (for example reducing over a dimension) work.
        It is also created on the output's device.  A non-positive variance
        returns the noiseless target.

        Assumes ``target_function`` returns a torch tensor.
        """
        with torch.no_grad():
            target_sample = self.target_function(domain_batch)

        if self.observation_variance <= 0:
            return target_sample

        noise = torch.normal(
            0,
            math.sqrt(self.observation_variance),
            target_sample.shape,
            device=target_sample.device,
        )
        return target_sample + noise



def generate_domain_samples(samplers, batch_size, runs_per_frequency):
    """Draw one batch from every sampler and concatenate them column-wise.

    Assuming each sampler's ``batch(batch_size, runs_per_frequency)`` returns a
    ``(batch_size, runs_per_frequency)`` tensor, the result has shape
    ``(batch_size, runs_per_frequency * len(samplers))``.  Each row is a batch
    element and each column a run.  Every sampler's columns are contiguous and
    appear in the order of ``samplers``.  With no samplers, an empty
    ``(batch_size, 0)`` tensor is returned.

    XXX: this MUTATES the samplers. Each one is advanced one time step
    (``advance_time``) right after it is sampled, so repeated calls walk the
    samplers forward in time.
    """
    domain_samples = []
    for sampler in samplers:
        domain_samples.append(sampler.batch(batch_size, runs_per_frequency))
        # Legacy coupling: sampling also moves the sampler's clock forward.
        sampler.advance_time()

    if not domain_samples:
        # torch.cat rejects an empty list; an empty batch is the natural result.
        return torch.empty((batch_size, 0))
    return torch.cat(domain_samples, dim=1)

def generate_test_domain_samples(samplers, batch_size, runs_per_frequency):
    """Draw from each sampler's overall (stationary) distribution and concatenate column-wise.

    This is the test-time counterpart of ``generate_domain_samples``, laid out
    the same way so that test losses can be computed for all parallel models in
    one vectorized pass.  Assuming each sampler's ``total_distribution_sample``
    returns a ``(batch_size, runs_per_frequency)`` tensor, the result has shape
    ``(batch_size, runs_per_frequency * len(samplers))``.  Each call to
    ``total_distribution_sample`` is an independent draw, so columns are not
    generally identical.  Samplers are not advanced.  With no samplers, an empty
    ``(batch_size, 0)`` tensor is returned.
    """
    domain_samples = [
        sampler.total_distribution_sample(batch_size, runs_per_frequency) for sampler in samplers
    ]
    if not domain_samples:
        # torch.cat rejects an empty list; an empty batch is the natural result.
        return torch.empty((batch_size, 0))
    return torch.cat(domain_samples, dim=1)


def test_generate_domain_samples():
    """Check that successive domain batches follow each sampler's sinusoidal mean over time."""
    frequencies = [1/4, 1/2]
    batch_size = 10
    runs_per_frequency = 3
    amplitude = 1
    variance = 1e-100  # effectively a point mass at the sampler's mean

    # NOTE(review): SinusoidalMeanGaussian is neither defined nor imported in this
    # file as shown, so this raises NameError unless it is provided elsewhere. Confirm.
    samplers = [SinusoidalMeanGaussian(amplitude, frequency, variance) for frequency in frequencies]

    # generate_domain_samples draws at the samplers' current time and then
    # advances them, so snapshots[t] is the draw made at time t.
    snapshots = [generate_domain_samples(samplers, batch_size, runs_per_frequency) for _ in range(5)]

    assert snapshots[0].shape == (batch_size, runs_per_frequency * len(frequencies))

    # With a ~1e-100 variance, every sample at t in {0, 2, 4} should be ~0.
    for t in (0, 2, 4):
        assert torch.allclose(snapshots[t], torch.zeros_like(snapshots[t]))

    first_freq = slice(0, runs_per_frequency)
    second_freq = slice(runs_per_frequency, None)
    zeros_block = torch.zeros((batch_size, runs_per_frequency))
    ones_block = torch.ones((batch_size, runs_per_frequency))

    # t = 1: first frequency at +amplitude, second at zero.
    assert torch.allclose(snapshots[1][:, first_freq], amplitude * ones_block)
    assert torch.allclose(snapshots[1][:, second_freq], zeros_block)

    # t = 3: first frequency at -amplitude, second at zero.
    assert torch.allclose(snapshots[3][:, first_freq], -amplitude * ones_block)
    assert torch.allclose(snapshots[3][:, second_freq], zeros_block)


def test_generate_test_domain_samples():
    """Check that stationary test samples spread over, but stay within, the sinusoid's range."""
    frequencies = [0.1, 0.2]
    batch_size = 10
    runs_per_frequency = 3
    amplitude = 1
    variance = 1e-100

    # NOTE(review): SinusoidalMeanGaussian is neither defined nor imported in this
    # file as shown, so this raises NameError unless it is provided elsewhere. Confirm.
    samplers = [SinusoidalMeanGaussian(amplitude, frequency, variance) for frequency in frequencies]
    samples = generate_test_domain_samples(samplers, batch_size, runs_per_frequency)

    assert samples.shape == (batch_size, runs_per_frequency * len(frequencies))

    # Samples drawn with a very tight variance should have been transformed by the
    # sinusoid mean: some clearly positive, some clearly negative...
    assert torch.any(samples > 0.5)
    assert torch.any(samples < -0.5)

    # ...but none beyond the amplitude (plus slack). These bound every element, so
    # they need torch.all; torch.any would pass as soon as one sample was in range.
    assert torch.all(samples < amplitude + 0.5)
    assert torch.all(samples > -amplitude - 0.5)


def test_ar2_process():
    """Simulate an AR(2)-mean sampler, plot it, and print target vs. empirical variances.

    This is a visual check: it opens matplotlib windows and prints, and the
    function body itself contains no assertions.
    """
    stationary_variance = 1.1
    peak_frequency = 1 / 100
    sampling_variance = 1.0
    noise_variance = 0.00001
    num_steps = 2000

    sampler = AR2MeanGaussian(stationary_variance, peak_frequency, sampling_variance, noise_variance)

    times = list(range(num_steps))
    means, samples = [], []
    for _ in times:
        means.append(sampler.mean())
        # One observation drawn around the current latent mean.
        samples.append(sampler.batch(1, 1).flatten()[0].item())
        sampler.advance_time()

    plt.scatter(times, samples, alpha=0.3)
    plt.plot(means, color='black', linestyle='--', linewidth=3)
    plt.show()

    target_mean_variance = stationary_variance - sampling_variance
    print(f'Target mean signal variance: {target_mean_variance:.3f}, Actual: {np.var(means):.3f}')
    print(f'Target overall variance: {stationary_variance:.3f}, Actual: {np.var(samples):.3f}')

    plotting.plot_spectra(means, samples, peak_frequency, 'AR(2)')


def plot_generate_domain_samples(num_frequencies, num_runs_per_frequency, batch_size, num_steps=50):
    """Plot ``num_steps`` successive domain batches drawn from sinusoidal-mean samplers.

    One figure is shown per run column, overlaying every batch element's trace
    over time as point markers.
    """
    frequencies = torch.linspace(0, 0.1, num_frequencies)
    # NOTE(review): SinusoidalMeanGaussian is neither defined nor imported in this
    # file as shown. Confirm where it lives.
    samplers = [SinusoidalMeanGaussian(1, frequency, 1) for frequency in frequencies]

    times = []
    samples = []
    for time in range(num_steps):
        # generate_domain_samples takes no time argument (passing one raised
        # TypeError). It samples at each sampler's current time and then advances
        # it, so successive calls walk forward in time.
        samples.append(generate_domain_samples(samplers, batch_size, num_runs_per_frequency))
        times.append(time)

    samples = torch.stack(samples)  # (time, batch, run)
    for run in range(num_runs_per_frequency * num_frequencies):
        for batch_idx in range(batch_size):
            plt.plot(times, samples[:, batch_idx, run], marker=',', markersize=1, c='k', linewidth=0)
        plt.show()

if __name__ == '__main__':
    # Optional visual check, disabled by default:
    #   plot_generate_domain_samples(3, 2, 100)
    for check in (test_ar2_process, test_generate_domain_samples, test_generate_test_domain_samples):
        check()
