import torch
import torch.nn as nn
import numpy as np
import os

def estimate_prior(dataset, device):
  """Return the pre-estimated per-label class prior as a constant tensor.

  The prior is a fixed table of 8 values (one per label); it is not
  computed from the data.

  Args:
    dataset: Currently unused. Kept in the signature so existing callers
      do not need to change.
    device: Device on which the returned tensor is allocated.

  Returns:
    A 1-D ``torch.float32`` tensor of length 8 located on ``device``.

  Note:
    An earlier, now disabled, variant derived the prior from the data:
    it stacked the labels of every dataset item, took the per-label mean
    (clamped to at least 1e-4), doubled it, and clamped the result to
    the range [1e-3, 0.5].
  """
  # Pre-estimated priors, in label order.
  label_priors = (
    0.0328, 0.0037, 0.0085, 0.0012,
    0.0127, 0.0245, 0.0010, 0.0117,
  )
  return torch.tensor(label_priors, dtype=torch.float32, device=device)
