import torch
import numpy as np
from functools import partial

from unlearning_methods.helper import cubic_func, utils


def stochastic_cubic_step_gdiff(
    model, loss_fn,
    retain_grad_batchloader,
    forget_grad_batchloader,
    hess_batchloader,
    M: float = 1,
    num_steps: int = 5,
    learning_rate: float = 0.001,
    sigma: float = 1e-5,
    forget_coeff: float = 0.01,
    device=None,
):
    """
    Apply one stochastic cubic-regularised gradient-difference step to ``model`` in place.

    The retain gradient is computed on ``retain_grad_batchloader`` and the
    forget gradient on ``forget_grad_batchloader``; a Hessian-vector-product
    operator is built over ``hess_batchloader``. ``gdiff_cubic_subsolver``
    turns these into a flat update vector, which is added to every parameter
    with ``requires_grad=True``.

    Args:
        model: torch module whose trainable parameters are updated in place.
        loss_fn: loss passed through to ``cubic_func.gradient`` and ``hvp_func``.
        retain_grad_batchloader: data for the retain gradient.
        forget_grad_batchloader: data for the forget gradient.
        hess_batchloader: data for the Hessian-vector products.
        M: cubic regularisation coefficient.
        num_steps: gradient-descent steps taken by the subsolver.
        learning_rate: initial subsolver step size.
        sigma: magnitude of the random perturbation added to the retain gradient.
        forget_coeff: weight of the forget term in the subsolver objective.
        device: device used for gradient / HVP computation.

    Returns:
        None. ``model`` is modified in place.
    """
    tuple_params = tuple(param for param in model.parameters() if param.requires_grad)

    grad = cubic_func.gradient(
        model, loss_fn, retain_grad_batchloader, device=device,
    )
    # Keep the gradients on the CPU so the device only holds the HVP graph.
    grad = tuple(x.detach().cpu() for x in grad)
    utils.clear_cache()
    compute_hvp = hvp_func(model, loss_fn, hess_batchloader, device=device)

    forget_grad = cubic_func.gradient(model, loss_fn, forget_grad_batchloader, device=device)
    forget_grad = tuple(x.detach().cpu() for x in forget_grad)

    dw = gdiff_cubic_subsolver(
        model,
        tuple_params,
        compute_hvp,
        grad,
        forget_grad,
        M=M,
        num_steps=num_steps,
        learning_rate=learning_rate,
        sigma=sigma,
        forget_coeff=forget_coeff,
        device=device,
    )
    tuple_param_update = cubic_func.decompose_param_vector(dw, tuple_params)
    # tuple_params holds exactly the trainable parameters in model.parameters()
    # order, so it pairs one-to-one with the decomposed update.
    for param, update in zip(tuple_params, tuple_param_update):
        # The update comes from a NumPy vector: it may sit on the CPU and be
        # float64 (the subsolver's perturbation is float64). Match the
        # parameter's device and dtype so the in-place add cannot fail with a
        # device mismatch.
        param.data += update.to(device=param.data.device, dtype=param.data.dtype)

