import wandb
from collections import OrderedDict
import torch

def update_summary(step, train_metrics, eval_metrics):
    """Log one step's train and eval metrics to wandb as a single flat row.

    The row starts with a ``step`` field, followed by every train metric
    prefixed with ``train_`` and every eval metric prefixed with ``eval_``
    (in each dict's iteration order), so the two dicts may share key names
    without colliding.
    """
    row = OrderedDict()
    row['step'] = step
    for prefix, metrics in (('train_', train_metrics), ('eval_', eval_metrics)):
        for name, value in metrics.items():
            row[prefix + name] = value
    wandb.log(row)

def eval(model, eval_dataloader, device):
    """Evaluate a classification model; return ``(mean_loss, accuracy)``.

    Every batch from ``eval_dataloader`` is a mapping of tensors. It is moved
    to ``device`` and passed to ``model`` as keyword arguments, and it must
    contain a ``"labels"`` entry. The model output must expose ``.logits`` and
    ``.loss``. Predictions are the argmax of the logits over the last dim.

    Both metrics are weighted by the number of examples in each batch. An
    unweighted average of per-batch values would let a smaller final batch
    skew the result. Accuracy is ``correct / total predictions`` (via
    ``numel``), so it stays within [0, 1] even for multi-dimensional
    predictions.

    The model's previous train/eval mode is restored on exit, including when
    an exception is raised.

    Args:
        model: module whose output has ``.logits`` and ``.loss``.
        eval_dataloader: iterable of dict batches; ``len()`` is not required.
        device: target passed to ``Tensor.to`` for every batch value.

    Returns:
        Tuple of Python floats ``(mean_loss, accuracy)``.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    n_examples = 0
    n_correct = 0
    n_predictions = 0
    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)

                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)
                # Assumes dim 0 of logits is the batch dimension.
                batch_size = logits.shape[0]

                # .mean() reduces a non-scalar loss (e.g. one value per device
                # under DataParallel, presumably) to a scalar batch mean.
                loss_sum += outputs.loss.mean().item() * batch_size
                n_examples += batch_size

                n_correct += (predictions == batch["labels"]).sum().item()
                n_predictions += predictions.numel()
    finally:
        model.train(was_training)

    if n_examples == 0 or n_predictions == 0:
        raise ValueError("eval_dataloader yielded no examples")
    return loss_sum / n_examples, n_correct / n_predictions

def eval_regression(model, eval_dataloader, device):
    """Evaluate a regression model and return its example-weighted mean loss.

    Every batch from ``eval_dataloader`` is a mapping of tensors. It is moved
    to ``device`` and passed to ``model`` as keyword arguments. The model
    output must expose ``.logits`` and ``.loss``.

    Each batch's mean loss is weighted by the batch size (dim 0 of the
    logits), so a smaller final batch does not skew the result. If the
    model's loss is a per-example mean (assumed, e.g. MSE with mean
    reduction), this gives the exact dataset mean.

    The model's previous train/eval mode is restored on exit, including when
    an exception is raised.

    Args:
        model: module whose output has ``.logits`` and ``.loss``.
        eval_dataloader: iterable of dict batches; ``len()`` is not required.
        device: target passed to ``Tensor.to`` for every batch value.

    Returns:
        Mean loss as a Python float.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    n_examples = 0
    try:
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)

                # Assumes dim 0 of logits is the batch dimension.
                batch_size = outputs.logits.shape[0]
                # .mean() reduces a non-scalar loss (e.g. one value per device,
                # presumably) to a scalar batch mean.
                loss_sum += outputs.loss.mean().item() * batch_size
                n_examples += batch_size
    finally:
        model.train(was_training)

    if n_examples == 0:
        raise ValueError("eval_dataloader yielded no examples")
    return loss_sum / n_examples

def clamp_sum(prob, k):
    """Rescale ``prob`` to sum to about ``k`` with every entry in [1e-9, 1].

    First scales ``prob`` to sum to ``k`` (with a small epsilon guarding a
    zero sum). It then runs a fixed 10 rounds of clamp-to-[1e-9, 1] followed
    by renormalization to ``k``, and finally clamps once more, so the bounds
    hold exactly while the sum is only approximately ``k``.

    Returns:
        ``(prob, mask)``: the adjusted probabilities, and a float 0/1 tensor of
        the same shape where each entry is 1 with probability ``prob``.
        The mask is drawn from torch's global RNG.
    """
    floor, ceil = 1e-9, 1.0
    scaled = prob / (prob.sum() + 1e-7) * k
    for _ in range(10):
        clipped = scaled.clamp(min=floor, max=ceil)
        scaled = clipped / clipped.sum() * k
    final = scaled.clamp(min=floor, max=ceil)
    mask = torch.rand_like(final).lt(final).float()
    return final, mask

