"""
ER-ACE + CLFD (NeurIPS 2024): Frequency-domain replay-friendly inputs.
Paper: "Continual Learning in the Frequency Domain (CLFD)", NeurIPS 2024.
"""

from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer

from avalanche.core import SupervisedPlugin
from avalanche.models.utils import avalanche_forward
from avalanche.training import ACECriterion
from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.utils import cycle

from MERS.mers_utils.storage_policy import ExemplarsBuffer
from avalanche.training.plugins import GSS_greedyPlugin

# ---- simple, dependency-free 2D DCT (type-II) and low-pass keep ----
def _dct_matrix(n, device, dtype):
    """Return the (n, n) orthonormal DCT-II basis matrix.

    Entry ``[k, i]`` is ``a_k * cos(pi / n * (i + 0.5) * k)`` with
    ``a_0 = sqrt(1 / n)`` and ``a_k = sqrt(2 / n)`` for ``k > 0``.  Rows are
    frequencies, columns are sample positions.  The matrix is orthogonal, so
    its transpose is its inverse.
    """
    # Build in a dtype that supports cos(); cast to the requested dtype at the end.
    build_dtype = dtype if dtype in (torch.float32, torch.float64) else torch.float32
    idx = torch.arange(n, device=device, dtype=build_dtype)
    basis = torch.cos((torch.pi / n) * (idx.view(1, -1) + 0.5) * idx.view(-1, 1))
    basis = basis * (2.0 / n) ** 0.5
    # The DC row needs an extra 1/sqrt(2) for the basis to be orthonormal.
    basis[0] = basis[0] / 2.0 ** 0.5
    return basis.to(dtype)


def _dct_1d(x):
    """Orthonormal DCT-II along the last dimension of ``x`` (shape ``(..., N)``)."""
    basis = _dct_matrix(x.shape[-1], x.device, x.dtype)
    # X[k] = sum_i x[i] * basis[k, i]  ->  x @ basis^T
    return torch.matmul(x, basis.transpose(0, 1))


def _idct_1d(x):
    """Inverse of :func:`_dct_1d` (orthonormal DCT-III) along the last dimension."""
    basis = _dct_matrix(x.shape[-1], x.device, x.dtype)
    # x[i] = sum_k X[k] * basis[k, i]  ->  X @ basis
    return torch.matmul(x, basis)


def dct2(x):
    """Separable 2D orthonormal DCT-II over the last two dimensions.

    Args:
        x: floating-point tensor of shape ``(..., H, W)``; ``H`` and ``W`` may differ.

    Returns:
        Tensor of the same shape holding the DCT coefficients, with the
        lowest frequencies in the top-left corner.
    """
    # Transform along H (via a transpose), then along W.
    x = _dct_1d(x.transpose(-1, -2)).transpose(-1, -2)
    return _dct_1d(x)


def idct2(x):
    """Inverse of :func:`dct2`: ``idct2(dct2(x))`` recovers ``x`` up to rounding.

    Args:
        x: floating-point tensor of DCT coefficients, shape ``(..., H, W)``.

    Returns:
        Tensor of the same shape in the spatial domain.
    """
    x = _idct_1d(x.transpose(-1, -2)).transpose(-1, -2)
    return _idct_1d(x)


def clfd_lowpass(x, keep_ratio: float = 0.5):
    """Low-pass filter ``x`` by keeping only the top-left block of DCT coefficients.

    Args:
        x: floating-point tensor of shape ``(..., H, W)`` (typically ``(B, C, H, W)``).
        keep_ratio: fraction of each spatial axis to keep in the frequency
            domain.  At least one coefficient per axis is always kept, and a
            ratio of 1.0 or more keeps everything and returns ``x`` unchanged
            up to rounding.

    Returns:
        Tensor of the same shape as ``x`` containing the low-frequency content.
    """
    height, width = x.shape[-2:]
    kh = max(1, int(height * keep_ratio))
    kw = max(1, int(width * keep_ratio))
    coeffs = dct2(x)
    # Copy only the retained low-frequency block; everything else stays zero.
    kept = torch.zeros_like(coeffs)
    kept[..., :kh, :kw] = coeffs[..., :kh, :kw]
    return idct2(kept)

