import numpy as np
import torch
from torch.autograd.functional import vhp
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm

from HINT.grad_functions import grad_z, grad_z_group
from HINT.utils import make_functional, load_weights, calc_loss

#####################################################################################
from torch.autograd import grad
def compute_test_imapct(model, testloader, device):
    """Average the gradient of the summed absolute model outputs over a loader.

    For every batch the scalar ``out.abs().sum(dim=1).mean()`` is
    differentiated with respect to all model parameters. Each batch gradient
    is weighted by the batch size before accumulation, so the final division
    by the total sample count yields a per-sample mean even when the last
    batch is smaller than the others.

    Arguments:
        model: torch NN; it is switched to eval mode by this function
        testloader: iterable yielding 3-tuples whose first item is the input
            batch (the other two items are ignored)
        device: device the input batch is moved to before the forward pass
    Returns:
        list of torch tensors, one per entry of ``model.parameters()`` and
        with the same shape as that parameter
    Raises:
        ValueError: if ``testloader`` yields no samples
    """
    model.eval()
    params = list(model.parameters())
    n_total = 0
    jacobian_sums = [torch.zeros_like(p) for p in params]

    for x, _, _ in tqdm(testloader, desc="Computing Test Jacobians"):
        x = x.to(device)
        model.zero_grad()
        out = model(x)
        # Per sample: sum of |output| over dim 1, then mean over the batch.
        scalar = out.abs().sum(dim=1).mean()
        grads = grad(scalar, params, retain_graph=False, allow_unused=False)

        batch_size = x.size(0)
        # ``scalar`` is a batch mean, so scaling by the batch size turns the
        # gradient back into a per-batch sum before it is accumulated.
        for j_sum, g in zip(jacobian_sums, grads):
            j_sum.add_(g.detach(), alpha=batch_size)

        n_total += batch_size

    if n_total == 0:
        # Dividing the zero-initialised sums by 0 would return NaN tensors
        # that silently corrupt everything computed from them.
        raise ValueError("testloader yielded no samples; cannot average gradients")

    return [j_sum / n_total for j_sum in jacobian_sums]
#####################################################################################

def s_test_sample(
        model,
        top_model,
        x_test,
        y_test,
        train_loader,
        args,
        loss_func="cross_entropy",
):
    """Average several independent ``s_test`` estimates for one test input.

    For each of ``args.r_average`` repetitions a fresh loader is built over
    ``train_loader.dataset`` that draws ``args.recur_depth`` samples at random
    with replacement; ``s_test`` is run on it and its result is divided by
    ``args.scale``. The element-wise mean over all repetitions is returned.

    Arguments:
        model: pytorch model whose parameters define the shape of the result
        top_model: forwarded unchanged to ``s_test``
        x_test: test input, forwarded to ``s_test``
        y_test: test target, forwarded to ``s_test``
        train_loader: pytorch dataloader; only its ``dataset`` is used
        args: namespace providing ``r_average``, ``recur_depth``,
            ``hvp_batch_size`` and ``scale`` (``s_test`` reads further fields)
        loss_func: loss identifier forwarded to ``s_test``
    Returns:
        list of torch tensors aligned with ``model.parameters()``"""

    dataset = train_loader.dataset
    accumulated = [
        torch.zeros_like(weight, dtype=torch.float) for weight in model.parameters()
    ]

    for repeat_idx in range(args.r_average):
        # A new random subset of the training data for every repetition.
        sampler = torch.utils.data.RandomSampler(
            dataset, True, num_samples=args.recur_depth
        )
        hessian_loader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=args.hvp_batch_size,
            num_workers=4,
        )

        estimate = s_test(
            x_test, y_test, model, top_model, repeat_idx, hessian_loader, args,
            loss_func=loss_func,
        )

        with torch.no_grad():
            accumulated = [
                total + (part / args.scale)
                for total, part in zip(accumulated, estimate)
            ]

    with torch.no_grad():
        return [total / args.r_average for total in accumulated]


