import torch
from im1k_utils import *
from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import ImageNetDataModule
import torch.nn as nn

# Global CUDA tuning, applied only when a CUDA device is visible.
if torch.cuda.is_available():
    # "medium" lets float32 matmuls use lower-precision internal kernels
    # (see the torch.set_float32_matmul_precision documentation).
    torch.set_float32_matmul_precision("medium")
    try:
        # Allow the fused scaled-dot-product-attention backends globally.
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    except Exception as exc:
        # Best-effort: the script still runs without these toggles, but the
        # failure is reported instead of being silently discarded.
        # There is deliberately no fallback to torch.backends.cuda.sdp_kernel:
        # it is a context manager and has no effect unless used in a `with`.
        import warnings  # local import: only needed on this failure path

        warnings.warn(
            f"Could not enable flash / memory-efficient SDP kernels: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )


# Device handed to the model loader further down; always CUDA in this script.
device = torch.device("cuda")

# Directory passed to the data module as its `root`.
path = "data"

# Data module options, collected in one place.
_datamodule_options = {
    "root": path,
    "batch_size": 4,
    "num_workers": 4,
    "pin_memory": True,
    "eval_ood": True,
}
dm = ImageNetDataModule(**_datamodule_options)

# The "fit" stage is set up before the validation loader is requested
# (presumably required for val_dataloader() to be available — confirm).
dm.setup("fit")
val_loader = dm.val_dataloader()


# Arguments shared by every ensemble member: all three are created from the
# same checkpoint string with the same pruning arguments.
# NOTE(review): because the arguments are identical, whether the members end
# up different depends on load_single_model (presumably from im1k_utils) —
# confirm this is intended.
_member_kwargs = {
    "ckpt_path": "torchvision://vit_b_16_imagenet1k",
    "device": device,
    "num_heads_to_prune": 4,
    "prune_loader": val_loader,
    "prune_strategy": "gradient",
}

# Members are loaded one after another, in order, under their original names.
model1, model2, model3 = [load_single_model(**_member_kwargs) for _ in range(3)]

# Combine the three members into a single ensemble model.
_members = [model1, model2, model3]
hydra = build_hydra_ensemble(
    _members,
    gfc_impl="einsum",
    copy_global_from=0,
)


# Trainer configured for GPU execution with bf16 mixed precision.
trainer = TUTrainer(
    accelerator="gpu",
    enable_progress_bar=True,
    precision="bf16-mixed",
)

# Wrap the ensemble in a 1000-class classification routine with OOD
# evaluation switched on.
_criterion = nn.CrossEntropyLoss()
routine = ClassificationRoutine(
    num_classes=1000,
    model=hydra,
    loss=_criterion,
    eval_ood=True,
)

# Run the test loop on the data module; the return value is kept in `results`.
results = trainer.test(routine, datamodule=dm)
