import os
import numpy as np
import cupy as cp
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from tqdm import *
import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from scipy.optimize import minimize
from scipy import io
import torch
from torch.utils import data
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import copy
from IPython.core.debugger import set_trace
import scipy.io as sio
from itertools import combinations
from scipy.special import gamma
from scipy.special import loggamma
from scipy import stats
from scipy.optimize import minimize
from sklearn import mixture
import math
import random
import argparse
from PIL import Image
import glob
from typing import Any

# Ask the CUDA runtime to number GPUs by PCI bus ID rather than its default
# ordering. The variable is only honoured if no CUDA context exists yet in this
# process, so it is set before any CUDA call below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Global compute device: CUDA when it is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make GPU 0 the current CUDA device so a bare "cuda" device resolves to it.
# Guarded: torch.cuda.set_device raises on CPU-only machines/builds, which would
# crash the module at import time and defeat the CPU fallback chosen above.
if device.type == "cuda":
    torch.cuda.set_device(0)

# Optional: uncomment to opt in to TF32 kernels for CUDA matmul / cuDNN
# (see the PyTorch TF32 notes for the current defaults and precision trade-off).
# torch.backends.cuda.matmul.allow_tf32=True
# torch.backends.cudnn.allow_tf32=True
