from functools import partial

import torch
import numpy as np

from helper import sto_cubic_func, sto_cubic_func_hf, utils


def scr_newton(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None, 
    unl_logs=None,
):
    """
    Run stochastic cubic-regularized Newton (SCR) steps on ``model`` using the retain set.

    Each outer step draws two independent random index subsets of ``retain_set`` (one for
    the gradient estimate, one for the Hessian estimate, each without replacement), wraps
    each subset in a dataloader whose batch size equals the subset size, and passes both
    to the stochastic cubic step function together with ``config.num_inner_steps``.

    Args:
        model: Model handed to the step function. Nothing is returned here, so it is
            presumably updated in place by the step function -- TODO confirm.
        forget_set: Not read by this function (kept for interface compatibility).
        retain_set: Indexable dataset; subsets are drawn via ``utils.sample``.
        config: Namespace-like object. Reads ``llama``, ``loss`` (non-llama only; the
            name of a ``torch.nn`` loss class), ``learning_rate``, ``num_outer_steps``,
            ``grad_sample_size``, ``hess_sample_size``, ``M``, ``num_inner_steps``,
            ``tofu`` (llama only) and ``sigma`` (non-llama only).
        trainer_init_func: Llama path only; called as
            ``trainer_init_func(**vars(trainer_init_kwargs))`` to build the trainer used
            for dataloaders and for the step function.
        trainer_init_kwargs: Llama path only; its ``model`` attribute is overwritten with
            ``model`` before the trainer is built.
        device: Forwarded to the step function.
        unl_logs: Not read by this function.

    Raises:
        ValueError: If ``config.llama`` is set but ``trainer_init_func`` or
            ``trainer_init_kwargs`` is missing. NumPy also raises ValueError if a sample
            size exceeds ``len(retain_set)``.
        RuntimeError: If the step function runs out of memory more often than the
            tolerance allows, or for any non-OOM RuntimeError raised by the step function.

    Lessons:
    1. The approximation improves as the sample sizes grow.
    2. Conversely, the stochasticity of each step helps the optimization: the noise it
       adds to the descent direction reduces the mutual information between the retained
       and the forgotten data.

    Notes:
    1. If you hit "derivative for aten::_scaled_dot_product_flash_attention_backward is
       not implemented", follow the suggestion in
       https://github.com/pytorch/pytorch/issues/116350 and patch your local copy of
       https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py.
    """
    if config.llama:
        if trainer_init_func is None or trainer_init_kwargs is None:
            raise ValueError(
                "config.llama requires both trainer_init_func and trainer_init_kwargs."
            )
        trainer_init_kwargs.model = model
        trainer = trainer_init_func(**vars(trainer_init_kwargs))
        scr_step_func = partial(sto_cubic_func_hf.hf_stochastic_cubic_step, trainer=trainer)
    else:
        loss_fn = getattr(torch.nn, config.loss)()
        scr_step_func = partial(sto_cubic_func.stochastic_cubic_step, loss_fn=loss_fn)

    learning_rate = config.learning_rate
    # Constant schedule. A step-decay alternative would be e.g.
    # `lambda step, lr: lr * 0.5 if step % 5 == 0 else lr`.
    lr_scheduler = lambda step, lr: lr
    oom_tol = 5  # out-of-memory outer steps tolerated before giving up
    # retain_set is not modified here, so its size is loop-invariant.
    num_train_samples = len(retain_set)

    for step in range(config.num_outer_steps):
        print(f"Stochastic Step {step + 1}, lr = {learning_rate}")
        # Passing the population size as an int draws the same indices (same RNG
        # consumption) as passing list(range(n)), without materialising the list.
        grad_ids = np.random.choice(num_train_samples, config.grad_sample_size, replace=False)
        hess_ids = np.random.choice(num_train_samples, config.hess_sample_size, replace=False)
        grad_batch = utils.sample(retain_set, grad_ids)
        hess_batch = utils.sample(retain_set, hess_ids)

        if config.llama:
            # Batch size == subset size so that each dataloader contains one batch.
            trainer.train_dataset = grad_batch
            trainer._train_batch_size = config.grad_sample_size
            grad_batchloader = trainer.get_train_dataloader()
            trainer.train_dataset = hess_batch
            trainer._train_batch_size = config.hess_sample_size
            hess_batchloader = trainer.get_train_dataloader()
        else:
            grad_batchloader = utils.get_dataloader(grad_batch, shuffle=False, batch_size=config.grad_sample_size)
            hess_batchloader = utils.get_dataloader(hess_batch, shuffle=False, batch_size=config.hess_sample_size)

        step_kwargs = dict(
            grad_batchloader=grad_batchloader,
            hess_batchloader=hess_batchloader,
            M=config.M,
            num_steps=config.num_inner_steps,
            learning_rate=learning_rate,
            device=device,
        )
        # The two step functions differ only in this one extra keyword.
        if config.llama:
            step_kwargs["tofu"] = config.tofu
        else:
            step_kwargs["sigma"] = config.sigma

        hit_oom = False
        try:
            scr_step_func(model, **step_kwargs)
        except RuntimeError as error:
            if "out of memory" not in str(error):
                raise
            if oom_tol == 0:
                # A bare string cannot be raised in Python 3; wrap it in an exception.
                raise RuntimeError("OOM too many times.") from error
            oom_tol -= 1
            hit_oom = True

        if hit_oom:
            # Done outside the except clause so objects referenced only by the exception's
            # traceback can be released first. No-op when CUDA is not initialized.
            torch.cuda.empty_cache()
            print("[WARNING] Recovered from OOM.")
            # The failed step is skipped (not retried) and the LR schedule is not advanced.
            continue

        learning_rate = lr_scheduler(step + 1, learning_rate)
