# Import the necessary libraries
import torch
import torch.nn as nn
import accelerate
from accelerate.optimizer import AcceleratedOptimizer
from transformers import Trainer
from trl.trainer.sft_trainer import SFTTrainer  
from trl.trainer.utils import entropy_from_logits
from torch.utils.data import Dataset, SequentialSampler
from transformers import EvalPrediction, TrainerCallback
from typing import Optional, Any, Callable, List
from collections import defaultdict
from itertools import repeat
from peft.peft_model import PeftModelForCausalLM
from peft import PeftType
from resnet20 import ResNet20ForCIFAR
from torch.nn.parallel import DistributedDataParallel
import torch
import os
from copy import deepcopy

# Pick the trainer base class from the TRAIN_LLM environment variable (case-insensitive
# "true" enables it): trl's SFTTrainer for language models, plain transformers Trainer otherwise.
TRAIN_LLM = os.environ.get("TRAIN_LLM", "false").lower() == "true"
if TRAIN_LLM:
    BaseTrainer = SFTTrainer
else:
    BaseTrainer = Trainer

class CustomOptimizer(torch.optim.Optimizer):
    """Optimizer wrapper that rewrites per-parameter gradients across processes before stepping.

    ``_step`` replaces every ``p.grad`` with ``get_grad(p, ...)``. It is meant to be invoked
    externally right before the optimizer step (e.g. by ``CustomOptimizerCallback``).
    ``step`` then simply delegates to the wrapped ``optimizer``. ``zero_grad``, ``state_dict``,
    ``load_state_dict`` and ``add_param_group`` are also delegated to the wrapped optimizer.

    ``set_accelerator`` must be called before ``_step``, because gradients are gathered
    through ``self.accelerator``.
    """

    def __init__(self, model, defaults, optimizer: Optional[torch.optim.Optimizer] = None, alternate_gpu: bool = True):
        """
        Args:
            model: Module whose ``parameters()`` are optimized.
            defaults: Per-group default hyperparameters (e.g. ``{'lr': ...}``).
            optimizer: Inner optimizer that applies the (rewritten) gradients. When None, plain
                SGD without momentum is built over this optimizer's own param groups.
            alternate_gpu: When False, ``_step`` leaves gradients untouched.
        """
        assert not isinstance(optimizer, torch.optim.LBFGS), "LBFGS optimizer is not supported."
        params = list(model.parameters())
        self.logs = defaultdict(list)  # For logging various metrics during optimization
        super().__init__(params, defaults)
        if optimizer is None:
            # Use SGD with no momentum as default optimizer; it shares self.param_groups.
            optimizer = torch.optim.SGD(self.param_groups, lr=defaults.get('lr', 1e-3), momentum=0.0)
        self.optimizer = optimizer
        self.alternate_gpu = alternate_gpu

        # Instance attributes shadow the torch.optim.Optimizer methods so that bookkeeping and
        # checkpointing go through the wrapped optimizer.
        self.zero_grad = lambda set_to_none=False: self.optimizer.zero_grad(set_to_none=set_to_none)
        self.state_dict = lambda : self.optimizer.state_dict()
        self.load_state_dict = lambda state_dict: self.optimizer.load_state_dict(state_dict)
        self.add_param_group = lambda param_group: self.optimizer.add_param_group(param_group)

    def preprocess_params(self) -> dict:
        """Return extra keyword arguments merged into each group's options for ``get_grad``.

        Subclasses override this to compute per-step quantities once; the returned keys
        take precedence over the param-group entries of the same name.
        """
        return {}

    def _step(self) -> None:
        """Overwrite ``p.grad`` for every parameter with a gradient by ``get_grad``.

        Every process must run this with the same set of parameters having gradients,
        because ``get_grad`` performs collective gathers.
        """
        if not self.alternate_gpu:
            return
        processed_params = self.preprocess_params()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad = self.get_grad(p, **(group | processed_params))

    def step(self, closure = None) -> None:
        """Delegate the actual parameter update to the wrapped optimizer."""
        return self.optimizer.step(closure)

    def set_accelerator(self, accelerator: accelerate.Accelerator):
        """Attach the accelerator used for cross-process gathers and allocate the stats counters."""
        self.accelerator = accelerator
        # Counters live in a device tensor rather than a Python list, presumably so they can be
        # gathered across processes later (not done in this class).
        self.stats = torch.zeros(4, dtype=torch.long, device=self.accelerator.device)  # [branch_1, branch_2, failed_feasibility, nan_in_delta]

    def collate_gr_gf(self, p):
        """Gather ``p.grad`` from all processes and split it into retain and forget gradients.

        Processes are paired by index (0,1), (2,3), ...; within each pair the even-indexed
        process holds the retain gradient and the odd-indexed one the forget gradient. Each
        is averaged over the pairs.

        Returns:
            (gr, gf): averaged retain and forget gradients, each shaped like ``p.grad``.

        Raises:
            RuntimeError: if the number of processes is not an even number >= 2, so the
                gathered gradients cannot be paired. The check runs before the collective
                gather, on every process alike.
        """
        num_processes = self.accelerator.num_processes
        if num_processes < 2 or num_processes % 2 != 0:
            raise RuntimeError(
                "collate_gr_gf needs an even number (>= 2) of processes to pair retain and "
                f"forget gradients, got {num_processes}."
            )
        local_grad = p.grad.contiguous()
        all_gpu_grads_concat = self.accelerator.gather(local_grad)

        # Split into (num_processes / 2) groups, each holding a (retain, forget) pair.
        grads_per_gpu = all_gpu_grads_concat.reshape(num_processes // 2, 2, *local_grad.shape)

        grads_per_gpu = grads_per_gpu.mean(dim=0)  # Average over GPU pairs
        gr = grads_per_gpu[0]
        gf = grads_per_gpu[1]
        return gr, gf

    def get_grad(self, p, *args, **kwargs):
        """Default rewrite: the mean of ``p.grad`` over all processes (extra arguments are ignored)."""
        local_grad = p.grad.contiguous()
        grad = self.accelerator.gather(local_grad).reshape(-1, *local_grad.shape).mean(dim=0)
        return grad

# MyOptimizer
class MyOptimizer(CustomOptimizer):
    """Custom optimizer implementing the min-R update rule for unlearning.

    Replaces the standard optimizer step (which collapses all gradients from all GPUs) with a
    custom step. That step separately gathers retain (``gr``) and forget (``gf``) gradients
    from alternate GPUs via ``collate_gr_gf``. For each parameter tensor, ``get_grad``
    computes the min-R step ``delta`` and writes back the equivalent gradient
    ``-delta / lr``. With the default plain-SGD inner optimizer at that same ``lr``, the
    regular step therefore applies exactly ``delta``.

    By construction of the two branches in ``min_r_update``, a feasible step has norm ``R``
    and satisfies ``<delta, gf> >= Q``.
    """

    def __init__(self, model, optimizer=None, lr=None, R=1e-3, Q=0., dual_update=False, stop_on_failed_feasibility=False, full_grad=False,
                 distribute_B_softmax_dp=False, distribute_B_gr_gf_norm=False, distribute_B_norm=False, debug=False):
        """
        Args:
            model: Model whose parameters are optimized; ``model.device`` is used for the
                internal scalar tensors.
            optimizer: Inner optimizer applying the rewritten gradients (SGD when None).
            lr: If set, the step radius becomes ``lr * ||gr||`` at every step and ``R`` is
                not used directly.
            R: Fixed step radius, used when ``lr`` is None.
            Q: Non-negative lower bound on ``<delta, gf>`` (logged under the key "B").
            dual_update: Compute the update from ``(-gf, -gr)`` instead of ``(gr, gf)``.
            stop_on_failed_feasibility: Return a NaN update instead of a fallback update
                when the constraints are infeasible.
            full_grad: Use norms / inner product over all parameters instead of per tensor.
            distribute_B_softmax_dp: Scale Q per tensor by ``exp(-<gr, gf>)``.
            distribute_B_gr_gf_norm: Scale Q per tensor by ``||gr|| * ||gf||``.
            distribute_B_norm: Together with a distribute_B_* option, normalize Q by the sum
                of that scaling factor over all parameter tensors.
            debug: Stored on the instance; not otherwise used in this class.
        """
        assert Q >= 0, "Q must be non-negative."
        if lr is not None:
            print("Lr is set, using lr to override R, R not used directly.")
            self.use_lr = True
        else:
            self.use_lr = False
            lr = 1.0  # Dummy value, not used
        self.model = model
        self.dual_update = dual_update
        self.stop_on_failed_feasibility = stop_on_failed_feasibility
        self.full_grad = full_grad
        self.distribute_B_softmax_dp = distribute_B_softmax_dp
        self.distribute_B_gr_gf_norm = distribute_B_gr_gf_norm
        self.distribute_B_norm = distribute_B_norm
        self.debug = debug
        super().__init__(model, defaults={'lr': lr, 'R': R, 'Q': Q}, optimizer=optimizer, alternate_gpu=True)
        self._lr = torch.tensor(lr, device=model.device)
        self._R = torch.tensor(R, device=model.device) if R is not None else torch.tensor(-1.0, device=model.device)
        self._Q = torch.tensor(Q, device=model.device)

    def _log_scalar(self, key: str, value) -> None:
        """Append ``value`` (a tensor or a plain Python number) to ``self.logs[key]`` as a CPU tensor.

        ``R`` and ``lr`` come from the param group as Python floats when ``lr`` is unset, so
        they cannot be moved with ``.to()`` directly; ``torch.as_tensor`` handles both cases
        and returns existing tensors unchanged.
        """
        self.logs[key].append(torch.as_tensor(value).to(device='cpu', non_blocking=True))

    def min_r_update(self, gr: torch.Tensor, gf: torch.Tensor, lr: torch.Tensor, R: torch.Tensor, Q: torch.Tensor, eps=1e-12, full_gr_norm: Optional[torch.Tensor]=None, full_gf_norm: Optional[torch.Tensor]=None, full_gdp: Optional[torch.Tensor]=None):
        """Compute the min-R step ``delta`` for one parameter tensor.

        Branch 1 (descent): ``delta = -(R / ||gr||) gr`` when it already satisfies
        ``<delta, gf> >= Q`` (i.e. ``<gr, gf> <= -Q ||gr|| / R``).
        Branch 2: ``delta = (Q / ||gf||^2) gf - sqrt(R^2 - Q^2 / ||gf||^2) * gr_n / ||gr_n||``,
        where ``gr_n`` is ``gr`` with its ``gf`` component removed.

        Args:
            gr, gf: Retain and forget gradients for this tensor.
            lr: Learning rate; only used to derive ``R`` when ``self.use_lr``.
            R: Step radius (ignored when ``self.use_lr``).
            Q: Lower bound on ``<delta, gf>``.
            eps: Threshold below which ``||gf||^2`` or ``||gr_n||`` are treated as zero.
            full_gr_norm, full_gf_norm, full_gdp: Optional precomputed norms / inner product
                (e.g. over all parameters) replacing the per-tensor values.

        Returns:
            The update ``delta`` (same shape as ``gr``). It is all-NaN when the problem is
            infeasible and ``stop_on_failed_feasibility`` is set; otherwise the fallback is a
            zero or pure-ascent step.

        Side effects:
            Appends diagnostics to ``self.logs`` and increments ``self.stats``
            (requires ``set_accelerator`` to have been called).
        """
        if full_gr_norm is None:
            gr_norm = torch.linalg.norm(gr)
        else:
            gr_norm = full_gr_norm
        if full_gf_norm is None:
            gf_norm = torch.linalg.norm(gf)
            gf_norm_sq = torch.sum(gf * gf)
        else:
            gf_norm = full_gf_norm
            gf_norm_sq = gf_norm * gf_norm
        if full_gdp is None:
            gdp = torch.dot(gr.flatten(), gf.flatten())
        else:
            gdp = full_gdp
        if self.use_lr:
            R = lr * gr_norm  # Scale R by lr
        if self.distribute_B_softmax_dp:
            Q = Q * torch.exp(-gdp)
        if self.distribute_B_gr_gf_norm:
            Q = Q * (gr_norm * gf_norm)
        threshold = -Q * gr_norm / R
        discriminant = R * R - Q * Q / gf_norm_sq
        gr_n = gr - (gdp / gf_norm_sq) * gf
        gr_n_norm = torch.linalg.norm(gr_n)
        self._log_scalar("R", R)
        self._log_scalar("B", Q)
        self._log_scalar("gf_norm", gf_norm)
        self._log_scalar("gr_norm", gr_norm)
        self._log_scalar("gr_n_norm", gr_n_norm)
        self._log_scalar("gdp", gdp)
        # Clamp the cosine: rounding can push it slightly outside [-1, 1], where acos is NaN.
        cos_angle = torch.clamp(gdp / (gr_norm * gf_norm), -1.0, 1.0)
        self._log_scalar("angle", torch.rad2deg(torch.acos(cos_angle)))
        self._log_scalar("R^2", R * R)
        Q_over_gf_norm = Q / gf_norm
        self._log_scalar("B^2/gf_norm^2", Q_over_gf_norm * Q_over_gf_norm)
        self._log_scalar("B/gf_norm", Q_over_gf_norm)
        self._log_scalar("feasibility", Q_over_gf_norm / R)
        self._log_scalar("branch_threshold", threshold)
        self._log_scalar("branch_2_1_eq_lr", Q_over_gf_norm / gf_norm)
        if discriminant >= 0:
            sqrt_discriminant = torch.sqrt(discriminant)
            self._log_scalar("sqrt_discriminant", sqrt_discriminant)
            self._log_scalar("branch_2_2_eq_lr", sqrt_discriminant / gr_n_norm)
        if gf_norm_sq < eps:
            # gf is (numerically) zero: the constraint can only hold if Q == 0.
            if Q == 0:
                self.stats[0] += 1
                update = -(R / gr_norm) * gr    # Gradient descent
                return update
            if not self.stop_on_failed_feasibility:
                self.stats[1] += 1
                return torch.zeros_like(gr)
            else:
                self.stats[2] += 1
                return torch.full_like(gr, float('nan'))    # Indicate failed feasibility
        update = (Q / gf_norm_sq) * gf    # Gradient ascent
        if discriminant < 0:
            # Even the pure-ascent step of norm Q/||gf|| exceeds the radius R.
            self.stats[2] += 1
            if self.stop_on_failed_feasibility:
                return torch.full_like(gr, float('nan'))    # Indicate failed feasibility
            else:
                return update
        if gdp <= threshold:
            self.stats[0] += 1
            update = -(R / gr_norm) * gr
            return update
        else:
            self.stats[1] += 1
            if gr_n_norm < eps:
                return update  # gr and gf are colinear
            update_n = -(R if Q == 0 else torch.sqrt(discriminant)) * gr_n / gr_n_norm
            update = update + update_n
            return update

    def get_grad(self, p, lr=torch.tensor(1.0, dtype=torch.float), R=torch.tensor(1.0, dtype=torch.float), Q=torch.tensor(.0, dtype=torch.float), gr_norm=None, gf_norm=None, gdp=None, **kwargs):
        """Return the gradient that makes an SGD step of size ``lr`` apply the min-R update.

        Extra param-group entries (e.g. ``params``, SGD options) arrive in ``kwargs`` and are ignored.
        """
        gr, gf = self.collate_gr_gf(p)
        if self.dual_update:
            delta = self.min_r_update(-gf, -gr, lr, R, Q, full_gr_norm=gf_norm, full_gf_norm=gr_norm, full_gdp=gdp)
        else:
            delta = self.min_r_update(gr, gf, lr, R, Q, full_gr_norm=gr_norm, full_gf_norm=gf_norm, full_gdp=gdp)
        equivalent_grad = - delta / lr
        return equivalent_grad

    def preprocess_params(self) -> dict:
        """Compute the per-step ``Q`` and optional global norms passed to every ``get_grad`` call.

        - ``full_grad``: sums ``||gr||^2``, ``||gf||^2`` and ``<gr, gf>`` over all parameter
          tensors and returns the global norms and inner product; Q is not rescaled.
        - ``distribute_B_*`` with ``distribute_B_norm``: divides Q by the sum over tensors of
          ``exp(-<gr, gf>)`` (softmax variant) or ``||gr|| * ||gf||``.
        - otherwise (no ``distribute_B_*``): divides Q by the number of parameter tensors.

        ``self._Q`` itself is never modified.
        """
        Q = self._Q
        gr_norm = None
        gf_norm = None
        gdp = None
        if self.full_grad:
            # Compute full gradient norms and gdp
            gr_norm_sqr = torch.tensor(0.0, device=self.accelerator.device)
            gf_norm_sqr = torch.tensor(0.0, device=self.accelerator.device)
            gdp = torch.tensor(0.0, device=self.accelerator.device)
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    gr, gf = self.collate_gr_gf(p)
                    gr_norm_sqr += torch.sum(gr * gr)
                    gf_norm_sqr += torch.sum(gf * gf)
                    gdp += torch.dot(gr.flatten(), gf.flatten())
            gr_norm = torch.sqrt(gr_norm_sqr)
            gf_norm = torch.sqrt(gf_norm_sqr)
            self._log_scalar("full_gr_norm", gr_norm)
            self._log_scalar("full_gf_norm", gf_norm)
            self._log_scalar("full_gdp", gdp)
        elif self.distribute_B_gr_gf_norm or self.distribute_B_softmax_dp:
            if self.distribute_B_norm:
                # Accumulate the per-tensor Q scaling factors so Q can be normalized by their sum
                gr_norm_gf_norm = torch.tensor(0.0, device=self.accelerator.device)
                exp_dp = torch.tensor(0.0, device=self.accelerator.device)
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        gr, gf = self.collate_gr_gf(p)
                        gr_norm_gf_norm += torch.linalg.norm(gr) * torch.linalg.norm(gf)
                        exp_dp += torch.exp(-torch.dot(gr.flatten(), gf.flatten()))
                if self.distribute_B_softmax_dp:
                    self._log_scalar("exp_dp", exp_dp)
                    Q = Q / exp_dp
                elif self.distribute_B_gr_gf_norm:
                    self._log_scalar("gr_norm_gf_norm", gr_norm_gf_norm)
                    Q = Q / gr_norm_gf_norm
        else:
            # Average Q over all params. Out-of-place on purpose: `Q` aliases `self._Q`, and an
            # in-place `/=` would shrink the stored Q again on every optimizer step.
            Q = Q / sum(len(group['params']) for group in self.param_groups)
        return {'Q': Q, 'gr_norm': gr_norm, 'gf_norm': gf_norm, 'gdp': gdp}

class CustomOptimizerCallback(TrainerCallback):
    """Trainer callback that runs ``CustomOptimizer._step`` right before every optimizer step.

    ``_step`` rewrites each parameter's gradient, so the subsequent regular optimizer step
    applies the rewritten gradient. A ``ValueError`` raised by ``_step`` is reported and asks
    the trainer to log and stop. Every hook in ``additional_callbacks`` runs afterwards,
    with ``(args, state, control, **kwargs)``.
    """

    def __init__(self, additional_callbacks: Optional[List[Callable]] = None):
        super().__init__()
        # Filled in by the trainer once a CustomOptimizer is known; None disables the hook.
        self.optimizer: Optional[CustomOptimizer] = None
        self.additional_callbacks = [] if additional_callbacks is None else additional_callbacks

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        custom_optimizer = self.optimizer
        if custom_optimizer is None:
            return
        try:
            custom_optimizer._step()
        except ValueError as err:
            print(f"⚠️ Stopping training due to error in optimizer step: {err}")
            # Flush the pending logs, then halt training.
            control.should_log = True
            control.should_training_stop = True
        for hook in self.additional_callbacks:
            hook(args, state, control, **kwargs)

class CustomTrainer(BaseTrainer):
    """Trainer that logs retain/forget metrics separately and hooks in ``CustomOptimizer``.

    With ``alternate_gpu=True`` the training step assumes that each (even, odd) pair of
    processes holds a (retain, forget) batch. ``gather_alternate`` uses that ordering to
    split metrics. The forward/backward runs under ``accelerator.no_sync`` so gradients are
    not all-reduced by DDP and the custom optimizer can combine them itself.
    """

    def __init__(self, alternate_gpu=True, debug=False, *args, **kwargs):
        """
        Args:
            alternate_gpu: Use the retain/forget-per-process training path in ``compute_loss``.
            debug: Stored on the instance; not otherwise used in this class.
            *args, **kwargs: Forwarded to the base trainer. ``callbacks`` is extended with the
                optimizer callback; ``compute_metrics`` is always set to ``_compute_metrics``.
        """
        # Copy so the caller's list is not mutated; also tolerate an explicit ``callbacks=None``.
        callbacks = list(kwargs.pop("callbacks", None) or [])
        self.custom_optimizer_callback = CustomOptimizerCallback(additional_callbacks=[self.my_log])
        callbacks.append(self.custom_optimizer_callback)
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        super().__init__(*args, compute_metrics=self._compute_metrics, callbacks=callbacks, **kwargs)
        if isinstance(self.optimizer, CustomOptimizer):
            self.custom_optimizer_callback.optimizer = self.optimizer
        self.alternate_gpu = alternate_gpu
        self.debug = debug

    def update_frozen_model(self):
        """Safely unwraps and copies the model using the Trainer's internal accelerator.

        The frozen copy (eval mode, no gradients) is stored as ``self.model.frozen_model[0]``.
        It is wrapped in a list so the copy is not registered as a submodule of the model.
        """
        accelerator = self.accelerator

        unwrapped_model = accelerator.unwrap_model(self.model)

        frozen_model = deepcopy(unwrapped_model)
        frozen_model.eval()
        frozen_model.requires_grad_(False)
        self.model.frozen_model = [frozen_model]

    def set_optimizer(self, optimizer: Optional[torch.optim.Optimizer]) -> None:
        """Replace the optimizer and, if it is a ``CustomOptimizer``, hook it into the callback."""
        self.optimizer = optimizer
        if isinstance(optimizer, CustomOptimizer):
            self.custom_optimizer_callback.optimizer = optimizer

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Merge the averaged custom metrics of the current mode into ``logs``, log, then reset them."""
        # Optimizer diagnostics may have been copied to CPU with non_blocking=True; wait for
        # those copies before reading them. torch.cuda.synchronize raises without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        mode = "train" if self.model.training else "eval"
        for k, v in self._metrics[mode].items():
            if any(isinstance(i, torch.Tensor) for i in v):
                self._metrics[mode][k] = [i.item() if isinstance(i, torch.Tensor) else i for i in v]
        metrics = {key: sum(val) / len(val) if len(val) > 0 else 0.0 for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs.update(metrics)
        # Bypass the base trainer's own log override and use transformers.Trainer.log directly.
        Trainer.log(self, logs, start_time)
        self._metrics[mode].clear()

    def my_log(self, *args, **kwargs):
        """Move the optimizer's accumulated ``logs`` into ``self._metrics`` for the current mode."""
        optimizer = self.optimizer
        if optimizer is None:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if isinstance(optimizer, AcceleratedOptimizer):
            optimizer = optimizer.optimizer
        optimizer_logs = getattr(optimizer, "logs", None)
        if optimizer_logs is None:
            # Not a CustomOptimizer: nothing to collect.
            return
        mode = "train" if self.model.training else "eval"
        for k, v in optimizer_logs.items():
            if len(v) > 0:
                self._metrics[mode][k].extend(v)
        optimizer_logs.clear()

    def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
        """Return a sequential sampler so retain and forget samples stay consistently paired.

        Falls back to ``self.train_dataset`` when called without a dataset.
        """
        if train_dataset is None:
            train_dataset = self.train_dataset
        return SequentialSampler(train_dataset)

    def _compute_metrics(self, pred: EvalPrediction):
        # Override so that metrics on different eval subsets will have different names
        # SFTTrainer's compute_metrics assumes a single eval set, and only the last eval set's metrics are returned
        # Here, we use the parent Trainer's compute_metrics to avoid that issue
        metrics = {}
        for k, v in self._metrics.get("eval", {}).items():
            metrics[k] = sum(v) / len(v) if len(v) > 0 else 0.0
            if isinstance(metrics[k], torch.Tensor):
                metrics[k] = metrics[k].item()
        self._metrics["eval"].clear()
        return metrics

    def gather_alternate(self, tensor: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
        """Gather ``tensor`` from all processes and reduce it separately for retain and forget.

        Consecutive gathered entries are paired as (retain, forget) and reduced over the
        pairs, so index 0 of the result is the retain value and index 1 the forget value.
        NOTE(review): all visible callers pass one scalar per process; for tensors with a
        batch dimension > 1 the (-1, 2) pairing would mix rows of the same process.

        Raises:
            ValueError: for an unsupported ``reduction``.
        """
        tensor = self.accelerator.gather_for_metrics(tensor).reshape(-1, 2, *tensor.shape[1:])
        if reduction == "mean":
            return tensor.mean(dim=0)
        elif reduction == "sum":
            return tensor.sum(dim=0)
        else:
            raise ValueError(f"Unsupported reduction method: {reduction}")

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """
        Compute training loss and additionally compute token accuracies
        Adapted from SFTTrainer's compute_loss

        In eval mode (or when ``alternate_gpu`` is False) the base trainer's loss is used
        and, for non-LLM runs, plain ``accuracy``/``entropy`` are logged. In alternating
        training mode, metrics are logged separately as ``retain_*`` and ``forget_*``.
        """
        mode = "train" if self.model.training else "eval"
        if not self.model.training or not self.alternate_gpu:
            # In eval mode, assume that the datasets are properly separated into separate retain and forget sets
            (loss, outputs) = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
            if not TRAIN_LLM:
                logits = outputs.logits
                labels = inputs["labels"]
                predictions = logits.argmax(dim=-1)
                correct = (predictions == labels).sum().item()
                total = labels.numel()
                accuracy = correct / total if total > 0 else 0.0
                self._metrics[mode]["accuracy"].append(accuracy)
                entropy = entropy_from_logits(outputs.logits)
                self._metrics[mode]["entropy"].append(entropy.mean().item())
            if return_outputs:
                return (loss, outputs)
            return loss

        # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
        # This can be removed when this issue is fixed.
        # When using CP or SP, labels are pre-shifted, we must use shift_labels instead.
        labels = inputs["labels"] if "shift_labels" not in inputs else None

        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
        if TRAIN_LLM:
            inputs["use_cache"] = False
        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
        if self.args.use_liger_kernel:
            raise NotImplementedError("Liger kernel not implemented in this version.")

        with self.accelerator.no_sync(model):  # Gather gradients from different GPUs separately
            (loss, outputs) = Trainer.compute_loss(self,  # Call SFTTrainer's parent's compute_loss
                model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
            )

        # Compute entropy
        if not self.args.use_liger_kernel:  # liger doesn't return logits
            with torch.no_grad():
                per_token_entropy = entropy_from_logits(outputs.logits)
                # When using Prompt Tuning, skip the virtual tokens in logits before entropy computation, since they
                # do not correspond to actual input tokens.
                if (TRAIN_LLM and
                    self.num_virtual_tokens > 0
                    and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
                ):
                    per_token_entropy = per_token_entropy[:, self.num_virtual_tokens :]
                if "attention_mask" in inputs:
                    attention_mask = inputs["attention_mask"]
                    entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
                elif "position_ids" in inputs:
                    entropy = torch.mean(per_token_entropy)
                elif not TRAIN_LLM:
                    entropy = torch.mean(per_token_entropy)
                else:
                    raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
                entropy = self.gather_alternate(entropy)
                self._metrics[mode]["retain_entropy"].append(entropy[0].item())
                self._metrics[mode]["forget_entropy"].append(entropy[1].item())

        if mode == "train":
            # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q,
            # cu_seq_lens_k, and max_length_k, max_length_q and position_ids.
            # The gathers are collectives and run on every process even when the count is unused.
            if "attention_mask" in inputs:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
            elif "position_ids" in inputs:
                local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device)
                num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item()
            elif not TRAIN_LLM:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["labels"].ne(-100).sum()).sum().item()
            else:
                raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
            if TRAIN_LLM:
                self._total_train_tokens += num_tokens_in_batch
        if TRAIN_LLM:
            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        if self.args.use_liger_kernel:
            raise NotImplementedError("Liger kernel token accuracy logging not implemented yet.")
        if not TRAIN_LLM:
            predictions = outputs.logits.argmax(dim=-1)
            mask = labels != -100
            correct_predictions = (predictions == labels) & mask
            total_tokens = mask.sum()
            correct_tokens = correct_predictions.sum()

            correct_tokens = self.gather_alternate(correct_tokens, reduction="sum")
            total_tokens = self.gather_alternate(total_tokens, reduction="sum")
            r_accuracy = (correct_tokens[0] / total_tokens[0]).item() if total_tokens[0] > 0 else 0.0
            f_accuracy = (correct_tokens[1] / total_tokens[1]).item() if total_tokens[1] > 0 else 0.0
            self._metrics[mode]["retain_token_accuracy"].append(r_accuracy)
            self._metrics[mode]["forget_token_accuracy"].append(f_accuracy)
        else:
            # Compute accuracy from logits using argmax (traditional method)
            with torch.no_grad():
                if "shift_labels" in inputs:
                    # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because:
                    # - The first discarded token from inputs["labels"] actually belongs to process n-1
                    # - The last logits require the label from process n+1
                    shift_logits = outputs.logits.contiguous()
                    shift_labels = inputs["shift_labels"]
                else:
                    shift_logits = outputs.logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()

                # Prompt Tuning and P-Tuning output logits for virtual tokens but Prefix-Tuning does not.
                if (
                    self.num_virtual_tokens > 0
                    and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
                ):
                    shift_logits = shift_logits[:, self.num_virtual_tokens :, :]

                # Get predictions
                predictions = shift_logits.argmax(dim=-1)

                # Create mask for non-padding tokens (assuming ignore_index is -100)
                mask = shift_labels != -100

                # Calculate accuracy only on non-padding tokens
                correct_predictions = (predictions == shift_labels) & mask
                total_tokens = mask.sum()
                correct_tokens = correct_predictions.sum()

                correct_tokens = self.gather_alternate(correct_tokens, reduction="sum")
                total_tokens = self.gather_alternate(total_tokens, reduction="sum")
                r_accuracy = (correct_tokens[0] / total_tokens[0]).item() if total_tokens[0] > 0 else 0.0
                f_accuracy = (correct_tokens[1] / total_tokens[1]).item() if total_tokens[1] > 0 else 0.0
                self._metrics[mode]["retain_token_accuracy"].append(r_accuracy)
                self._metrics[mode]["forget_token_accuracy"].append(f_accuracy)

        # Log auxiliary loss if enabled (applies to both Liger and non-Liger)
        if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
            aux_loss = outputs.aux_loss
            aux_loss = self.gather_alternate(aux_loss)
            self._metrics[mode]["retain_aux_loss"].append(aux_loss[0].item())
            self._metrics[mode]["forget_aux_loss"].append(aux_loss[1].item())

        with torch.no_grad():
            gathered_loss = self.gather_alternate(loss)
            self._metrics[mode]["retain_loss"].append(gathered_loss[0].item())
            self._metrics[mode]["forget_loss"].append(gathered_loss[1].item())

        self.compute_stats(mode)
        return (loss, outputs) if return_outputs else loss

    def compute_stats(self, mode: str):
        """Hook for subclasses to log extra statistics after ``compute_loss``; no-op here."""
        return

    def _compute_KL(
            self,
            model: PeftModelForCausalLM,
            inputs: dict[str, torch.Tensor | Any],
            logits: torch.Tensor,
            num_items_in_batch: torch.Tensor | None = None
        ) -> torch.Tensor:
        """Return the (masked) mean KL divergence KL(reference || current) over positions.

        The reference logits come from ``compute_full_outputs_logits`` (the "full" adapter
        for LLMs, or the frozen copy for ResNet20) and are computed without gradients.
        """
        with torch.no_grad():
            target = self.compute_full_outputs_logits(model, inputs, return_outputs=False, num_items_in_batch=num_items_in_batch)
        # kl_div(input=log q, target=log p, log_target=True) sums p * (log p - log q) = KL(p || q)
        kl = torch.nn.functional.kl_div(logits.log_softmax(dim=-1), target.log_softmax(dim=-1), log_target=True, reduction='none').sum(dim=-1)    # batch_size x seq_len
        if (TRAIN_LLM and
            self.num_virtual_tokens > 0
            and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
        ):
            kl = kl[:, self.num_virtual_tokens :]
        if "attention_mask" in inputs:
            kl = torch.sum(kl * inputs["attention_mask"]) / inputs["attention_mask"].sum()
        elif "position_ids" in inputs:
            kl = torch.mean(kl)
        elif not TRAIN_LLM:
            kl = torch.mean(kl)
        else:
            raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")
        return kl

    def compute_full_outputs_logits(
            self,
            model: PeftModelForCausalLM | DistributedDataParallel,
            inputs: dict[str, torch.Tensor | Any],
            return_outputs: bool = False,
            num_items_in_batch: torch.Tensor | None = None
        ) -> torch.Tensor:
        """Return reference-model logits for ``inputs`` without modifying the caller's batch.

        LLM runs temporarily switch the PEFT model to its "full" adapter. ResNet20 runs use
        ``model.frozen_model[0]`` (see ``update_frozen_model``).

        Raises:
            NotImplementedError: for any other model type.
        """
        if hasattr(model, "module"):
            model = model.module
        if TRAIN_LLM:
            # Assumes model has a "full" adapter that corresponds to the full model.
            # Build a new dict instead of popping, so the caller's batch keeps its labels.
            inputs = {k: v for k, v in inputs.items() if k != "labels"}
            if self.model_accepts_loss_kwargs:
                kwargs = {}
                if num_items_in_batch is not None:
                    kwargs["num_items_in_batch"] = num_items_in_batch
                inputs = {**inputs, **kwargs}

            original_adapter = model.active_adapter
            model.set_adapter("full")
            try:
                with self.accelerator.no_sync(model):  # Gather gradients from different GPUs separately
                    outputs = model(**inputs)
            finally:
                # Restore the trainable adapter even if the forward pass fails.
                model.set_adapter(original_adapter)
        elif isinstance(model, ResNet20ForCIFAR):
            model.frozen_model[0].eval()
            outputs = model.frozen_model[0](**inputs)
        else:
            raise NotImplementedError(f"Full model logits computation not implemented for non-PEFT models. ({type(model)})")

        return outputs.logits