def s_test_group_sample(
        model,
        top_model,
        val_loader,
        train_loader,
        args,
        loss_func="cross_entropy",
):
    """Average several independent ``s_test_group`` estimates.

    For each of ``args.r_average`` repetitions a fresh loader is built over
    ``train_loader.dataset`` that draws ``args.recur_depth`` samples at random
    with replacement; ``s_test_group`` is run on it and its result is divided
    by ``args.scale``. The element-wise mean over all repetitions is returned.

    Arguments:
        model: pytorch model whose parameters define the shape of the result
        top_model: forwarded unchanged to ``s_test_group``
        val_loader: loader forwarded to ``s_test_group``
        train_loader: pytorch dataloader; only its ``dataset`` is used
        args: namespace providing ``r_average``, ``recur_depth``,
            ``hvp_batch_size`` and ``scale`` (``s_test_group`` reads further
            fields)
        loss_func: loss identifier forwarded to ``s_test_group``
    Returns:
        list of torch tensors aligned with ``model.parameters()``"""

    dataset = train_loader.dataset
    # One float accumulator per parameter of the full model.
    accumulated = [
        torch.zeros_like(weight, dtype=torch.float) for weight in model.parameters()
    ]

    for repeat_idx in range(args.r_average):
        # A new random subset of the training data for every repetition.
        sampler = torch.utils.data.RandomSampler(
            dataset, True, num_samples=args.recur_depth
        )
        hessian_loader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=args.hvp_batch_size,
            num_workers=4,
        )

        estimate = s_test_group(
            val_loader, model, top_model, repeat_idx, hessian_loader, args,
            loss_func=loss_func,
        )

        with torch.no_grad():
            accumulated = [
                total + (part / args.scale)
                for total, part in zip(accumulated, estimate)
            ]

    with torch.no_grad():
        return [total / args.r_average for total in accumulated]

def extract_embedding(model, x):
    """Return the flattened, globally average-pooled features of ``x``.

    The input is passed through ``model.conv1``, ``model.bn1``, a ReLU and
    then ``model.layer1`` .. ``model.layer4`` in order. The result is pooled
    to a 1x1 spatial size and flattened from dimension 1 onwards. Everything
    runs under ``torch.no_grad()``, so the output carries no autograd graph.

    Arguments:
        model: object exposing ``conv1``, ``bn1`` and ``layer1``..``layer4``
            callables
        x: torch tensor, input batch
    Returns:
        embedding: torch tensor, the flattened pooled features"""
    stage_names = ("layer1", "layer2", "layer3", "layer4")
    with torch.no_grad():
        # Stem: convolution -> batch norm -> ReLU.
        features = F.relu(model.bn1(model.conv1(x)))
        for stage_name in stage_names:
            features = getattr(model, stage_name)(features)
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        return torch.flatten(pooled, 1)

def s_test(x_test, y_test, model, top_model, i, samples_loader, args, loss_func="cross_entropy"):
    """Stochastically estimate the inverse-Hessian-vector product for a test input.

    ``v`` is the gradient of the test loss with respect to every model
    parameter. Starting from ``h = v``, the estimate is refined once per batch
    of ``samples_loader`` with ``h <- v + (1 - args.damp) * h - Hh / args.scale``,
    where ``Hh`` is obtained from ``torch.autograd.functional.vhp`` applied to
    the loss on that training batch.

    Arguments:
        x_test: torch tensor, test input fed to ``model``
        y_test: torch tensor, test target passed to ``calc_loss``
        model: torch NN; switched to eval mode by this function
        top_model: not used by this function; kept for interface compatibility
        i: repetition number, only shown in the progress-bar label
        samples_loader: iterable yielding 4-tuples whose first two items are
            the training inputs and targets (the other two are ignored)
        args: namespace providing ``device``, ``damp`` and ``scale``
        loss_func: loss identifier forwarded to ``calc_loss``
    Returns:
        h_estimate: sequence of torch tensors, one per model parameter
    """
    model.eval()
    x_test, y_test = x_test.to(args.device), y_test.to(args.device)
    prediction = model(x_test)
    loss = calc_loss(prediction, y_test, loss_func=loss_func)
    # Materialise the parameter generator so ``grad`` gets a proper sequence.
    v = grad(loss, tuple(model.parameters()), retain_graph=False, allow_unused=False)
    h_estimate = v

    # NOTE(review): make_functional presumably replaces the model's
    # nn.Parameters with plain attributes so that ``load_weights`` can swap
    # tensors in and out -- confirm against HINT.utils.
    params, names = make_functional(model)
    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in params)

    try:
        # TODO: Dynamically set the recursion depth so that iterations stop once h_estimate stabilises
        progress_bar = tqdm(samples_loader, desc=f"IHVP sample {i}")
        # ``step`` is kept separate from the ``i`` argument to avoid shadowing it.
        for step, (x_train, y_train, _, _) in enumerate(progress_bar):

            x_train, y_train = x_train.to(args.device), y_train.to(args.device)

            def f(*new_params):
                # Evaluate the training loss as a function of the parameters.
                load_weights(model, names, new_params)
                out = model(x_train)
                return calc_loss(out, y_train, loss_func=loss_func)

            hv = vhp(f, params, tuple(h_estimate), strict=False)[1]

            # Recursively calculate h_estimate
            with torch.no_grad():
                h_estimate = [
                    _v + (1 - args.damp) * _h_e - _hv / args.scale
                    for _v, _h_e, _hv in zip(v, h_estimate, hv)
                ]

                if step % 100 == 0:
                    norm = sum([h_.norm() for h_ in h_estimate])
                    progress_bar.set_postfix({"est_norm": norm.item()})
    finally:
        # Always put the weights back, even if a batch fails, so the caller's
        # model is not left in its functional state.
        with torch.no_grad():
            load_weights(model, names, params, as_params=True)

    return h_estimate

