from __future__ import annotations
import torch
from torch import Tensor
from typing import Dict, Optional
from collections import deque

class ControlGPlanner:
    """Turn per-entry losses and difficulty scores into an allocation over ``K`` entries.

    Each :meth:`plan` call:

    1. lifts a reference point above the current losses (the stored reference
       never decreases from one call to the next);
    2. weights each entry by ``1 / (reference - loss)``, so entries closest to
       their reference get the most weight;
    3. damps those weights by ``1 / (1 + gamma * difficulty)``;
    4. normalizes them and mixes in a per-entry floor of ``f_min``.

    NOTE(review): inputs are assumed to be 1-D tensors with ``K`` entries; the
    uniform fallback builds length-``K`` tensors, but shapes are not validated.
    """

    def __init__(self, K: int, device: str = "cpu", ref_delta: float = 0.1,
                 ref_window: int = 5, ref_eps: float = 0.01, gamma: float = 1.0, f_min: float = 0.05):
        """Configure the planner.

        Args:
            K: number of entries in the allocation.
            device: device on which losses, references and allocations are kept.
            ref_delta: margin used to lift the reference above the losses, both
                multiplicatively ``(1 + ref_delta) * loss`` and additively
                ``loss + ref_delta``.
            ref_window: how many past loss vectors feed the reference.
            ref_eps: stored for callers; not read by any method of this class.
            gamma: strength of the difficulty damping ``1 / (1 + gamma * D)``.
            f_min: minimum share per entry; if ``f_min * K >= 1`` the allocation
                falls back to uniform.
        """
        self.K = K
        self.device = device
        self.ref_delta = ref_delta
        self.ref_window = ref_window
        self.ref_eps = ref_eps
        self.gamma = gamma
        self.f_min = f_min
        # Bounded window: the deque drops the oldest vector beyond ref_window entries.
        self.loss_history: deque = deque(maxlen=ref_window)
        self.reference: Optional[Tensor] = None
        self.allocation: Optional[Tensor] = None
        self.epoch_idx: int = 0
        # Snapshot of the intermediate tensors from the latest plan() call.
        self.last_plan_info: Dict[str, Tensor] = {}

    def compute_reference_points(self, current_losses: Tensor) -> Tensor:
        """Return a reference point lying above ``current_losses``.

        The result is the element-wise maximum of:

        * ``(1 + ref_delta)`` times the largest loss seen across the history
          window and ``current_losses`` (just ``current_losses`` if the history
          is empty);
        * ``current_losses + ref_delta``, an additive margin that keeps the
          reference strictly above the losses whenever ``ref_delta > 0``;
        * the previously stored reference, if any, so it never decreases.

        This method does not modify ``self``; :meth:`plan` stores the result.
        """
        current_losses = current_losses.to(self.device)
        if not self.loss_history:
            ref = (1 + self.ref_delta) * current_losses
        else:
            stacked = torch.stack(list(self.loss_history), dim=0)
            max_losses = torch.maximum(stacked.max(dim=0).values.to(self.device), current_losses)
            ref = (1 + self.ref_delta) * max_losses
        ref = torch.maximum(ref, current_losses + self.ref_delta)
        if self.reference is not None:
            ref = torch.maximum(ref, self.reference.to(self.device))
        return ref

    def plan(self, normalized_losses: Tensor, difficulty: Tensor) -> Tensor:
        """Compute, store and return a new allocation.

        Args:
            normalized_losses: current per-entry losses.
            difficulty: per-entry difficulty ``D``. With ``gamma > 0``, larger
                values shrink an entry's weight through ``1 / (1 + gamma * D)``.
                NOTE(review): values with ``1 + gamma * D <= 0`` are not guarded.

        Returns:
            An allocation whose entries sum to ~1 (up to the ``1e-8`` stabilizer).
            Each entry is at least ``f_min``, unless ``f_min * K >= 1``, in which
            case the allocation is uniform ``1 / K``.

        Side effects: updates ``reference``, ``allocation`` and ``last_plan_info``,
        appends a detached copy of the losses to ``loss_history``, and increments
        ``epoch_idx``.
        """
        losses, D = normalized_losses.to(self.device), difficulty.to(self.device)
        self.reference = self.compute_reference_points(losses)
        # The reference is >= losses + ref_delta, so this clamp only matters for
        # ref_delta <= 0 or rounding; it prevents division by zero below.
        gaps = torch.clamp(self.reference - losses, min=1e-8)
        # 1/gap is |d/dl_i sum_j log(ref_j - l_j)| with the reference held fixed.
        w_hv = 1.0 / gaps
        tempered = w_hv / (1 + self.gamma * D)
        base_alloc = tempered / (tempered.sum() + 1e-8)
        floor_mass = self.f_min * self.K
        if floor_mass >= 1.0:
            # The floors alone would consume all the mass, so fall back to uniform.
            # Use base_alloc's dtype so both branches return the same dtype.
            allocation = torch.ones(self.K, dtype=base_alloc.dtype, device=self.device) / self.K
        else:
            # Every entry gets f_min; the remaining 1 - K*f_min is split by base_alloc.
            allocation = self.f_min + (1.0 - floor_mass) * base_alloc
        self.allocation = allocation
        self.last_plan_info = {
            "reference": self.reference.clone(),
            "gaps": gaps.clone(),
            "w_hv": w_hv.clone(),
            "tempered": tempered.clone(),
            "allocation": allocation.clone(),
            "difficulty": D.clone(),
        }
        self.loss_history.append(losses.detach().clone())
        self.epoch_idx += 1
        return allocation

    def get_allocation(self) -> Tensor:
        """Return a copy of the latest allocation, or uniform ``1 / K`` before any plan."""
        if self.allocation is None:
            return torch.ones(self.K, device=self.device) / self.K
        return self.allocation.clone()

    def get_log_hv(self, losses: Optional[Tensor] = None) -> float:
        """Return ``sum_i log(reference_i - losses_i)``, with each gap clamped to ``>= 1e-8``.

        This is the log-volume of the axis-aligned box between ``losses`` and the
        current reference. It returns ``0.0`` before the first :meth:`plan` call.
        When ``losses`` is None, the most recently planned losses are used (zeros
        if the history is empty). ``losses`` is moved to the reference's device,
        so a CPU tensor can be scored against a reference stored elsewhere.
        """
        if self.reference is None:
            return 0.0
        ref = self.reference
        if losses is None:
            if self.loss_history:
                losses = self.loss_history[-1]
            else:
                losses = torch.zeros(self.K, dtype=ref.dtype, device=ref.device)
        gaps = torch.clamp(ref - torch.as_tensor(losses, device=ref.device), min=1e-8)
        return float(torch.log(gaps).sum().item())
