import copy
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader

from helper.thirdparty.tofu.dataloader import custom_data_collator_forget, CustomTrainerForgetting


def npo_tofu(model, oracle_model, tokenizer, train_data, config, training_args):
    """Run NPO unlearning through the TOFU ``CustomTrainerForgetting`` trainer.

    Args:
        model: model being unlearned; handed to the trainer as ``model``.
        oracle_model: reference model handed to the trainer as ``oracle_model``.
        tokenizer: tokenizer passed through to the trainer.
        train_data: dataset passed to the trainer as ``train_dataset``.
        config: not used by this function.
        training_args: trainer arguments, passed as ``args``.
    """
    npo_trainer = CustomTrainerForgetting(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        # No metrics callback: metrics are expected to be computed by the caller's own callback.
        compute_metrics=None,
        args=training_args,
        data_collator=custom_data_collator_forget,
        oracle_model=oracle_model,
        forget_loss='npo',
        # Evaluation is switched off while unlearning runs.
        eval_cfg=None,
    )

    npo_trainer.train()

def _per_sample_loss(losses):
    """Collapse an unreduced loss tensor to one value per sample (dim 0)."""
    # Losses such as MSE or per-token cross-entropy keep extra dimensions when
    # reduction='none'; average them so each sample yields a single log-ratio.
    if losses.dim() > 1:
        return losses.flatten(1).mean(dim=1)
    return losses


def npo(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Unlearn ``forget_set`` from ``model`` with Negative Preference Optimization.

    A frozen deep copy of ``model``, taken before training, is the reference
    (oracle) model. The optimized objective is

        L = -(2 / beta) * mean_i[ log sigmoid(beta * (l_cur_i - l_ref_i)) ]

    where ``l_cur_i`` and ``l_ref_i`` are the per-sample losses of the current
    and reference model. For negative-log-likelihood style losses,
    ``l_cur_i - l_ref_i = -log(pi_cur / pi_ref)``.

    Args:
        model: torch module to unlearn; its parameters are updated in place.
        forget_set: dataset whose items unpack as ``(inputs, labels)``.
        retain_set: not used by this implementation.
        config: object providing ``train_batch_size``, ``loss`` (name of a
            ``torch.nn`` loss class), ``optimizer`` (name of a ``torch.optim``
            class), ``learning_rate``, ``weight_decay``, ``num_epochs`` and,
            optionally, ``beta`` (NPO inverse temperature, default 1.0).
        trainer_init_func, trainer_init_kwargs, unl_logs: not used here.
        device: passed to ``Tensor.to`` for each input/label batch. The model
            itself is not moved by this function.
    """
    forget_loader = DataLoader(forget_set, shuffle=True, batch_size=config.train_batch_size)
    # reduction='none': NPO applies log-sigmoid to each sample's log-ratio and
    # only then averages; a batch-mean loss would compute a different objective.
    criterion = getattr(torch.nn, config.loss)(reduction='none')
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    # Hoisted out of the loop; configurable, with the previous hard-coded value as default.
    beta = float(getattr(config, 'beta', 1.0))

    # Snapshot the pre-unlearning model as the fixed reference policy.
    oracle_model = copy.deepcopy(model)
    oracle_model.eval()
    oracle_model.requires_grad_(False)

    model.train()
    for _ in tqdm(range(config.num_epochs), desc="NPO"):
        for forget_inputs, forget_labels in forget_loader:
            # Move each batch once and reuse it for both forward passes.
            forget_inputs = forget_inputs.to(device)
            forget_labels = forget_labels.to(device)

            optimizer.zero_grad()
            forget_loss_current = _per_sample_loss(criterion(model(forget_inputs), forget_labels))

            with torch.no_grad():
                forget_loss_oracle = _per_sample_loss(
                    criterion(oracle_model(forget_inputs), forget_labels)
                )

            neg_log_ratios = forget_loss_current - forget_loss_oracle
            loss = -F.logsigmoid(beta * neg_log_ratios).mean() * 2 / beta

            loss.backward()
            optimizer.step()