def importance_sampling_loss_hard(loss, ratio=0.5):
    """Importance-sample losses, favouring high-loss ("hard") elements.

    Keep probabilities start as ``1 - exp(-loss)``. They are passed through
    ``clamp_sum`` so about ``K = int(ratio * loss.numel())`` elements are kept
    in expectation. Each kept loss is divided by its keep probability, so the
    sum of the result estimates ``loss.sum()`` without bias, in expectation
    over the random mask. If ``K`` is 0, almost surely nothing is kept.

    Args:
        loss: unreduced loss tensor, presumably one value per example
            (TODO confirm with callers).
        ratio: target fraction of elements to keep.

    Returns:
        1-D tensor of the kept, re-weighted losses (may be empty).
    """
    K = int(ratio * loss.numel())
    # The keep probabilities are importance weights, not part of the
    # objective. Compute them from a detached copy so autograd does not
    # differentiate through 1/prob, which would bias the gradient estimate.
    prob = 1 - torch.exp(-loss.detach())
    prob, mask = clamp_sum(prob, K)
    return (loss / prob).masked_select(mask.bool())

def importance_sampling_loss_easy(loss, ratio=0.5):
    """Importance-sample losses, favouring low-loss ("easy") elements.

    Keep probabilities start as ``exp(-loss)``. They are passed through
    ``clamp_sum`` so about ``K = int(ratio * loss.numel())`` elements are kept
    in expectation. Each kept loss is divided by its keep probability, so the
    sum of the result estimates ``loss.sum()`` without bias, in expectation
    over the random mask. If ``K`` is 0, almost surely nothing is kept.

    Args:
        loss: unreduced loss tensor, presumably one value per example
            (TODO confirm with callers).
        ratio: target fraction of elements to keep.

    Returns:
        1-D tensor of the kept, re-weighted losses (may be empty).
    """
    K = int(ratio * loss.numel())
    # Detach: keep probabilities are importance weights and must not carry
    # gradient. Otherwise loss / exp(-loss) = loss * exp(loss) would be
    # differentiated as a whole.
    prob = torch.exp(-loss.detach())
    prob, mask = clamp_sum(prob, K)
    return (loss / prob).masked_select(mask.bool())

def importance_sampling_loss_middle(loss, ratio=0.5):
    """Importance-sample losses, favouring medium-loss elements.

    Keep probabilities start as ``p * (1 - p)`` with ``p = exp(-loss)``,
    which is largest where ``p = 0.5`` (``loss = ln 2``). They are passed
    through ``clamp_sum`` so about ``K = int(ratio * loss.numel())`` elements
    are kept in expectation. Each kept loss is divided by its keep
    probability, so the sum of the result estimates ``loss.sum()`` without
    bias, in expectation over the random mask. If ``K`` is 0, almost surely
    nothing is kept.

    Args:
        loss: unreduced loss tensor, presumably one value per example
            (TODO confirm with callers).
        ratio: target fraction of elements to keep.

    Returns:
        1-D tensor of the kept, re-weighted losses (may be empty).
    """
    K = int(ratio * loss.numel())
    # Detach: keep probabilities are importance weights, not part of the
    # objective, so no gradient should flow through 1/prob.
    prob = torch.exp(-loss.detach())
    prob = prob * (1 - prob)
    prob, mask = clamp_sum(prob, K)
    return (loss / prob).masked_select(mask.bool())

def uniform_sampling_loss(loss, ratio=0.5):
    """Keep each loss element independently with probability ``ratio``.

    Kept elements are divided by ``ratio``, so the sum of the result
    estimates ``loss.sum()`` without bias, in expectation over the random
    mask. The mask is drawn from torch's global RNG.

    Args:
        loss: floating-point loss tensor of any shape (including 0-dim).
        ratio: keep probability per element.

    Returns:
        1-D tensor of the kept, re-scaled losses (may be empty).
    """
    # The comparison already yields a bool mask; no sample count is needed.
    keep = torch.rand_like(loss) < ratio
    return (loss / ratio).masked_select(keep)

def uniform_pruning_loss(loss, ratio=0.5):
    """Keep each loss element independently with probability ``ratio``.

    Unlike ``uniform_sampling_loss``, the kept values are not re-scaled. The
    mask is drawn from torch's global RNG.

    Args:
        loss: floating-point loss tensor of any shape (including 0-dim).
        ratio: keep probability per element.

    Returns:
        1-D tensor of the kept losses, unchanged in value (may be empty).
    """
    # The comparison already yields a bool mask; no sample count is needed.
    keep = torch.rand_like(loss) < ratio
    return loss.masked_select(keep)