import logging
from typing import List, Optional, Tuple
import math
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

from models.base import BaseLearner
from utils.inc_net import IncrementalNet

num_workers = 8


def rsvd(A, rank, oversampling, n_iter=2):
    """Compute a truncated SVD of ``A`` with a randomized range finder.

    A Gaussian test matrix with ``rank + oversampling`` columns (clipped to
    ``[1, min(m, n)]``) sketches the column space of ``A``. ``n_iter`` rounds
    of QR-stabilized subspace iteration refine the sketch. An exact SVD of the
    small projected matrix then yields the factors.

    Args:
        A: 2-D matrix ``(m, n)``; converted to float32 before use.
        rank: number of singular triplets requested.
        oversampling: extra sketch columns that improve accuracy.
        n_iter: number of power (subspace) iterations.

    Returns:
        ``(U, S, Vh)`` in float32 with shapes ``(m, r)``, ``(r,)`` and
        ``(r, n)``, where ``r <= rank``. Singular values come back in the
        descending order produced by ``torch.linalg.svd``.
    """
    mat = A.float()
    rows, cols = mat.shape
    rank, oversampling, n_iter = int(rank), int(oversampling), int(n_iter)

    sketch_width = max(1, min(rank + oversampling, rows, cols))
    test_matrix = torch.randn(
        cols, sketch_width, dtype=torch.float32, device=mat.device
    )

    basis, _ = torch.linalg.qr(mat @ test_matrix, mode="reduced")
    for _ in range(n_iter):
        # Re-orthonormalize on both sides each round to avoid the loss of
        # precision that plain repeated multiplication by A and A.T causes.
        row_basis, _ = torch.linalg.qr(mat.T @ basis, mode="reduced")
        basis, _ = torch.linalg.qr(mat @ row_basis, mode="reduced")

    small = basis.T @ mat
    U_small, S, Vh = torch.linalg.svd(small, full_matrices=False)
    U = basis @ U_small

    keep = min(rank, U.shape[1], S.numel(), Vh.shape[0])
    return U[:, :keep], S[:keep], Vh[:keep, :]


def find_adaptive_k_by_knee(vals, max_rank):
    """Pick a cut-off count at the "knee" of a decaying value sequence.

    The first ``min(len(vals), max_rank)`` entries are min-max normalized to
    ``[0, 1]`` and placed on an evenly spaced ``[0, 1]`` x-axis. The point
    farthest from the chord ``x + y = 1`` is the knee. The function expects
    ``vals`` to be sorted in descending order; it does not sort them itself.

    Returns:
        The 1-based position of the knee (always >= 1). Sequences shorter
        than 3 elements return ``max(1, min(len, max_rank))``.
    """
    count = vals.numel()
    if count < 3:
        return max(1, min(count, max_rank))

    n = min(len(vals), max_rank)
    window = vals[:n]

    x_norm = torch.linspace(0, 1, n, device=window.device)
    y_norm = (window - window.min()) / (window.max() - window.min() + 1e-8)

    # |x + y - 1| is proportional to the distance to the chord.
    gap = (x_norm + y_norm - 1).abs()
    knee_pos = torch.argmax(gap).item() + 1
    return max(1, int(knee_pos))


def tsallis_gate(energy_map, temp):
    """Map an energy tensor to a gate in ``[0, 1]`` with a q-exponential.

    The entropic index ``q`` is adapted to the data:

    1. Sort all energies in descending order and choose ``k`` at the knee of
       that curve, searching at most ``max(5, 20% of entries)`` positions.
    2. Estimate the tail index with the Hill estimator
       ``xi = mean(log E_1..E_k) - log E_k``.
    3. Set ``q = clamp(1 / (1 + xi), 0.1, 0.8)``.

    The gate is ``max(0, 1 - (1 - q) * E / temp) ** (1 / (1 - q))``. It is
    1 where the energy is 0 and falls towards 0 as the energy grows.

    Args:
        energy_map: energy tensor; it is flattened with ``view(-1)``, so it
            must be contiguous.
        temp: positive temperature; smaller values close the gate faster.

    Returns:
        A tensor with the same shape as ``energy_map``.
    """
    ranked = torch.sort(energy_map.view(-1), descending=True).values
    cap = max(5, int(ranked.numel() * 0.2))
    k = find_adaptive_k_by_knee(ranked, max_rank=cap)

    head_logs = torch.log(ranked[:k] + 1e-8)
    xi = head_logs.mean() - head_logs[-1]

    q = torch.clamp(1.0 / (1.0 + xi + 1e-6), 0.1, 0.8)
    one_minus_q = 1.0 - q

    base = 1.0 - (one_minus_q / temp) * energy_map
    return torch.clamp(base, min=0) ** (1.0 / one_minus_q)


