import torch
import numpy as np
import matplotlib.pyplot as plt
from utils import get_circular_error
import pandas as pd  

def _to_degrees(values):
    """Detach a tensor, move it to CPU, and return it as a NumPy array in degrees.

    Assumes ``values`` holds angles in radians (as documented for the labels
    passed to ``plot_decoded_results``).
    """
    return values.detach().cpu().numpy() * 180 / np.pi


def _error_summary(errors_by_method):
    """Return a DataFrame of error statistics, one row per (method_name, errors) pair."""
    rows = []
    for method_name, errs in errors_by_method:
        q25, q75 = np.percentile(errs, [25, 75])
        rows.append({
            'Method': method_name,
            'MeanErr': errs.mean(),
            'StdErr': errs.std(),
            'MedianErr': np.median(errs),
            'IQR': q75 - q25,
        })
    return pd.DataFrame(rows).set_index('Method')


def _plot_decoding_panel(method_title, color, abs_error_degrees,
                         labels_test_degrees, decoded_degrees,
                         error_label, bins):
    """Show one figure: error histogram (left) and true-vs-decoded scatter (right)."""
    fig, axs = plt.subplots(1, 2, figsize=(9, 3))

    # Left: histogram of absolute error
    axs[0].hist(
        abs_error_degrees,
        bins=bins,
        color=color,
        alpha=0.7
    )
    axs[0].set_xlabel(f"{error_label} (°)")
    axs[0].set_ylabel("Frequency")
    axs[0].set_title(f"{method_title} Decoding: Absolute Error Histogram")
    axs[0].grid(True)

    # Right: scatter of true vs. decoded labels
    axs[1].scatter(
        labels_test_degrees,
        decoded_degrees,
        alpha=0.6,
        label=method_title,
        color=color
    )
    # Identity line for perfect decoding, spanning the range of true labels
    min_deg = labels_test_degrees.min()
    max_deg = labels_test_degrees.max()
    axs[1].plot(
        [min_deg, max_deg],
        [min_deg, max_deg],
        color='black',
        linestyle='--',
        label="Ideal"
    )
    axs[1].set_xlabel("True Angle (°)")
    axs[1].set_ylabel("Decoded Angle (°)")
    axs[1].set_xlim(min_deg, max_deg)
    axs[1].set_title(f"{method_title} Decoding: True vs. Decoded")
    axs[1].legend()
    axs[1].grid(True)

    plt.tight_layout()
    plt.show()


def plot_decoded_results(labels_test, decoded_label_mode, decoded_label_mean,
                         label_params, *, bins=180):
    """
    Visualize decoding performance via absolute error histograms and
    true vs. decoded angle scatter plots for both mode and mean decoding.

    Parameters
    ----------
    labels_test : torch.Tensor, shape [N]
        Ground-truth labels (angles in radians).
    decoded_label_mode : torch.Tensor, shape [N]
        Decoded labels via maximum a posteriori (mode).
    decoded_label_mean : torch.Tensor, shape [N]
        Decoded labels via posterior mean (circular mean if applicable).
    label_params : dict
        Dictionary containing label parameters, including:
        - 'is_label_range_circular': bool
            True if the label range is circular (e.g., angles).
    bins : int, optional (keyword-only)
        Number of histogram bins for the error plots. Defaults to 180.

    Returns
    -------
    pandas.DataFrame
        Error summary (degrees) indexed by method ('mode', 'mean') with
        columns MeanErr, StdErr, MedianErr and IQR. The same table is
        also printed.

    Raises
    ------
    ValueError
        If ``labels_test`` is empty (no statistics or plot ranges exist).
    KeyError
        If ``label_params`` lacks 'is_label_range_circular'.
    """
    is_circular = label_params['is_label_range_circular']

    if labels_test.numel() == 0:
        # Without this guard, numpy fails later with an unhelpful error
        # from percentile/min on an empty array.
        raise ValueError("plot_decoded_results: labels_test is empty")

    # 1) Compute absolute error for mode and mean decodings
    if is_circular:
        # Wrap-around-aware error for circular labels
        abs_error_mode = get_circular_error(labels_test, decoded_label_mode)
        abs_error_mean = get_circular_error(labels_test, decoded_label_mean)
        error_label = "Absolute Circular Error"
    else:
        # Plain absolute difference for non-circular label ranges
        abs_error_mode = torch.abs(labels_test - decoded_label_mode)
        abs_error_mean = torch.abs(labels_test - decoded_label_mean)
        error_label = "Absolute Error"

    # 2) Move everything to NumPy and convert radians -> degrees
    abs_error_mode_degrees = _to_degrees(abs_error_mode)
    abs_error_mean_degrees = _to_degrees(abs_error_mean)
    decoded_label_mode_degrees = _to_degrees(decoded_label_mode)
    decoded_label_mean_degrees = _to_degrees(decoded_label_mean)
    labels_test_degrees = _to_degrees(labels_test)

    # 3) Summarize error statistics
    df_stats = _error_summary([
        ('mode', abs_error_mode_degrees),
        ('mean', abs_error_mean_degrees),
    ])
    print("Decoding error summary (degrees):\n", df_stats)

    # 4) One figure per decoding method
    _plot_decoding_panel("Mode", 'magenta', abs_error_mode_degrees,
                         labels_test_degrees, decoded_label_mode_degrees,
                         error_label, bins)
    _plot_decoding_panel("Mean", 'red', abs_error_mean_degrees,
                         labels_test_degrees, decoded_label_mean_degrees,
                         error_label, bins)

    return df_stats