"""
Code adapted from the TRAK examples: https://github.com/MadryLab/trak/blob/main/examples/qnli.py

Model: bert-base-cased (https://huggingface.co/bert-base-cased)

Tokenizers and loaders are adapted from the Hugging Face example
(https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
"""

from argparse import ArgumentParser
from tqdm import tqdm

import torch 
import torch.nn as nn
import logging
from torch.utils.data import DataLoader

# Huggingface
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    default_data_collator,
)


class SequenceClassificationModel(nn.Module):
    """
    ``nn.Module`` wrapper around the Hugging Face ``bert-base-cased``
    sequence-classification model, configured with two labels for QNLI.
    """

    def __init__(self):
        super().__init__()
        pretrained_name = 'bert-base-cased'
        # Hub/loading options shared by both ``from_pretrained`` calls below.
        hub_kwargs = dict(cache_dir=None, revision='main', token=None)

        self.config = AutoConfig.from_pretrained(
            pretrained_name,
            num_labels=2,
            finetuning_task='qnli',
            **hub_kwargs,
        )

        # Materialize every logger registered so far, then raise the level of
        # those whose name mentions "transformers" (case-insensitive) to ERROR.
        registered = [logging.getLogger(n) for n in logging.root.manager.loggerDict]
        for lg in registered:
            if 'transformers' not in lg.name.lower():
                continue
            lg.setLevel(logging.ERROR)

        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_name,
            config=self.config,
            ignore_mismatched_sizes=False,
            **hub_kwargs,
        )

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped model and return its raw (un-normalized) logits."""
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return outputs.logits

    def get_data_representation(self, input_ids, token_type_ids, attention_mask):
        """
        Return the final hidden layer's vector at sequence position 0 for each
        example (for BERT-tokenized inputs this is normally the [CLS] token).
        """
        hidden_states = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states
        last_layer = hidden_states[-1]
        return last_layer[:, 0, :]

def init_model(ckpt_path, device='cuda'):
    """
    Build a ``SequenceClassificationModel`` and load checkpoint weights into it.

    Args:
        ckpt_path: path to a file written by ``torch.save`` that holds a state
            dict for the wrapped Hugging Face model (``model.model``).
        device: device the returned model is moved to.

    Returns:
        The wrapper with the checkpoint weights loaded, placed on ``device``.
    """
    model = SequenceClassificationModel()
    # Deserialize onto the CPU: load_state_dict copies the tensors into the
    # model's own parameters anyway, and this lets a checkpoint saved from a
    # GPU be loaded on a machine without one.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.model.load_state_dict(state_dict)
    # Place the model on the device the caller asked for.
    return model.to(device)

# Registry mapping a short model name to the wrapper class that builds it.
model_dict = dict(
    bert=SequenceClassificationModel,
)

if __name__ == "__main__":
    def test():
        """Smoke-test the wrapper's forward pass on a dummy tokenized batch."""
        # Build directly from the pretrained weights; use init_model(path)
        # instead to exercise loading a real fine-tuned checkpoint file.
        model = SequenceClassificationModel()
        model.eval()

        batch_size, seq_len = 1, 16
        # Random in-vocabulary token ids, a single segment, and no padding.
        input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            logits = model(input_ids, token_type_ids, attention_mask)
            features = model.get_data_representation(
                input_ids, token_type_ids, attention_mask
            )

        print(logits.size())
        assert logits.shape == (batch_size, model.config.num_labels)
        assert features.shape == (batch_size, model.config.hidden_size)

    print("Running test...")
    test()
    print("Test passed")



# if __name__ == "__main__":
#     parser = ArgumentParser()
#     parser.add_argument('--ckpt', type=str, help='model checkpoint', required=True)
#     parser.add_argument('--out', type=str, help='dir to save TRAK scores and metadata to', required=True)
#     args = parser.parse_args()

#     device = 'cuda'
#     loader_train, loader_val = init_loaders()
#     model = init_model(args.ckpt, device)