def gdiff_cubic_subsolver(
    model,
    tuple_params,
    tuple_hvp_fn: callable,
    tuple_grad,
    tuple_forget_grad,
    M: float,
    num_steps: int,
    learning_rate: float,
    sigma: float,
    forget_coeff: float,
    device=None,
    lr_decay: float = 0.8,
):
    """
    Solve the gradient-difference cubic subproblem using gradient descent.

    Based on Algorithm 2 in Carmon and Duchi (2016) (https://arxiv.org/pdf/1612.00547).
    The iterate ``s`` starts at the Cauchy point of the cubic model built from
    the retain gradient ``g`` and the HVP operator ``B``. Then ``num_steps``
    descent steps are taken along

        (1 - forget_coeff) * (g~ + B[s] + M/2 * ||s|| * s) - forget_coeff * g_f

    where ``g~ = g + sigma * u``, ``u`` is a random unit vector, and ``g_f`` is
    the forget gradient.

    Args:
        model: module whose ``.grad`` buffers are cleared if the initial HVP
            raises an out-of-memory ``RuntimeError``.
        tuple_params: per-parameter tensors used as the layout template for
            composing / decomposing flat vectors.
        tuple_hvp_fn: HVP callable forwarded to ``hvp``.
        tuple_grad: per-parameter retain gradient.
        tuple_forget_grad: per-parameter forget gradient.
        M: cubic regularisation coefficient.
        num_steps: number of gradient-descent steps.
        learning_rate: initial step size, multiplied by ``lr_decay`` after each step.
        sigma: magnitude of the random perturbation added to the retain gradient.
        forget_coeff: weight of the (negated) forget gradient; the retain terms
            are weighted by ``1 - forget_coeff``.
        device: device the HVP input vectors are moved to.
        lr_decay: multiplicative step-size decay per step (default 0.8, the
            value previously hard-coded).

    Returns:
        The update ``s`` as a flat NumPy array.

    Raises:
        RuntimeError: re-raised from the initial HVP. On out-of-memory errors the
            parameter gradients are dropped and the cache is cleared first.
    """
    grad = utils.convert_torch_to_numpy(
        cubic_func.compose_param_vector(tuple_grad, tuple_params)
    )
    grad_norm = np.linalg.norm(grad)
    print("grad_norm:", grad_norm)

    forget_grad = utils.convert_torch_to_numpy(
        cubic_func.compose_param_vector(tuple_forget_grad, tuple_params)
    )

    print("Setting x0 = cauchy point")
    if grad_norm > 0:
        try:
            B_grad = _hvp_as_numpy(tuple_hvp_fn, tuple_grad, tuple_params, device)
        except RuntimeError as error:
            if "out of memory" in str(error):
                print("Bad batch. Please try another batch.")
                for param in model.parameters():
                    param.grad = None
                utils.clear_cache()
            raise
        s = _cauchy_point(grad, B_grad, grad_norm, M)
        del B_grad
    else:
        # The Cauchy direction -g/||g|| is undefined for a zero gradient (the
        # formula would divide by zero and produce NaNs). Start at the origin
        # and let the random perturbation below move the iterate.
        s = np.zeros_like(grad)
    utils.clear_cache()

    print("Doing gradient descent...")
    print("Sigma:", sigma)
    perturb = np.random.randn(*grad.shape)
    perturb /= np.linalg.norm(perturb)  # uniform direction on the sphere: (Muller, 1959), (Marsaglia, 1972)
    grad = grad + sigma * perturb
    del perturb
    utils.clear_cache()

    for step in range(num_steps):
        print(f"GD step {step + 1}, LR = {learning_rate}")
        tuple_s = cubic_func.decompose_param_vector(s, tuple_params)
        B_s = _hvp_as_numpy(tuple_hvp_fn, tuple_s, tuple_params, device)
        s_norm = np.linalg.norm(s)
        # Gradient of the cubic model (g~ + B s + M/2 ||s|| s), mixed with the
        # constant ascent term on the forget gradient.
        s_grad = (1 - forget_coeff) * (grad + B_s + M / 2 * s_norm * s) - forget_coeff * forget_grad
        step_update = learning_rate * s_grad
        s = s - step_update

        print("descent_grad_norm:", np.linalg.norm(s_grad))
        print("scaled_descent_grad_norm:", np.linalg.norm(step_update))
        del B_s
        utils.clear_cache()
        learning_rate *= lr_decay

    return s


def _hvp_as_numpy(tuple_hvp_fn, tuple_v, tuple_params, device):
    """Apply the HVP operator to per-parameter tensors and return a flat CPU NumPy vector."""
    B_v = hvp(tuple_hvp_fn, tuple_v, device=device)
    B_v = cubic_func.compose_param_vector(B_v, tuple_params)
    return utils.convert_torch_to_numpy(B_v.detach().cpu())


def _cauchy_point(grad, B_grad, grad_norm, M):
    """Return the Cauchy point -R_c * g / ||g|| of the cubic model (requires grad_norm > 0)."""
    curvature = (grad.T @ B_grad) / (M * grad_norm**2)
    Rc = -curvature + np.sqrt(curvature**2 + 2 * grad_norm / M)
    return -Rc * grad / grad_norm


def hvp(compute_hvp: callable, tuple_v, device=None):
    """
    Move every tensor in ``tuple_v`` to ``device`` and apply ``compute_hvp``.

    ``compute_hvp`` is expected to follow the
    ``torch.autograd.functional.hvp`` return convention, a pair
    ``(func_output, hvp)``. Only element ``[1]``, the Hessian-vector
    product, is returned.
    """
    on_device = tuple(map(lambda tensor: tensor.to(device), tuple_v))
    outputs = compute_hvp(on_device)
    return outputs[1]

def hvp_func(model, loss_fn, dataloader, device=None):
    """
    Build a Hessian-vector-product operator over ``dataloader``.

    The result is ``torch.autograd.functional.hvp`` with two arguments bound:
    ``cubic_func.compute_loss`` (with ``model``, ``loss_fn``, ``dataloader``,
    ``device`` pre-applied) and the model's trainable parameters.
    Calling it with ``v`` (one tensor per trainable parameter) returns
    ``(func_output, hvp)`` per the torch convention.
    """
    trainable = tuple(filter(lambda p: p.requires_grad, model.parameters()))
    loss_of_params = partial(cubic_func.compute_loss, model, loss_fn, dataloader, device)
    return partial(torch.autograd.functional.hvp, loss_of_params, trainable)

