import os

import torch

from tqdm import tqdm

from resonance import training, sampling, plotting

def frequency_sweep(
        make_sampler,
        frequencies,
        runs_per_frequency,

        # target function sampling
        observation_noise_variance,

        # training parameters
        training_iterations,
        batch_size,
        learning_rate,
        momentum,

        # measurement parameters
        evaluation_period,

        # initialization parameters
        init_at_target,
):
    """
    Train one bank of parallel linear learners against a bank of target
    functions, one sampler per frequency, and record how far the learners are
    from their targets over time.

    Parameters:
    - make_sampler: callable invoked once per entry of `frequencies`; the
      resulting samplers are handed to sampling.generate_domain_samples.
    - frequencies: sized iterable of frequencies to sweep over.
    - runs_per_frequency: number of parallel runs per frequency; the banks
      hold len(frequencies) * runs_per_frequency functions.
    - observation_noise_variance: forwarded to sampling.NoisyLinearTarget.
    - training_iterations: number of update steps to perform.
    - batch_size: forwarded to sampling.generate_domain_samples.
    - learning_rate, momentum: forwarded to the learners' update() call.
    - evaluation_period: the distance to target is recorded on every step
      whose index is a multiple of this value (so step 0 is always recorded).
    - init_at_target: must be falsy; a truthy value raises
      NotImplementedError.

    Returns:
    A pair (iterations, target_distances): a tensor of the recorded step
    indices, and the per-evaluation results of
    training.vectorized_distance_to_target stacked with torch.stack.

    NOTES:
    - Keep frequencies below 0.1 (10 samples per period) unless you want to start fighting aliasing problems.
    - The whole routine runs under torch.no_grad().
    """
    with torch.no_grad():
        parallel_count = len(frequencies) * runs_per_frequency
        target_functions = training.ParallelLinear(parallel_count)
        learned_functions = training.ParallelLinear(parallel_count)

        if init_at_target:
            raise NotImplementedError()

        recorded_steps = []
        recorded_distances = []
        samplers = [make_sampler(freq) for freq in frequencies]

        def draw_domain_samples():
            # XXX: samplers are stateful -- their integer time advances on
            # every call, so the number and order of calls matters.
            return sampling.generate_domain_samples(samplers, batch_size, runs_per_frequency)

        previous_samples = draw_domain_samples()
        for step in tqdm(range(training_iterations)):
            current_samples = draw_domain_samples()

            # Markov state per observation, assuming 2-Markovianness: the
            # previous draw stacked on top of the current one.
            markov_state = torch.vstack((previous_samples, current_samples))

            noisy_target = sampling.NoisyLinearTarget(target_functions, observation_noise_variance)
            targets = noisy_target.batch(markov_state)

            learned_functions.update(learning_rate, momentum, markov_state, targets)

            previous_samples = current_samples

            if step % evaluation_period != 0:
                continue

            # Weight-space distance is the recorded metric (rather than a loss).
            recorded_steps.append(step)
            recorded_distances.append(
                training.vectorized_distance_to_target(target_functions, learned_functions)
            )

        return torch.tensor(recorded_steps), torch.stack(recorded_distances)


def save_sweep(
    frequencies,
    learning_rates,
    momenta,

    runs_per_frequency,
    training_iterations,

    results_dir,

    make_sampler,

    # target function sampling
    observation_noise_variance,

    # training parameters
    batch_size,

    # measurement parameters
    evaluation_period,

    init_at_target,

    # Only used for intermediate results plotting, not actual data recording.
    tail_fraction,
):
    """
    Run frequency_sweep for every (learning_rate, momentum) combination and
    save each result under results_dir.

    For each combination the distances are written to
    '<results_dir>/lr=<learning_rate> mu=<momentum>.pt'. Combinations whose
    file already exists are skipped, so an interrupted sweep can be resumed by
    calling this function again with the same arguments. The evaluation
    iteration indices are written to '<results_dir>/iterations.pt' after every
    computed combination.

    Parameters:
    - frequencies, runs_per_frequency, make_sampler,
      observation_noise_variance, training_iterations, batch_size,
      evaluation_period, init_at_target: forwarded unchanged to
      frequency_sweep.
    - learning_rates, momenta: iterables whose cartesian product is swept.
    - results_dir: output directory; created (including parents) if missing.
    - tail_fraction: forwarded to plotting.plot_frequency_sweep_mean_distance
      only; it does not affect the saved data.

    Returns:
    None. Results are communicated through the files written.
    """
    # exist_ok=True replaces the racy "check, then create" sequence and is a
    # no-op when the directory is already there.
    os.makedirs(results_dir, exist_ok=True)

    for learning_rate in learning_rates:
        for momentum in momenta:
            label = f'lr={learning_rate} mu={momentum}'
            results_path = os.path.join(results_dir, f'{label}.pt')
            if os.path.exists(results_path):
                # Already computed by an earlier (possibly interrupted) run.
                continue

            # Keyword arguments keep this call correct even if the order of
            # frequency_sweep's parameters changes.
            iterations, distances = frequency_sweep(
                make_sampler=make_sampler,
                frequencies=frequencies,
                runs_per_frequency=runs_per_frequency,
                observation_noise_variance=observation_noise_variance,
                training_iterations=training_iterations,
                batch_size=batch_size,
                learning_rate=learning_rate,
                momentum=momentum,
                evaluation_period=evaluation_period,
                init_at_target=init_at_target,
            )
            plotting.plot_frequency_sweep_mean_distance(
                frequencies, runs_per_frequency, distances, label, tail_fraction,
            )

            torch.save(distances, results_path)
            torch.save(iterations, os.path.join(results_dir, 'iterations.pt'))  # Overwritten, but always the same


if __name__ == '__main__':
    # Example sweep configuration; outputs are written under results_dir.
    results_dir = 'results/test'
    # Only affects the intermediate plots, not the saved data.
    tail_fraction = 0.5

    # Three evenly spaced frequencies from 0.0 to 0.03 (all below the 0.1
    # aliasing guideline noted in frequency_sweep).
    frequencies = torch.linspace(0.0, 0.03, 3)
    # A single learning rate: [0.001].
    learning_rates = [thousandths / 1000 for thousandths in range(1, 2)]
    # Three momenta in descending order: 0.999, 0.998, 0.997.
    momenta = [thousandths / 1000 for thousandths in range(999, 996, -1)]

    runs_per_frequency = 5

    # NOTE(review): this call does not match save_sweep's signature as defined
    # in this file. save_sweep has no `variance` or `amplitude` parameters and
    # takes no **kwargs, and its required `make_sampler` argument is not
    # supplied here, so the call raises TypeError as written. Presumably
    # variance/amplitude should be bound into a `make_sampler(frequency)`
    # factory built from the sampling module -- TODO confirm and fix.
    save_sweep(
        frequencies=frequencies,
        learning_rates=learning_rates,
        momenta=momenta,

        runs_per_frequency=runs_per_frequency,
        training_iterations=8000,

        results_dir=results_dir,

        # sinusoidal mean gaussian domain sampling (the frequency itself is
        # supplied per sampler inside frequency_sweep)
        variance=0.5,
        amplitude=0.2,

        # target function sampling
        observation_noise_variance=0.0,

        # training parameters
        batch_size=10,

        # measurement parameters
        evaluation_period=100,

        init_at_target=False,

        tail_fraction=tail_fraction,
    )