import torch
from im1k_utils import *
from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import ImageNetDataModule
from torch_uncertainty.models import deep_ensembles

if torch.cuda.is_available():
    import warnings

    # "medium" permits float32 matmuls to use a reduced-precision internal
    # representation for speed; acceptable for an evaluation-only run.
    torch.set_float32_matmul_precision("medium")

    # Best-effort opt-in to the fused scaled-dot-product-attention kernels.
    # Each toggle is attempted independently, so a missing flash toggle no
    # longer prevents enabling the memory-efficient one.
    # NOTE: the previous fallback, torch.backends.cuda.sdp_kernel(...), is a
    # context manager; calling it without `with` never changed any setting,
    # so it has been dropped rather than kept as misleading dead code.
    for _sdp_toggle in ("enable_flash_sdp", "enable_mem_efficient_sdp"):
        _enable = getattr(torch.backends.cuda, _sdp_toggle, None)
        if _enable is None:
            # Don't crash: keep the best-effort intent, but make it visible.
            warnings.warn(
                f"torch.backends.cuda.{_sdp_toggle} is unavailable; "
                "leaving default SDPA kernel selection.",
                RuntimeWarning,
            )
        else:
            _enable(True)


# The rest of this script moves models to CUDA and uses a GPU trainer. Fail
# fast with a clear message rather than erroring deep inside model loading,
# and before paying for the (potentially slow) datamodule setup below.
if not torch.cuda.is_available():
    raise RuntimeError("This evaluation script requires a CUDA-capable GPU.")

# Root directory handed to ImageNetDataModule.
path = "data"

dm = ImageNetDataModule(
    root=path,
    batch_size=4,
    num_workers=4,
    pin_memory=True,
)

# Only the "fit" stage is set up here; its validation loader is later passed
# as the pruning loader to load_single_model.
dm.setup("fit")
val_loader = dm.val_dataloader()
device = torch.device("cuda")

# Head-pruning "circuits" read from logged search results. Their exact format
# is defined by im1k_utils.heads_from_log (presumably a collection of attention
# heads to remove; it is passed as `num_heads_to_prune` below - confirm there).
accuracy_circuit = heads_from_log("circuits/acc.json")
ood_circuit = heads_from_log("circuits/ood.json")
average_circuit = heads_from_log("circuits/avg.json")

# All ensemble members start from the same pretrained torchvision ViT-B/16.
VIT_CKPT = "torchvision://vit_b_16_imagenet1k"

# Build one pruned copy of the checkpoint per circuit, in the same order as
# before (accuracy, OOD, average). The generator is consumed left to right
# during unpacking, so loading order is unchanged.
model1, model2, model3 = (
    load_single_model(
        ckpt_path=VIT_CKPT,
        device=device,
        num_heads_to_prune=circuit,
        prune_loader=val_loader,
        prune_strategy="predefined",
    )
    for circuit in (accuracy_circuit, ood_circuit, average_circuit)
)
        

# Baseline deep ensemble over the same three pruned models. It is constructed
# but not evaluated below; pass `model=deep` to the routine to evaluate it.
deep = deep_ensembles([model1, model2, model3])

# Hydra ensemble over the three pruned models. The meaning of `gfc_impl` and
# `copy_global_from` lives in im1k_utils.build_hydra_ensemble (presumably
# member 0 donates the shared/global weights - confirm there).
hydra = build_hydra_ensemble([model1, model2, model3], gfc_impl="einsum", copy_global_from=0)

trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, precision="bf16-mixed")

routine = ClassificationRoutine(
    num_classes=1000,
    model=hydra,
    # Use torch.nn explicitly: `nn` is not imported by name in this file and
    # previously only resolved if `from im1k_utils import *` re-exported it.
    loss=torch.nn.CrossEntropyLoss(),
)
results = trainer.test(routine, datamodule=dm)