import torch
import torch.nn.functional as F
from torch.autograd import grad, Variable

from HINT.utils import calc_loss

def grad_z(x, y, model, top_model, args, loss_func="cross_entropy"):
    """Calculates the gradient z. One grad_z should be computed for each
    training sample.

    The backbone ``model`` is put in eval mode and used purely as a feature
    extractor; the loss is differentiated only with respect to the
    parameters of ``top_model``.

    Arguments:
        x: torch tensor, training data points
        y: torch tensor, training data labels
        model: torch NN, model used to evaluate the dataset; must expose the
            ``conv1``, ``bn1`` and ``layer1``..``layer4`` submodules that
            ``extract_features`` calls
        top_model: torch NN, top layer applied to features
        args: object with a ``device`` attribute; ``x`` and ``y`` are moved
            to that device
        loss_func: loss identifier forwarded unchanged to ``calc_loss``
    Returns:
        grad_z: tuple of torch tensors, one gradient per parameter of
            ``top_model``, in ``top_model.parameters()`` order
    """
    model.eval()
    x, y = x.to(args.device), y.to(args.device)

    # Only top_model's parameters are differentiated below, so the backbone
    # forward pass does not need an autograd graph (grad_z_group does the same).
    with torch.no_grad():
        embedding = extract_features(model, x)

    prediction = top_model(embedding)
    loss = calc_loss(prediction, y, loss_func=loss_func)

    return grad(loss, tuple(top_model.parameters()))

def extract_features(model, x):
    """Run the convolutional trunk of ``model`` and return flat embeddings.

    Arguments:
        model: object exposing callables ``conv1``, ``bn1`` and
            ``layer1``..``layer4``
        x: torch tensor, input batch fed to ``model.conv1``
    Returns:
        torch tensor with every dimension after the first flattened, taken
        after a global (1, 1) adaptive average pool
        (presumably (B, 512) for a ResNet18 trunk -- TODO confirm)
    """
    # Stem: conv -> batch norm -> ReLU.
    out = F.relu(model.bn1(model.conv1(x)))

    # Residual stages, applied in order.
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        out = stage(out)

    # Collapse the spatial dimensions to 1x1, then drop them.
    pooled = F.adaptive_avg_pool2d(out, (1, 1))
    return torch.flatten(pooled, 1)

def grad_z_group(val_loader, model, top_model, args):
    """Gradient of the mean cross-entropy loss over a whole loader.

    Both networks are switched to eval mode. Features are extracted from
    ``model`` without gradient tracking, so the result is taken with respect
    to the parameters of ``top_model`` only.

    Arguments:
        val_loader: iterable yielding ``(x, y, extra)`` triples; the third
            element is ignored
        model: torch NN used as feature extractor via ``extract_features``
        top_model: torch NN, top layer applied to the features
        args: object with a ``device`` attribute; each batch is moved there
    Returns:
        tuple of torch tensors, one gradient per parameter of ``top_model``,
        in ``top_model.parameters()`` order
    Raises:
        ValueError: if ``val_loader`` yields no batches
    """
    model.eval()
    top_model.eval()

    per_sample_losses = []
    for x_val, y_val, _ in val_loader:
        x_val, y_val = x_val.to(args.device), y_val.to(args.device)

        # Feature extraction: no autograd graph is needed for the backbone.
        with torch.no_grad():
            features = extract_features(model, x_val)

        # Prediction
        prediction = top_model(features)

        # Keep the unreduced losses so the final mean weights every sample
        # equally, regardless of how the loader batches them.
        per_sample_losses.append(
            F.cross_entropy(prediction, y_val, reduction='none')
        )

    if not per_sample_losses:
        raise ValueError("val_loader yielded no batches; cannot compute a gradient")

    # Mean loss over all samples (single concatenation instead of one per batch).
    loss = torch.cat(per_sample_losses, dim=0).mean()

    # Gradient with respect to top_model's parameters.
    return grad(loss, tuple(top_model.parameters()), retain_graph=False)



