from collections import defaultdict
from typing import List
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)   # Ignore warning from itertools partial

from functools import partial
import argparse
import torch
from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk, Array3D
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed, TrainingArguments
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModelForCausalLM
from trl.trainer.sft_trainer import SFTTrainer, SFTConfig
from itertools import batched, cycle
import random
import os
from unlearning_utils import CustomOptimizer, MyOptimizer, ThresholdStoppingCallback, CustomTrainer, MyTrainer, GATrainer, GDiffTrainer, SCRUBTrainer, KLTrainer, GDTrainer
from custom_callbacks import MyLoggingCallback, NaNStoppingCallback
from time import time
import gc
import atexit

# Turn off Hugging Face tokenizers parallelism for this process (set unconditionally).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# TRAIN_LLM=true (case-insensitive) selects the helpers from train_llm_utils;
# any other value (or unset) selects the helpers from train_cnn_utils.
TRAIN_LLM = os.environ.get("TRAIN_LLM", "false").lower() == "true"

if TRAIN_LLM:
    from train_llm_utils import get_pad_dataset, load_model_and_processor, load_datasets, preprocess_function
else:
    from train_cnn_utils import get_pad_dataset, load_model_and_processor, load_datasets, preprocess_function

def parse_args(argv=None):
    """Build the command-line parser for this training/unlearning script and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the previous zero-argument version;
            passing a list makes the parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Note:
        argparse %-formats every help string, so a literal percent sign must be
        written as ``%%``; otherwise ``--help`` fails with a formatting error.
    """

    def _str_to_bool(value):
        # ``type=bool`` would turn every non-empty string (including "False") into True.
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Train a model with PEFT using SFTTrainer")

    # General / model arguments
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for initialization")
    parser.add_argument("--shuffle_seed", type=int, default=42,
                        help="Random seed for shuffling the dataset")
    parser.add_argument("--model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help="Pretrained model name or path")
    parser.add_argument("--num_labels", type=int, default=10,
                        help="Number of labels for classification")
    parser.add_argument("--img_size", type=int, default=None,
                        help="Image size for the model (if applicable)")
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Output directory for model and checkpoints")
    parser.add_argument("--final_model_output_dir", type=str, default=None,
                        help="Output directory for final model (if different from output_dir)")

    parser.add_argument("--split_ratio", type=float, default=0.1,
                        help="Split ratio for hard unlearning")
    parser.add_argument("--hard_probability", type=float, default=0.0,
                        help="Hard unlearning: split retain/forget subsets randomly by split_ratio rather than by class labels. "
                             "Value between 0 and 1 indicating the probability of mixing in non-class examples into the forget set. "
                             "I.e., hard_probability=0.1 means 90%% of the forget set are class examples while 10%% are randomly "
                             "sampled from the remainder, hard_probability=1.0 means the forget set is fully random samples.")
    parser.add_argument("--dataset_repeat", type=int, default=1,
                        help="Number of times to repeat the dataset (help speed up training dataloader)")

    parser.add_argument("--dataset_name", type=str, default="locuslab/TOFU",
                        help="Dataset name from HuggingFace. Full dataset if training, retain dataset if unlearning")
    parser.add_argument("--dataset_subset", type=str, default="full",
                        help="Dataset subset to use")
    parser.add_argument("--dataset_split", type=str, default="train",
                        help="Dataset split to use")
    parser.add_argument("--dataset_prompt_field", type=str, default="question",
                        help="Field name in dataset containing prompts")
    parser.add_argument("--dataset_response_field", type=str, default="answer",
                        help="Field name in dataset containing responses")

    parser.add_argument("--forget_dataset_name", type=str, default="locuslab/TOFU",
                        help="Forget dataset name from HuggingFace")
    parser.add_argument("--forget_dataset_subset", type=str, default="full",
                        help="Forget dataset subset to use")
    parser.add_argument("--forget_dataset_split", type=str, default="train",
                        help="Forget dataset split to use")
    parser.add_argument("--forget_dataset_prompt_field", type=str, default="question",
                        help="Field name in forget dataset containing prompts")
    parser.add_argument("--forget_dataset_response_field", type=str, default="answer",
                        help="Field name in forget dataset containing responses")

    parser.add_argument("--duplicate_dataset_name", type=str, default="locuslab/TOFU",
                        help="Duplicate dataset name from HuggingFace")
    parser.add_argument("--duplicate_dataset_subset", type=str, default="full",
                        help="Duplicate dataset subset to use")
    parser.add_argument("--duplicate_dataset_split", type=str, default="train",
                        help="Duplicate dataset split to use")
    parser.add_argument("--duplicate_dataset_prompt_field", type=str, default="question",
                        help="Field name in duplicate dataset containing prompts")
    parser.add_argument("--duplicate_dataset_response_field", type=str, default="answer",
                        help="Field name in duplicate dataset containing responses")

    parser.add_argument("--add_forget_to_train", action="store_true", default=False,
                        help="Whether to add forget examples to the train set")
    parser.add_argument("--add_duplicate_to_retain", action="store_true", default=False,
                        help="Whether to add duplicate examples to the retain set")
    parser.add_argument("--eval_on_subsets", action="store_true", default=False,
                        help="Whether to evaluate on all subsets")
    parser.add_argument("--eval_on_forget", action="store_true", default=False,
                        help="Whether to evaluate on forgotten examples")
    parser.add_argument("--interlace_forget", action="store_true", default=False,
                        help="Whether to interlace forgotten examples during training")
    parser.add_argument("--shuffle_dataset", action="store_true", default=False,
                        help="Whether to shuffle the dataset before training")

    # LoRA arguments
    parser.add_argument("--add_lora", action="store_true", default=False,
                        help="Whether to add LoRA adapters")
    parser.add_argument("--lora_name", type=str, default="added",
                        help="Name of LoRA adapter to add (if any)")
    parser.add_argument("--train_existing", action="store_true", default=False,
                        help="Train existing LoRA parameters")
    parser.add_argument("--lora_r", type=int, default=16,
                        help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32,
                        help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05,
                        help="LoRA dropout")

    # Custom optimizer arguments
    parser.add_argument("--use_custom_optim", action="store_true", default=False,
                        help="Use custom optimizer")
    parser.add_argument("--B", type=float, default=0.0,
                        help="Custom optimizer parameter B")
    parser.add_argument("--R", type=float, default=None,
                        help="Custom optimizer parameter R")
    parser.add_argument("--use_lr", action="store_true", default=False,
                        help="Use learning rate instead of R in custom optimizer")
    parser.add_argument("--dual_update", action="store_true", default=False,
                        help="Use dual update in custom optimizer")
    parser.add_argument("--continue_on_failed_feasibility", action="store_true", default=False,
                        help="Continue training if feasibility condition fails in custom optimizer")
    parser.add_argument("--full_grad", action="store_true", default=False,
                        help="Use full gradient computation in custom optimizer")
    parser.add_argument("--distribute_B_softmax_dp", action="store_true", default=False,
                        help="Dynamically distribute B based on softmax of dot product of gr and gf")
    parser.add_argument("--distribute_B_gr_gf_norm", action="store_true", default=False,
                        help="Dynamically distribute B based on product of gr and gf norm")
    parser.add_argument("--no_distribute_B_norm", action="store_true", default=False,
                        help="Disable normalizing distributed B by gradient norm")

    parser.add_argument("--use_optimizer", action="store_true", default=False,
                        help="Also use optimizer on top of custom optimizer")

    parser.add_argument("--freeze_batchnorm", action="store_true", default=False,
                        help="Freeze BatchNorm layers during training")

    # Other optimizers
    parser.add_argument("--use_gd", action="store_true", default=False,
                        help="Use GD optimizer")
    parser.add_argument("--use_ga", action="store_true", default=False,
                        help="Use GA optimizer")
    parser.add_argument("--use_gdiff", action="store_true", default=False,
                        help="Use GDiff optimizer")
    parser.add_argument("--use_kl", action="store_true", default=False,
                        help="Use KL divergence optimizer")
    parser.add_argument("--use_scrub", action="store_true", default=False,
                        help="Use Scrub optimizer")
    parser.add_argument("--scrub_alpha", type=float, default=0.001,
                        help="SCRUB alpha parameter")
    parser.add_argument("--scrub_gamma", type=float, default=0.99,
                        help="SCRUB gamma parameter")

    # Training arguments
    parser.add_argument("--use_sft", type=_str_to_bool, default=True,
                        help="Whether to use SFTTrainer (true/false)")

    parser.add_argument("--num_train_epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
                        help="Evaluation batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.0,
                        help="Weight decay")
    parser.add_argument("--logging_steps", type=float, default=0.25,
                        help="Logging steps")
    parser.add_argument("--logging_strategy", type=str, default="steps",
                        help="Logging strategy")
    parser.add_argument("--save_steps", type=float, default=None,
                        help="Save steps")
    parser.add_argument("--save_strategy", type=str, default="no",
                        help="Save strategy")
    parser.add_argument("--eval_steps", type=float, default=None,
                        help="Evaluation steps")
    parser.add_argument("--longer_eval_steps", type=int, default=None,
                        help="Longer evaluation interval (passed to the logging callback as longer_eval_interval)")
    parser.add_argument("--eval_strategy", type=str, default="epoch",
                        help="Evaluation strategy")
    parser.add_argument("--eval_on_start", action="store_true", default=False,
                        help="Evaluate on start of training")
    parser.add_argument("--warmup_ratio", type=float, default=0.0,
                        help="Warmup ratio")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant",
                        help="Learning rate scheduler type")
    parser.add_argument("--max_grad_norm", type=float, default=1.0,
                        help="Max gradient norm")
    parser.add_argument("--fp16", action="store_true", default=False,
                        help="Use FP16 training")
    parser.add_argument("--bf16", action="store_true", default=False,
                        help="Use BF16 training")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False,
                        help="Use gradient checkpointing")
    parser.add_argument("--group_by_length", action="store_true", default=False,
                        help="Group sequences by length")
    parser.add_argument("--ddp_find_unused_parameters", action="store_true", default=False,
                        help="Enable find_unused_parameters in DDP (adds overhead)")
    parser.add_argument("--torch_empty_cache_steps", type=int, default=None,
                        help="Number of steps between torch.cuda.empty_cache() calls")

    # Wandb arguments
    parser.add_argument("--use_wandb", action="store_true", default=False,
                        help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="hard",
                        help="Wandb project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Wandb entity (username or team name)")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name")
    parser.add_argument("--wandb_tags", type=str, nargs="+", default=None,
                        help="Wandb tags for the run")
    parser.add_argument("--wandb_notes", type=str, default=None,
                        help="Notes for the wandb run")

    parser.add_argument("--debug", action="store_true", default=False,
                        help="Debug mode (forwarded to the custom optimizer and trainer; also trains on at most 32 examples)")
    parser.add_argument("--dry_run", action="store_true", default=False,
                        help="Dry run mode")

    return parser.parse_args(argv)

