from typing import Dict
import numpy as np
import torch
import torch.nn.functional as F

from context_general_bci.tasks.preproc_utils import apply_minmax_norm, unapply_minmax_norm
from context_general_bci.rtndt.config import OnlineConfig

class Rescaler:
    r"""
        Online analog of `preproc_utils.apply_minmax_norm`.

        Holds the covariate scaling statistics (`cov_mean`, `cov_max`, `cov_min`)
        taken from a payload, each padded out to `max_dims`, and uses them to
        normalize / unnormalize covariates via the `preproc_utils` helpers.
    """
    def __init__(self, payload: Dict[str, torch.Tensor], max_dims: int):
        r"""
            payload: mapping that may contain 'cov_mean', 'cov_max', 'cov_min'.
                Missing entries are stored as None.
            max_dims: length each present statistic is padded to.
        """
        stat_keys = ('cov_mean', 'cov_max', 'cov_min')
        self.scale_params = {
            key: pad_if_not_none(payload.get(key), pad_to=max_dims)
            for key in stat_keys
        }
        # Not fatal: only warn when the payload carried no statistics at all.
        if not any(stat is not None for stat in self.scale_params.values()):
            print("Warning: No scale parameters found in payload")

    # Let's make this in active dims
    def normalize(self, covariates):
        r"""
            Apply min-max normalization with the stored statistics.
            Returns the first element of `apply_minmax_norm`'s result, as a numpy array.
        """
        normed = apply_minmax_norm(covariates, self.scale_params)
        return normed[0].numpy()

    def unnormalize(self, covariates):
        r"""
            Invert the normalization with the stored statistics; returns a numpy array.
        """
        restored = unapply_minmax_norm(covariates, self.scale_params)
        return restored.numpy()


def pad_if_not_none(x: torch.Tensor, pad_to: int):
    r"""
        Right-pad `x` with zeros and return it as a numpy array; pass None through.

        The pad amount is `pad_to - x.size(0)` and is applied at the end of the
        last dimension, so a 1D `x` comes back with length `pad_to`.
    """
    if x is not None:
        n_trailing = pad_to - x.size(0)
        padded = F.pad(x, (0, n_trailing))
        return padded.numpy()
    return None

def roll_and_pad(tensor: np.ndarray, bins_elapsed: int, pad=0, new_value=0):
    r"""
        Advance a time-major buffer by `bins_elapsed` steps and write the newest entry.

        tensor: time x *
        The buffer is rolled toward the start along axis 0 (on a copy, the input
        is left untouched). Of the `bins_elapsed` trailing slots, all but the last
        are filled with `pad`; the last slot always receives `new_value`.
    """
    shifted = np.roll(tensor, -bins_elapsed, axis=0)
    # Slots between the old data and the newest bin were skipped: fill them.
    if bins_elapsed > 1:
        skipped = slice(-bins_elapsed, -1)
        shifted[skipped] = pad
    shifted[-1] = new_value
    return shifted

def nucleus_filter(logits, nucleus_p):
    r"""
        Return the indices forming the nucleus (top-p) set of `logits`.

        logits: array of unnormalized log-probabilities (treated as a single
            distribution; assumed 1D).
        nucleus_p: cumulative probability mass to retain.
        Returns indices in descending-probability order, cut after the first
        entry at which the cumulative probability reaches `nucleus_p`.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to rank; an empty index array matches the empty input.
        return np.empty(0, dtype=np.intp)

    # Softmax is shift-invariant, so subtract the max before exponentiating.
    # Without this, large logits overflow to inf (inf/inf = NaN) and very
    # negative logits underflow to all-zero (0/0 = NaN).
    probs = np.exp(logits - np.max(logits))
    probs = probs / np.sum(probs)

    sorted_indices = np.argsort(probs)[::-1]
    cumulative_probs = np.cumsum(probs[sorted_indices])
    cutoff_index = np.searchsorted(cumulative_probs, nucleus_p)
    # +1 so the element that crosses the threshold is included.
    return sorted_indices[:cutoff_index + 1]

def sample_from_logits(logits: np.ndarray, temperature=1.0, nucleus_p=0.9) -> int:
    """
    Sample from logits with temperature and nucleus (top-p) parameters.

    :param logits: Array of logits (treated as a single distribution; assumed 1D).
    :param temperature: Temperature parameter to adjust sharpness of distribution.
        Must be non-zero.
    :param nucleus_p: Nucleus (top-p) parameter for truncating the distribution.
    :return: Index of the sampled element.
    :raises ValueError: If ``temperature`` is zero, or if sampling is impossible
        (e.g. empty ``logits``).
    """
    if temperature == 0:
        # Dividing by zero would only surface later as NaN probabilities;
        # fail early with an explicit message instead.
        raise ValueError("temperature must be non-zero")

    # Adjust logits with temperature
    scaled = np.asarray(logits) / temperature

    # Convert logits to probabilities. Softmax is shift-invariant, so subtract
    # the max first to avoid overflow/underflow producing NaN probabilities.
    probs = np.exp(scaled - np.max(scaled))
    probs = probs / np.sum(probs)

    # Sort probabilities (descending) and filter out the tail
    sorted_indices = np.argsort(probs)[::-1]
    cumulative_probs = np.cumsum(probs[sorted_indices])
    cutoff_index = np.searchsorted(cumulative_probs, nucleus_p)
    # +1 so the element that crosses the threshold is kept.
    top_indices = sorted_indices[:cutoff_index + 1]

    # Renormalize the truncated distribution
    top_probs = probs[top_indices]
    top_probs = top_probs / np.sum(top_probs)

    # Sample from the truncated distribution
    return np.random.choice(top_indices, p=top_probs)