def s_test_group(val_loader, model, top_model, i, samples_loader, args,
                 loss_func="cross_entropy"):
    """Stochastically estimate the inverse-Hessian-vector product for a loader.

    ``v`` is the vector returned by ``compute_test_imapct(model, val_loader,
    device=args.device)``. Starting from ``h = v``, the estimate is refined
    once per batch of ``samples_loader`` with
    ``h <- v + (1 - args.damp) * h - Hh / args.scale``, where ``Hh`` is
    obtained from ``torch.autograd.functional.vhp`` applied to the loss on
    that training batch.

    Arguments:
        val_loader: loader passed to ``compute_test_imapct`` to build ``v``
        model: torch NN whose parameters the estimate is taken over
        top_model: not used by this function; kept for interface compatibility
        i: the sample number, only shown in the progress-bar label
        samples_loader: iterable yielding 4-tuples whose first two items are
            the training inputs and targets (the other two are ignored)
        args: namespace providing ``device``, ``damp`` and ``scale``
        loss_func: loss identifier forwarded to ``calc_loss``
    Returns:
        h_estimate: list of torch tensors, s_test"""

    v = compute_test_imapct(model, val_loader, device=args.device)
    h_estimate = v

    # NOTE(review): make_functional presumably replaces the model's
    # nn.Parameters with plain attributes so that ``load_weights`` can swap
    # tensors in and out -- confirm against HINT.utils.
    params, names = make_functional(model)
    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in params)

    try:
        # TODO: Dynamically set the recursion depth so that iterations stop once h_estimate stabilises
        progress_bar = tqdm(samples_loader, desc=f"IHVP sample {i}")
        # ``step`` is kept separate from the ``i`` argument to avoid shadowing it.
        for step, (x_train, y_train, _, _) in enumerate(progress_bar):

            x_train, y_train = x_train.to(args.device), y_train.to(args.device)

            def f(*new_params):
                # Evaluate the training loss as a function of the parameters.
                load_weights(model, names, new_params)
                out = model(x_train)
                return calc_loss(out, y_train, loss_func=loss_func)

            hv = vhp(f, params, tuple(h_estimate), strict=True)[1]

            # Recursively calculate h_estimate
            with torch.no_grad():
                h_estimate = [
                    _v + (1 - args.damp) * _h_e - _hv / args.scale
                    for _v, _h_e, _hv in zip(v, h_estimate, hv)
                ]

                if step % 50 == 0:
                    norm = sum([h_.norm() for h_ in h_estimate])
                    progress_bar.set_postfix({"est_norm": norm.item()})
    finally:
        # Always put the weights back, even if a batch fails, so the caller's
        # model is not left in its functional state.
        with torch.no_grad():
            load_weights(model, names, params, as_params=True)

    return h_estimate