class GATrainer(CustomTrainer):
    """Gradient-ascent baseline: maximizes the loss on the training data.

    Assumes the training dataset only contains forget data, so the generic metrics logged
    by ``CustomTrainer.compute_loss`` are re-labelled as ``forget_*`` during training.
    """

    def __init__(self, alternate_gpu=False, *args, **kwargs):
        assert not alternate_gpu, "GATrainer requires alternate_gpu to be False."
        super().__init__(alternate_gpu, *args, **kwargs)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """Return the negated loss in training mode (gradient ascent) and the plain loss in eval."""
        mode = "train" if self.model.training else "eval"
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
        if mode == "train":
            metrics = self._metrics[mode]
            metrics["forget_loss"].append(loss.item())
            # pop() instead of del: the parent does not log every key in every configuration.
            metrics["forget_entropy"].extend(metrics.pop("entropy", []))
            if TRAIN_LLM:
                metrics["forget_token_accuracy"].extend(metrics.pop("mean_token_accuracy", []))
            else:
                metrics["forget_accuracy"].extend(metrics.pop("accuracy", []))
            loss = -loss
        return (loss, outputs) if return_outputs else loss

class GDTrainer(CustomTrainer):
    """Gradient-descent baseline: minimizes the loss on the training data.

    Assumes the training dataset only contains retain data, so the generic metrics logged
    by ``CustomTrainer.compute_loss`` are re-labelled as ``retain_*`` during training.
    """

    def __init__(self, alternate_gpu=False, *args, **kwargs):
        assert not alternate_gpu, "GDTrainer requires alternate_gpu to be False."
        super().__init__(alternate_gpu, *args, **kwargs)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """Return the plain loss; in training mode also re-label metrics as retain metrics."""
        mode = "train" if self.model.training else "eval"
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
        if mode == "train":
            metrics = self._metrics[mode]
            metrics["retain_loss"].append(loss.item())
            # pop() instead of del: the parent does not log every key in every configuration.
            metrics["retain_entropy"].extend(metrics.pop("entropy", []))
            if TRAIN_LLM:
                metrics["retain_token_accuracy"].extend(metrics.pop("mean_token_accuracy", []))
            else:
                metrics["retain_accuracy"].extend(metrics.pop("accuracy", []))
        return (loss, outputs) if return_outputs else loss

