from __future__ import annotations

import torch
from torch import Tensor
from typing import Dict, List, Optional, Tuple, Union
import numpy as np


def project_simplex(v: Tensor, z: float = 1.0) -> Tensor:
    """Project ``v`` onto the set ``{w : w >= 0, sum(w) == z}``.

    Uses the sort-and-threshold scheme: sort the entries in decreasing
    order, find how many of them stay strictly positive after a uniform
    shift, and subtract that shift from every entry before clipping at 0.

    Args:
        v: Vector to project. The prefix sums run along dim 0 and a single
            scalar shift is read from them, so a 1-D tensor is expected.
        z: Target sum of the projected vector.

    Returns:
        The projected tensor. An empty ``v`` is returned unchanged.
    """
    size = v.numel()
    if size == 0:
        return v

    descending_vals = torch.sort(v, descending=True).values
    prefix_sums = descending_vals.cumsum(dim=0)
    positions = torch.arange(1, size + 1, device=v.device, dtype=v.dtype)

    # Entry k (1-based) belongs to the support while it remains positive
    # after subtracting the shift implied by the first k sorted entries.
    in_support = (descending_vals - (prefix_sums - z) / positions) > 0
    support_size = max(1, int(in_support.sum().item()))

    shift = (prefix_sums[support_size - 1] - z) / float(support_size)
    return (v - shift).clamp(min=0.0)


def rank_normalize(values: Tensor, eps: float = 1e-8) -> Tensor:
    """Map values to their tie-averaged ranks, rescaled to roughly [0, 1].

    Elements are sorted ascending and given 1-based ranks. A run of sorted
    elements that are ``torch.isclose`` to the first element of the run is
    treated as a tie and every member receives the run's average rank.
    The result is ``(rank - 1) / (n - 1 + eps)``.

    Args:
        values: Tensor to rank. The element-wise tie scan indexes scalars,
            so a 1-D tensor is expected.
        eps: Small constant added to the denominator.

    Returns:
        Tensor of normalized ranks in the input's order. The dtype is the
        input dtype for floating inputs and the default float dtype
        otherwise. Inputs with at most one element yield all zeros.
    """
    n = values.numel()
    # Ranks can be fractional (tie averages), so they must live in a float
    # dtype even when the input is integral; otherwise 1.5 would be stored
    # as 1 and ties would collapse onto the lower rank.
    if values.is_floating_point():
        rank_dtype = values.dtype
    else:
        rank_dtype = torch.get_default_dtype()

    if n <= 1:
        return torch.zeros_like(values, dtype=rank_dtype)

    sorted_vals, sorted_idx = torch.sort(values, descending=False)
    ranks_sorted = torch.empty_like(sorted_vals, dtype=rank_dtype)

    start = 0
    while start < n:
        # Extend the run while elements stay close to the run's first value.
        stop = start + 1
        while stop < n and torch.isclose(sorted_vals[stop], sorted_vals[start]):
            stop += 1
        # Average of the 1-based ranks start+1 .. stop.
        ranks_sorted[start:stop] = (float(start + 1) + float(stop)) / 2.0
        start = stop

    # Scatter the ranks back to the original element order.
    ranks = torch.empty_like(ranks_sorted)
    ranks[sorted_idx] = ranks_sorted
    return (ranks - 1.0) / (float(n - 1) + eps)


def zscore_normalize(values: Tensor, eps: float = 1e-8) -> Tensor:
    """Standardize ``values`` to zero mean and (approximately) unit std.

    Args:
        values: Floating-point tensor; statistics are taken over all
            elements.
        eps: Added to the standard deviation to avoid division by zero for
            constant inputs.

    Returns:
        ``(values - mean) / (std + eps)``. Inputs with fewer than two
        elements return zeros, because the sample standard deviation is
        undefined there and would otherwise turn the result into NaN.
    """
    if values.numel() < 2:
        return torch.zeros_like(values)
    centered = values - values.mean()
    spread = values.std() + eps
    return centered / spread


class RunningStats:
    """Streaming mean / standard deviation for ``K`` independent slots.

    Each slot keeps a sample count, a running mean and a running sum of
    squared deviations (``M2``), updated one observation at a time with
    Welford's recurrence.
    """

    def __init__(self, K: int, device: str = "cpu"):
        """Create ``K`` empty accumulators on ``device``."""
        self.K = K
        self.device = device
        # One zero-initialized length-K tensor per accumulator.
        self.count, self.mean, self.M2 = (
            torch.zeros(K, device=device) for _ in range(3)
        )

    def update(self, k: int, value: float) -> None:
        """Fold a single observation ``value`` into slot ``k``."""
        self.count[k] += 1
        # Deviation from the mean *before* it absorbs the new value...
        dev_before = value - self.mean[k]
        self.mean[k] += dev_before / self.count[k]
        # ...and from the mean *after*; their product is the M2 increment.
        dev_after = value - self.mean[k]
        self.M2[k] += dev_before * dev_after

    def get_mean(self, k: int) -> float:
        """Return the running mean of slot ``k`` as a Python float."""
        slot_mean = self.mean[k]
        return float(slot_mean.item())

    def get_std(self, k: int, eps: float = 1e-8) -> float:
        """Return ``sqrt(sample_variance + eps)`` for slot ``k``.

        Falls back to ``1.0`` while the slot holds fewer than two samples.
        """
        samples = self.count[k]
        if samples < 2:
            return 1.0
        sample_var = self.M2[k] / (samples - 1)
        return float((sample_var + eps).sqrt().item())

    def normalize(self, k: int, value: float, eps: float = 1e-8) -> float:
        """Standardize ``value`` with slot ``k``'s current mean and std."""
        center = self.get_mean(k)
        scale = self.get_std(k, eps)
        return (value - center) / scale


