import matplotlib.pyplot as plt
import numpy as np
import os
from configs.config import SAVE, LEFT_LIMIT, RIGHT_LIMIT, SUPERVISED

def plot_all_results(
    losses_runs_dict, metrics_runs_dict,
    test_inputs_list_dict, test_outputs_list_dict,
    test_predictions_list_dict, test_prediction_before_projection_list_dict,
    output_dir, functions, train_inputs, train_outputs,
    test_inputs, test_outputs
):
    """Plot training losses, dataset, predictions, constraints and parity for every mode.

    Modes are the keys of ``losses_runs_dict``; the code recognizes
    ``'constrained'`` and ``'unconstrained'``.  Figures are written to
    ``<output_dir>/plots/`` when the module-level ``SAVE`` flag is set (the
    directory is created if missing).  Figures are not closed here.

    Args:
        losses_runs_dict: mode -> list of runs; each run is a list of per-epoch
            loss dicts.  Constrained entries are read with keys
            ``'loss_data_after_projection'``, ``'loss_data_before_projection'``
            and ``'loss_displacement'``; unconstrained entries with
            ``'loss_unconstrained'``.
        metrics_runs_dict: unused; kept for interface compatibility.
        test_inputs_list_dict: mode -> per-run test inputs; only run 0 is used
            (data is the same across runs).
        test_outputs_list_dict: mode -> per-run test targets, indexed ``[:, i]``
            per function; only run 0 is used.
        test_predictions_list_dict: mode -> per-run predictions, stacked to
            shape [N, num_samples, num_functions].
        test_prediction_before_projection_list_dict: same as above for the
            constrained mode, before projection.
        output_dir: root directory for saved figures.
        functions: reference callables, one per output dimension.
        train_inputs, train_outputs: training data for the distribution plot.
        test_inputs, test_outputs: unused; the per-mode run-0 values from the
            ``*_list_dict`` arguments are plotted instead.
    """
    if not losses_runs_dict:
        # Nothing was trained, so there is nothing to plot.
        return

    modes = list(losses_runs_dict.keys())
    num_functions = len(functions)

    if SAVE:
        # savefig does not create missing directories.
        os.makedirs(os.path.join(output_dir, "plots"), exist_ok=True)

    series, displacement = _collect_loss_series(losses_runs_dict)
    _plot_loss_figures(series, displacement, output_dir)

    _plot_dataset_distribution(
        functions, test_inputs_list_dict[modes[0]][0],
        test_outputs_list_dict[modes[0]][0], train_inputs, train_outputs,
        output_dir,
    )

    # Dense grid on which the reference functions are evaluated.
    x1_grid = np.linspace(LEFT_LIMIT, RIGHT_LIMIT, 1000)

    parity_data = {}
    for mode in modes:
        mode_inputs = test_inputs_list_dict[mode][0]  # Same inputs across runs
        mode_outputs = test_outputs_list_dict[mode][0]  # Same outputs across runs
        predictions = np.array(test_predictions_list_dict[mode])  # Shape: [N, num_samples, num_functions]
        mean_pred = np.mean(predictions, axis=0)
        std_pred = np.std(predictions, axis=0)

        before_stats = None
        if mode == 'constrained':
            before = np.array(test_prediction_before_projection_list_dict[mode])
            before_stats = (np.mean(before, axis=0), np.std(before, axis=0))

        # Prediction/constraint plots overlay the reference functions on a 1-D
        # grid; they are only attempted for the two-function setup.
        if num_functions == 2:
            try:
                _plot_prediction_figure(
                    mode, functions, x1_grid, mode_inputs,
                    mean_pred, std_pred, before_stats, output_dir,
                )
                _plot_constraint_figure(
                    mode, functions, x1_grid, mean_pred, std_pred, output_dir,
                )
            except Exception as e:
                # Best effort: a failure here must not prevent the parity plots.
                print(f"Error plotting functions: {e}")

        _plot_parity(mode, mode_outputs, mean_pred, std_pred, num_functions, output_dir)
        parity_data[mode] = (mode_outputs, mean_pred)

    if 'constrained' in parity_data and 'unconstrained' in parity_data:
        _plot_parity_comparison(parity_data, num_functions, output_dir)


def _save_figure(fig, output_dir, filename):
    """Write *fig* to ``<output_dir>/plots/<filename>`` when SAVE is set."""
    if SAVE:
        fig.savefig(os.path.join(output_dir, "plots", filename))