def double_grad_wrt_input(x, y, grad_test_hessian, model, args, top_params, loss_func="cross_entropy"):
    """Differentiate <d loss / d top_params, grad_test_hessian> w.r.t. the input.

    The loss on ``(x, y)`` is first differentiated with respect to
    ``top_params`` (keeping the graph), the result is contracted with the
    detached tensors in ``grad_test_hessian``, and that scalar is then
    differentiated with respect to the input batch.

    Arguments:
        x: torch tensor, input batch
        y: torch tensor, labels for ``x``
        grad_test_hessian: iterable of torch tensors paired positionally with
            ``top_params``; treated as constants (detached)
        model: torch NN called directly on the input; switched to eval mode
        args: object with a ``device`` attribute; ``x`` and ``y`` are moved
            there
        top_params: iterable of parameters the loss is differentiated
            against first
        loss_func: loss identifier forwarded unchanged to ``calc_loss``
    Returns:
        tuple holding a single torch tensor: the gradient with respect to
        the input, with the same shape as ``x``
    """
    model.eval()

    # initialize
    x, y = x.to(args.device), y.to(args.device)

    # Detached leaf sharing x's data, so gradients can be taken w.r.t. the
    # input without flagging the caller's tensor (non-deprecated equivalent
    # of Variable(x.data, requires_grad=True)).
    var_x = x.detach().requires_grad_(True)

    prediction = model(var_x)

    loss = calc_loss(prediction, y, loss_func=loss_func)

    # create_graph=True keeps the parameter gradients differentiable so the
    # second derivative w.r.t. the input can be taken below.
    grad_theta = grad(loss, tuple(top_params), retain_graph=True, create_graph=True)

    # Inner product accumulated out of place, in the gradients' own dtype
    # and device (no fixed float32 buffer).
    inner_product = sum(
        torch.sum(grad_elem * ih_elem.detach())
        for grad_elem, ih_elem in zip(grad_theta, grad_test_hessian)
    )

    grad_input = grad(inner_product, var_x)

    return grad_input


def double_grad_wrt_input_fast(samples, targets, grad_test_hessian, model, args, top_params, loss_func="cross_entropy"):
    """Per-sample version of the input gradient of <d loss / d top_params, grad_test_hessian>.

    For every sample ``i`` the cross-entropy loss of that single sample is
    differentiated with respect to ``top_params``, contracted with the
    detached tensors in ``grad_test_hessian``, and the resulting scalar is
    differentiated with respect to the whole ``samples`` batch.

    Arguments:
        samples: torch tensor, batch of inputs (first dimension indexes
            samples)
        targets: torch tensor, one target per sample; each is reshaped to
            shape (1,) before being passed to ``F.cross_entropy``
        grad_test_hessian: iterable of torch tensors paired positionally with
            ``top_params``; treated as constants (detached)
        model: torch NN called on one sample at a time; switched to eval mode
        args: object with a ``device`` attribute; inputs are moved there
        top_params: iterable of parameters the loss is differentiated
            against first (a generator is accepted)
        loss_func: accepted for signature parity with
            ``double_grad_wrt_input`` but not used; the loss is always
            ``F.cross_entropy``
    Returns:
        list with one torch tensor per sample. Each tensor is the gradient
        with respect to the full batch, so it has the shape of ``samples``
        and only row ``i`` of entry ``i`` can be non-zero.
    """
    model.eval()

    # initialize
    samples, targets = samples.to(args.device), targets.to(args.device)

    # Fresh leaf copy of the batch: gradients are taken w.r.t. it without
    # touching the caller's tensor.
    samples = samples.detach().clone().requires_grad_(True)

    # Materialise once. A generator such as ``module.parameters()`` would be
    # exhausted after the first iteration and break every later grad() call.
    top_params = tuple(top_params)
    hessian_terms = [ih_elem.detach() for ih_elem in grad_test_hessian]

    influences = []

    for i in range(len(samples)):
        sample = samples[i].unsqueeze(0)

        prediction = model(sample)

        loss = F.cross_entropy(prediction, targets[i].reshape(1))

        # create_graph=True keeps the parameter gradients differentiable so
        # the second derivative w.r.t. the input can be taken below.
        grad_theta = grad(loss, top_params, retain_graph=True, create_graph=True)

        # Inner product accumulated out of place, in the gradients' own
        # dtype and device (no fixed float32 buffer).
        inner_product = sum(
            torch.sum(grad_elem * ih_elem)
            for grad_elem, ih_elem in zip(grad_theta, hessian_terms)
        )

        grad_out = grad(inner_product, samples)

        influences.append(grad_out[0])

    return influences