import warnings

import torch
from cif100_utils import *
from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.models import deep_ensembles
from torchvision.models import ViT_B_16_Weights

# Opt in to faster CUDA kernels when a GPU is visible; skipped entirely on
# hosts where CUDA is not available.
if torch.cuda.is_available():
    # Trade some float32 matmul precision for speed (see the PyTorch docs for
    # the exact meaning of the "medium" setting).
    torch.set_float32_matmul_precision("medium")
    try:
        # Global toggles for the fused scaled-dot-product-attention backends.
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    except (AttributeError, RuntimeError) as err:
        # Best effort only: evaluation still works with the default attention
        # path, so report the problem instead of aborting or hiding it.
        # No fallback to ``torch.backends.cuda.sdp_kernel(...)``: that API is
        # a context manager, so calling it without ``with`` changes nothing.
        warnings.warn(
            f"Could not enable fused scaled-dot-product attention kernels: {err}",
            RuntimeWarning,
        )

# Root directory handed to the data module for dataset storage.
path = "data"
# Evaluation preprocessing taken from the torchvision ViT-B/16 ImageNet weights
# so inputs match what the pretrained backbone expects.
weights_tfms = ViT_B_16_Weights.IMAGENET1K_V1.transforms(antialias=True)

dm = CIFAR100DataModule(
    test_transform=weights_tfms,
    root=path,
    batch_size=4,
    num_workers=4,
    eval_ood=True,
)

# ``setup`` is invoked by hand here, outside of a Trainer, so the Trainer has
# not yet run the ``prepare_data`` hook. Call it explicitly so the dataset
# files are expected to exist before ``setup`` reads them.
# NOTE(review): assumes CIFAR100DataModule downloads in ``prepare_data`` and
# that the call is a no-op when the files are already present; confirm.
dm.prepare_data()
dm.setup("fit")
# This loader is later passed to the pruning step as ``prune_loader``.
# NOTE(review): check which split backs the validation loader, to rule out
# pruning on data that is also used for the final test metrics.
val_loader = dm.val_dataloader()
# The script is GPU-only: the trainer below also requests accelerator="gpu".
device = torch.device("cuda")


# Arguments shared by every ensemble member. Keeping them in one place stops
# the three members from drifting apart through copy-paste edits.
_member_kwargs = dict(
    ckpt_path="torchvision://vit_b_16_cifar100",
    device=device,
    num_heads_to_prune=4,
    prune_loader=val_loader,
    prune_strategy="gradient",
)

# Build the three members one after another, in the same order as before.
# The individual names are kept in case later code refers to them.
# NOTE(review): all members use the same checkpoint and pruning settings.
# Unless ``load_single_model`` is stochastic they may be identical, which
# would give the ensemble no diversity; confirm this is intended.
model1, model2, model3 = (load_single_model(**_member_kwargs) for _ in range(3))

# Merge the members into a single ensemble model. The meaning of ``gfc_impl``
# and ``copy_global_from`` is defined by ``build_hydra_ensemble`` in
# cif100_utils, which is not visible here.
hydra = build_hydra_ensemble([model1, model2, model3], gfc_impl="einsum", copy_global_from=0)


# GPU trainer with the progress bar on and bf16 mixed precision.
trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, precision="bf16-mixed")

# Wrap the ensemble in the classification routine for the 100 CIFAR-100
# classes, with OOD evaluation enabled to match the data module.
routine = ClassificationRoutine(
    num_classes=100,
    model=hydra,
    # Use the explicitly imported ``torch`` package: ``nn`` is not imported in
    # this file and would only resolve if the wildcard import exports it.
    loss=torch.nn.CrossEntropyLoss(),
    eval_ood=True,
)
# Run the test loop and keep whatever the trainer returns for inspection.
results = trainer.test(routine, datamodule=dm)