def _grid_axes(n_panels, max_cols=2):
    """Create a figure with room for *n_panels* 6x6-inch axes, at most *max_cols* per row.

    Returns ``(fig, flat_axes)`` with the axes in row-major order.  Cells beyond
    *n_panels* are hidden; a zero-panel request still yields one (hidden) cell
    so ``plt.subplots`` never receives a zero dimension.
    """
    n_cells = max(n_panels, 1)
    ncols = min(n_cells, max_cols)
    nrows = (n_cells + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    flat_axes = axes.ravel()
    for ax in flat_axes[n_panels:]:
        ax.axis('off')
    return fig, flat_axes


def _mean_std_over_runs(losses_runs, key):
    """Stack ``loss[key]`` into a (runs, epochs) array and return its mean and std over runs."""
    values = np.array([[loss[key] for loss in losses] for losses in losses_runs])
    return np.mean(values, axis=0), np.std(values, axis=0)


def _collect_loss_series(losses_runs_dict):
    """Compute the loss curves to draw, in mode order.

    Returns ``(series, displacement)`` where *series* is a list of
    ``(label, color, linestyle, mean, std)`` tuples and *displacement* is the
    constrained ``(mean, std)`` projection displacement, or None when there is
    no constrained mode.
    """
    series = []
    displacement = None
    for mode, losses_runs in losses_runs_dict.items():
        if mode == 'constrained':
            for key, label, linestyle in (
                ('loss_data_after_projection', 'Constrained After Projection', '-'),
                ('loss_data_before_projection', 'Constrained Before Projection', '--'),
            ):
                mean, std = _mean_std_over_runs(losses_runs, key)
                series.append((label, 'blue', linestyle, mean, std))
            displacement = _mean_std_over_runs(losses_runs, 'loss_displacement')
        elif mode == 'unconstrained':
            mean, std = _mean_std_over_runs(losses_runs, 'loss_unconstrained')
            series.append(('Unconstrained', 'green', '-', mean, std))
    return series, displacement


def _plot_mean_band(ax, centre, lower, upper, *, color, linestyle, label):
    """Draw *centre* over epochs 1..len(centre) with a light band between *lower* and *upper*."""
    epochs = np.arange(1, len(centre) + 1)
    ax.plot(epochs, centre, color=color, linestyle=linestyle, label=label)
    ax.fill_between(epochs, lower, upper, color=color, alpha=0.1)


def _plot_displacement(ax, displacement, *, log_scale):
    """Plot the constrained projection displacement, or hide *ax* when there is none."""
    if displacement is None:
        ax.axis('off')
        return
    mean, std = displacement
    _plot_mean_band(
        ax, mean, mean - std, mean + std,
        color='red', linestyle='-', label='Constrained Projection Displacement',
    )
    if log_scale:
        ax.set_yscale("log")
    ax.set_title("Projection Displacement")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()


def _plot_loss_figures(series, displacement, output_dir):
    """Draw the linear-scale and log-scale loss comparison figures."""
    title = "Loss Data" if SUPERVISED else "Objective Value"

    # Linear scale.
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(121)
    for label, color, linestyle, mean, std in series:
        _plot_mean_band(ax, mean, mean - std, mean + std,
                        color=color, linestyle=linestyle, label=label)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss" if SUPERVISED else "Objective Value")
    ax.legend()
    _plot_displacement(fig.add_subplot(122), displacement, log_scale=False)
    fig.tight_layout()
    _save_figure(fig, output_dir, "losses_comparison.png")

    # Log scale.  An unsupervised objective may be zero or negative, which a
    # log axis cannot show, so it is drawn through arcsinh instead (log-like
    # for large |x|, defined everywhere).
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(121)
    for label, color, linestyle, mean, std in series:
        if SUPERVISED:
            centre, lower, upper = mean, mean - std, mean + std
        else:
            # arcsinh is monotonic, so transforming the band edges gives the
            # exact image of [mean - std, mean + std].
            centre = np.arcsinh(mean)
            lower = np.arcsinh(mean - std)
            upper = np.arcsinh(mean + std)
        _plot_mean_band(ax, centre, lower, upper,
                        color=color, linestyle=linestyle, label=label)
    if SUPERVISED:
        ax.set_yscale("log")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss" if SUPERVISED else "arcsinh(Objective Value)")
    ax.legend()
    _plot_displacement(fig.add_subplot(122), displacement, log_scale=True)
    fig.tight_layout()
    _save_figure(fig, output_dir, "losses_log_comparison.png")


def _plot_dataset_distribution(functions, ref_inputs, ref_outputs,
                               train_inputs, train_outputs, output_dir):
    """Scatter test and training targets per function (only for the two-function setup)."""
    num_functions = len(functions)
    fig, flat_axes = _grid_axes(num_functions)
    if num_functions == 2:
        try:
            for i, ax in enumerate(flat_axes[:num_functions]):
                ax.scatter(ref_inputs, ref_outputs[:, i], c="red", label='Test')
                ax.scatter(train_inputs, train_outputs[:, i], c="blue", label='Training')
                ax.set_title(f'y{i+1}')
                ax.legend()
        except Exception as e:
            # Best effort: the figure is still titled and saved.
            print(f"Error plotting functions: {e}")
    fig.suptitle("Training Points and Predictions (Data is the Same Across Runs)")
    _save_figure(fig, output_dir, "dataset_distribution.png")


def _plot_prediction_figure(mode, functions, x1_grid, inputs, mean_pred, std_pred,
                            before_stats, output_dir):
    """Plot mean +/- std test predictions against the reference functions for one mode.

    When *before_stats* ``(mean, std)`` is given (constrained mode), each
    function gets a second panel overlaying predictions before and after
    projection.
    """
    label = mode.capitalize()
    num_functions = len(functions)
    if before_stats is not None:
        mean_before, std_before = before_stats
        fig, axes = plt.subplots(num_functions, 2, figsize=(12, 6 * num_functions), squeeze=False)
        for i, func in enumerate(functions):
            true_values = func(x1_grid)
            ax_pred, ax_proj = axes[i]

            ax_pred.plot(x1_grid, true_values, label='True Function')
            ax_pred.errorbar(inputs, mean_pred[:, i], yerr=std_pred[:, i], fmt='o', ecolor='gray', capsize=3, label=f"{label} Prediction")
            ax_pred.set_title(f"y{i+1} - Prediction")
            ax_pred.legend()

            ax_proj.plot(x1_grid, true_values, label='True Function')
            ax_proj.errorbar(inputs, mean_before[:, i], yerr=std_before[:, i], fmt='o', ecolor='orange', capsize=3, label=f"{label} Prediction Before Projection")
            ax_proj.errorbar(inputs, mean_pred[:, i], yerr=std_pred[:, i], fmt='o', ecolor='gray', capsize=3, label=f"{label} Prediction After Projection")
            ax_proj.set_title(f"y{i+1} - Before and After Projection")
            ax_proj.legend()
    else:
        fig, flat_axes = _grid_axes(num_functions)
        for i, (func, ax) in enumerate(zip(functions, flat_axes)):
            ax.plot(x1_grid, func(x1_grid), label='True Function')
            ax.errorbar(inputs, mean_pred[:, i], yerr=std_pred[:, i], fmt='o', ecolor='gray', capsize=3, label=f"{label} Prediction")
            ax.set_title(f"y{i+1} - Prediction")
            ax.legend()
    fig.suptitle(f"Test Prediction ({label})")
    fig.tight_layout()
    _save_figure(fig, output_dir, f"{mode}_prediction_projection.png")


def _plot_constraint_figure(mode, functions, x1_grid, mean_pred, std_pred, output_dir):
    """Plot predictions against the curve traced by every pair of reference functions."""
    from itertools import combinations

    pairs = list(combinations(range(len(functions)), 2))
    if not pairs:
        # A constraint curve needs at least two functions.
        return
    label = mode.capitalize()
    fig, flat_axes = _grid_axes(len(pairs))
    for ax, (i, j) in zip(flat_axes, pairs):
        ax.plot(functions[i](x1_grid), functions[j](x1_grid), label='Constraint Curve')
        ax.errorbar(mean_pred[:, i], mean_pred[:, j], xerr=std_pred[:, i], yerr=std_pred[:, j], fmt='o', ecolor='gray', capsize=3, label=f"{label} Prediction")
        ax.set_xlabel(f"y{i+1}")
        ax.set_ylabel(f"y{j+1}")
        ax.legend()
    fig.suptitle(f"Constraints ({label})")
    fig.tight_layout()
    _save_figure(fig, output_dir, f"{mode}_constraint.png")


def _plot_identity_line(ax, *value_arrays):
    """Draw a dashed red y = x reference line spanning every array in *value_arrays*."""
    low = min(values.min() for values in value_arrays)
    high = max(values.max() for values in value_arrays)
    ax.plot([low, high], [low, high], 'r--')


def _plot_parity(mode, outputs, mean_pred, std_pred, num_functions, output_dir):
    """Plot true targets against mean predictions (+/- std) per function for one mode."""
    fig, flat_axes = _grid_axes(num_functions)
    for i, ax in enumerate(flat_axes[:num_functions]):
        ax.errorbar(outputs[:, i], mean_pred[:, i], xerr=0, yerr=std_pred[:, i], fmt='o', ecolor='gray', capsize=3)
        _plot_identity_line(ax, outputs[:, i], mean_pred[:, i])
        ax.set_xlabel('True')
        ax.set_ylabel('Prediction')
        ax.set_title(f'y{i+1}')
    fig.suptitle(f'Parity Plots ({mode.capitalize()})')
    fig.tight_layout()
    _save_figure(fig, output_dir, f"{mode}_parity_plots.png")


def _plot_parity_comparison(parity_data, num_functions, output_dir):
    """Overlay constrained and unconstrained parity scatters per function.

    *parity_data* maps mode -> ``(true_outputs, mean_predictions)``.
    """
    fig, flat_axes = _grid_axes(num_functions)
    for i, ax in enumerate(flat_axes[:num_functions]):
        plotted = []
        for mode, marker in (('constrained', 'o'), ('unconstrained', 's')):
            outputs, mean_pred = parity_data[mode]
            ax.scatter(outputs[:, i], mean_pred[:, i], label=mode.capitalize(), marker=marker)
            plotted += [outputs[:, i], mean_pred[:, i]]
        # Span the diagonal over both modes so neither set of points falls outside it.
        _plot_identity_line(ax, *plotted)
        ax.set_xlabel('True')
        ax.set_ylabel('Prediction')
        ax.set_title(f'y{i+1}')
        ax.legend()
    fig.suptitle('Parity Plots Comparison')
    fig.tight_layout()
    _save_figure(fig, output_dir, "parity_plots_comparison.png")
