# generative_replay.py
from __future__ import annotations
from typing import Sequence, Optional, Tuple, Union, Literal
import torch
from torch import Tensor
from tqdm import tqdm

# Names of the supported label-sampling policies.
# NOTE(review): nothing in the visible part of this file reads this alias;
# presumably it is kept for callers/importers -- confirm before removing.
LabelPolicy = Literal["round_robin", "balanced", "uniform"]


class GenerativeReplay:
    """
    Serve replay batches of old-class samples from a pre-generated pool.

    - On construction (and on `update_teacher`) the frozen `teacher` is asked
      for `pool_size_per_class` samples of every old class.
    - Call `replay()` to get `n_old()` pairs (x_old, y_old), taken by walking
      a random permutation of the pool; the permutation is re-drawn whenever
      it cannot supply a full batch.
    - Expects `teacher.sample(..., save=None)` to return a *tensor* batch
      shaped [B, C, H, W].

    Args:
        teacher: frozen generator from the end of task (t-1) with .sample(...)
        old_classes: iterable of old class IDs (e.g., [0,1,2,3]); duplicates
            are dropped and the IDs are stored sorted
        batch_size: your training batch size (used to compute replay size)
        alpha: fraction of `batch_size` returned by each `replay()` call
            (replay size is n_old = int(alpha * batch_size)).
            NOTE(review): earlier documentation described alpha as the
            fraction of *new* data, with (1 - alpha) replayed; the code uses
            alpha directly. Both agree at the default 0.5 -- confirm the
            intended meaning with callers.
        pool_size_per_class: number of samples generated for each old class
        num_inference_steps, eta, device: forwarded to teacher.sample(...)
        seed: base seed; each generation chunk of a class is sampled with
            `seed + <samples already generated for that class>`. `None` is
            forwarded to teacher.sample(...) unchanged.
    """

    # Maximum number of samples requested per teacher.sample(...) call;
    # generation is chunked to avoid OOM. Override on a subclass or instance.
    pool_chunk_size: int = 500

    def __init__(
        self,
        teacher,
        old_classes: Sequence[int],
        batch_size: int,
        alpha: float = 0.5,
        pool_size_per_class: int = 5000,
        num_inference_steps: int = 50,
        eta: float = 0.0,
        seed: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        self.teacher = teacher
        self.batch_size = int(batch_size)
        self.alpha = float(alpha)
        self.num_inference_steps = int(num_inference_steps)
        self.eta = float(eta)
        self.seed = seed
        self.pool_size_per_class = int(pool_size_per_class)

        # Resolve device: prefer the device of the teacher's UNet parameters,
        # otherwise fall back to CUDA when available, else CPU.
        if device is None:
            try:
                device = next(self.teacher.unet.parameters()).device
            except Exception:
                # Best-effort probe: the teacher may have no `.unet` or no
                # parameters at all.
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

        # store classes & round-robin rotor
        self.old_classes = self._as_class_tensor(old_classes)
        # Not read anywhere in this class; kept for external users.
        self._rotor = 0

        self.build_pool()

    @staticmethod
    def _as_class_tensor(old_classes: Sequence[int]) -> Tensor:
        """Return the unique class IDs as a sorted 1-D long tensor."""
        return torch.tensor(sorted({int(c) for c in old_classes}), dtype=torch.long)

    @torch.no_grad()
    def _generate_class_chunks(self, class_id: int) -> Tuple[list, list]:
        """Generate exactly `pool_size_per_class` samples of one class.

        Returns two parallel lists (image chunks, label chunks); each chunk
        holds at most `pool_chunk_size` samples.
        """
        chunk_cap = max(1, int(self.pool_chunk_size))
        images, labels = [], []
        produced = 0
        while produced < self.pool_size_per_class:
            # The last chunk is shortened so the pool is not over-generated.
            take = min(chunk_cap, self.pool_size_per_class - produced)
            # A missing base seed is forwarded as None instead of crashing on
            # `None + int`.
            chunk_seed = None if self.seed is None else self.seed + produced
            y_old = torch.full((take,), class_id, dtype=torch.long, device=self.device)
            x_old = self.teacher.sample(
                batch_size=take,
                labels=y_old,
                num_inference_steps=self.num_inference_steps,
                eta=self.eta,
                save=None,  # must return a tensor batch
                seed=chunk_seed,
                device=self.device,
            )
            images.append(x_old)
            labels.append(y_old)
            produced += take
        return images, labels

    def _reshuffle(self) -> None:
        """Draw a fresh permutation of the pool and rewind the read pointer."""
        self.pool_indices = torch.randperm(self.pool_size, device=self.device)
        self.pool_ptr = 0

    @torch.no_grad()
    def build_pool(self):
        """(Re)generate the replay pool for every class in `old_classes`.

        Sets `pool`, `pool_labels`, `pool_size`, `pool_indices` and
        `pool_ptr`. With no old classes (or a non-positive
        `pool_size_per_class`) the pool is empty rather than an error.
        """
        images, labels = [], []
        for class_id in tqdm(self.old_classes.tolist()):
            class_images, class_labels = self._generate_class_chunks(class_id)
            images.extend(class_images)
            labels.extend(class_labels)

        if images:
            self.pool = torch.cat(images, dim=0)
            self.pool_labels = torch.cat(labels, dim=0)
        else:
            # torch.cat rejects an empty list; keep a well-defined empty pool.
            self.pool = torch.empty(0, device=self.device)
            self.pool_labels = torch.empty(0, dtype=torch.long, device=self.device)

        self.pool_size = self.pool.shape[0]
        print(f"GenerativeReplay: built pool of {self.pool_size} images for replay.")
        self._reshuffle()

    def set_old_classes(self, old_classes: Sequence[int]) -> None:
        """Update the set of old classes (e.g., at the start of a new task).

        This does not rebuild the pool; call `build_pool()` (or
        `update_teacher`) to regenerate samples for the new class set.
        """
        self.old_classes = self._as_class_tensor(old_classes)
        self._rotor = 0

    def update_teacher(self, new_teacher, old_classes: Optional[Sequence[int]] = None) -> None:
        """Swap in a new frozen teacher (e.g., after finishing the current task).

        Optionally replaces the old-class set, then regenerates the pool
        with the new teacher.
        """
        self.teacher = new_teacher
        if old_classes is not None:
            self.set_old_classes(old_classes)
        # reset pool
        self.build_pool()

    def n_old(self) -> int:
        """Number of replay samples per call: int(alpha * batch_size)."""
        return int(self.alpha * self.batch_size)

    @torch.no_grad()
    def replay(self):
        """Return `(x_old, y_old)` with `n_old()` samples drawn from the pool.

        Samples are taken without replacement along the current permutation.
        When fewer than `n_old()` entries remain, the permutation is re-drawn
        first, so every batch is full-sized. If the pool is smaller than one
        batch, indices are drawn with replacement. An empty pool or a
        non-positive `n_old()` yields empty tensors.
        """
        n_replay_samples = self.n_old()
        if n_replay_samples <= 0 or self.pool_size == 0:
            return self.pool[:0], self.pool_labels[:0]

        if n_replay_samples > self.pool_size:
            # One pass over the pool cannot fill the batch.
            idx = torch.randint(
                self.pool_size, (n_replay_samples,), device=self.pool_indices.device
            )
        else:
            # Check before slicing so the batch stays full even if
            # batch_size/alpha changed since the previous call.
            if self.pool_ptr + n_replay_samples > self.pool_size:
                self._reshuffle()
            idx = self.pool_indices[self.pool_ptr:self.pool_ptr + n_replay_samples]
            self.pool_ptr += n_replay_samples

        return self.pool[idx], self.pool_labels[idx]