class MoCLOptimizer:
    """Adam-style optimizer with low-rank moments for one ``nn.Linear`` weight.

    * Adam's first and second moments are stored as rank-limited
      factorizations produced by ``rsvd`` instead of dense ``(m, n)``
      tensors.
    * At the end of a task, the statistics gathered by the layer's
      ``KFAC_Tracker`` are eigendecomposed. They are merged into
      energy-weighted historical bases: ``History_U_G`` over the ``m``
      output rows and ``History_U_A`` over the ``n`` input columns.
    * On later tasks, gradient components that fall into high-energy
      historical directions are attenuated by a Tsallis (q-exponential) gate
      before the Adam update.
    """

    def __init__(
        self,
        layer: nn.Linear,
        r1: int,
        r2: int,
        eps_th: float,
        lr: float,
        device: torch.device,
        weight_decay: float = 0.0,
        beta1: float = 0.8,
        beta2: float = 0.999,
        eps: float = 1e-8,
        metabolic_temp: float = 0.001,
    ):
        """Bind the optimizer to ``layer`` and initialize empty state.

        Args:
            layer: Linear layer whose ``weight`` is optimized. Its bias, if
                present, is frozen (``requires_grad = False``).
            r1: Upper bound on the rank of the compressed Adam moments.
            r2: Maximum number of eigen-directions taken from one task when
                updating the historical subspaces.
            eps_th: Cumulative-energy threshold used to choose that rank.
            lr: Learning rate.
            device: Device on which gradients and optimizer state are kept.
            weight_decay: L2 coefficient added to the raw gradient. This is
                coupled decay: it also goes through the projection and Adam.
            beta1: Adam first-moment decay.
            beta2: Adam second-moment decay.
            eps: Adam denominator epsilon.
            metabolic_temp: Temperature of the Tsallis gate.

        Raises:
            TypeError: If ``layer`` is not an ``nn.Linear``.
        """
        if not isinstance(layer, nn.Linear):
            raise TypeError("MoCLOptimizer expects an nn.Linear layer.")
        if layer.bias is not None:
            layer.bias.requires_grad = False
        layer.weight.requires_grad = True

        self.layer = layer
        self.device = device
        self.r1 = r1
        self.r2 = r2
        self.eps_th = eps_th
        self.lr = lr
        self.weight_decay = weight_decay

        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

        self.metabolic_temp = metabolic_temp

        # (out_features, in_features) of the managed weight.
        self.m, self.n = layer.weight.shape

        self.fim_tracker = KFAC_Tracker(self.m, self.n, device=self.device)
        self.fim_bank = []

        # Low-rank factors of Adam's moments: X ~= U @ diag(S) @ V.T.
        self.momentum_U: Optional[torch.Tensor] = None
        self.momentum_S: Optional[torch.Tensor] = None
        self.momentum_V: Optional[torch.Tensor] = None

        self.exp_avg_sq_U: Optional[torch.Tensor] = None
        self.exp_avg_sq_S: Optional[torch.Tensor] = None
        self.exp_avg_sq_V: Optional[torch.Tensor] = None

        self.adam_t = 0

        # Protected historical bases and their eigen-energies.
        self.History_U_G: Optional[torch.Tensor] = None  # (m, r_total)
        self.History_U_A: Optional[torch.Tensor] = None  # (n, r_total)
        self.History_E_G: Optional[torch.Tensor] = None  # (r_total,)
        self.History_E_A: Optional[torch.Tensor] = None  # (r_total,)

    def _reconstruct_momentum(self) -> torch.Tensor:
        """Return the dense first moment, or zeros if none is stored."""
        if self.momentum_U is None:
            return torch.zeros(
                self.m, self.n,
                device=self.device,
                dtype=self.layer.weight.dtype
            )

        return self.momentum_U @ torch.diag(self.momentum_S) @ self.momentum_V.T

    def _reconstruct_exp_avg_sq(self) -> torch.Tensor:
        """Return a dense, strictly positive second moment.

        Without stored factors, a constant ``1e-10`` matrix is returned. A
        low-rank reconstruction can contain negative entries. Those entries
        are replaced by the mean magnitude of all negative entries, and every
        value is floored at ``1e-10`` so ``sqrt`` stays well defined.
        """
        if self.exp_avg_sq_U is None:
            return torch.full(
                (self.m, self.n),
                fill_value=1e-10,
                device=self.device,
                dtype=self.layer.weight.dtype
            )

        v_recon = self.exp_avg_sq_U @ torch.diag(
            self.exp_avg_sq_S) @ self.exp_avg_sq_V.T

        neg_mask = v_recon < 0
        v_safe = torch.clamp(v_recon, min=1e-10)
        if neg_mask.any():
            zeta = v_recon[neg_mask].abs().mean()
            v_safe[neg_mask] = zeta.clamp(min=1e-10)

        return v_safe

    def _low_rank_factorize(
        self, full: torch.Tensor, label: str
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Return rank-limited ``(U, S, V)`` factors of ``full``, or ``None``.

        ``None`` is returned (after a warning) when RSVD raises or produces
        non-finite singular values. ``label`` only names the tensor in the log.
        """
        oversampling = max(5, min(10, self.r1 // 5))
        target_rank = max(
            1, min(self.r1, min(self.m, self.n) - oversampling - 1))

        try:
            U, S, Vh = rsvd(full, rank=target_rank, oversampling=oversampling)
        except (torch.linalg.LinAlgError, RuntimeError) as e:
            logging.warning(
                "RSVD failed for %s: %s, skipping compression", label, e)
            return None

        if torch.isnan(S).any() or torch.isinf(S).any():
            logging.warning("RSVD returned NaN/Inf for %s, skipping", label)
            return None

        return U.contiguous(), S.contiguous(), Vh.T.contiguous()

    def _compress_momentum(self, momentum_full: torch.Tensor) -> None:
        """Store ``momentum_full`` as low-rank factors.

        A negligible momentum clears the factors; they then reconstruct as
        zeros, instead of the previous stale value being kept. If
        factorization fails, the previous factors are kept (best effort).
        """
        if momentum_full.norm() < 1e-12:
            self.momentum_U = self.momentum_S = self.momentum_V = None
            return

        factors = self._low_rank_factorize(momentum_full, "momentum")
        if factors is not None:
            self.momentum_U, self.momentum_S, self.momentum_V = factors

    def _compress_exp_avg_sq(self, v_full: torch.Tensor) -> None:
        """Store ``v_full`` as low-rank factors.

        A negligible tensor clears the factors; they then reconstruct as the
        ``1e-10`` floor. If factorization fails, the previous factors are
        kept (best effort).
        """
        if v_full.norm() < 1e-12:
            self.exp_avg_sq_U = self.exp_avg_sq_S = self.exp_avg_sq_V = None
            return

        factors = self._low_rank_factorize(v_full, "exp_avg_sq")
        if factors is not None:
            self.exp_avg_sq_U, self.exp_avg_sq_S, self.exp_avg_sq_V = factors

    def _get_adaptive_rank(self, vals: torch.Tensor, energy_th: float, magnitude_th: float, max_rank: int) -> int:
        """Choose how many leading eigenvalues of ``vals`` to keep.

        ``vals`` is assumed to be sorted in descending order (``vals[0]`` is
        treated as the maximum). The result is the smaller of two counts,
        capped at ``max_rank`` and at least 1:

        * the number of leading values needed to reach ``energy_th`` of the
          total sum (one past the last prefix at or below the threshold);
        * the number of values above ``magnitude_th * vals[0]``.
        """
        max_val = vals[0]
        valid_idx_magnitude = vals > (max_val * magnitude_th)

        total_energy = vals.sum()
        cumsum_ratio = torch.cumsum(vals, dim=0) / total_energy
        valid_idx_energy = cumsum_ratio <= energy_th
        r_energy = valid_idx_energy.sum().item()
        if r_energy < len(vals):
            r_energy += 1

        r_magnitude = valid_idx_magnitude.sum().item()
        r_final = min(r_energy, r_magnitude, max_rank)

        return max(r_final, 1)

    def _project_gradient(self, grad: torch.Tensor) -> torch.Tensor:
        """Attenuate the components of ``grad`` that lie in protected directions.

        ``grad`` is expressed in the historical bases as
        ``U_G.T @ grad @ U_A``. Each coefficient is blocked in proportion to
        ``1 - gate``, where the gate comes from the normalized outer product
        of the historical energies. The blocked part is removed from
        ``grad``. The removed correction is rescaled so its norm never
        exceeds that of ``grad``.

        ``grad`` is returned unchanged when there is no history, or when all
        historical energies are zero or non-finite. Normalizing by such a
        maximum would otherwise produce NaNs.
        """
        if self.History_U_G is None or self.History_U_A is None:
            return grad

        energy_map = torch.outer(self.History_E_G, self.History_E_A)
        max_energy = energy_map.max()
        if not torch.isfinite(max_energy) or max_energy <= 0:
            return grad

        U_G, U_A = self.History_U_G, self.History_U_A
        core_proj = U_G.T @ (grad @ U_A)

        gate = tsallis_gate(energy_map / max_energy, self.metabolic_temp)
        gate = torch.clamp(gate, min=1e-13)
        blocked_core = core_proj * (1.0 - gate)

        correction = U_G @ (blocked_core @ U_A.T)

        c_norm = correction.norm()
        g_norm = grad.norm() + 1e-10
        if c_norm > g_norm:
            correction = correction * (g_norm / c_norm)

        return grad - correction

    def step(self, step_idx: int) -> None:
        """Update ``layer.weight`` in place from its current ``.grad``.

        The gradient plus coupled weight decay is first projected away from
        the protected history. A bias-corrected Adam step follows; its norm
        is clipped to 5.0 and it is scaled by ``lr``. Finally the new moments
        are re-compressed. ``step_idx`` is accepted for interface
        compatibility and is not used.
        """
        if self.layer.weight.grad is None:
            return
        grad = self.layer.weight.grad.detach()
        if grad.numel() == 0:
            return
        grad = grad.to(self.device)

        grad_total = grad + self.weight_decay * self.layer.weight.detach()
        grad_safe = self._project_gradient(grad_total)

        m_prev = self._reconstruct_momentum()
        v_prev = self._reconstruct_exp_avg_sq()

        self.adam_t += 1

        m_new = self.beta1 * m_prev + (1 - self.beta1) * grad_safe
        v_new = self.beta2 * v_prev + (1 - self.beta2) * (grad_safe ** 2)

        bias_correction1 = 1.0 - self.beta1 ** self.adam_t
        bias_correction2 = 1.0 - self.beta2 ** self.adam_t

        m_hat = m_new / bias_correction1
        v_hat = v_new / bias_correction2

        update_step = m_hat / (torch.sqrt(v_hat) + self.eps)

        update_norm = update_step.norm()
        if update_norm > 5.0:
            update_step = update_step * (5.0 / update_norm)

        with torch.no_grad():
            self.layer.weight -= self.lr * update_step

        self._compress_momentum(m_new)
        self._compress_exp_avg_sq(v_new)

    @staticmethod
    def _merge_subspace(
        old_U: torch.Tensor,
        old_E: torch.Tensor,
        new_U: torch.Tensor,
        new_E: torch.Tensor,
        keep_ratio: float = 0.98,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge two energy-weighted orthonormal bases into one.

        Returns the left singular vectors of
        ``[old_U * sqrt(old_E), new_U * sqrt(new_E)]`` together with their
        energies (the squared singular values). These are the eigenpairs of
        ``old_U diag(old_E) old_U.T + new_U diag(new_E) new_U.T``. At most
        ``int(dim * keep_ratio)`` directions are kept, and at least one. For
        ``dim > 1`` this means the protected subspace never spans the whole
        space.
        """
        W_old = old_U * torch.sqrt(old_E).view(1, -1)
        W_new = new_U * torch.sqrt(new_E).view(1, -1)
        U, S, _ = torch.linalg.svd(
            torch.cat([W_old, W_new], dim=1), full_matrices=False)

        keep_rank = max(1, int(U.shape[0] * keep_ratio))
        if U.shape[1] > keep_rank:
            U = U[:, :keep_rank]
            S = S[:keep_rank]

        return U, S ** 2

    def update_historical_subspace(self) -> None:
        """Fold this task's K-FAC eigenpairs into the protected history.

        The tracker's statistics are finalized. An adaptive number of leading
        eigen-directions (at most ``r2``) is selected on each side and merged
        with the existing history. The tracker and all Adam state are then
        reset for the next task.
        """
        self.fim_tracker.finalize()
        vals_A, vecs_A = self.fim_tracker.eval_A, self.fim_tracker.QA
        vals_G, vecs_G = self.fim_tracker.eval_G, self.fim_tracker.QG

        r2_G = self._get_adaptive_rank(
            vals_G,
            energy_th=self.eps_th,
            magnitude_th=1e-2,
            max_rank=self.r2
        )
        r2_A = self._get_adaptive_rank(
            vals_A,
            energy_th=self.eps_th,
            magnitude_th=1e-2,
            max_rank=self.r2
        )

        cur_U_G, cur_E_G = vecs_G[:, :r2_G], vals_G[:r2_G]
        cur_U_A, cur_E_A = vecs_A[:, :r2_A], vals_A[:r2_A]

        if self.History_U_G is None:
            self.History_U_G, self.History_E_G = cur_U_G, cur_E_G
        else:
            self.History_U_G, self.History_E_G = self._merge_subspace(
                self.History_U_G, self.History_E_G, cur_U_G, cur_E_G)

        if self.History_U_A is None:
            self.History_U_A, self.History_E_A = cur_U_A, cur_E_A
        else:
            self.History_U_A, self.History_E_A = self._merge_subspace(
                self.History_U_A, self.History_E_A, cur_U_A, cur_E_A)

        self.fim_tracker.reset()
        self.adam_t = 0
        self.momentum_U = None
        self.momentum_S = None
        self.momentum_V = None
        self.exp_avg_sq_U = None
        self.exp_avg_sq_S = None
        self.exp_avg_sq_V = None


class KFAC_Tracker:
    """Accumulate Kronecker-factored second-moment statistics for a linear layer.

    ``cov_A`` (``n x n``) sums ``x.T @ x`` over layer inputs, and ``cov_G``
    (``m x m``) sums ``g.T @ g`` over output gradients. Updates are recorded
    only while ``mode == 'accumulation'``; in any other mode, :meth:`update`
    does nothing. :meth:`finalize` averages the sums over the accumulated
    rows and stores descending eigenpairs in ``eval_A``/``QA`` and
    ``eval_G``/``QG``.
    """

    def __init__(self, m: int, n: int, device=None):
        """Create zeroed ``(n, n)`` and ``(m, m)`` accumulators on ``device``."""
        self.m = m
        self.n = n
        self.device = device

        self.cov_A = torch.zeros(n, n, device=device)
        self.cov_G = torch.zeros(m, m, device=device)

        self.mode = 'sleep'
        self.step_count = 0  # total number of rows accumulated so far

        self.eval_A = None
        self.eval_G = None
        self.QA = None
        self.QG = None

        # Slot where callers (e.g. a forward hook) can park a layer input
        # until the matching gradient arrives. Declared so it always exists.
        self.saved_x = None

    def switch_to_accumulation(self):
        """Clear the accumulators and start recording updates."""
        self.cov_A.zero_()
        self.cov_G.zero_()
        self.step_count = 0
        self.mode = 'accumulation'

    def _eigendecompose(self):
        """Eigendecompose the shrunk covariances into descending eigenpairs.

        Each covariance is shrunk towards a scaled identity
        (``(1 - gamma) * C + gamma * trace(C)/dim * I``) before ``eigh``.
        The results are flipped to descending order and negative round-off
        eigenvalues are clamped to 0.
        """
        gamma = 1e-3

        avg_energy_A = self.cov_A.trace() / self.n
        cov_A_shrunk = (1 - gamma) * self.cov_A + gamma * \
            avg_energy_A * torch.eye(self.n, device=self.device)
        avg_energy_G = self.cov_G.trace() / self.m
        cov_G_shrunk = (1 - gamma) * self.cov_G + gamma * \
            avg_energy_G * torch.eye(self.m, device=self.device)

        vals_A, vecs_A = torch.linalg.eigh(cov_A_shrunk)
        vals_G, vecs_G = torch.linalg.eigh(cov_G_shrunk)
        # eigh returns ascending order; callers expect the largest first.
        vals_A = vals_A.flip(0)
        vecs_A = vecs_A.flip(1)
        vals_G = vals_G.flip(0)
        vecs_G = vecs_G.flip(1)
        vals_A = torch.clamp(vals_A, min=0)
        vals_G = torch.clamp(vals_G, min=0)
        self.eval_A, self.eval_G = vals_A, vals_G
        self.QA, self.QG = vecs_A, vecs_G

    def finalize(self):
        """Average the accumulated sums in place and compute the eigenpairs.

        The in-place division means this should be called once per
        accumulation phase.
        """
        if self.step_count > 0:
            self.cov_A.div_(self.step_count)
            self.cov_G.div_(self.step_count)
        self._eigendecompose()

    def update(self, x: torch.Tensor, g: torch.Tensor):
        """Add one batch of inputs ``x`` and output gradients ``g``.

        Leading dimensions are flattened, so ``x``'s last dim must be ``n``
        and ``g``'s last dim must be ``m``. ``reshape`` is used because the
        tensors may be non-contiguous, where ``view`` would fail. Both are
        cast to the accumulator dtype so that half-precision activations do
        not overflow in the outer products.
        """
        if self.mode != 'accumulation':
            return

        if x.dim() > 2:
            x = x.reshape(-1, x.size(-1))
        if g.dim() > 2:
            g = g.reshape(-1, g.size(-1))
        x = x.to(self.cov_A.dtype)
        g = g.to(self.cov_G.dtype)

        self.cov_A.add_(x.T @ x)
        self.cov_G.add_(g.T @ g)
        self.step_count += x.size(0)

    def reset(self):
        """Return to ``'sleep'`` mode and drop all statistics and results."""
        self.mode = 'sleep'
        self.cov_A.zero_()
        self.cov_G.zero_()
        self.eval_A = None
        self.eval_G = None
        self.QA = None
        self.QG = None
        self.step_count = 0
        self.saved_x = None


class Learner(BaseLearner):
    """Class-incremental learner that trains attention projections with MoCL.

    The backbone is frozen except for its ``attn.proj`` linear layers. Each
    of those is updated by its own ``MoCLOptimizer``. The classifier head is
    trained with AdamW, and per-task head slices are cached and re-applied so
    old-class rows keep their trained values.
    """

    def __init__(self, args):
        """Read hyper-parameters from ``args`` and set up MoCL optimizers and hooks.

        ``args`` is used through ``.get``, so it is presumably a dict-like
        config.
        """
        super().__init__(args)
        self._network = IncrementalNet(args, pretrained=True)

        self.epochs = args.get("epochs", 40)
        self.batch_size = args.get("batch_size", 128)
        self.temperature = args.get("temperature", 1.0)
        self.fc_lr = args.get("fc_lr", args.get("lr", 0.001))
        self.backbone_lr = args.get("backbone_lr", args.get("lr", 0.001))
        self.fc_weight_decay = args.get("fc_weight_decay", 0.0)
        self.attn_weight_decay = args.get("attn_weight_decay", 0.0)
        self.mlp_weight_decay = args.get("mlp_weight_decay", 0.0)
        self.min_lr = args.get("min_lr", 0.0)
        self.mocl_scheduler = args.get("mocl_scheduler", "cosine")
        self.mocl_milestones = args.get("mocl_milestones", [])
        self.mocl_gamma = args.get("mocl_gamma", 0.1)
        self.mocl_r1 = args.get("mocl_r1", args.get("r1", 50))
        self.mocl_r2 = args.get("mocl_r2", args.get("r2", 120))
        self.mocl_eps_th = args.get("mocl_eps_th", args.get("eps_th", 0.98))
        self.num_workers = args.get("num_workers", num_workers)
        self.metabolic_temp = args.get("metabolic_temp", 2.0)
        self.cal_epochs = args.get("cal_epochs", 5)

        # One cached (cpu) head slice per finished task, in task order.
        self._head_weight_cache: List[torch.Tensor] = []
        self._head_bias_cache: List[Optional[torch.Tensor]] = []

        if len(self._multiple_gpus) > 1:
            logging.warning(
                "MoCL currently operates on a single GPU; defaulting to the first device."
            )

        self._freeze_backbone_except_attention()
        self.mocl_optimizers: List[MoCLOptimizer] = []
        self._build_mocl_optimizers()
        self._register_hooks()

    def _register_hooks(self) -> None:
        """Feed (input, grad_output) pairs of each managed layer to its tracker.

        The forward hook caches the detached layer input only while the
        tracker is in ``'accumulation'`` mode. The backward hook pairs it with
        the detached output gradient, calls ``tracker.update`` and clears the
        cache.
        """
        mod_to_opt = {opt.layer: opt for opt in self.mocl_optimizers}

        def fwd_hook(module, inp, out):
            if module in mod_to_opt:
                tracker = mod_to_opt[module].fim_tracker
                if tracker.mode == 'accumulation':
                    tracker.saved_x = inp[0].detach()

        def bwd_hook(module, grad_in, grad_out):
            if module in mod_to_opt:
                tracker = mod_to_opt[module].fim_tracker
                x = getattr(tracker, 'saved_x', None)
                if x is not None:
                    tracker.update(x, grad_out[0].detach())
                    tracker.saved_x = None

        for module in mod_to_opt:
            module.register_forward_hook(fwd_hook)
            module.register_full_backward_hook(bwd_hook)

    def _freeze_backbone_except_attention(self) -> None:
        """Freeze every backbone parameter.

        The ``attn.proj`` weights are re-enabled later by
        ``_build_mocl_optimizers``.
        """
        for param in self._network.backbone.parameters():
            param.requires_grad = False

    def _build_mocl_optimizers(self) -> None:
        """Create one ``MoCLOptimizer`` per backbone ``nn.Linear`` named ``*attn.proj*``."""
        device = self._device
        for name, module in self._network.backbone.named_modules():
            if isinstance(module, nn.Linear) and ("attn.proj" in name):
                module.weight.requires_grad = True
                module.name = name
                optimizer = MoCLOptimizer(
                    module,
                    r1=self.mocl_r1,
                    r2=self.mocl_r2,
                    eps_th=self.mocl_eps_th,
                    lr=self.backbone_lr,
                    device=device,
                    weight_decay=self.attn_weight_decay,
                    metabolic_temp=self.metabolic_temp,
                )
                self.mocl_optimizers.append(optimizer)
        if not self.mocl_optimizers:
            logging.warning(
                "MoCL optimizer list is empty; no attention projection layers were found.")
        else:
            logging.info(
                "Initialized %d MoCL optimizers for attention projection layers.",
                len(self.mocl_optimizers),
            )

    def after_task(self):
        """Mark the classes of the finished task as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Train on the next task provided by ``data_manager``.

        The gate temperature decays as ``metabolic_temp / sqrt(task)``. The
        head is grown and cached slices are restored before training; after
        training, this task's head slice is cached.
        """
        self._cur_task += 1
        current_temp = self.metabolic_temp / math.sqrt(max(1, self._cur_task))
        for opt in self.mocl_optimizers:
            opt.metabolic_temp = current_temp
        task_size = data_manager.get_task_size(self._cur_task)
        self._total_classes = self._known_classes + task_size

        self._network.update_fc(self._total_classes)
        self._apply_cached_head()
        logging.info("Learning on classes %d-%d",
                     self._known_classes, self._total_classes)

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
        )
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes),
            source="test",
            mode="test",
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

        self._train()
        self._cache_current_head_slice()
        self._apply_cached_head()

    def _train(self) -> None:
        """Train the head and the MoCL-managed layers on the current task.

        The fc head is optimized by AdamW under the configured scheduler.
        MoCL layers follow the same per-epoch schedule, scaled by
        ``backbone_lr / fc_lr``. During the last ``cal_epochs`` epochs the
        K-FAC trackers accumulate statistics. Once training ends, those
        statistics are folded into each optimizer's historical subspace.
        """
        device = self._device
        self._network.to(device)
        self._network.train()
        self._network.backbone.train()
        self._network.fc.train()

        criterion = nn.CrossEntropyLoss()

        for mocl_optim in self.mocl_optimizers:
            mocl_optim.fim_tracker.reset()
            # The previous task's schedule left ``lr`` at its final annealed
            # value (0 with cosine and min_lr=0). Restart from the base rate,
            # as the freshly created head optimizer below does, so the first
            # epoch of every task also trains the backbone.
            mocl_optim.lr = self.backbone_lr

        optimizer = optim.AdamW(
            self._network.fc.parameters(),
            lr=self.fc_lr,
            weight_decay=0.0,
        )

        lr_ratio = self.backbone_lr / self.fc_lr if self.fc_lr > 0 else 1.0

        scheduler = self._build_scheduler(optimizer)

        total_batches = max(len(self.train_loader), 1)

        start_cal_epoch = max(0, self.epochs - self.cal_epochs)

        for epoch in range(self.epochs):
            if epoch >= start_cal_epoch:
                for opt in self.mocl_optimizers:
                    if opt.fim_tracker.mode != 'accumulation':
                        opt.fim_tracker.switch_to_accumulation()
            else:
                for opt in self.mocl_optimizers:
                    opt.fim_tracker.mode = 'sleep'

            epoch_loss = 0.0
            correct = 0
            total = 0
            prog_bar = tqdm(
                enumerate(self.train_loader),
                total=total_batches,
                leave=False,
                desc=f"Task {self._cur_task} | Epoch {epoch + 1}/{self.epochs}",
            )

            for step_in_epoch, (_, inputs, targets) in prog_bar:
                global_step = epoch * total_batches + step_in_epoch
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = self._network(inputs)["logits"]
                if self._cur_task == 0:
                    logits = outputs[:, : self._total_classes]
                    labels = targets
                else:
                    # Later tasks train only on the new-class logits.
                    logits = outputs[:,
                                     self._known_classes: self._total_classes]
                    labels = targets - self._known_classes

                loss = criterion(logits / self.temperature, labels)

                optimizer.zero_grad()
                loss.backward()

                # Keep old-class head rows fixed.
                if self._known_classes > 0:
                    self._network.fc.weight.grad[:self._known_classes] = 0.0
                    if self._network.fc.bias is not None:
                        self._network.fc.bias.grad[:self._known_classes] = 0.0

                for mocl_optim in self.mocl_optimizers:
                    mocl_optim.step(global_step)

                optimizer.step()

                # Decoupled weight decay applied only to the new-class rows.
                if self.fc_weight_decay > 0.0:
                    current_lr = optimizer.param_groups[0]["lr"]

                    with torch.no_grad():
                        start_idx = self._known_classes
                        self._network.fc.weight.data[start_idx:].mul_(
                            1.0 - current_lr * self.fc_weight_decay)

                        if self._network.fc.bias is not None:
                            self._network.fc.bias.data[start_idx:].mul_(
                                1.0 - current_lr * self.fc_weight_decay
                            )

                # Also clears the MoCL-managed backbone gradients.
                self._network.zero_grad()

                batch_size = inputs.size(0)
                epoch_loss += loss.item() * batch_size
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(labels).sum().item()
                total += batch_size

                prog_bar.set_postfix(
                    loss=epoch_loss / total if total else 0.0,
                    acc=100.0 * correct / total if total else 0.0,
                )

            epoch_loss /= max(total, 1)
            train_acc = 100.0 * correct / max(total, 1)
            if scheduler is not None:
                scheduler.step()

            current_fc_lr = optimizer.param_groups[0]["lr"]
            for mocl_optim in self.mocl_optimizers:
                mocl_optim.lr = current_fc_lr * lr_ratio

            logging.info(
                "Task %d Epoch %d/%d => loss %.4f | train_acc %.2f | lr_fc %.5f | lr_backbone %.5f",
                self._cur_task,
                epoch + 1,
                self.epochs,
                epoch_loss,
                train_acc,
                current_fc_lr,
                current_fc_lr * lr_ratio,
            )

        for mocl_optim in self.mocl_optimizers:
            mocl_optim.update_historical_subspace()

        self._network.backbone.eval()

    def _build_scheduler(self, optimizer: optim.Optimizer):
        """Return the per-epoch LR scheduler selected by ``mocl_scheduler``.

        ``"none"``/``"constant"`` -> ``None``; ``"steplr"`` -> ``MultiStepLR``
        (default milestone: half the epochs); ``"cosine"`` or any unknown
        name -> ``CosineAnnealingLR``. Unknown names also log a warning.
        Matching is case-insensitive.
        """
        name = self.mocl_scheduler.lower()
        if name in {"none", "constant"}:
            return None
        if name == "steplr":
            milestones = self.mocl_milestones or [max(self.epochs // 2, 1)]
            return optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=self.mocl_gamma,
            )
        if name != "cosine":
            logging.warning(
                "Unknown scheduler %s for MoCL; falling back to cosine annealing.",
                self.mocl_scheduler,
            )
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.epochs,
            eta_min=self.min_lr,
        )

    def _cache_current_head_slice(self) -> None:
        """Copy this task's head rows to the CPU cache (overwriting on re-run)."""
        start = self._known_classes
        end = self._total_classes
        if start >= end:
            return
        weight_slice = self._network.fc.weight.data[start:end].detach(
        ).cpu().clone()
        bias_slice = None
        if self._network.fc.bias is not None:
            bias_slice = self._network.fc.bias.data[start:end].detach(
            ).cpu().clone()

        if len(self._head_weight_cache) <= self._cur_task:
            self._head_weight_cache.append(weight_slice)
            self._head_bias_cache.append(bias_slice)
        else:
            self._head_weight_cache[self._cur_task] = weight_slice
            self._head_bias_cache[self._cur_task] = bias_slice

    def _cached_class_count(self) -> int:
        """Return the total number of head rows held in the cache."""
        if not self._head_weight_cache:
            return 0
        return sum(weight.size(0) for weight in self._head_weight_cache)

    def _apply_cached_head(self) -> None:
        """Write cached head rows back into the first rows of ``fc``.

        The bias is restored only when every cached slice has a bias.
        """
        cached_classes = self._cached_class_count()
        if cached_classes == 0:
            return

        device = self._network.fc.weight.device
        dtype = self._network.fc.weight.dtype
        weight = torch.cat(
            [w.to(device=device, dtype=dtype)
             for w in self._head_weight_cache],
            dim=0,
        )
        with torch.no_grad():
            self._network.fc.weight.data[:cached_classes] = weight

        if self._network.fc.bias is None:
            return

        bias_tensors = [b for b in self._head_bias_cache if b is not None]
        if len(bias_tensors) != len(self._head_bias_cache):
            return
        bias = torch.cat(
            [b.to(device=device, dtype=self._network.fc.bias.dtype)
             for b in bias_tensors],
            dim=0,
        )
        with torch.no_grad():
            self._network.fc.bias.data[:cached_classes] = bias

    def eval_task(self):
        """Restore cached head rows, then delegate to the base evaluation."""
        self._apply_cached_head()
        return super().eval_task()
