import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from helper import cubic_func


def cr_newton(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
    M=None,
):
    """Run ``config.num_epochs`` calls to ``cubic_func.full_cubic_step`` on the retain set.

    Each iteration calls ``cubic_func.full_cubic_step`` with ``model``, a
    ``torch.nn`` loss instance and a shuffled ``DataLoader`` over
    ``retain_set``, and collects the value it returns. NOTE(review): judging by
    the name this is presumably a cubic-regularized Newton step that updates
    ``model`` in place; confirm against ``helper.cubic_func``.

    Args:
        model: Model forwarded to ``full_cubic_step``.
        forget_set: Not used by this function.
        retain_set: Dataset wrapped in a shuffled ``DataLoader`` with batch
            size ``config.train_batch_size``.
        config: Object providing ``llama``, ``train_batch_size``, ``loss``
            (name of a class in ``torch.nn``, instantiated with no arguments),
            ``num_epochs``, ``save_hess`` and optionally ``temp_dir``
            (defaults to ``None`` when absent).
        trainer_init_func: Not used by this function.
        trainer_init_kwargs: Not used by this function.
        device: Forwarded to ``full_cubic_step``.
        unl_logs: Optional mapping. When given, the collected values are
            stored under the key ``"alpha"``.
        M: Forwarded to ``full_cubic_step``.

    Returns:
        List with one entry per epoch holding the value returned by
        ``full_cubic_step``.

    Raises:
        ValueError: If ``config.llama`` is truthy.
    """
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if config.llama:
        raise ValueError("CR-Newton can't be run on Llama.")

    retain_loader = DataLoader(retain_set, shuffle=True, batch_size=config.train_batch_size)
    criterion = getattr(torch.nn, config.loss)()
    # Loop-invariant; ``temp_dir`` is optional on the config.
    temp_dir = getattr(config, "temp_dir", None)

    alpha_list = []
    for _ in tqdm(range(config.num_epochs), desc="CuReNU"):
        alpha = cubic_func.full_cubic_step(
            model,
            criterion,
            retain_loader,
            M=M,
            save_hess=config.save_hess,
            temp_dir=temp_dir,
            device=device,
        )
        alpha_list.append(alpha)

    # ``unl_logs`` defaults to None, so only record into it when a mapping was supplied.
    if unl_logs is not None:
        unl_logs["alpha"] = alpha_list
    return alpha_list