def main():
    """Entry point: build datasets, optimizer and trainer from CLI arguments, then train.

    Flow:
      1. Parse arguments. For the custom optimizer, validate ``R``/``B`` and append a
         run-specific suffix to ``output_dir`` (and ``final_model_output_dir``).
      2. Optionally export Weights & Biases settings as environment variables.
      3. Load the model/processor and the retain (``dataset``), duplicate and forget
         datasets; when ``TRAIN_LLM`` is set, tokenize them here with cached ``map``.
      4. Assemble the training set. Forget batches may be interlaced between retain
         batches, and the GD/GA variants duplicate their source set.
      5. Build ``TrainingArguments`` (non-LLM) or ``SFTConfig`` (LLM), the optimizer,
         callbacks and the trainer class selected by the ``--use_*`` flags, then train.
      6. If ``save_strategy != "no"`` and at least one epoch completed, save the model
         to ``final_model_output_dir``.
    """
    args = parse_args()
    unlearning = (
        args.use_custom_optim
        or args.use_ga
        or args.use_gd
        or args.use_gdiff
        or args.use_scrub
        or args.use_kl
    )
    if args.use_custom_optim:
        # Validate R before it is used in arithmetic: the B-bound assertion below
        # multiplies by R and would otherwise fail with an opaque TypeError on None.
        if not args.use_lr:
            assert args.R is not None, "R must be specified if not using learning rate."
        if args.use_lr:
            R = args.learning_rate * args.max_grad_norm
        else:
            R = args.R
        # NOTE(review): with --use_lr this bounds B by learning_rate * max_grad_norm**2
        # (max_grad_norm is applied twice); confirm that is intended.
        assert args.B <= R * args.max_grad_norm, f"B ({args.B}) must be less than or equal to R * max_grad_norm ({R * args.max_grad_norm})"
        # Run-specific suffix shared by both output directories. The R variant ends with
        # "/" (so "_B..." becomes a sub-directory) while the lr variant does not; this is
        # kept as-is so existing result paths do not move.
        run_suffix = (
            ("dual_" if args.dual_update else "")
            + (f"R{args.R}/" if not args.use_lr else f"lr{args.learning_rate}")
            + f"_B{args.B}"
            + ("_continueOnFail" if args.continue_on_failed_feasibility else "")
        )
        args.output_dir = f"{args.output_dir}/{run_suffix}"
        if args.final_model_output_dir is not None:
            args.final_model_output_dir = (
                f"{args.final_model_output_dir}/{run_suffix}"
                + (f"_{args.max_grad_norm}gn" if args.max_grad_norm < 1e3 else "")
            )
    os.makedirs(args.output_dir, exist_ok=True)
    if args.final_model_output_dir is not None:
        os.makedirs(args.final_model_output_dir, exist_ok=True)

    if args.dry_run:
        args.use_wandb = False

    if args.use_wandb:
        # Export wandb settings through the environment.
        os.environ["WANDB_PROJECT"] = args.wandb_project
        if args.wandb_entity:
            os.environ["WANDB_ENTITY"] = args.wandb_entity
        if args.wandb_run_name:
            os.environ["WANDB_RUN_NAME"] = args.wandb_run_name
        if args.wandb_tags:
            os.environ["WANDB_TAGS"] = ",".join(args.wandb_tags)
        if args.wandb_notes:
            os.environ["WANDB_NOTES"] = args.wandb_notes

    # Load model and processor. SCRUB and KL request the full model (on the non-LLM
    # path the trainer later calls update_frozen_model()).
    load_full = args.use_scrub or args.use_kl
    model, processor = load_model_and_processor(
        model_name=args.model_name,
        seed=args.seed,
        add_lora=args.add_lora,
        lora_name=args.lora_name,
        load_full=load_full,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        num_labels=args.num_labels,
        size=args.img_size,
        freeze_batchnorm=args.freeze_batchnorm,
    )

    # Load datasets
    dataset, duplicate_dataset, forget_dataset = load_datasets(
        processor,
        **vars(args)
    )

    if TRAIN_LLM:
        # Preprocess for LLM here as it doesn't take much storage; results are cached
        # under ./dataset_cache, which is created up front.
        os.makedirs("./dataset_cache", exist_ok=True)
        dataset = dataset.map(
            partial(preprocess_function, processor, args.dataset_prompt_field, args.dataset_response_field),
            remove_columns=dataset.column_names,
            load_from_cache_file=True,
            cache_file_name=f"./dataset_cache/cache_{args.dataset_name.replace('/', '_')}_{args.dataset_subset}_{args.dataset_split}_{args.dataset_prompt_field}_{args.dataset_response_field}.arrow"
        )
        if duplicate_dataset is not None and len(duplicate_dataset) > 0:
            duplicate_dataset = duplicate_dataset.map(
                partial(preprocess_function, processor, args.duplicate_dataset_prompt_field, args.duplicate_dataset_response_field),
                remove_columns=duplicate_dataset.column_names,
                load_from_cache_file=True,
                cache_file_name=f"./dataset_cache/cache_{args.duplicate_dataset_name.replace('/', '_')}_{args.duplicate_dataset_subset}_{args.duplicate_dataset_split}_{args.duplicate_dataset_prompt_field}_{args.duplicate_dataset_response_field}_duplicate.arrow"
            )
        forget_dataset = forget_dataset.map(
            partial(preprocess_function, processor, args.forget_dataset_prompt_field, args.forget_dataset_response_field),
            remove_columns=forget_dataset.column_names,
            load_from_cache_file=True,
            cache_file_name=f"./dataset_cache/cache_{args.forget_dataset_name.replace('/', '_')}_{args.forget_dataset_subset}_{args.forget_dataset_split}_{args.forget_dataset_prompt_field}_{args.forget_dataset_response_field}_forget.arrow"
        )
    # Non-LLM preprocessing is already done inside load_datasets.

    if args.add_duplicate_to_retain:
        retain_dataset = concatenate_datasets([dataset, duplicate_dataset])
    else:
        retain_dataset = dataset
    if args.use_ga:
        train_dataset = forget_dataset
    elif args.add_forget_to_train:
        train_dataset = concatenate_datasets([retain_dataset, forget_dataset])
    else:
        train_dataset = retain_dataset
    if args.shuffle_dataset:
        # Both permutations come from one RNG seeded with shuffle_seed.
        train_indices = list(range(len(train_dataset)))
        forget_indices = list(range(len(forget_dataset)))
        random.seed(args.shuffle_seed)
        random.shuffle(train_indices)
        random.shuffle(forget_indices)
        train_dataset = train_dataset.select(train_indices)
        forget_dataset = forget_dataset.select(forget_indices)

    dataset_pad = 0
    if args.interlace_forget:
        assert not args.use_gd and not args.use_ga and not args.add_forget_to_train, "Cannot use interlace_forget with use_gd, use_ga or add_forget_to_train."
        # Interlace forget batches into the retain data: one retain batch, then one
        # forget batch, repeated -> [r1, r2, f1, f2, r3, r4, f3, f4, ...].
        batch_size = args.per_device_train_batch_size
        if batch_size == -1:
            batch_size = len(train_dataset) // int(os.environ.get("WORLD_SIZE", 1))
        train_indices = range(len(train_dataset))
        forget_indices = range(len(forget_dataset))
        train_batched = batched(train_indices, batch_size)
        # The forget indices cycle forever; zip() below stops when the retain batches run out.
        forget_batched = batched(cycle(forget_indices), batch_size)
        train_batched = (train_dataset.select(batch) for batch in train_batched)
        forget_batched = (forget_dataset.select(batch) for batch in forget_batched)
        interlaced_dataset = []
        for train_batch, forget_batch in zip(train_batched, forget_batched):
            # Pad the last batch if needed, so that retain and forget batches are split
            # onto different GPUs correctly.
            if len(train_batch) < batch_size:
                dataset_pad = batch_size - len(train_batch)
                pad_dataset = get_pad_dataset(dataset_pad, processor=processor)
                train_batch = concatenate_datasets([train_batch, pad_dataset])
                forget_batch = concatenate_datasets([forget_batch.select(range(len(forget_batch) - dataset_pad)), pad_dataset])
            interlaced_dataset.append(train_batch)
            interlaced_dataset.append(forget_batch)
        train_dataset = concatenate_datasets(interlaced_dataset)
    elif args.use_gd:
        train_dataset = concatenate_datasets([retain_dataset, retain_dataset])
    elif args.use_ga:
        # Repeat the forget set 2 * (len(retain) // len(forget)) times. Keep at least one
        # repetition: a forget set larger than the retain set would otherwise produce an
        # empty list, and concatenate_datasets([]) raises.
        repeats = max(1, len(retain_dataset) // len(forget_dataset))
        train_dataset = concatenate_datasets([forget_dataset, forget_dataset] * repeats)

    # Training arguments
    per_device_train_batch_size = args.per_device_train_batch_size
    if per_device_train_batch_size == -1:
        # -1 means one full-dataset batch, split evenly across WORLD_SIZE processes.
        per_device_train_batch_size = len(train_dataset) // int(os.environ.get("WORLD_SIZE", 1))
    training_args = {
        "seed": args.seed,
        "output_dir": args.output_dir,
        "num_train_epochs": args.num_train_epochs // args.dataset_repeat,
        "per_device_train_batch_size": per_device_train_batch_size,
        "per_device_eval_batch_size": args.per_device_eval_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "gradient_checkpointing": args.gradient_checkpointing,
        # Fractional logging_steps are handled by MyLoggingCallback (unlearning runs);
        # the Trainer itself gets at least 1.
        "logging_steps": args.logging_steps if args.logging_steps >= 1 else 1,
        "logging_strategy": args.logging_strategy,
        "eval_steps": args.eval_steps,
        "eval_strategy": args.eval_strategy if (args.eval_on_forget or args.eval_on_subsets) else "no",
        "eval_on_start": args.eval_on_start,
        "save_steps": args.save_steps,
        "save_strategy": args.save_strategy,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "fp16": args.fp16 and torch.cuda.is_available(),
        "bf16": args.bf16 and torch.cuda.is_available(),
        "max_grad_norm": args.max_grad_norm,
        "warmup_ratio": args.warmup_ratio,
        "group_by_length": args.group_by_length,
        "lr_scheduler_type": args.lr_scheduler_type,
        "report_to": "wandb" if args.use_wandb else "none",
        "ddp_find_unused_parameters": args.ddp_find_unused_parameters,
        "torch_empty_cache_steps": args.torch_empty_cache_steps,
        "remove_unused_columns": False,
        "torch_compile": True,
    }
    if not TRAIN_LLM:
        if not args.debug:
            # os.cpu_count() may return None; treat that as "no spare workers".
            num_workers = ((os.cpu_count() or 0) // 8) - 1
            if num_workers > 1:
                training_args |= {
                    "dataloader_num_workers": num_workers,
                    "dataloader_prefetch_factor": 4,
                    "dataloader_pin_memory": True,
                }
        training_args = TrainingArguments(**training_args)
    else:
        training_args |= {
            "completion_only_loss": True,
            "max_length": None,
        }
        training_args = SFTConfig(**training_args)

    # Initialize Trainer
    optimizer = None
    TrainerClass = MyTrainer
    nan_stopping_callback = NaNStoppingCallback()
    if unlearning:
        if args.use_optimizer:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=args.learning_rate
            )
        else:
            optimizer = None
        # The (optional) AdamW above is wrapped by the custom/unlearning optimizer.
        if args.use_custom_optim:
            optimizer = MyOptimizer(
                model,
                optimizer=optimizer,
                lr=args.learning_rate if args.use_lr else None,
                R=args.R,
                Q=args.B,
                dual_update=args.dual_update,
                stop_on_failed_feasibility=not args.continue_on_failed_feasibility,
                full_grad=args.full_grad,
                distribute_B_softmax_dp=args.distribute_B_softmax_dp,
                distribute_B_gr_gf_norm=args.distribute_B_gr_gf_norm,
                distribute_B_norm=not args.no_distribute_B_norm,
                debug=args.debug
            )
        else:
            optimizer = CustomOptimizer(
                model,
                defaults={'lr': args.learning_rate},
                optimizer=optimizer
            )
        callbacks: List[TrainerCallback] = [
            nan_stopping_callback,
            ThresholdStoppingCallback(stop_on_failed_feasibility=not args.continue_on_failed_feasibility),
            MyLoggingCallback(
                logging_interval=args.logging_steps,
                eval_interval=args.eval_steps,
                longer_eval_interval=args.longer_eval_steps,
            ),
        ]
        if args.use_ga:
            TrainerClass = GATrainer
        elif args.use_gdiff:
            TrainerClass = GDiffTrainer
        elif args.use_scrub:
            TrainerClass = SCRUBTrainer
        elif args.use_kl:
            TrainerClass = KLTrainer
        elif args.use_gd:
            TrainerClass = GDTrainer
    else:
        callbacks: List[TrainerCallback] = [nan_stopping_callback]
    if args.eval_on_subsets:
        eval_dataset = {
            "retain": retain_dataset,
            "forget": forget_dataset
        }
        if duplicate_dataset is not None and len(duplicate_dataset) > 0:
            # The third split is reported as "duplicate" for LLMs and "test" otherwise.
            if TRAIN_LLM:
                eval_dataset["duplicate"] = duplicate_dataset
            else:
                eval_dataset["test"] = duplicate_dataset
    elif args.eval_on_forget:
        eval_dataset = forget_dataset
    else:
        eval_dataset = None
    if args.dataset_repeat > 1:
        train_dataset = train_dataset.repeat(args.dataset_repeat)
    if args.debug:
        train_dataset = train_dataset.select(range(min(32, len(train_dataset))))
    trainer: CustomTrainer = TrainerClass(
        alternate_gpu=args.use_custom_optim or args.use_gdiff or args.use_scrub or args.use_kl,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        optimizers=(optimizer, None),
        callbacks=callbacks,
        debug=args.debug,
        processing_class=processor,
    )
    if unlearning:
        if not args.debug:
            assert trainer.accelerator.num_processes % 2 == 0, "Number of processes must be even."
        trainer.optimizer.set_accelerator(trainer.accelerator)
    nan_stopping_callback.accelerator = trainer.accelerator
    if not TRAIN_LLM and load_full:
        trainer.update_frozen_model()
    if isinstance(trainer, SCRUBTrainer):
        trainer._alpha = args.scrub_alpha  # SCRUB hyperparameter
        trainer._gamma = args.scrub_gamma  # SCRUB hyperparameter

    gc.collect()
    torch.cuda.empty_cache()
    # Train the model
    if trainer.is_world_process_zero():
        print("Training arguments:", args)
        if isinstance(model, PeftModelForCausalLM):
            model.print_trainable_parameters()
        print(f"Dataset size: {(len(train_dataset) - dataset_pad * (1 + args.interlace_forget)) // args.dataset_repeat} (padded to {len(train_dataset) // args.dataset_repeat})")
        print("Starting training...")
        # Log args to wandb after trainer initialization
        if args.use_wandb:
            import wandb
            run = wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=args.wandb_run_name, tags=args.wandb_tags, notes=args.wandb_notes)
            forget_pct = len(forget_dataset) / (len(dataset) + len(forget_dataset))
            method = "ga" if args.use_ga else "gd" if args.use_gd else "gdiff" if args.use_gdiff else "scrub" if args.use_scrub else "kl" if args.use_kl else ("ours_dual" if args.dual_update else "ours") if args.use_custom_optim else None
            wandb.config.update(
                {f"command/{k}": v for k, v in vars(args).items()} | {
                    "training_dataset_len": len(train_dataset),
                    "retain_dataset_len": len(dataset),
                    "forget_dataset_len": len(forget_dataset),
                    "duplicate_dataset_len": len(duplicate_dataset) if duplicate_dataset is not None else 0,
                    "forget_pct": int(round(100 * forget_pct)),
                    "ratio_b_lr": round(args.B/args.learning_rate, 8),
                    "ratio_b_r": round(args.B/args.R, 8) if args.R is not None else None,
                    "method": method
                    },
                allow_val_change=True)
            wandb.save("*.py")

    start_time = time()
    if not args.dry_run:
        trainer.train()
    training_time = time() - start_time
    if trainer.is_world_process_zero():
        print(f"Training completed in {training_time:.2f} seconds.")
        if args.use_wandb:
            run.summary["training_time_seconds"] = training_time

    # Save the fine-tuned model (trainer handles multi-process automatically),
    # but only if at least one full epoch completed.
    if args.save_strategy != "no":
        if trainer.state.epoch is None or trainer.state.epoch < 1.0:
            if trainer.is_world_process_zero() and not args.dry_run:
                failed_msg = f"Training terminated at epoch {trainer.state.epoch}, before 1 epoch. Skipping model save."
                print(failed_msg)
                with open(f"{args.output_dir}/training_skipped.txt", "w") as f:
                    f.write(failed_msg)
                if args.use_wandb:
                    run.summary["training_skipped"] = True
        elif not args.dry_run:
            if args.final_model_output_dir is not None:
                trainer.save_model(args.final_model_output_dir)
                if trainer.is_world_process_zero():
                    print(f"Model saved to {args.final_model_output_dir}")


def cleanup():
    """Destroy the default torch.distributed process group if one is active.

    Does nothing when torch.distributed is unavailable or uninitialized.
    Otherwise prints this rank's progress before and after teardown.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return
    my_rank = dist.get_rank()
    print(f"Rank {my_rank} reached cleanup", flush=True)
    dist.destroy_process_group()
    print(f"Rank {my_rank} finished cleanup", flush=True)

if __name__ == "__main__":
    # Register the process-group teardown before training starts. atexit handlers
    # run on normal interpreter exit, including after an unhandled exception in main().
    atexit.register(cleanup)
    main()
