import torch
from torch import nn
from torch.utils.data import DataLoader
from typing import Callable, Optional


def hvp(
    loss_fn: Callable[[nn.Module], torch.Tensor],
    model: nn.Module,
    vector: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the Hessian-vector product H @ vector with two autograd passes
    (double backpropagation, cf. Pearlmutter, 1994).

    Args:
        loss_fn: Callable that takes the model and returns a scalar loss
            computed on the current model params.
        model: Module whose parameters with ``requires_grad=True`` define the
            coordinates of the Hessian, in ``model.parameters()`` order.
        vector: Flat tensor with one entry per trainable parameter element,
            laid out in the same order as the flattened parameters.

    Returns:
        Flat tensor of the same length as ``vector`` holding H @ vector.
        Entries belonging to parameters the loss does not depend on (or whose
        gradient does not depend on any parameter) are zero.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn(model)

    # create_graph=True keeps the gradient differentiable for the second pass.
    # allow_unused=True lets parameters that do not feed into the loss come
    # back as None instead of raising; their gradient is identically zero.
    first_order = torch.autograd.grad(
        loss, params, create_graph=True, allow_unused=True
    )
    flat_grad = torch.cat([
        (g if g is not None else torch.zeros_like(p)).reshape(-1)
        for g, p in zip(first_order, params)
    ])

    # If the gradient is constant w.r.t. every parameter (e.g. a loss linear
    # in the params) there is no graph to differentiate: the Hessian is zero.
    if not flat_grad.requires_grad:
        return torch.zeros_like(flat_grad)

    # Differentiating <flat_grad, vector> w.r.t. params yields H @ vector.
    second_order = torch.autograd.grad(
        flat_grad, params, grad_outputs=vector, allow_unused=True
    )
    return torch.cat([
        (h if h is not None else torch.zeros_like(p)).reshape(-1)
        for h, p in zip(second_order, params)
    ])


def conjugate_gradient(
    hvp_fn: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    tol: float = 1e-5,
    max_iter: int = 1000,
) -> torch.Tensor:
    """
    Iteratively solve H x = b with the conjugate gradient method.

    H is never formed; it is accessed only through hvp_fn(v) = H v.
    The iterate starts at zero, and the loop ends after max_iter steps or as
    soon as the residual norm falls below tol.
    """
    solution = torch.zeros_like(b)
    residual = b.clone()
    direction = residual.clone()
    res_sq = residual.dot(residual)

    steps_taken = 0
    while steps_taken < max_iter:
        steps_taken += 1
        h_dir = hvp_fn(direction)
        # Small constant in the denominator guards against division by zero.
        step = res_sq / (direction.dot(h_dir) + 1e-12)
        solution = solution + step * direction
        residual = residual - step * h_dir
        next_res_sq = residual.dot(residual)
        if next_res_sq.sqrt() < tol:
            break
        # New search direction, conjugate to the previous ones.
        direction = residual + (next_res_sq / res_sq) * direction
        res_sq = next_res_sq

    return solution


def flatten_params(model: nn.Module) -> torch.Tensor:
    """
    Concatenate every parameter with requires_grad=True into one 1-D tensor,
    in model.parameters() order.
    """
    trainable = (weight for weight in model.parameters() if weight.requires_grad)
    pieces = [weight.reshape(-1) for weight in trainable]
    return torch.cat(pieces)


def set_params_from_flat(model: nn.Module, flat: torch.Tensor) -> None:
    """
    Overwrite the trainable parameters of model in place from a flat vector.

    Consecutive slices of flat are copied into each parameter with
    requires_grad=True, in model.parameters() order; parameters that do not
    require grad are left untouched.

    Args:
        model: Module whose trainable parameters are overwritten.
        flat: Tensor holding exactly as many elements as the trainable
            parameters contain in total.

    Raises:
        ValueError: If flat does not contain exactly the required number of
            elements. The check runs before any write, so the model is never
            left partially updated.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    expected = sum(p.numel() for p in trainable)
    if flat.numel() != expected:
        raise ValueError(
            f"flat has {flat.numel()} elements but the model has "
            f"{expected} trainable parameter elements"
        )

    flat = flat.reshape(-1)
    offset = 0
    for p in trainable:
        numel = p.numel()
        # Write through .data so the copy is not recorded by autograd.
        p.data.copy_(flat[offset: offset + numel].reshape(p.shape))
        offset += numel


def grad_z(
    model: nn.Module,
    loss_per_sample: torch.Tensor,
    sample_idx: int,
) -> torch.Tensor:
    """
    Return the flattened gradient of one sample's loss with respect to the
    model parameters that require grad.

    The entry loss_per_sample[sample_idx] is differentiated; the autograd
    graph is retained so the function can be called again for other entries.
    """
    sample_loss = loss_per_sample[sample_idx]
    trainable = [weight for weight in model.parameters() if weight.requires_grad]
    per_param = torch.autograd.grad(sample_loss, trainable, retain_graph=True)
    flat_pieces = [piece.reshape(-1) for piece in per_param]
    return torch.cat(flat_pieces)


def s_test(
    model: nn.Module,
    loss_fn: Callable[[nn.Module], torch.Tensor],
    v: torch.Tensor,
    damp: float = 0.01,
    scale: float = 25.0,
    cg_tol: float = 1e-5,
    cg_max_iter: int = 1000,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Approximate s_test = H^{-1} v with conjugate gradient on Hessian-vector
    products, as in Koh & Liang (2017, Sec. 3).

    The Hessian is replaced by its damped version H + damp * I, and both the
    operator and the right-hand side are divided by scale before CG is run.
    When mask is given, it is multiplied element-wise into v, into every
    vector fed to the HVP, and into the HVP result.
    """
    has_mask = mask is not None

    def _restrict(t: torch.Tensor) -> torch.Tensor:
        # Apply the optional element-wise mask; no-op without one.
        return t * mask if has_mask else t

    def hvp_damped(vec: torch.Tensor) -> torch.Tensor:
        masked_vec = _restrict(vec)
        curvature = _restrict(hvp(loss_fn, model, masked_vec))
        return (curvature + damp * masked_vec) / scale

    rhs = _restrict(v) / scale
    return conjugate_gradient(hvp_damped, rhs, tol=cg_tol, max_iter=cg_max_iter)


def influence_on_test_point(
    model: nn.Module,
    train_loss_per_sample: torch.Tensor,
    test_loss_grad: torch.Tensor,
    loss_fn_for_hvp: Callable[[nn.Module], torch.Tensor],
    damp: float = 0.01,
    scale: float = 25.0,
) -> torch.Tensor:
    """
    Compute influence values for all training samples on a test point.

    Args:
        model: Module whose trainable parameters the gradients are taken
            with respect to.
        train_loss_per_sample: Tensor of per-sample training losses, indexed
            along its first dimension; it must still be attached to the
            autograd graph of model.
        test_loss_grad: Flattened gradient of the test loss, passed to
            s_test as the right-hand side.
        loss_fn_for_hvp: Callable returning the scalar loss used for the
            Hessian-vector products inside s_test.
        damp: Damping term forwarded to s_test.
        scale: Scaling term forwarded to s_test.

    Returns:
        1-D tensor with one entry per training sample: the negative inner
        product of that sample's flattened gradient with s_test. Empty when
        train_loss_per_sample has no samples.
    """
    n_train = train_loss_per_sample.shape[0]
    if n_train == 0:
        # torch.stack rejects an empty list; return early and skip the CG solve.
        return test_loss_grad.new_empty(0)

    s = s_test(model, loss_fn_for_hvp, test_loss_grad, damp=damp, scale=scale)
    # grad_z retains the graph, so every sample can be differentiated in turn.
    influences = [
        -torch.dot(grad_z(model, train_loss_per_sample, i), s).detach()
        for i in range(n_train)
    ]
    return torch.stack(influences)