class EMATracker:
    """Exponential moving average over ``K`` independent slots.

    The first observation for a slot seeds it directly; later observations
    are blended as ``(1 - alpha) * old + alpha * new``. Optional
    ``v_min`` / ``v_max`` bounds are applied after every update.
    """

    def __init__(
        self,
        K: int,
        alpha: float = 0.1,
        v_min: Optional[float] = None,
        v_max: Optional[float] = None,
        init_value: float = 0.5,
        device: str = "cpu",
    ):
        """Create a tracker.

        Args:
            K: Number of slots.
            alpha: Weight given to the new observation in each blend.
            v_min: Optional lower bound applied after updates.
            v_max: Optional upper bound applied after updates.
            init_value: Value reported for a slot before its first update.
            device: Device on which the state tensors are allocated.
        """
        self.K = K
        self.alpha = alpha
        self.v_min = v_min
        self.v_max = v_max
        self.device = device
        # Current average per slot; holds init_value until first update.
        self.values = torch.full((K,), init_value, device=device)
        # True once a slot has received at least one observation.
        self.initialized = torch.zeros(K, dtype=torch.bool, device=device)

    def update(self, k: int, new_value: float) -> float:
        """Fold ``new_value`` into slot ``k`` and return the new average."""
        if not self.initialized[k]:
            # First observation seeds the slot without blending.
            self.values[k] = new_value
            self.initialized[k] = True
        else:
            self.values[k] = (1 - self.alpha) * self.values[k] + self.alpha * new_value
        if self.v_min is not None:
            self.values[k] = max(self.v_min, self.values[k].item())
        if self.v_max is not None:
            self.values[k] = min(self.v_max, self.values[k].item())
        return float(self.values[k].item())

    def update_all(self, new_values: Tensor) -> Tensor:
        """Fold one observation per slot and return a copy of all averages.

        Args:
            new_values: Tensor indexable by a length-``K`` boolean mask,
                holding one observation per slot.

        Returns:
            A clone of the updated per-slot averages.
        """
        # Snapshot the mask before touching it: slots seeded in this call
        # must take the new value exactly and must not also be blended.
        seen_before = self.initialized.clone()
        fresh = ~seen_before

        self.values[seen_before] = (
            (1 - self.alpha) * self.values[seen_before]
            + self.alpha * new_values[seen_before]
        )
        self.values[fresh] = new_values[fresh]
        self.initialized[fresh] = True

        if self.v_min is not None:
            self.values = torch.clamp(self.values, min=self.v_min)
        if self.v_max is not None:
            self.values = torch.clamp(self.values, max=self.v_max)
        return self.values.clone()

    def get(self, k: int) -> float:
        """Return the current average of slot ``k`` as a Python float."""
        return float(self.values[k].item())

    def get_all(self) -> Tensor:
        """Return a copy of the current averages for all slots."""
        return self.values.clone()


def compute_log_hypervolume(
    losses: Tensor,
    reference: Tensor,
    eps: float = 1e-8,
) -> float:
    """Return the sum of log-gaps between ``reference`` and ``losses``.

    Each gap ``reference - losses`` is floored at ``eps`` before taking the
    logarithm, so entries at or beyond the reference contribute
    ``log(eps)`` instead of an undefined value.

    Args:
        losses: Loss values.
        reference: Reference values, broadcastable against ``losses``.
        eps: Lower bound applied to every gap.

    Returns:
        The summed log-gaps as a Python float.
    """
    floored_gaps = (reference - losses).clamp(min=eps)
    total = floored_gaps.log().sum()
    return float(total.item())


def softmax_with_temperature(logits: Tensor, tau: float = 1.0) -> Tensor:
    """Softmax over dim 0 of ``logits`` divided by a temperature.

    Args:
        logits: Unnormalized scores.
        tau: Temperature; values below ``1e-8`` are raised to ``1e-8`` so
            the division stays finite.

    Returns:
        Tensor of the same shape as ``logits`` whose entries along dim 0
        sum to one.
    """
    floor = 1e-8
    # Written so that it picks the same operand as max(tau, floor) would.
    temperature = floor if floor > tau else tau
    return (logits / temperature).softmax(dim=0)
