import os

import matplotlib.pyplot as plt
import numpy as np


def evaluate_acc_error(pred, gt):
    """Return the mean squared error over growing leading slices.

    For every prefix length ``end`` from 2 up to ``gt.shape[0] - 1``
    (inclusive), the squared difference between ``pred[:end]`` and
    ``gt[:end]`` is averaged over all of its elements.

    Note that the full-length prefix (``end == gt.shape[0]``) is not
    evaluated, so the result has ``max(gt.shape[0] - 2, 0)`` entries.

    Args:
        pred: Predicted values, sliceable along the first axis.
        gt: Reference values; its first dimension drives the iteration.

    Returns:
        A list with one mean squared error per prefix length.
    """

    def _prefix_mse(end):
        # Mean of squared residuals restricted to the first `end` rows.
        residual = pred[:end] - gt[:end]
        return (residual ** 2).mean()

    return [_prefix_mse(end) for end in range(2, gt.shape[0])]


def compute_l2_error(x, y, reduce=True):
    """Compute the mean squared difference between ``x`` and ``y``.

    Args:
        x: First operand; its shape determines which axes are averaged
            when ``reduce`` is false.
        y: Second operand, subtracted from ``x``.
        reduce: When true, average over every element and return a single
            value. When false, average over all axes except the first, so
            one value is returned per entry along axis 0.

    Returns:
        The mean squared difference, either fully reduced or kept per
        entry of the first axis.
    """
    squared_diff = (x - y) ** 2
    if not reduce:
        # Every axis after the leading one is collapsed.
        trailing_axes = tuple(range(1, len(x.shape)))
        return squared_diff.mean(trailing_axes)
    return squared_diff.mean()


def show_compared_params(pred_params, gt_params, title, show=False, x_axis=None):
    """Plot predicted against ground-truth parameters, one subplot each.

    Column ``i`` of both inputs is drawn in subplot ``i`` as two curves
    labelled 'Pred' and 'GT'.

    Args:
        pred_params: Predicted parameters, indexed as ``[row, parameter]``;
            ``pred_params.shape[1]`` sets the number of subplots.
        gt_params: Ground-truth parameters with the same layout.
        title: Figure-level title.
        show: When true, call ``plt.show()`` before returning.
        x_axis: Optional x coordinates, one per row of the inputs. Defaults
            to ``np.arange(0.25, 20.0, 0.5)`` (40 points), the grid this
            function has always used.

    Returns:
        The created matplotlib figure. It is not closed here, so the caller
        decides whether to save, log, or close it.
    """
    params_num = pred_params.shape[1]
    if x_axis is None:
        x_axis = np.arange(0.25, 20.0, 0.5)

    # squeeze=False always yields a 2-D array of axes, so the one-parameter
    # and multi-parameter cases share a single code path.
    fig, axs = plt.subplots(1, params_num, squeeze=False)
    # A figure-level title keeps the caller's title from being overwritten
    # by the per-subplot 'Parameter i' titles.
    fig.suptitle(title)

    for i, ax in enumerate(axs[0]):
        ax.set_title(f'Parameter {i}')
        ax.plot(x_axis, pred_params[:, i], label='Pred')
        ax.plot(x_axis, gt_params[:, i], label='GT')
        ax.set_xlabel('x')
        # The curve labels are only visible once a legend is drawn.
        ax.legend()

    if show:
        plt.show()
    return fig


def show_pred_sol(pred_sol, gt_sol, title, show=False, save_path=None):
    """Draw the predicted and ground-truth solutions side by side.

    Both inputs are rendered with ``imshow`` (row 0 at the bottom) and share
    one colorbar. The figure is always closed before returning, including
    when saving or showing raises.

    Args:
        pred_sol: Image-like data shown in the left panel ('Predicted').
        gt_sol: Image-like data shown in the right panel ('GT').
        title: Figure-level title.
        show: When true, call ``plt.show()`` after the optional save.
        save_path: When not ``None``, the figure is written to this path.

    Returns:
        None.
    """
    fig, axs = plt.subplots(1, 2)
    try:
        # Reserve the right-hand strip of the figure for the colorbar.
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.suptitle(title)

        pred_im = axs[0].imshow(pred_sol, aspect='auto', origin='lower')
        axs[0].set_title('Predicted')
        axs[0].set_ylabel('Time')
        gt_im = axs[1].imshow(gt_sol, aspect='auto', origin='lower')
        axs[1].set_title('GT')
        axs[1].set_yticks([])

        # Each imshow autoscales its own color range. Without a common range
        # the single colorbar (built from the GT image) would not describe
        # the predicted panel, so apply the union of both ranges to both.
        clims = [pred_im.get_clim(), gt_im.get_clim()]
        if all(bound is not None for clim in clims for bound in clim):
            vmin = min(low for low, _ in clims)
            vmax = max(high for _, high in clims)
            pred_im.set_clim(vmin, vmax)
            gt_im.set_clim(vmin, vmax)

        fig.colorbar(gt_im, cax=cbar_ax)

        if save_path is not None:
            fig.savefig(save_path)
        if show:
            plt.show()
    finally:
        # Close this specific figure even on failure so figures do not
        # accumulate across repeated calls.
        plt.close(fig)
