#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F

def compute_gradient(model, tokenizer, sample, label, device="cuda"):
    """
    Return the flattened gradient of the model loss for a single example.

    Side effects: calls ``model.zero_grad()`` first, so any gradients already
    stored on ``model`` are cleared. This example's gradients stay in each
    parameter's ``.grad`` after the function returns, because it uses
    ``loss.backward()``.

    Args:
        model: called as ``model(**inputs, labels=labels)``; its output must
            expose a ``.loss`` attribute that ``backward()`` can be called on.
        tokenizer: called with ``return_tensors="pt"``. The encoded ``sample``
            is moved to ``device``, and the ``"input_ids"`` of the encoded
            ``label`` are passed as ``labels``.
        sample: model input text.
        label: target text, tokenized separately from ``sample``.
        device: device the encoded tensors are moved to.

    Returns:
        1-D tensor that concatenates the gradients of every parameter whose
        ``.grad`` is populated, in ``model.parameters()`` order.
    """
    model.zero_grad()
    encoded = tokenizer(sample, return_tensors="pt").to(device)
    target_ids = tokenizer(label, return_tensors="pt")["input_ids"].to(device)

    model(**encoded, labels=target_ids).loss.backward()

    # Parameters without a populated .grad are skipped entirely.
    flat_parts = [
        param.grad.detach().flatten()
        for param in model.parameters()
        if param.grad is not None
    ]
    return torch.cat(flat_parts)


def gradient_consistency_score(model, tokenizer, pseudo_sample, pseudo_label, ref_set, device="cuda"):
    """
    Return the cosine similarity between a pseudo sample's gradient and the
    mean gradient of a reference set.

    Every gradient comes from ``compute_gradient``, which calls
    ``model.zero_grad()`` and ``loss.backward()``. After this function returns,
    ``model``'s ``.grad`` fields therefore hold the last reference example's
    gradients.

    Args:
        model: model passed through to ``compute_gradient``.
        tokenizer: tokenizer passed through to ``compute_gradient``.
        pseudo_sample: input text of the pseudo example.
        pseudo_label: target text of the pseudo example.
        ref_set: iterable of ``(sample, label)`` pairs (the reference set S0).
            It is iterated exactly once.
        device: device passed through to ``compute_gradient``.

    Returns:
        Cosine similarity as a Python float.

    Raises:
        ValueError: if ``ref_set`` yields no pairs.
    """
    g_pseudo = compute_gradient(model, tokenizer, pseudo_sample, pseudo_label, device)

    # Reference gradient: mean over S0. Keep a running sum instead of stacking
    # every per-example gradient. Each gradient is model-sized, so stacking
    # would hold |S0| of them in memory at once.
    g_ref_sum = None
    n_ref = 0
    for x, y in ref_set:
        g_ref = compute_gradient(model, tokenizer, x, y, device)
        if g_ref_sum is None:
            # Clone so the in-place accumulation below never aliases a tensor
            # owned by compute_gradient.
            g_ref_sum = g_ref.clone()
        else:
            g_ref_sum.add_(g_ref)
        n_ref += 1

    if n_ref == 0:
        raise ValueError("ref_set must contain at least one (sample, label) pair")
    g_ref_mean = g_ref_sum.div_(n_ref)

    cos_sim = F.cosine_similarity(g_pseudo.unsqueeze(0), g_ref_mean.unsqueeze(0))
    return cos_sim.item()


def tracin_mini_score(model_init, model_curr, tokenizer, pseudo_sample, pseudo_label, val_set, device="cuda"):
    """
    Estimate the influence of a pseudo sample on a validation set (TracIn-mini).

    Let ``g(θ; s, l)`` be the flattened loss gradient at checkpoint ``θ``, where
    θ0 is ``model_init`` and θt is ``model_curr``. Each validation pair
    ``(x, y)`` contributes one term::

        <g(θ0; pseudo), g(θ0; x, y)> + <g(θt; pseudo), g(θt; x, y)>

    The terms are unweighted (no learning-rate factor), and the result is
    their mean over ``val_set``.

    How gradients are taken:
    - They use ``torch.autograd.grad``, so parameters' ``.grad`` fields are not
      modified.
    - Only parameters with ``requires_grad=True`` are differentiated.
    - A trainable parameter the loss does not depend on contributes a zero
      block. This keeps every gradient from the same model aligned for the
      dot products.

    NOTE(review): the models are used in whatever train/eval mode the caller
    left them in. Confirm ``eval()`` is set if dropout should not perturb the
    gradients.

    Args:
        model_init: model at the initial checkpoint θ0.
        model_curr: model at the current checkpoint θt.
        tokenizer: called with ``return_tensors="pt"``. The ``"input_ids"`` of
            the encoded label are used as ``labels``.
        pseudo_sample: input text of the pseudo example.
        pseudo_label: target text of the pseudo example.
        val_set: iterable of ``(sample, label)`` pairs, iterated once.
        device: device the encoded tensors are moved to.

    Returns:
        Mean influence term as a Python float.

    Raises:
        ValueError: if ``val_set`` yields no pairs.
    """
    def grad_loss(model, sample, label):
        # Frozen parameters cannot be passed to autograd.grad; it raises on them.
        params = [p for p in model.parameters() if p.requires_grad]
        inputs = tokenizer(sample, return_tensors="pt").to(device)
        labels = tokenizer(label, return_tensors="pt")["input_ids"].to(device)
        loss = model(**inputs, labels=labels).loss
        # Without allow_unused=True, autograd.grad raises for any trainable
        # parameter outside the loss graph. Those parameters get zeros rather
        # than being dropped, so the vector layout stays identical across calls.
        grads = torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)
        return torch.cat([
            (torch.zeros_like(p) if g is None else g).flatten()
            for p, g in zip(params, grads)
        ])

    # compute pseudo gradients at θ0 and θt
    g0 = grad_loss(model_init, pseudo_sample, pseudo_label)
    gt = grad_loss(model_curr, pseudo_sample, pseudo_label)

    # one influence term per validation pair
    influence_terms = []
    for x, y in val_set:
        g_val0 = grad_loss(model_init, x, y)
        g_valt = grad_loss(model_curr, x, y)
        influence_terms.append(torch.dot(g0, g_val0) + torch.dot(gt, g_valt))

    if not influence_terms:
        raise ValueError("val_set must contain at least one (sample, label) pair")

    return torch.mean(torch.stack(influence_terms)).item()
