import json
import os

import matplotlib.pyplot as plt
import numpy as np

def plot_pretrain_loss(epoch_loss, task_name):
    """Plot the per-epoch InfoNCE pre-training loss and dump it to JSON.

    Writes two files:
      * ``output/{task_name}_pretrain_loss.png`` -- the loss curve;
      * ``output/{task_name}_pretrain_loss.json`` -- one entry per epoch,
        keyed ``"epoch={i}/{n_epochs}"`` (``i`` is 0-based), each mapping to
        ``{"InfoNCE": loss}``.

    The output directory is created if it does not exist.

    Args:
        epoch_loss: Sequence of per-epoch loss values (must support ``len``
            and iteration). Each value is converted with ``float()`` before
            being written to JSON.
        task_name: Prefix used for the output file names.
    """
    png_path = f"output/{task_name}_pretrain_loss.png"
    json_path = f"output/{task_name}_pretrain_loss.json"
    # savefig/open raise FileNotFoundError when the directory is missing.
    os.makedirs(os.path.dirname(png_path), exist_ok=True)

    # Draw on a private figure so the caller's active pyplot figure is neither
    # drawn on nor cleared, and close it so figures don't accumulate.
    fig, ax = plt.subplots()
    ax.plot(epoch_loss, label="InfoNCE loss")
    ax.legend()
    fig.savefig(png_path)
    plt.close(fig)

    n_epochs = len(epoch_loss)
    # float() turns numpy/torch scalars (e.g. np.float32), which json cannot
    # serialize, into plain Python floats.
    loss_info = {
        f"epoch={i}/{n_epochs}": {"InfoNCE": float(loss)}
        for i, loss in enumerate(epoch_loss)
    }
    with open(json_path, "w") as f:
        json.dump(loss_info, f, indent=4)

def plot_train_loss(epoch_loss, task_name):
    """Plot the per-epoch training losses and dump them to JSON.

    Writes two files:
      * ``output/{task_name}_training_loss.png`` -- the Denoising (blue),
        Contrastive (green) and Total (red) loss curves;
      * ``output/{task_name}_training_loss.json`` -- one entry per epoch,
        keyed ``"epoch={i}/{n_epochs}"`` (``i`` is 0-based), each mapping to
        ``{"Total": ..., "Denoising": ..., "Contrastive": ...}``.

    The output directory is created if it does not exist.

    Args:
        epoch_loss: Mapping with keys ``"Denoising"``, ``"Contrastive"`` and
            ``"Total"``, each an indexable per-epoch sequence of loss values.
            The number of epochs is taken from ``epoch_loss["Total"]``; the
            other two sequences are indexed up to that length. Values are
            converted with ``float()`` before being written to JSON.
        task_name: Prefix used for the output file names.
    """
    png_path = f"output/{task_name}_training_loss.png"
    json_path = f"output/{task_name}_training_loss.json"
    # savefig/open raise FileNotFoundError when the directory is missing.
    os.makedirs(os.path.dirname(png_path), exist_ok=True)

    # Draw on a private figure so the caller's active pyplot figure is neither
    # drawn on nor cleared, and close it so figures don't accumulate.
    fig, ax = plt.subplots()
    ax.plot(epoch_loss["Denoising"], label="Denoising loss", color="blue")
    ax.plot(epoch_loss["Contrastive"], label="Contrastive loss", color="green")
    ax.plot(epoch_loss["Total"], label="Total loss", color="red")
    ax.legend()
    fig.savefig(png_path)
    plt.close(fig)

    n_epochs = len(epoch_loss["Total"])
    loss_info = {}
    # Index all three series by the "Total" length (not zip) so a shorter
    # component series still raises instead of being silently truncated.
    for i in range(n_epochs):
        # float() turns numpy/torch scalars (e.g. np.float32), which json
        # cannot serialize, into plain Python floats.
        loss_info[f"epoch={i}/{n_epochs}"] = {
            "Total": float(epoch_loss["Total"][i]),
            "Denoising": float(epoch_loss["Denoising"][i]),
            "Contrastive": float(epoch_loss["Contrastive"][i]),
        }
    with open(json_path, "w") as f:
        json.dump(loss_info, f, indent=4)

def plot_predicted_results(test_no, result_data, task_name, dim_no):
    """Plot all predicted series against the ground truth for one dimension.

    Saves ``output/{task_name}_{test_no}_{dim_no}.png`` under
    ``output/figures/``, creating the directory if needed. Each predicted
    series is drawn in blue and the ground truth in red. Prints a short
    completion message afterwards.

    Args:
        test_no: Identifier of the test case; used in the file name and the
            printed message.
        result_data: Mapping with ``"Predictions"`` (array-like indexed as
            ``[:, :, dim_no]``; each slice along the first axis is plotted as
            one curve) and ``"Ground truth"`` (array-like indexed as
            ``[:, dim_no]``). Presumably (n_samples, time, dims) and
            (time, dims) respectively -- TODO confirm with callers.
        task_name: Prefix used for the output file name.
        dim_no: Index of the last-axis dimension to plot.
    """
    pred = result_data["Predictions"][:, :, dim_no]
    truth = result_data["Ground truth"][:, dim_no]

    fig_path = f"output/figures/{task_name}_{test_no}_{dim_no}.png"
    # savefig raises FileNotFoundError when the directory is missing.
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    # Draw on a private figure so the caller's active pyplot figure is neither
    # drawn on nor cleared, and close it so figures don't accumulate across
    # many test cases.
    fig, ax = plt.subplots()
    for sample in pred:
        ax.plot(sample, color="blue")
    ax.plot(truth, color="red")
    fig.savefig(fig_path)
    plt.close(fig)
    print(f"{test_no} evaluated done")