class GDiffTrainer(CustomTrainer):
    """Gradient-difference trainer: gradient ascent on forget, descent on retain.

    Processes are split by rank parity. Even-indexed processes (retain GPUs) use the
    base trainer loss as-is. Odd-indexed processes (forget GPUs) negate it.
    The training dataset should contain both retain and forget data, partitioned
    so that each process receives the correct split.
    """

    def __init__(self, alternate_gpu: bool = True, *args, **kwargs):
        assert alternate_gpu, "GDiffTrainer requires alternate_gpu to be True."
        super().__init__(alternate_gpu=alternate_gpu, *args, **kwargs)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """Return the base loss, negated on forget (odd-rank) processes during training.

        In eval mode the base loss is returned unchanged.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        mode = "train" if self.model.training else "eval"
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)
        if mode == "train" and self.accelerator.process_index % 2 == 1:
            # Forget GPU: negate the loss to perform gradient ascent.
            loss = -loss
        return (loss, outputs) if return_outputs else loss

class KLTrainer(CustomTrainer):
    """Unlearning trainer: gradient ascent on forget data, KL term on retain data.

    Roles are assigned by rank parity. Even-indexed processes (retain GPUs)
    optimise the term returned by ``_compute_KL``. Odd-indexed processes (forget
    GPUs) negate the base trainer loss. The training dataset must contain both
    retain and forget data, partitioned so that each process gets its split.
    """

    def __init__(self, alternate_gpu: bool = True, *args, **kwargs):
        assert alternate_gpu, "KLTrainer requires alternate_gpu to be True."
        super().__init__(alternate_gpu=alternate_gpu, *args, **kwargs)

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """Compute the per-rank unlearning loss and log KL metrics.

        Train mode:
            - Retain ranks return the KL term.
            - Forget ranks return the negated base loss.
            - The KL term is passed through ``gather_alternate``. Elements ``[0]`` and
              ``[1]`` are logged as ``retain_KL`` and ``forget_KL``.

        Eval mode:
            - The base loss is returned unchanged.
            - The KL term is gathered and its mean is logged as ``KL``.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        mode = "train" if self.model.training else "eval"
        base_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)

        if mode == "train":
            # Compute the KL term without triggering a DDP gradient all-reduce.
            with self.accelerator.no_sync(model):
                kl_term = self._compute_KL(model, inputs, outputs.logits, num_items_in_batch)
            on_retain_rank = self.accelerator.process_index % 2 == 0
            final_loss = kl_term if on_retain_rank else -base_loss
            with torch.no_grad():
                split_kl = self.gather_alternate(kl_term)
            self._metrics[mode]["retain_KL"].append(split_kl[0].mean().item())
            self._metrics[mode]["forget_KL"].append(split_kl[1].mean().item())
        else:
            final_loss = base_loss
            with torch.no_grad():
                kl_term = self._compute_KL(model, inputs, outputs.logits, num_items_in_batch)
            mean_kl = self.accelerator.gather_for_metrics(kl_term).mean()
            self._metrics[mode]["KL"].append(mean_kl.item())

        if return_outputs:
            return final_loss, outputs
        return final_loss
        
