import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

import gd_baseline
import train_lsa
import noisy_lr_data_generation

# Multipliers applied to the sampled inputs for the out-of-distribution sweep.
SCALE_LIST = [1, 2, 4, 8, 16, 32, 64, 128, 200, 300, 400, 500, 600, 700, 800]
# Output variances for which the theoretical learning rate is plotted.
VARIANCES = [0.5, 1, 1.5, 2]
# Number of sequences sampled for every scale of the sweep.
NUM_SEQUENCES = 500
# Sample size handed to gd_baseline.theoretical_lr.
LR_SAMPLE_SIZE = 100000


def _select_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _evaluate_scale_sweep(lsa_layer, gd_model, train_config, scale_list,
                          device, num_sequences=NUM_SEQUENCES):
    """Compare the LSA layer with the GD baseline on rescaled inputs.

    For every scale a fresh NoisyLinearRegressionTask is built, its sampled
    xs are multiplied by the scale, and both models predict the last point.

    Args:
        lsa_layer: Callable returned by train_lsa.lsa_training_loop.
        gd_model: gd_baseline.GDBaseline instance.
        train_config: Config providing dimension, output_variance and
            sequence_length.
        scale_list: Multipliers applied to the sampled xs.
        device: torch.device on which the task samples its data.
        num_sequences: Batch size used for every scale.

    Returns:
        A pair of lists with one float per scale:
        (MSE between LSA and GD predictions, MSE between LSA and last_y).
    """
    lsa_gd_differences = []
    lsa_errors = []
    for scale in scale_list:
        # The task is deliberately re-created for each scale, exactly as the
        # flat script did, so every scale is evaluated on a new task object.
        task = noisy_lr_data_generation.NoisyLinearRegressionTask(
            dimension=train_config.dimension,
            output_variance=train_config.output_variance,
            batch_size=num_sequences,
            device=device,
        )
        # xs: (batch_size, num_points, dimension)
        xs = scale * task.sample_xs(train_config.sequence_length)
        context, last_y, ys = task.evaluate(xs)

        # LSA prediction: entry at the last index of both trailing axes.
        lsa_predictions = lsa_layer(context)[:, -1, -1]

        # GD baseline: every point but the last is context, the last x is
        # the query.
        x_context = xs[:, :-1, :]
        y_context = ys[:, :-1, :]
        x_query = xs[:, -1, :]
        gd_result = gd_model.evaluate(x_context, y_context, x_query)  # (batch_size,)

        # NOTE(review): F.mse_loss broadcasts operands of different shapes;
        # confirm last_y has the same shape as lsa_predictions.
        lsa_gd_differences.append(F.mse_loss(lsa_predictions, gd_result).item())
        lsa_errors.append(F.mse_loss(lsa_predictions, last_y).item())
    return lsa_gd_differences, lsa_errors


def _theoretical_lr_by_variance(train_config, variances,
                                sample_size=LR_SAMPLE_SIZE):
    """Return gd_baseline.theoretical_lr for each output variance.

    train_config.output_variance is overwritten for every call and restored
    to its original value afterwards, so the caller's config is unchanged.
    """
    original_variance = train_config.output_variance
    predicted_lr_list = []
    try:
        for var in variances:
            train_config.output_variance = var
            predicted_lr_list.append(
                gd_baseline.theoretical_lr(train_config=train_config,
                                           sample_size=sample_size)
            )
    finally:
        train_config.output_variance = original_variance
    return predicted_lr_list


def _finish_figure(tick_labels, xlabel, ylabel, title, filename):
    """Label, save, show and clear the current matplotlib figure.

    Tick positions are derived from tick_labels so they cannot drift out of
    sync with the plotted series (plotted against indices 0..len-1).
    """
    plt.xticks(list(range(len(tick_labels))), tick_labels)
    plt.legend()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.savefig(filename)
    plt.show()
    plt.clf()


def main():
    """Train the LSA layer, compare it with 1-step GD and produce both plots."""
    # 1. Train LSA layer
    train_config = train_lsa.icl_gd_config()
    lsa_layer, train_config = train_lsa.lsa_training_loop(train_config)

    # 2. Get GD baseline
    predicted_lr = gd_baseline.theoretical_lr(train_config,
                                              sample_size=LR_SAMPLE_SIZE)
    gd_model = gd_baseline.GDBaseline(predicted_lr)

    # 3. For each scale, measure the error between GD and the LSA layer,
    # and also the loss of the LSA layer against the target.
    device = _select_device()
    lsa_gd_differences, lsa_errors = _evaluate_scale_sweep(
        lsa_layer, gd_model, train_config, SCALE_LIST, device
    )

    # Measure error
    print("MSE between LSA and GD:", lsa_gd_differences)
    print("MSE between LSA and Ground Truth:", lsa_errors)

    # 4. Plot the OOD behavior of LSA layer.
    plt.plot(lsa_gd_differences, label="LSA and GD Diff")
    plt.plot(lsa_errors, label="LSA and Target Y Diff")
    _finish_figure(
        tick_labels=SCALE_LIST,
        xlabel="Scaling of X",
        ylabel="MSE",
        title="LSA and 1-step GD Difference",
        filename="lsa_gd_difference_plot.png",
    )

    # 5. Plot how the theoretical learning rate depends on sigma.
    predicted_lr_list = _theoretical_lr_by_variance(train_config, VARIANCES)
    plt.plot(predicted_lr_list, label="Predicted LR")
    # Raw strings: "\s" is not a valid escape sequence in a normal literal.
    _finish_figure(
        tick_labels=VARIANCES,
        xlabel=r"Output variance $\sigma$",
        ylabel="Learning rate",
        title=r"Theoretical Prediction for Learning Rate using $\sigma$",
        filename="lr_sigma_plot.png",
    )


if __name__ == "__main__":
    main()