import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline, DebertaV2Config, BertConfig, RobertaConfig
from torch.nn import functional as F
import dataclasses
import torch
from scipy.stats import entropy
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DefaultDataCollator, DataCollator, InputDataClass
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict
from utils import batch
import numpy as np
import torch.nn as nn
import transformers
import logging
from datasets import Dataset
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression as LR

class AvaModel:
    """Text classifier wrapping a Hugging Face sequence-classification checkpoint.

    The checkpoint is loaded with ``AutoModelForSequenceClassification`` and
    exposed through a ``"text-classification"`` pipeline stored in ``self.p``.
    The tokenizer is picked from the model's config type (BERT, RoBERTa or
    DeBERTa-v2); for any other architecture it is loaded from ``model_path``.
    """

    def __init__(self, model_path, label_path=None, cuda=False):
        """Load the model, its tokenizer and the inference pipeline.

        Args:
            model_path: Checkpoint name or directory passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            label_path: Optional list (or tuple) of label names; element ``i``
                becomes the name of class ``i`` in ``self.labels``. Any other
                value (including a string) is ignored and the model config's
                ``id2label`` mapping is used instead.
            cuda: If True, the pipeline runs on device 0; otherwise the
                pipeline's default device is used.
        """
        print(model_path, flush=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

        # Exact type match (not isinstance) on purpose: a subclass of one of
        # these configs may belong to a different architecture/vocabulary, so
        # it falls back to the tokenizer shipped with the checkpoint instead.
        tokenizer_name = {
            BertConfig: "bert-base-cased",
            RobertaConfig: "roberta-base",
            DebertaV2Config: "microsoft/deberta-v3-base",
        }.get(type(self.model.config), model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=128)
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [f"[unused{k}]" for k in range(1000)]}
        )

        # Resize only *after* the special tokens are added, and only grow:
        # newly added tokens get ids >= the old vocabulary size, which would
        # index past the embedding matrix if it were left (or shrunk) smaller.
        vocab_size = len(self.tokenizer)
        if vocab_size > self.model.get_input_embeddings().num_embeddings:
            self.model.resize_token_embeddings(vocab_size)

        pipeline_kwargs = {"device": 0} if cuda else {}
        self.p = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            **pipeline_kwargs,
        )

        if isinstance(label_path, (list, tuple)):
            self.labels = dict(enumerate(label_path))
        else:
            self.labels = self.model.config.id2label

    def __call__(self, sent):
        """Classify ``sent`` and return ``(label, score)`` of the top entry.

        Input is truncated to 128 tokens. The label and the score are both
        taken from the highest-scoring entry of the pipeline output.
        """
        model_output = self.p(sent, truncation=True, max_length=128)
        best = max(model_output, key=lambda pred: pred["score"])
        return best["label"], best["score"]

    def ood_mass(self, sent, ood_labels):
        """Return the summed score of every label contained in ``ood_labels``.

        Args:
            sent: Input text. Only the first pipeline result (``[0]``) is
                used, so a single string is expected.
            ood_labels: Container of label names (as emitted by the pipeline)
                whose scores are added together.

        Returns:
            The total score mass on ``ood_labels`` (0 if none match).
        """
        # NOTE(review): truncates at 512 tokens while __call__/batched_call use
        # 128 -- confirm this difference is intended. `return_all_scores` is
        # deprecated in newer transformers releases in favour of `top_k=None`.
        all_scores = self.p(sent, return_all_scores=True, truncation=True, max_length=512)[0]
        return sum(pred["score"] for pred in all_scores if pred["label"] in ood_labels)

    def cuda(self):
        """Set ``self.device`` to ``"cuda"`` and return ``self`` for chaining.

        NOTE(review): this does not move ``self.model`` or rebuild ``self.p``,
        and no method here reads ``self.device``; inference keeps running on
        the device chosen in ``__init__`` (pass ``cuda=True`` there to use
        GPU 0). Confirm whether callers rely on this being a no-op.
        """
        self.device = "cuda"
        return self

    def batched_call(self, sents, metric="maxprob", batch_size=5):
        """Classify ``sents`` in batches and return ``(labels, confidences)``.

        Args:
            sents: Input texts, chunked by the ``batch`` helper into groups of
                ``batch_size``; each text is truncated to 128 tokens.
            metric: ``"maxprob"`` -> confidence is the highest class score;
                ``"entropy"`` -> confidence is the negated entropy of the
                class scores (higher means more confident).
            batch_size: Number of texts sent to the pipeline per call.

        Returns:
            Two parallel lists. Labels are ints when the model uses the
            default ``LABEL_<k>`` names, otherwise the label strings.

        Raises:
            ValueError: If ``metric`` is not ``"maxprob"`` or ``"entropy"``.
        """
        # Validate before running (potentially long) inference.
        if metric not in ("entropy", "maxprob"):
            raise ValueError(f"Unknown metric: {metric!r}")

        model_outputs = []
        for b in tqdm(batch(sents, n=batch_size)):
            model_outputs.extend(self.p(b, sequential=False, return_all_scores=True, truncation=True, max_length=128))
        if not model_outputs:
            return [], []

        best = [max(output, key=lambda pred: pred["score"]) for output in model_outputs]
        prefix = "LABEL_"
        # Default id2label names look like "LABEL_3"; map those back to ints.
        if all(
            pred["label"].startswith(prefix) and pred["label"][len(prefix):].isdigit()
            for pred in best
        ):
            labels = [int(pred["label"][len(prefix):]) for pred in best]
        else:
            labels = [pred["label"] for pred in best]

        if metric == "entropy":
            confidences = [-entropy([k["score"] for k in x]) for x in model_outputs]
        else:
            confidences = [max(p["score"] for p in x) for x in model_outputs]
        return labels, confidences
    
