from functools import partial

import torch
import numpy as np

from unlearning_methods.helper import sto_cubic_func_gdiff, sto_cubic_func_hf_gdiff, utils


def _single_batch_loader(trainer, dataset, batch_size):
    """Point ``trainer`` at ``dataset`` and return its train dataloader.

    The trainer's ``train_dataset`` and ``_train_batch_size`` attributes are
    overwritten as a side effect. ``batch_size`` is chosen by the caller to be
    at least ``len(dataset)`` so that the loader is intended to yield a single
    batch (NOTE(review): this assumes the trainer does not further split or
    drop the batch -- confirm against the trainer implementation).
    """
    trainer.train_dataset = dataset
    trainer._train_batch_size = batch_size
    return trainer.get_train_dataloader()


def unlearn(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None, 
    unl_logs=None,
):
    """
    Run ``config.num_outer_steps`` stochastic unlearning steps on ``model``.

    Each outer step draws fresh random subsets of ``retain_set`` (one for the
    gradient, one for the Hessian) and of ``forget_set`` (for the gradient),
    wraps each subset in a dataloader built by the trainer, and hands them to
    ``sto_cubic_func_hf_gdiff.hf_stochastic_cubic_step``.

    Args:
        model: Model to unlearn from. It is passed to the step function, whose
            return value is not used here (presumably the model is updated in
            place -- confirm against the helper).
        forget_set: Dataset to be forgotten. If it has more than
            ``config.grad_sample_size`` items, a random subset of that size is
            used per step; otherwise the whole set is used.
        retain_set: Dataset to be retained. Must contain at least
            ``config.grad_sample_size`` and ``config.hess_sample_size`` items,
            because sampling is done without replacement.
        config: Object providing ``learning_rate``, ``num_outer_steps``,
            ``grad_sample_size``, ``hess_sample_size``, ``M``,
            ``num_inner_steps``, ``forget_coeff`` and ``tofu``.
        trainer_init_func: Callable that builds the trainer. It is called with
            the attributes of ``trainer_init_kwargs`` as keyword arguments.
        trainer_init_kwargs: Namespace-like object holding the trainer's
            keyword arguments. Its ``model`` attribute is set to ``model``
            (the object is modified in place).
        device: Forwarded unchanged to the step function.
        unl_logs: Currently unused.

    Returns:
        None.

    Raises:
        ValueError: If ``trainer_init_func`` or ``trainer_init_kwargs`` is
            ``None``, or (raised by NumPy) if a sample size exceeds
            ``len(retain_set)``.
        RuntimeError: If the step function runs out of memory more often than
            the tolerated number of times, or raises any other
            ``RuntimeError``.

    Lessons:
    1. The approximation gets better as the sample size increases.
    2. On the other hand, a stochastic step benefits the optimization by adding noise to the descent direction, degrading the mutual information between the remaining and the forgotten data.

    Notes:
    1. If you encounter the error that the derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented,
    use the suggestion in this thread (https://github.com/pytorch/pytorch/issues/116350) and modify the corresponding local file of https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py.

    """
    if trainer_init_func is None or trainer_init_kwargs is None:
        raise ValueError(
            "unlearn() requires both trainer_init_func and trainer_init_kwargs."
        )

    trainer_init_kwargs.model = model
    trainer = trainer_init_func(**vars(trainer_init_kwargs))
    # TODO: only the HuggingFace-trainer variant of the step function is wired up.
    scr_step_func_gdiff = partial(
        sto_cubic_func_hf_gdiff.hf_stochastic_cubic_step, trainer=trainer)

    def _next_learning_rate(step, lr):
        # The learning rate is currently held constant. A decaying alternative
        # would be: ``lr * 0.5 if step % 5 == 0 else lr``.
        return lr

    learning_rate = config.learning_rate
    # Number of out-of-memory failures that are skipped before giving up.
    oom_retries_left = 5

    # Dataset sizes do not change inside the loop, so compute them once.
    num_retain_samples = len(retain_set)
    num_forget_samples = len(forget_set)

    for step in range(config.num_outer_steps):
        print(f"Stochastic Step {step + 1}, lr = {learning_rate}")

        # Passing an int to np.random.choice samples from range(int).
        retain_grad_ids = np.random.choice(
            num_retain_samples,
            config.grad_sample_size,
            replace=False,
        ).tolist()
        hess_ids = np.random.choice(
            num_retain_samples,
            config.hess_sample_size,
            replace=False,
        ).tolist()
        retain_grad_batch = utils.sample(retain_set, retain_grad_ids)
        hess_batch = utils.sample(retain_set, hess_ids)

        if num_forget_samples > config.grad_sample_size:
            # Convert to a plain list, as is done for the retain indices.
            forget_grad_ids = np.random.choice(
                num_forget_samples,
                config.grad_sample_size,
                replace=False,
            ).tolist()
            forget_grad_batch = utils.sample(forget_set, forget_grad_ids)
        else:
            forget_grad_batch = forget_set

        # Batch size equals the sample size so that each loader holds 1 batch.
        retain_grad_batchloader = _single_batch_loader(
            trainer, retain_grad_batch, config.grad_sample_size)
        forget_grad_batchloader = _single_batch_loader(
            trainer, forget_grad_batch, config.grad_sample_size)
        hess_batchloader = _single_batch_loader(
            trainer, hess_batch, config.hess_sample_size)

        try:
            scr_step_func_gdiff(
                model,
                retain_grad_batchloader=retain_grad_batchloader,
                forget_grad_batchloader=forget_grad_batchloader,
                hess_batchloader=hess_batchloader,
                M=config.M,
                num_steps=config.num_inner_steps,
                learning_rate=learning_rate,
                forget_coeff=config.forget_coeff,
                tofu=config.tofu,
                device=device,
            )
        except RuntimeError as error:
            if "out of memory" not in str(error):
                raise
            if oom_retries_left == 0:
                # Raise a real exception; raising a bare string is a TypeError.
                raise RuntimeError("OOM too many times.") from error
            oom_retries_left -= 1
            # Release cached blocks so the next step has a chance to fit.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("[WARNING] Recovered from OOM.")
            # The failed step is skipped but still counts as an outer step.
            continue

        learning_rate = _next_learning_rate(step + 1, learning_rate)
