import torch
import matplotlib.pyplot as plt

import difficulty
from training import train_measure_segments


def sweep_difficulties(model_optimizer_init, training_iterations, num_means, difficulties, bound, test_size, measurement_interval, test_grid=False, batch_size_per_distribution=1):
    """Train a freshly initialised model at each difficulty and plot its losses.

    For every entry of ``difficulties`` this:
      1. builds ``num_means`` distributions with ``difficulty.get_distributions``;
      2. plots the values obtained by calling each distribution with ``1``
         (presumably one sample per distribution — TODO confirm against
         ``difficulty``), with the y-axis fixed to ``(-1, 1)`` regardless of
         ``bound``;
      3. trains a new ``(model, optimizer)`` pair via ``train_measure_segments``;
      4. prints/plots the loss curves with ``plot_log`` (y-axis capped at 0.2).

    Args:
        model_optimizer_init: zero-argument callable returning ``(model, optimizer)``.
        training_iterations: total training budget; divided evenly (floor
            division) across the ``num_means`` distributions.
        num_means: number of distributions per difficulty.
        difficulties: iterable of difficulty values to sweep.
        bound: forwarded to ``difficulty.get_distributions``.
        test_size, measurement_interval, test_grid, batch_size_per_distribution:
            forwarded positionally to ``train_measure_segments``.

    Returns:
        List of ``(steps, segment_test_losses, segment_train_losses)`` tuples,
        one per difficulty, in the order of ``difficulties``.
    """
    steps_per_segment = training_iterations // num_means
    collected = []

    for level in difficulties:
        dists = difficulty.get_distributions(num_means, steps_per_segment, level, bound)

        # Visualise how the domain moves across the sequence of distributions.
        plt.plot([sample_fn(1) for sample_fn in dists])
        plt.ylim((-1, 1))
        plt.title(f'domain trajectory, d={level}')
        plt.show()

        # A brand-new model/optimizer per difficulty so runs are independent.
        model, optimizer = model_optimizer_init()
        steps, test_losses, train_losses = train_measure_segments(
            model, optimizer, dists, measurement_interval,
            test_size, test_grid, batch_size_per_distribution,
        )
        collected.append((steps, test_losses, train_losses))

        plot_log(f'difficulty {level}:', steps, test_losses, train_losses, 0.2)

    return collected


def sweep_learning_rates(model_init, training_iterations, num_means, diff, bound, test_size, learning_rates, measurement_interval, test_grid=False, final_intervals_to_average=5):
    """Train a fresh model with Adam at each learning rate and summarise test loss.

    One sequence of ``num_means`` distributions is built (via
    ``difficulty.get_distributions`` at difficulty ``diff``) before the loop and
    the same distribution objects are reused for every learning rate.

    Args:
        model_init: zero-argument callable returning a new model.
        training_iterations: total training budget; divided evenly (floor
            division) across the ``num_means`` distributions.
        num_means: number of distributions in the sequence.
        diff: difficulty forwarded to ``difficulty.get_distributions``.
        bound: forwarded to ``difficulty.get_distributions``.
        test_size, measurement_interval, test_grid: forwarded positionally to
            ``train_measure_segments``.
        learning_rates: iterable of learning rates, each used for a new
            ``torch.optim.Adam`` optimizer.
        final_intervals_to_average: number of trailing test-loss measurements
            ``plot_log`` summarises for each run.

    Returns:
        ``(means, variances)``: two lists aligned with ``learning_rates``
        holding the mean and variance returned by ``plot_log`` for each run.
    """
    segment_length = training_iterations // num_means
    # Built once so every learning rate is trained on the same distribution objects.
    distributions = difficulty.get_distributions(num_means, segment_length, diff, bound)

    means = []
    variances = []  # named to avoid shadowing the built-in ``vars``
    for learning_rate in learning_rates:
        model = model_init()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        steps, segment_test_losses, segment_train_losses = train_measure_segments(
            model,
            optimizer,
            distributions,
            measurement_interval,
            test_size,
            test_grid,
        )

        # Honour the caller's averaging window instead of a fixed value.
        mean, var = plot_log(
            f'learning rate {learning_rate}:',
            steps,
            segment_test_losses,
            segment_train_losses,
            1,
            final_intervals_to_average=final_intervals_to_average,
        )
        means.append(mean)
        variances.append(var)

    return means, variances


def plot_log(title_string, steps, segment_test_losses, segment_train_losses, ymax, final_intervals_to_average=5):
    """Print tail statistics of the test loss and plot train/test loss curves.

    Computes the mean and variance of the last ``final_intervals_to_average``
    entries of ``segment_test_losses`` and prints them under ``title_string``.
    It then plots both loss series against ``steps``, fixes the y-axis to
    ``(0, ymax)`` and shows the figure.

    Args:
        title_string: heading printed before the statistics; also the plot title.
        steps: x-values shared by both loss series.
        segment_test_losses: test-loss history; a tensor or a float sequence
            accepted by ``torch.as_tensor``.
        segment_train_losses: train-loss history (plotted only).
        ymax: upper limit of the y-axis.
        final_intervals_to_average: number of trailing test-loss entries to
            summarise. If the history is shorter, all of it is used.

    Returns:
        ``(mean, var)`` as 0-dim tensors. ``var`` uses ``torch.var``'s default
        Bessel-corrected estimator, so it is NaN when only one entry is summarised.

    Raises:
        ValueError: if ``final_intervals_to_average`` is less than 1. Without the
            guard, 0 would slice ``[-0:]`` (the entire history). Negative values
            would silently drop leading entries instead.
    """
    if final_intervals_to_average < 1:
        raise ValueError(
            f'final_intervals_to_average must be >= 1, got {final_intervals_to_average}'
        )

    # Slice once. mean/var do not modify their input, so no defensive copy is
    # needed. as_tensor returns an existing tensor unchanged and also accepts lists.
    tail = torch.as_tensor(segment_test_losses[-final_intervals_to_average:])
    mean = torch.mean(tail)
    var = torch.var(tail)

    print(title_string)
    print(f'mean test loss, final {final_intervals_to_average} intervals:     {mean:.6f}')
    print(f'test loss variance, final {final_intervals_to_average} intervals: {var:.6f}')
    print('')
    plt.plot(steps, segment_train_losses, label='train')
    plt.plot(steps, segment_test_losses, label='test')
    plt.legend()
    plt.title(title_string)
    plt.ylim((0, ymax))
    plt.show()

    return mean, var
