# The star import is kept first so that the explicit imports below take
# precedence over any same-named symbols exported by prune_bert.
from prune_bert import *

from collections import OrderedDict
from collections.abc import Mapping

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import Sst2DataModule
from torch_uncertainty.routines import ClassificationRoutine
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class HFClassifier(nn.Module):
    """Wrap a Hugging Face sequence classifier so ``forward`` returns logits.

    The wrapped model is stored in ``self.backbone``; ``forward`` calls it and
    returns the ``.logits`` attribute of its output instead of the full
    output object.
    """

    def __init__(self, model_name: str, num_labels: int = 2, local_files_only: bool = False):
        """Load the pretrained sequence-classification model.

        Args:
            model_name: Identifier passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_labels: Forwarded as ``num_labels`` to ``from_pretrained``.
            local_files_only: Forwarded as ``local_files_only`` to
                ``from_pretrained``.
        """
        super().__init__()
        self.backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, local_files_only=local_files_only
        )

    def forward(self, *args, **kwargs):
        """Run the backbone and return its logits.

        Accepted call forms:
            * ``model(batch)`` where ``batch`` is any mapping of backbone
              keyword arguments (a plain ``dict`` or a dict-like object such
              as a ``UserDict``); extra keyword arguments are merged in and
              take precedence over keys of the mapping.
            * ``model(**batch)`` with keyword arguments only.
            * Any other positional/keyword combination, which is forwarded
              to the backbone unchanged.

        Returns:
            The ``logits`` attribute of the backbone output.
        """
        # Check against Mapping rather than dict: dict-like batch containers
        # that do not subclass ``dict`` would otherwise be ignored and the
        # backbone would be called without any inputs.
        if len(args) == 1 and isinstance(args[0], Mapping):
            return self.backbone(**{**args[0], **kwargs}).logits
        # Forward everything else as-is instead of silently dropping
        # positional inputs.
        return self.backbone(*args, **kwargs).logits

# SST-2 data module with a batch size of 64. ``eval_ood=True`` presumably
# makes the module also provide out-of-distribution evaluation data
# (TODO confirm against the torch_uncertainty datamodule documentation).
dm = Sst2DataModule(batch_size=64, eval_ood=True)

# Prepare the data and set the module up explicitly before it is handed to
# the trainer.
# NOTE(review): the stage requested here is "fit" although the script only
# calls ``trainer.test`` later on; confirm this stage is the intended one.
dm.prepare_data()
dm.setup("fit")


# Two-class BERT classifier wrapped so that it returns raw logits.
# NOTE(review): only the stock "bert-base-uncased" weights are loaded here and
# no fine-tuned checkpoint is restored before testing; confirm whether a
# checkpoint load is expected (hf_hub_download / OrderedDict are imported at
# the top of the file but unused in this part of the script).
net1 = HFClassifier("bert-base-uncased", num_labels=2)

# Probe CUDA once and reuse the answer for both the pruning device and the
# trainer accelerator, so the script falls back to CPU instead of crashing on
# hosts without a CUDA device. Behaviour on CUDA hosts is unchanged.
_cuda_available = torch.cuda.is_available()

# Structurally prune a fixed ("predefined") set of attention heads.
# The mapping is presumably {encoder layer index: [head indices to remove]}
# (TODO confirm against prune_bert); the original insertion order is kept.
net1.backbone = structurally_prune_attention_heads_bert(
    net1.backbone,
    num_heads_to_prune={
        8: [0, 2, 6, 8, 10, 11],
        9: [0, 4, 5, 6, 7, 8, 9],
        10: [0, 1, 4, 5, 6, 7, 8, 9, 10, 11],
        11: [0, 2, 3, 4, 5, 6, 7, 8, 9, 11],
        6: [1, 2, 3, 9, 11],
        4: [1, 3, 4, 6, 9],
        5: [2, 5, 6, 9],
        7: [0, 4, 6, 7, 9, 10, 11],
        2: [0, 2, 3, 5, 7, 8],
        0: [4, 5, 6, 7, 11],
        3: [4, 5, 7, 8],
        1: [3, 5, 9],
    },
    strategy="predefined",
    context="layer",
    dataloader=None,  # no dataloader is supplied for the predefined strategy
    device=torch.device("cuda" if _cuda_available else "cpu"),
    verbose=True,
    stochastic=False,
    seed=42,
    use_bias=True,
)

# Single-device trainer; uses the GPU accelerator only when CUDA is present.
trainer = TUTrainer(
    accelerator="gpu" if _cuda_available else "cpu",
    enable_progress_bar=True,
    devices=1,
)

# Classification routine around the pruned model, with cross-entropy loss and
# OOD evaluation enabled to match the data module's ``eval_ood=True``.
routine = ClassificationRoutine(
    num_classes=2,
    model=net1,
    loss=nn.CrossEntropyLoss(),
    eval_ood=True,
)

# Run the test loop on the data module and keep the returned results.
res = trainer.test(routine, datamodule=dm)