class SCRUBTrainer(CustomTrainer):
    """SCRUB unlearning trainer with rank-parity role assignment.

    Roles by rank parity:
        - Retain GPUs (even process index) minimise ``alpha * KL + gamma * CE``, where
          CE is the base trainer loss.
        - Forget GPUs (odd process index) minimise ``-KL``.

    Args:
        alternate_gpu: Must be True.
        alpha: Weight on the KL term for retain ranks.
        gamma: Weight on the base loss for retain ranks.
    """

    def __init__(self, alternate_gpu: bool = True, alpha=1.0, gamma=1.0, *args, **kwargs):
        assert alternate_gpu, "SCRUBTrainer requires alternate_gpu to be True."
        super().__init__(alternate_gpu=alternate_gpu, *args, **kwargs)
        self._alpha = alpha
        self._gamma = gamma

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
        """Compute the SCRUB loss for this rank and log KL metrics.

        Train mode:
            - The per-role loss described on the class is returned.
            - The ``gather_alternate`` result's ``[0]`` and ``[1]`` entries are logged as
              ``retain_KL`` and ``forget_KL``.

        Eval mode:
            - The base loss is returned unchanged.
            - The gathered mean KL is logged as ``KL``.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        mode = "train" if self.model.training else "eval"
        base_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch)

        if mode == "train":
            # Compute the KL term without triggering a DDP gradient all-reduce.
            with self.accelerator.no_sync(model):
                kl_term = self._compute_KL(model, inputs, outputs.logits, num_items_in_batch)
            if self.accelerator.process_index % 2 != 0:
                # Forget GPU: push the model away from the reference.
                final_loss = -kl_term
            else:
                # Retain GPU: stay close to the reference while fitting the data.
                final_loss = self._alpha * kl_term + self._gamma * base_loss
            with torch.no_grad():
                split_kl = self.gather_alternate(kl_term)
            self._metrics[mode]["retain_KL"].append(split_kl[0].mean().item())
            self._metrics[mode]["forget_KL"].append(split_kl[1].mean().item())
        else:
            final_loss = base_loss
            with torch.no_grad():
                kl_term = self._compute_KL(model, inputs, outputs.logits, num_items_in_batch)
            mean_kl = self.accelerator.gather_for_metrics(kl_term).mean()
            self._metrics[mode]["KL"].append(mean_kl.item())

        if return_outputs:
            return final_loss, outputs
        return final_loss

class MyTrainer(CustomTrainer):
    """Trainer that surfaces diagnostics collected by ``MyOptimizer``.

    The training dataset should contain both retain and forget data, partitioned
    correctly to distribute across GPUs.
    """

    def __init__(self, alternate_gpu: bool = True, *args, **kwargs):
        # assert alternate_gpu, "MyTrainer requires alternate_gpu to be True."
        super().__init__(alternate_gpu=alternate_gpu, *args, **kwargs)

    def my_log(self, *args, **kwargs):
        """Record the max of the optimizer's ``feasibility`` log, then defer to the parent.

        Skips the feasibility bookkeeping when there is no optimizer, when the
        unwrapped optimizer has no ``logs`` attribute, or when no feasibility
        values have been logged yet.
        """
        optimizer = self.optimizer
        if optimizer is not None:
            if torch.cuda.is_available():
                # Wait for outstanding CUDA work before reading the optimizer logs.
                # Guarded so that CPU-only runs do not crash.
                torch.cuda.synchronize()
            if isinstance(optimizer, AcceleratedOptimizer):
                optimizer = optimizer.optimizer
            logs = getattr(optimizer, "logs", None)
            feasibility = logs.get("feasibility") if logs is not None else None
            # A present-but-empty list (e.g. created by defaultdict access) would make max() raise.
            if feasibility:
                logs["max_feasibility"].append(max(feasibility))
        super().my_log(*args, **kwargs)

    def compute_stats(self, mode: str):
        """Publish MyOptimizer's accumulated branch counters as normalised training metrics.

        This is a no-op outside training, or when the optimizer is not an
        ``AcceleratedOptimizer`` wrapping a ``MyOptimizer``. Counters are summed
        across processes and reported as proportions. Zeros are reported if nothing
        has been counted yet.
        """
        if mode != "train":
            return
        if not isinstance(self.optimizer, AcceleratedOptimizer):
            return
        inner = self.optimizer.optimizer
        if not isinstance(inner, MyOptimizer):
            return

        local_stats = inner.stats
        # Use plain `gather`. `gather_for_metrics` may drop "duplicate" trailing
        # elements at the end of a dataloader, which would corrupt these non-batch counters.
        stats = self.accelerator.gather(local_stats).reshape(-1, local_stats.size(0)).sum(dim=0)
        total = stats.sum()
        if total > 0:
            stats = stats / total  # Normalize to get proportions
        else:
            # No counts yet: report zeros instead of 0/0 = NaN.
            stats = torch.zeros_like(stats, dtype=torch.float)
        stats = stats.cpu()
        # Do not zero stats, as we are accumulating over the entire training
        self._metrics[mode]["branch_1"] = [stats[0].item(),]
        self._metrics[mode]["branch_2"] = [stats[1].item(),]
        self._metrics[mode]["failed_feasibility"] = [stats[2].item(),]
        self._metrics[mode]["nan_in_delta"] = [stats[3].item(),]

class ThresholdStoppingCallback(TrainerCallback):
    """Stop training once the logged ``failed_feasibility`` rate exceeds a threshold.

    Args:
        stop_on_failed_feasibility: If True, the threshold is 0.0, so any positive
            failure rate stops training. Otherwise the threshold is 0.99.
        failed_feasibility_threshold: Optional keyword-only override for the
            threshold. When given, it takes precedence over
            ``stop_on_failed_feasibility``.
    """

    def __init__(self, stop_on_failed_feasibility=False, *args, failed_feasibility_threshold=None, **kwargs):
        super().__init__(*args, **kwargs)
        if failed_feasibility_threshold is not None:
            self.failed_feasibility_threshold = float(failed_feasibility_threshold)
        elif stop_on_failed_feasibility:
            self.failed_feasibility_threshold = 0.0
        else:
            self.failed_feasibility_threshold = 0.99

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Request a training stop when ``logs["failed_feasibility"]`` exceeds the threshold."""
        if logs is None:
            return control

        # NOTE: the heuristics below are currently disabled; kept for reference.
        # # Check retain loss
        # retain_loss = max((i for i in (logs.get("retain_loss"), logs.get("eval_retain_loss")) if i is not None), default=None)
        # if retain_loss is not None and (torch.tensor(retain_loss) > 0.4):
        #     print(f"⚠️ Stopping training due to high retain_loss: {retain_loss}")
        #     control.should_training_stop = True
        #     return control

        # # Check forget loss
        # forget_loss = min((i for i in (logs.get("forget_loss"), logs.get("eval_forget_loss")) if i is not None), default=None)
        # if forget_loss is not None and (torch.tensor(forget_loss) < 0.1):
        #     print(f"⚠️ Stopping training due to low forget_loss: {forget_loss}")
        #     control.should_training_stop = True
        #     return control

        # # Check retain token accuracy
        # retain_token_accuracy = min((i for i in (logs.get("retain_token_accuracy"), logs.get("eval_retain_token_accuracy")) if i is not None), default=None)
        # if retain_token_accuracy is not None and torch.tensor(retain_token_accuracy) < 0.85:
        #     print(f"⚠️ Stopping training due to low retain_token_accuracy: {retain_token_accuracy}")
        #     control.should_training_stop = True
        #     return control

        # # Check forget token accuracy
        # forget_token_accuracy = max((i for i in (logs.get("forget_token_accuracy"), logs.get("eval_forget_token_accuracy")) if i is not None), default=None)
        # if forget_token_accuracy is not None and torch.tensor(forget_token_accuracy) > 0.98:
        #     print(f"⚠️ Stopping training due to high forget_token_accuracy: {forget_token_accuracy}")
        #     control.should_training_stop = True
        #     return control

        # Check feasibility failures. Logged values are scalars, so compare as a
        # Python float rather than allocating a tensor on every log call.
        failed_feasibility = logs.get("failed_feasibility")
        if failed_feasibility is not None and float(failed_feasibility) > self.failed_feasibility_threshold:
            print(f"⚠️ Stopping training due to high failed_feasibility rate: {failed_feasibility}")
            control.should_training_stop = True
        return control