from __future__ import annotations

import torch
from torch import Tensor
import numpy as np
from typing import Dict, Optional, Tuple

from .utils import softmax_with_temperature


class ControlGController:
    """Stochastic task selector driven by a PID-style tracking controller.

    On every :meth:`step` the controller compares the cumulative number of
    selections each of the ``K`` tasks should have received so far
    (``allocation * block``) with the number actually realized.  That deficit
    is turned into per-task logits from proportional, integral and derivative
    terms (plus an optional staleness bonus).  The logits are clamped, passed
    to the project helper ``softmax_with_temperature`` (imported from
    ``.utils``; assumed to return a probability vector -- confirm there),
    mixed with a uniform distribution, and one task is sampled.

    State kept between steps (all tensors have ``K`` entries):
        N_realized: how many times each task has been selected.
        I: clamped running sum of the deficit (integral state).
        prev_e: deficit seen on the previous step.
        stale_counts: exponentially decayed selection counts.
        block_idx: number of steps taken since the last :meth:`reset`.
        global_block: counter that drives the temperature schedule.
    """

    def __init__(
        self,
        K: int,
        device: str = "cpu",
        Kp: float = 1.0,
        Ki: float = 0.5,
        Kd: float = 0.1,
        rho_p: float = 0.5,
        rho_i: float = 0.5,
        rho_d: float = 0.5,
        I_max: float = 5.0,
        u_min: float = -10.0,
        u_max: float = 10.0,
        epsilon: float = 0.1,
        tau0: float = 1.0,
        tau_min: float = 0.1,
        tau_anneal: float = 100.0,
        stale_zeta: float = 0.5,
        stale_gamma: float = 0.9,
        stale_Smax: float = 2.0,
        use_staleness_term: bool = False,
        use_gain_scheduling: bool = False,
    ):
        """Store the hyper-parameters and initialise the controller state.

        Args:
            K: Number of tasks; every state tensor has this many entries.
            device: Torch device on which state tensors are allocated.
            Kp: Base proportional gain.
            Ki: Base integral gain.
            Kd: Base derivative gain.
            rho_p: Difficulty sensitivity of the proportional gain.
            rho_i: Difficulty sensitivity of the integral gain.
            rho_d: Difficulty sensitivity of the derivative gain.
            I_max: Symmetric bound applied to the integral state.
            u_min: Lower clamp on the logits.
            u_max: Upper clamp on the logits.
            epsilon: Weight of the uniform distribution mixed into the
                sampling probabilities.
            tau0: Temperature when ``global_block`` is zero.
            tau_min: Floor of the temperature schedule.
            tau_anneal: Time constant of the exponential temperature decay.
            stale_zeta: Weight of the staleness term in the logits.
            stale_gamma: Decay factor applied to ``stale_counts`` on each
                staleness update.
            stale_Smax: Upper clamp on the staleness value.
            use_staleness_term: Add the staleness bonus to the logits.
            use_gain_scheduling: Scale each gain by ``1 / (1 + rho * D)``
                where ``D`` is the per-task difficulty.
        """
        # Problem size and tensor placement.
        self.K = K
        self.device = device

        # Base gains and their difficulty sensitivities.
        self.Kp0, self.Ki0, self.Kd0 = Kp, Ki, Kd
        self.rho_p, self.rho_i, self.rho_d = rho_p, rho_i, rho_d

        # Saturation limits for the integral state and the logits.
        self.I_max = I_max
        self.u_min, self.u_max = u_min, u_max

        # Sampling distribution: uniform mixing and temperature schedule.
        self.epsilon = epsilon
        self.tau0, self.tau_min, self.tau_anneal = tau0, tau_min, tau_anneal

        # Staleness bonus parameters.
        self.stale_zeta = stale_zeta
        self.stale_gamma = stale_gamma
        self.stale_Smax = stale_Smax

        # Feature switches.
        self.use_staleness_term = use_staleness_term
        self.use_gain_scheduling = use_gain_scheduling

        self.reset()

    def reset(self) -> None:
        """Zero every state tensor and both block counters."""
        size, dev = self.K, self.device
        self.N_realized = torch.zeros(size, device=dev)
        self.I = torch.zeros(size, device=dev)
        self.prev_e = torch.zeros(size, device=dev)
        self.stale_counts = torch.zeros(size, device=dev)
        self.block_idx = 0
        self.global_block = 0

    def set_global_block(self, global_block: int) -> None:
        """Overwrite the counter that drives the temperature schedule."""
        self.global_block = global_block

    def get_temperature(self) -> float:
        """Return ``max(tau_min, tau0 * exp(-global_block / tau_anneal))``."""
        annealed = self.tau0 * np.exp(-self.global_block / self.tau_anneal)
        # The floor wins unless the annealed value is strictly above it.
        return annealed if annealed > self.tau_min else self.tau_min

    def compute_gains(self, difficulty: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return the per-task ``(Kp, Ki, Kd)`` gain tensors.

        Without gain scheduling each gain is a constant vector of length
        ``K`` and ``difficulty`` is not read.  With gain scheduling each gain
        is ``base / (1 + rho * difficulty)``.

        Args:
            difficulty: Per-task difficulty; moved to ``self.device`` when
                gain scheduling is enabled.

        Returns:
            Tuple of proportional, integral and derivative gain tensors.
        """
        if not self.use_gain_scheduling:
            shape = (self.K,)
            return (
                torch.full(shape, self.Kp0, device=self.device),
                torch.full(shape, self.Ki0, device=self.device),
                torch.full(shape, self.Kd0, device=self.device),
            )

        level = difficulty.to(self.device)
        gain_p = self.Kp0 / (1 + self.rho_p * level)
        gain_i = self.Ki0 / (1 + self.rho_i * level)
        gain_d = self.Kd0 / (1 + self.rho_d * level)
        return gain_p, gain_i, gain_d

    def compute_staleness(self) -> Tensor:
        """Return ``1 / (stale_counts + 1)`` clamped to ``[0, stale_Smax]``."""
        inverse_counts = 1.0 / (self.stale_counts + 1)
        return inverse_counts.clamp(min=0.0, max=self.stale_Smax)

    def update_staleness(self, selected_task: int) -> None:
        """Decay all selection counts by ``stale_gamma`` and bump the chosen one.

        Args:
            selected_task: Index of the task that was just selected.
        """
        # The decay builds a new tensor; the increment then edits it in place.
        self.stale_counts = self.stale_counts * self.stale_gamma
        self.stale_counts[selected_task] += 1

    def step(self, allocation: Tensor, difficulty: Tensor, rng: np.random.RandomState) -> Tuple[int, Dict[str, Tensor]]:
        """Advance one block: update the controller and sample a task.

        Args:
            allocation: Target share per task; multiplied by the 1-based
                block number to obtain the reference selection counts.
            difficulty: Per-task difficulty forwarded to
                :meth:`compute_gains`.
            rng: Random generator whose ``choice`` draws the task index.

        Returns:
            ``(selected, info)`` where ``selected`` is the sampled task index
            and ``info`` maps diagnostic names to tensors.  ``"N_realized"``
            reflects the counts after this selection, while ``"staleness"``
            was computed before the staleness counts were updated.
        """
        dev = self.device
        n_tasks = self.K
        block_number = self.block_idx + 1

        target_share = allocation.to(dev)
        task_difficulty = difficulty.to(dev)

        # Deficit = selections each task should have by now minus what it got.
        reference_counts = target_share * block_number
        deficit = reference_counts - self.N_realized

        gain_p, gain_i, gain_d = self.compute_gains(task_difficulty)

        # Integral state is clamped to +/- I_max; derivative is the change in
        # deficit since the previous step.
        self.I = (self.I + deficit).clamp(-self.I_max, self.I_max)
        deficit_change = deficit - self.prev_e
        self.prev_e = deficit.clone()

        proportional = gain_p * deficit
        integral = gain_i * self.I
        derivative = gain_d * deficit_change

        if self.use_staleness_term:
            staleness = self.compute_staleness()
            staleness_bonus = self.stale_zeta * staleness
        else:
            staleness = torch.zeros(n_tasks, device=dev)
            staleness_bonus = torch.zeros(n_tasks, device=dev)

        logits = torch.clamp(
            proportional + integral + derivative + staleness_bonus,
            self.u_min,
            self.u_max,
        )

        temperature = self.get_temperature()
        soft_probs = softmax_with_temperature(logits, temperature)
        # Mix in a uniform distribution so every task keeps some probability.
        probs = (1 - self.epsilon) * soft_probs + self.epsilon / n_tasks

        # Renormalise on the NumPy side before handing the weights to rng.choice.
        weights = probs.detach().cpu().numpy()
        weights = weights / weights.sum()
        selected = int(rng.choice(n_tasks, p=weights))

        # Book-keeping for the chosen task and the block counters.
        self.N_realized[selected] += 1
        self.update_staleness(selected)
        self.block_idx += 1
        self.global_block += 1

        info = {
            "block": torch.tensor(block_number, device=dev),
            "N_ref": reference_counts.clone(),
            "N_realized": self.N_realized.clone(),
            "deficit": deficit.clone(),
            "integral": self.I.clone(),
            "derivative": deficit_change.clone(),
            "staleness": staleness.clone(),
            "P_term": proportional.clone(),
            "I_term": integral.clone(),
            "D_term": derivative.clone(),
            "stale_term": staleness_bonus.clone(),
            "logits": logits.clone(),
            "temperature": torch.tensor(temperature, device=dev),
            "probabilities_soft": soft_probs.clone(),
            "probabilities": probs.clone(),
            "selected": torch.tensor(selected, device=dev),
            "Kp": gain_p.clone(),
            "Ki": gain_i.clone(),
            "Kd": gain_d.clone(),
        }
        return selected, info

    def get_realized_allocation(self) -> Tensor:
        """Return the realized selection shares.

        Returns:
            ``N_realized`` divided by its sum, or a uniform vector of ``1 / K``
            when that sum is below ``1e-8`` (nothing selected yet).
        """
        selections = self.N_realized.sum()
        nothing_selected = selections < 1e-8
        return (
            torch.ones(self.K, device=self.device) / self.K
            if nothing_selected
            else self.N_realized / selections
        )
