import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from typing import Dict, Any, Tuple

from utils import get_grid_uniform, get_fdprs,get_circular_error


def decode_labels_CMLR(
    CMLR_model: torch.nn.Module,
    J_decode: int,
    xsamples_test: torch.Tensor,
    label_params: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decode continuous labels from test features with a trained CMLR model.

    Two point estimates are produced per trial from a softmax posterior over
    a uniform grid of candidate labels: the MAP (mode) estimate and the
    posterior-mean estimate (circular mean for circular labels, ordinary
    weighted mean otherwise).

    Steps:
      1. Build a uniform grid of label values with ``get_grid_uniform``.
      2. Evaluate the basis matrix ``'Bmat'`` on the grid via ``get_fdprs``
         and project the model weights ``mu_q`` onto the grid.
      3. Compute logits and a softmax posterior over grid points.
      4. Mode decoding: take the grid point with the highest posterior.
      5. Mean decoding:
         - circular labels: weighted sin/cos, ``atan2``, wrapped to [0, 2*pi);
         - otherwise: posterior-weighted average of the grid points.

    Parameters
    ----------
    CMLR_model : nn.Module
        Trained model exposing ``mu_q`` (expected [Bdim, D]); it is detached
        and cast to float32 before use.
    J_decode : int
        Decoding resolution, forwarded to ``get_grid_uniform``.
    xsamples_test : Tensor [num_trials, D]
        Test feature matrix. It is cast to float32 to match the float32
        weight functions (``torch.matmul`` does not promote dtypes).
    label_params : dict
        Label-space configuration. This function reads only
        ``'is_label_range_circular'`` directly (evaluated by truthiness);
        the dict is also forwarded unchanged to ``get_grid_uniform`` and
        ``get_fdprs``.
        NOTE(review): earlier documentation listed ``'no_of_inputs'``,
        ``'dimension_1_range'`` and ``'dimension_2_range'`` as keys used by
        those helpers, while other code in this module reads
        ``'no_of_outputs'`` -- confirm the exact schema against ``utils``.

    Returns
    -------
    decoded_label_mode : Tensor
        Grid rows selected by argmax with axis 1 squeezed. Assuming the grid
        is [J, dims], this is [num_trials] for dims == 1 and
        [num_trials, dims] otherwise.
    decoded_label_mean : Tensor
        Circular case: [num_trials] angles in [0, 2*pi).
        Linear case: [num_trials, dims]. Note that for dims == 1 this is
        [num_trials, 1], i.e. it is NOT squeezed like the mode estimate;
        callers comparing against 1-D label vectors must align shapes.

    Notes
    -----
    The circular branch multiplies the [num_trials, J] posterior by the
    squeezed grid, so it only makes sense for a single-column grid.
    """
    # --- 1) Uniform decoding grid and the basis evaluated on it ---
    grid_decode = get_grid_uniform(J_decode, label_params)  # expected [J, dims]
    fdprs_decode = get_fdprs(grid_decode, label_params)
    Bmat_decode = fdprs_decode['Bmat']                      # expected [J, Bdim]

    # --- 2) Reconstruct the weight functions on the grid (float32) ---
    mu_q = CMLR_model.mu_q.detach()                         # expected [Bdim, D]
    w_decode = Bmat_decode.float() @ mu_q.float()           # [J, D]

    # --- 3) Posterior over grid points for every trial ---
    # matmul requires matching dtypes; w_decode is float32 by construction.
    logits = xsamples_test.float() @ w_decode.T             # [num_trials, J]
    posterior = F.softmax(logits, dim=1)                    # [num_trials, J]

    # --- 4) Mode decoding: grid point with the highest posterior ---
    assigned_class = torch.argmax(posterior, dim=1)         # [num_trials]
    decoded_label_mode = grid_decode[assigned_class].squeeze(1)

    # --- 5) Mean decoding ---
    # Truthiness check, matching how the flag is tested elsewhere in this module.
    if label_params['is_label_range_circular']:
        grid_decode_flat = grid_decode.squeeze(1)           # [J] for a 1-column grid
        # Posterior-weighted resultant vector components for each trial.
        weighted_sin = torch.sum(posterior * torch.sin(grid_decode_flat), dim=1)
        weighted_cos = torch.sum(posterior * torch.cos(grid_decode_flat), dim=1)
        # atan2 yields an angle in [-pi, pi]; wrap it into [0, 2*pi).
        decoded_label_mean = torch.atan2(weighted_sin, weighted_cos)
        decoded_label_mean = (decoded_label_mean + 2 * np.pi) % (2 * np.pi)  # [num_trials]
    else:
        # Align the grid dtype with the posterior so matmul cannot fail on a
        # float32/float64 mismatch.
        decoded_label_mean = posterior @ grid_decode.to(posterior.dtype)    # [num_trials, dims]

    return decoded_label_mode, decoded_label_mean


def _per_trial_error_fn(label_params):
    """Return the per-trial error metric ``f(labels, decoded)`` for this label space."""
    n_outputs = label_params['no_of_outputs']

    if n_outputs == 1:
        if label_params['is_label_range_circular']:
            def circular_error(labels, decoded):
                # Flatten both so [N] vs [N, 1] inputs cannot broadcast to [N, N].
                return get_circular_error(labels.reshape(-1), decoded.reshape(-1))
            return circular_error

        def absolute_error(labels, decoded):
            # Same flattening: the decoder's mode and mean estimates may differ
            # in shape ([N] vs [N, 1]) for 1-D linear labels.
            return torch.abs(labels.reshape(-1) - decoded.reshape(-1))
        return absolute_error

    if n_outputs == 2:
        def euclidean_error(labels, decoded):
            return torch.norm(decoded - labels, dim=1)
        return euclidean_error

    raise ValueError(
        f"Unsupported label_params['no_of_outputs']={n_outputs!r}; expected 1 or 2."
    )


def get_avg_decoding_error_vs_resolution_CMLR(CMLR_model, J_decode_range, xsamples_test, labels_test, label_params):
    """
    Compute the average decoding error for different decoding resolutions.

    For every resolution in ``J_decode_range`` the labels are decoded with
    ``decode_labels_CMLR`` and the per-trial error of both the mode and the
    mean estimate is averaged over trials. The metric depends on
    ``label_params``:

      - ``'no_of_outputs' == 1`` and ``'is_label_range_circular'`` truthy:
        ``get_circular_error(labels, decoded)``;
      - ``'no_of_outputs' == 1`` otherwise: absolute difference;
      - ``'no_of_outputs' == 2``: Euclidean norm of the difference along axis 1.

    In the one-output cases labels and estimates are flattened before being
    compared so that an ``[N]`` / ``[N, 1]`` shape mismatch cannot silently
    broadcast into an ``[N, N]`` error matrix.

    Parameters
    ----------
    CMLR_model : CMLRModel
        The trained CMLR model.
    J_decode_range : iterable of int (e.g. np.ndarray)
        Decoding resolutions to evaluate.
    xsamples_test : torch.Tensor
        Test samples.
    labels_test : torch.Tensor
        True labels for the test samples.
    label_params : dict
        Label parameters; must contain ``'no_of_outputs'`` (1 or 2) and, for
        one output, ``'is_label_range_circular'``. Also forwarded to
        ``decode_labels_CMLR``.

    Returns
    -------
    avg_errors_mode_CMLR : np.ndarray
        Average error of the mode estimate, one entry per resolution.
    avg_errors_mean_CMLR : np.ndarray
        Average error of the mean estimate, one entry per resolution.

    Raises
    ------
    ValueError
        If ``label_params['no_of_outputs']`` is neither 1 nor 2. This is
        checked before any decoding is performed.

    Notes
    -----
    Errors are returned in the same units as the labels; no
    radians-to-degrees conversion is applied.
    """
    # Resolve (and validate) the metric once, before any expensive decoding.
    error_fn = _per_trial_error_fn(label_params)

    avg_errors_mode_CMLR = []  # mode decoding, one value per J_decode
    avg_errors_mean_CMLR = []  # mean decoding, one value per J_decode

    for J_decode in J_decode_range:
        decoded_label_mode_CMLR, decoded_label_mean_CMLR = decode_labels_CMLR(
            CMLR_model, J_decode, xsamples_test, label_params
        )
        # Average the per-trial errors for the current resolution.
        avg_errors_mode_CMLR.append(
            torch.mean(error_fn(labels_test, decoded_label_mode_CMLR)).item()
        )
        avg_errors_mean_CMLR.append(
            torch.mean(error_fn(labels_test, decoded_label_mean_CMLR)).item()
        )

    return np.array(avg_errors_mode_CMLR), np.array(avg_errors_mean_CMLR)
