from functools import partial

import torch
import numpy as np

from helper import sto_cubic_func_gdiff, sto_cubic_func_hf_gdiff, utils


def scr_newton_gdiff(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """
    Run ``config.num_outer_steps`` stochastic cubic steps on ``model``.

    Every outer step draws fresh random subsets of ``retain_set`` (one for the
    gradient, one for the Hessian) and of ``forget_set`` (gradient only), wraps
    them in dataloaders and hands them to the step function: the HF-trainer
    variant when ``config.llama`` is truthy, the plain PyTorch variant
    otherwise. A step that fails with an out-of-memory ``RuntimeError`` is
    skipped; this is tolerated a limited number of times.

    Args:
        model: Model forwarded as first positional argument to the step function.
        forget_set: Dataset to sample forget-gradient batches from.
        retain_set: Dataset to sample retain-gradient and Hessian batches from.
        config: Namespace-like object. Fields read here: ``llama``,
            ``num_outer_steps``, ``grad_sample_size``, ``hess_sample_size``,
            ``M``, ``num_inner_steps``, ``learning_rate``, ``forget_coeff``,
            plus ``tofu`` (llama path) or ``loss`` and ``sigma`` (other path).
        trainer_init_func: Callable building the trainer; required when
            ``config.llama`` is truthy.
        trainer_init_kwargs: Namespace of keyword arguments for
            ``trainer_init_func``; its ``model`` attribute is overwritten
            with ``model``. Required when ``config.llama`` is truthy.
        device: Forwarded unchanged to the step function.
        unl_logs: Currently unused by this function.

    Raises:
        ValueError: If ``config.llama`` is truthy but the trainer factory or
            its kwargs are missing.
        RuntimeError: If out-of-memory errors exceed the tolerance, or for any
            other ``RuntimeError`` raised by the step function.

    Lessons:
    1. Approximation gets better when sample size increases.
    2. On the contrary, stochastic step benefits the optimization by adding noise to the descent direction, degrading the mutual information between the remaining and the forgotten data.

    Notes:
    1. If you encounter the error that the derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented,
    use the suggestion in this thread (https://github.com/pytorch/pytorch/issues/116350) and modify the corresponding local file of https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py.

    """
    if config.llama:
        if trainer_init_func is None or trainer_init_kwargs is None:
            # Fail early with a clear message instead of an opaque
            # AttributeError/TypeError on the lines below.
            raise ValueError(
                "trainer_init_func and trainer_init_kwargs are required "
                "when config.llama is set."
            )
        trainer_init_kwargs.model = model
        trainer = trainer_init_func(**vars(trainer_init_kwargs))
        # TODO
        scr_step_func_gdiff = partial(
            sto_cubic_func_hf_gdiff.hf_stochastic_cubic_step, trainer=trainer)
    else:
        loss_fn = getattr(torch.nn, config.loss)()
        scr_step_func_gdiff = partial(
            sto_cubic_func_gdiff.stochastic_cubic_step_gdiff, loss_fn=loss_fn)

    learning_rate = config.learning_rate
    # Identity schedule; swap in a decay rule here if needed, e.g.
    # halve the rate every 5 steps.
    lr_scheduler = lambda step, lr: lr
    # Number of out-of-memory failures that may still be skipped.
    oom_tol = 5

    # Population sizes do not change between outer steps.
    num_retain_samples = len(retain_set)
    num_forget_samples = len(forget_set)

    for step in range(config.num_outer_steps):
        print(f"Stochastic Step {step + 1}, lr = {learning_rate}")

        # Passing an int population samples from range(n) without replacement.
        retain_grad_ids = np.random.choice(
            num_retain_samples,
            config.grad_sample_size,
            replace=False,
        )
        hess_ids = np.random.choice(
            num_retain_samples,
            config.hess_sample_size,
            replace=False,
        )
        retain_grad_batch = utils.sample(retain_set, retain_grad_ids)
        hess_batch = utils.sample(retain_set, hess_ids)

        # Sub-sample the forget set only when it exceeds one gradient batch.
        if num_forget_samples > config.grad_sample_size:
            forget_grad_ids = np.random.choice(
                num_forget_samples, config.grad_sample_size, replace=False,
            )
            forget_grad_batch = utils.sample(forget_set, forget_grad_ids)
        else:
            forget_grad_batch = forget_set

        if config.llama:
            # The trainer builds its dataloader from its current
            # train_dataset/_train_batch_size, so set both before each call.
            trainer.train_dataset = retain_grad_batch
            trainer._train_batch_size = config.grad_sample_size    # so that dataloader contains 1 batch
            retain_grad_batchloader = trainer.get_train_dataloader()
            trainer.train_dataset = forget_grad_batch
            trainer._train_batch_size = config.grad_sample_size    # so that dataloader contains 1 batch
            forget_grad_batchloader = trainer.get_train_dataloader()
            trainer.train_dataset = hess_batch
            trainer._train_batch_size = config.hess_sample_size
            hess_batchloader = trainer.get_train_dataloader()
        else:
            retain_grad_batchloader = utils.get_dataloader(
                retain_grad_batch,
                shuffle=False,
                batch_size=config.grad_sample_size,
            )
            # NOTE(review): this loader is built from the full forget_set
            # (shuffled), not from forget_grad_batch sampled above, unlike the
            # llama branch. Kept as-is; confirm which one is intended.
            forget_grad_batchloader = utils.get_dataloader(
                forget_set,
                shuffle=True,
                batch_size=config.grad_sample_size,
            )
            hess_batchloader = utils.get_dataloader(
                hess_batch,
                shuffle=False,
                batch_size=config.hess_sample_size,
            )

        # Arguments shared by both step-function variants.
        step_kwargs = dict(
            retain_grad_batchloader=retain_grad_batchloader,
            forget_grad_batchloader=forget_grad_batchloader,
            hess_batchloader=hess_batchloader,
            M=config.M,
            num_steps=config.num_inner_steps,
            learning_rate=learning_rate,
            forget_coeff=config.forget_coeff,
            device=device,
        )
        # Each variant takes one extra, backend-specific argument.
        if config.llama:
            step_kwargs["tofu"] = config.tofu
        else:
            step_kwargs["sigma"] = config.sigma

        try:
            scr_step_func_gdiff(model, **step_kwargs)
        except RuntimeError as error:
            if "out of memory" not in str(error):
                raise
            if oom_tol == 0:
                # Must raise an exception instance: raising a plain string is
                # itself a TypeError in Python 3.
                raise RuntimeError("OOM too many times.") from error
            oom_tol -= 1
            print("[WARNING] Recovered from OOM.")
            # Skip the scheduler update for the failed step.
            continue

        learning_rate = lr_scheduler(step + 1, learning_rate)