class ER_ACE_Clfd(SupervisedTemplate):
    """
    ER-ACE augmented with CLFD-style low-frequency inputs.
    NeurIPS 2024 reports gains & training efficiency with rehearsal methods.

    Every input batch (current and replayed) is passed through
    ``clfd_lowpass`` before it reaches the model, both in training and in
    ``forward``.  The replay buffer is owned by an ``ER_ACE_Plugin`` that
    this strategy creates and keeps in ``self.er_ace_plugin``.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        mem_size: int = 200,
        batch_size_mem: int = 10,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = 1,
        device: Union[str, torch.device] = "cpu",
        storage_policy: Optional[ExemplarsBuffer] = None,
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[EvaluationPlugin, Callable[[], EvaluationPlugin]] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch",
        args=None,
        use_gss: bool = False,
        clfd_keep_ratio: float = 0.5,  # how much low-frequency content to keep
    ):
        """Build the strategy and register its replay (and optional GSS) plugin.

        Args:
            model: network to train.
            optimizer: optimizer stepping ``model``'s parameters.
            criterion: loss used when no replay batch is available.
            mem_size: capacity passed to the replay plugin (and GSS plugin).
            batch_size_mem: number of buffer samples drawn per iteration.
            train_mb_size, train_epochs, eval_mb_size, device, evaluator,
                eval_every, peval_mode: forwarded to ``SupervisedTemplate``.
            storage_policy: replay buffer policy; when ``None`` the plugin
                falls back to its own default.
            plugins: extra plugins; the list is copied, not modified.
            args: optional namespace read for ``gss_mem_strength``,
                ``gss_input_size`` and ``sel_strategy``.
            use_gss: also register a ``GSS_greedyPlugin``.
            clfd_keep_ratio: fraction of low-frequency content kept per axis.
        """
        self.mem_size = mem_size
        self.batch_size_mem = batch_size_mem
        self.replay_loader = None
        self.ace_criterion = ACECriterion()
        self.args = args
        self.use_gss = use_gss
        self.clfd_keep_ratio = clfd_keep_ratio

        # Base ER-ACE plugin (class-balanced buffer). Keep a direct reference:
        # its position in the plugin list depends on how many plugins the
        # caller supplied, so a fixed index is not reliable.
        self.er_ace_plugin = ER_ACE_Plugin(storage_policy=storage_policy, mem_size=mem_size)

        # Work on a copy so the caller's list is not mutated.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(self.er_ace_plugin)
        if use_gss:
            gss_plugin = GSS_greedyPlugin(
                mem_size=mem_size,
                mem_strength=getattr(args, 'gss_mem_strength', 20),
                input_size=getattr(args, 'gss_input_size', [3, 32, 32]),
            )
            plugins.append(gss_plugin)

        super().__init__(model, optimizer, criterion, train_mb_size, train_epochs,
                         eval_mb_size, device, plugins, evaluator, eval_every, peval_mode)

    def _clfd(self, x):
        """Return the low-pass filtered version of ``x``.

        Floating-point inputs are used as-is; other dtypes are converted to
        float and divided by 255 first (assumes 8-bit image data -- TODO confirm).
        """
        if x.dtype.is_floating_point:
            xn = x
        else:
            xn = x.float() / 255.0
        return clfd_lowpass(xn, self.clfd_keep_ratio)

    def training_epoch(self, **kwargs):
        """Run one epoch, mixing each mini-batch with a replay batch when available.

        With a replay batch the ACE criterion is applied to the current and
        buffer outputs; otherwise the regular criterion is used.
        """
        for self.mbatch in self.dataloader:
            if self._stop_training:
                break

            self._unpack_minibatch()
            self._before_training_iteration(**kwargs)

            have_replay = False
            if self.replay_loader is not None:
                self.mb_buffer_x, self.mb_buffer_y, self.mb_buffer_tid = next(self.replay_loader)
                self.mb_buffer_x = self.mb_buffer_x.to(self.device)
                self.mb_buffer_y = self.mb_buffer_y.to(self.device)
                self.mb_buffer_tid = self.mb_buffer_tid.to(self.device)
                have_replay = True

            # ---- NO assignment to self.mb_x! build transformed tensors instead ----
            x_cur = self._clfd(self.mb_x.to(self.device))
            x_rep = self._clfd(self.mb_buffer_x) if have_replay else None

            self.optimizer.zero_grad()
            self.loss = self._make_empty_loss()

            # Forward
            self._before_forward(**kwargs)
            # use avalanche_forward directly so we can pass our transformed inputs
            self.mb_output = avalanche_forward(self.model, x_cur, self.mb_task_id)
            if have_replay:
                self.mb_buffer_out = avalanche_forward(self.model, x_rep, self.mb_buffer_tid)
            self._after_forward(**kwargs)

            # Loss
            if have_replay:
                self.loss += self.ace_criterion(
                    self.mb_output, self.mb_y,
                    self.mb_buffer_out, self.mb_buffer_y
                )
            else:
                self.loss += self.criterion()

            # Backward / step
            self._before_backward(**kwargs)
            self.backward()
            # Gradient clipping is only enabled for the 'rm' / 'gss' selection strategies.
            if hasattr(self.args, 'sel_strategy') and self.args.sel_strategy in ('rm', 'gss'):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self._after_backward(**kwargs)

            self._before_update(**kwargs)
            self.optimizer_step()
            self._after_update(**kwargs)

            self._after_training_iteration(**kwargs)

    def _before_training_exp(self, **kwargs):
        """Update the replay buffer and (re)create the replay loader.

        A loader is only created after the first experience and once the
        buffer holds at least ``batch_size_mem`` samples.
        """
        storage_policy = self.er_ace_plugin.storage_policy
        storage_policy.update(self, **kwargs)
        buffer = storage_policy.buffer
        if len(buffer) >= self.batch_size_mem and self.experience.current_experience > 0:
            self.replay_loader = cycle(
                torch.utils.data.DataLoader(buffer, batch_size=self.batch_size_mem, shuffle=True, drop_last=True)
            )
        super()._before_training_exp(**kwargs)

    def forward(self):
        """Forward the current mini-batch through the model after the CLFD transform."""
        # read-only; safe to access
        x = self.mb_x.to(self.device)
        x = self._clfd(x)  # same transform for both train & eval
        return avalanche_forward(self.model, x, self.mb_task_id)

    def _train_cleanup(self):
        """Drop the replay loader once training on the experience is finished."""
        super()._train_cleanup()
        self.replay_loader = None


class ER_ACE_Plugin(SupervisedPlugin):
    """Plugin that holds the replay storage policy for the ER-ACE strategy.

    Attributes are plain data: ``mem_size`` is the requested buffer capacity
    and ``storage_policy`` is the buffer policy the strategy reads from.
    """

    def __init__(self, storage_policy: Optional["ExemplarsBuffer"] = None, mem_size: int = 2000):
        """Store the given policy, or create a class-balanced one of ``mem_size``."""
        super().__init__()
        self.mem_size = mem_size
        if storage_policy is None:
            # No policy supplied: fall back to an adaptive class-balanced buffer.
            storage_policy = ClassBalancedBuffer(
                max_size=self.mem_size, adaptive_size=True
            )
        self.storage_policy = storage_policy
