import os

import numpy as np

from scipy.integrate import solve_ivp

from tqdm import tqdm

def monodromy_spectral_norm(cov_shift_amplitude, cov_shift_frequency, mu, eta):
    '''
    Return the largest eigenvalue magnitude of the monodromy matrix.

    A 4-dimensional linear ODE with a sinusoidally varying coefficient
    (amplitude ``cov_shift_amplitude``, frequency ``cov_shift_frequency``) is
    integrated over exactly one period, once from each standard basis vector.
    The four end states form the monodromy matrix, and the maximum absolute
    value of its eigenvalues is returned.

    ``mu`` and ``eta`` enter through the damping coefficient
    ``(mu - 1) / sqrt(eta)`` and through the time rescaling of the period by
    ``sqrt(eta)`` (callers in this file pass a momentum and a step size /
    learning rate for them).

    The integration interval grows like 1 / frequency, so expect high cost and
    poor numerical behaviour as the frequency approaches zero.
    '''
    # A zero frequency means the system is time invariant, i.e. periodic for
    # every period T; use T = 1 so that we never divide by zero.
    time_invariant = cov_shift_frequency == 0
    raw_period = 1 if time_invariant else 1 / cov_shift_frequency
    horizon = raw_period * np.sqrt(eta)

    # Constant damping term on the two "velocity" components.
    damping = (mu - 1) / np.sqrt(eta)

    def mean_at(t):
        # With zero frequency the sinusoid never leaves sin(0) = 0.
        if time_invariant:
            return 0
        return cov_shift_amplitude * np.sin(2 * np.pi * t / horizon)

    def rhs(t, state):
        x_bar = mean_at(t)
        system = np.array([
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [-2 * (1 + x_bar ** 2), -2 * x_bar, damping, 0],
            [-2 * x_bar, -2, 0, damping],
        ])
        return system @ state

    # Each end state is one column of the monodromy matrix; stacking them as
    # rows yields its transpose, which has the same eigenvalues.
    end_states = [
        solve_ivp(rhs, [0, horizon], basis_vector).y[:, -1]
        for basis_vector in np.eye(4)
    ]
    eigenvalues = np.linalg.eigvals(np.vstack(end_states))
    return np.max(np.abs(eigenvalues))


def spectral_norm_surface(momenta, frequencies, amplitude, step_size):
    '''
    Evaluate monodromy_spectral_norm over a momentum x frequency grid.

    Returns an array of shape (len(momenta), len(frequencies)) whose entry
    [i, j] is the value for momenta[i] and frequencies[j] at the given
    amplitude and step size.  Progress over the momenta is shown with tqdm.
    '''
    surface = np.zeros((len(momenta), len(frequencies)))
    for row, mu in enumerate(tqdm(momenta)):
        # Fill one momentum row at a time, sweeping the frequencies in order.
        surface[row, :] = [
            monodromy_spectral_norm(amplitude, frequency, mu, step_size)
            for frequency in frequencies
        ]

    return surface

def save_theoretical_results(learning_rates, momenta, frequencies, amplitude, results_dir):
    '''
    Compute and cache one spectral-norm surface per learning rate.

    For every ``eta`` in ``learning_rates`` the surface over
    ``momenta`` x ``frequencies`` is written to
    ``<results_dir>/spectral_norm_surface_<eta>.npy``.  A learning rate whose
    file already exists is skipped, so the directory acts as a cache.

    ``results_dir`` is created (including parents) when it is missing.
    '''
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(results_dir, exist_ok=True)

    for eta in learning_rates:
        theory_results_path = os.path.join(results_dir,  f'spectral_norm_surface_{eta}.npy')
        if os.path.isfile(theory_results_path):
            continue

        theory_result = spectral_norm_surface(momenta, frequencies, amplitude, eta)

        # File existence is the cache key, so never expose a half-written
        # file: write to a sibling temporary path and publish it atomically.
        partial_path = theory_results_path + '.tmp'
        with open(partial_path, 'wb') as f:
            np.save(f, theory_result)
        os.replace(partial_path, theory_results_path)

def load_theoretical_results(amplitude, momenta, frequencies, learning_rate):
    '''
    Load the cached spectral-norm surface for a single learning rate.

    The directory is derived from the sweep parameters via unique_sweep_id,
    and the file inside it is ``spectral_norm_surface_<learning_rate>.npy``.

    Note the asymmetry with save_theoretical_results, which is handed its
    directory explicitly and loops over several learning rates.  Ideally the
    saver would build its own path from the sweep parameters and handle one
    learning rate at a time, but that cleanup has not happened yet: here
    beauty awaits productivity.
    '''
    sweep_dir = unique_sweep_id(amplitude, momenta, frequencies)
    file_name = f'spectral_norm_surface_{learning_rate}.npy'
    with open(os.path.join(sweep_dir, file_name), 'rb') as handle:
        return np.load(handle)

                
def unique_sweep_id(amplitude, momenta, frequencies):
    '''
    Build the results directory path that identifies a parameter sweep.

    Each swept axis is summarised as "(first,last,count)", so the path encodes
    the amplitude plus the endpoints and lengths of momenta and frequencies.
    '''
    def axis_summary(values):
        return f'({values[0]},{values[-1]},{len(values)})'

    momentum_part = axis_summary(momenta)
    frequency_part = axis_summary(frequencies)
    return f'results/theoretical/a={amplitude},m={momentum_part},f={frequency_part}'

if __name__ == '__main__':
    # Self-check: a single evaluation and the matching corner of a small sweep
    # must both reproduce the same reference value.
    reference_norm = 0.9572
    tolerance = 1e-3

    cov_shift_amplitude, cov_shift_frequency = 0.5, 0.0135
    mu, eta = 0.99, 0.001

    single_value = monodromy_spectral_norm(cov_shift_amplitude, cov_shift_frequency, mu, eta)
    assert np.abs(single_value - reference_norm) < tolerance

    # Small grid whose [0, 0] entry uses exactly the parameters checked above.
    momenta = np.linspace(mu, 0.95, 3)
    frequencies = np.linspace(cov_shift_frequency, 0.05, 4)

    surface = spectral_norm_surface(momenta, frequencies, cov_shift_amplitude, eta)
    assert np.abs(surface[0, 0] - reference_norm) < tolerance