import torch
import torch.nn.functional as F
import numpy as np

class CosineAnnealingScheduler:
    """Cosine-annealed ``(eps, sigma)`` schedule with decaying warm restarts.

    The ``max_steps`` budget is split into ``N`` cycles of ``T = max_steps // N``
    steps (never less than 1). Within a cycle both values follow a half cosine
    from their current peaks (``init_eps`` / ``init_sig``) toward the floors
    (``lb_eps`` / ``lb_sig``). At each of the first ``N - 1`` cycle boundaries the
    peaks are multiplied by ``decay_rate``; after that the peaks stay fixed, but
    the cosine keeps restarting every ``T`` steps, including past ``max_steps``.

    NOTE(review): only the peaks decay, not the floors. With enough cycles (or
    ``decay_rate < 1`` and a small gap) ``init_eps`` can fall below ``lb_eps``,
    and the curve then rises toward the floor instead of falling. Confirm this
    cannot happen for the configurations in use.
    """

    def __init__(self, max_steps, init_eps, lb_eps, init_sig, lb_sig, N, decay_rate=0.9):
        assert N >= 1, "N must be >= 1"
        self.max_steps = int(max_steps)
        # Peaks (mutated by step()) and fixed floors for both schedules.
        self.init_eps, self.lb_eps = float(init_eps), float(lb_eps)
        self.init_sig, self.lb_sig = float(init_sig), float(lb_sig)
        self.N = int(N)
        self.decay_rate = float(decay_rate)
        # Cycle length; clamped to 1 so the modulo in step()/sample() is safe.
        self.T = max(1, self.max_steps // self.N)
        self.step_count = 0
        self.cycle = 0

    def step(self):
        """Advance by one step; at a cycle boundary before the last cycle, decay the peaks."""
        self.step_count += 1
        crossed_boundary = self.step_count % self.T == 0
        if not crossed_boundary or self.cycle >= self.N - 1:
            return
        self.cycle += 1
        self.init_eps *= self.decay_rate
        self.init_sig *= self.decay_rate

    def sample(self):
        """Return the current ``(eps, sigma)`` as Python floats without advancing the schedule."""
        phase = (self.step_count % self.T) / self.T
        # Half-cosine weight: 1 at the start of a cycle, approaching 0 at its end.
        weight = 0.5 * (1.0 + np.cos(np.pi * phase))
        eps = self.lb_eps + (self.init_eps - self.lb_eps) * weight
        sig = self.lb_sig + (self.init_sig - self.lb_sig) * weight
        return float(eps), float(sig)
    
class Sampler:
    """Draw multiplicative perturbation factors ``q`` at a scheduled rate.

    Each element independently keeps ``q == 1`` when a uniform draw exceeds
    ``eps``; otherwise it receives a Gaussian factor with mean 1 and standard
    deviation ``std``. The result is clamped to ``[q_min, q_max]``.
    ``(eps, std)`` are read from the scheduler, which is then advanced by one
    step per ``sample`` call.
    """

    def __init__(self, scheduler: CosineAnnealingScheduler, q_min: float, q_max: float):
        self.scheduler = scheduler
        self.q_min = q_min
        self.q_max = q_max

    def sample(self, ref_probs_like: torch.Tensor) -> torch.Tensor:
        """Return factors with the shape/device of ``ref_probs_like``.

        Only the shape, dtype and device of ``ref_probs_like`` are used, not its values.

        NOTE(review): ``torch.normal`` is called without a dtype, so the noise
        uses torch's default float dtype. The returned tensor follows type
        promotion between that and ``ref_probs_like.dtype`` (e.g. float32 for
        half-precision inputs). Confirm callers expect this.
        """
        sched = self.scheduler
        eps, std = sched.sample()
        sched.step()
        with torch.no_grad():
            like = ref_probs_like
            # 1.0 where the element is left untouched, 0.0 where it gets noise.
            untouched = (torch.rand_like(like) > eps).to(like.dtype)
            gauss = torch.normal(mean=1, std=std, size=like.shape, device=like.device)
            factors = untouched * 1.0 + (1.0 - untouched) * gauss
            return factors.clamp(self.q_min, self.q_max)
        
class SurprisalAwareSampler:
    """Draw per-token multiplicative factors ``q`` that grow with surprisal.

    Surprisal is ``-src``, where ``src`` is the ``new`` or ``ref`` tensor chosen by
    ``use_source``. These are presumably per-token log-probabilities (TODO
    confirm with callers). Surprisal is normalized per row over the positions
    where ``mask > 0`` to a weight ``w`` in [0, 1]. Each factor is then drawn as
    ``q = (1 + alpha * w) + (sigma * w) * N(0, 1)``, with ``(alpha, sigma)`` taken
    from ``scheduler.sample()``, and clamped to ``[q_min, q_max]``. Positions where
    ``mask`` is 0 get ``q == 1``; fractional mask values blend ``q`` with 1.
    """

    def __init__(
        self,
        scheduler: "CosineAnnealingScheduler",
        use_source: str = "new",
        norm: str = "rank",
        q_min: float = 0.2,
        q_max: float = 1.8,
        detach_surprisal: bool = True,
    ):
        assert use_source in ("new", "ref")
        assert norm in ("rank", "minmax", "zscore")
        self.use_source = use_source
        self.scheduler = scheduler
        self.norm = norm
        self.q_min = q_min
        self.q_max = q_max
        self.detach_surprisal = detach_surprisal

    def _row_weights(self, vals: torch.Tensor) -> torch.Tensor:
        """Map one row's valid surprisal values to weights according to ``self.norm``."""
        if self.norm == "rank":
            # Double argsort yields each value's rank; scale ranks to [0, 1].
            ranks = torch.argsort(torch.argsort(vals))
            return ranks.float() / max(1, vals.numel() - 1)
        if self.norm == "minmax":
            mn, mx = vals.min(), vals.max()
            if (mx - mn) > 1e-6:
                return (vals - mn) / (mx - mn)
            # (Near-)constant row: no spread to normalize, use the midpoint.
            return torch.full_like(vals, 0.5)
        # "zscore": squash the per-row standardized value through a sigmoid.
        mu = vals.mean()
        sd = vals.std(unbiased=False).clamp_min(1e-6)
        return torch.sigmoid((vals - mu) / sd)

    @torch.no_grad()
    def _normalize(self, s: torch.Tensor, mask: torch.Tensor):
        """Return per-row weights in [0, 1]; positions with ``mask <= 0`` get 0.

        Assumes ``s`` and ``mask`` are indexed ``[batch, position]`` with matching
        shapes (TODO confirm).
        """
        w = torch.zeros_like(s)
        for b in range(s.size(0)):
            valid = mask[b] > 0
            if not valid.any():
                continue  # fully masked row keeps weight 0
            w[b, valid] = self._row_weights(s[b, valid])
        return w.clamp(0, 1)

    def sample(self, *, new: torch.Tensor, ref: torch.Tensor, mask: torch.Tensor):
        """Return surprisal-scaled factors and advance the scheduler by one step.

        ``mask`` may be bool, integer or floating point.
        """
        with torch.no_grad():
            alpha, sigma = self.scheduler.sample()
            self.scheduler.step()
            src = new if self.use_source == "new" else ref
            s = (-src).float()
            if self.detach_surprisal:
                # Already inside no_grad, so this is a no-op kept for the flag's contract.
                s = s.detach()

            w = self._normalize(s, mask)

            mu = 1.0 + alpha * w
            std = sigma * w

            q = mu + std * torch.randn_like(mu)
            q = torch.clamp(q, self.q_min, self.q_max)

            # `1.0 - mask` is not supported for bool tensors; convert non-float masks.
            if not torch.is_floating_point(mask):
                mask = mask.to(q.dtype)
            return q * mask + (1.0 - mask)

class EntropyCompressor:
    """Per-token factors that grow linearly with predictive entropy above a threshold.

    ``q = clamp(1 + ratio * relu(H - threshold), q_min, q_max)``, where ``H`` is the
    natural-log entropy of ``softmax(logits)`` over the last dimension. Positions
    where ``mask`` is 0 get ``q == 1``; fractional mask values blend ``q`` with 1.
    """

    def __init__(
        self,
        threshold: float,
        ratio: float = 1.0,
        use_source: str = "new",   # "new" | "ref"
        q_min: float = 0.6,
        q_max: float = 1.4,
    ):
        assert use_source in ("new", "ref")
        self.threshold = threshold
        self.ratio = ratio
        self.use_source = use_source
        self.q_min = q_min
        self.q_max = q_max

    @torch.no_grad()
    def sample(
        self,
        *,
        new_logits: torch.Tensor,  # (B, T, V)
        ref_logits: torch.Tensor,   # (B, T, V)
        mask: torch.Tensor          # (B, T); bool, integer or floating point
    ):
        """Return entropy-based factors for the logits selected by ``use_source``."""
        logits = new_logits if self.use_source == "new" else ref_logits

        probs = F.softmax(logits.float(), dim=-1)             # (B, T, V)
        # Clamping before log keeps 0 * log(0) finite (0) instead of NaN.
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)  # (B, T)

        q = 1.0 + self.ratio * torch.relu(entropy - self.threshold)
        q = torch.clamp(q, self.q_min, self.q_max)

        # `1.0 - mask` is not supported for bool tensors; convert non-float masks.
        if not torch.is_floating_point(mask):
            mask = mask.to(q.dtype)
        return q * mask + (1.0 - mask)
