import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from ..base import BaseNet


class SST2Based(BaseNet):
    """Wrapper around a Hugging Face sequence classifier for GLUE SST-2.

    Subclasses provide the checkpoint name; this class supplies tokenizer
    and model loading, dataset preparation and the small accessors below.

    NOTE(review): ``self.model``, ``self.truncated_model`` and
    ``self.last_layer`` are read here but never assigned in this class;
    presumably ``BaseNet`` sets them up -- confirm against the base class.
    """

    def __init__(self, hg_model_name=None, split='val'):
        """Store the configuration, then run the base-class initialisation.

        Args:
            hg_model_name: Identifier handed to the ``from_pretrained``
                loaders of the tokenizer and the model.
            split: Dataset split; ``'val'`` is translated to ``'validation'``
                when the dataset is loaded, any other value is used as-is.
        """
        # Assigned before ``super().__init__()`` -- presumably because the
        # base initialiser already needs them (TODO confirm in BaseNet).
        self.split = split
        self.hg_model_name = hg_model_name
        super().__init__()

    def get_tokenizer(self):
        """Load the tokenizer that belongs to ``self.hg_model_name``."""
        return AutoTokenizer.from_pretrained(self.hg_model_name)

    def create_model(self):
        """Load the sequence-classification model for ``self.hg_model_name``."""
        return AutoModelForSequenceClassification.from_pretrained(self.hg_model_name)

    def forward_truncated(self, input):
        """Run ``self.truncated_model`` on ``input`` and return its first output.

        ``input`` is unpacked as keyword arguments. The module is called
        (rather than its ``forward`` method) so that hooks registered on the
        top-level module are executed as well.
        """
        return self.truncated_model(**input, return_dict=False)[0]

    def forward_whole(self, input):
        """Run ``self.model`` on ``input`` and return its first output.

        ``input`` is unpacked as keyword arguments. The module is called
        (rather than its ``forward`` method) so that hooks registered on the
        top-level module are executed as well.
        """
        return self.model(**input, return_dict=False)[0]

    def get_dataset(self):
        """Load, tokenize and torch-format the requested SST-2 split.

        Returns:
            The ``glue``/``sst2`` dataset for the configured split, with the
            ``sentence`` column tokenized (truncated, padded to
            ``max_length``) and every column except ``sentence`` and ``idx``
            exposed in ``torch`` format.
        """
        split = 'validation' if self.split == 'val' else self.split
        dataset = load_dataset('glue', 'sst2', split=split)
        tokenizer = self.get_tokenizer()
        dataset = dataset.map(
            lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length'),
            batched=True,
        )
        # ``set_format`` is documented to take a list of column names; an
        # ordered list also keeps the column order deterministic, unlike a set.
        excluded = ('sentence', 'idx')
        columns = [name for name in dataset.column_names if name not in excluded]
        dataset.set_format(type='torch', columns=columns)
        return dataset

    def loader_iter_to_input_label(self, iter):
        """Split one loader item into ``(inputs, labels)``.

        The ``'label'`` entry is popped, so ``iter`` itself is modified and
        returned as the inputs.
        """
        labels = iter.pop('label')
        return iter, labels

    def get_dataset_name(self):
        """Return the dataset identifier, which includes the configured split."""
        return f'GLUE_sst2_{self.split}'

    def get_w(self):
        """Return the detached weight of ``self.last_layer``."""
        return self.last_layer.weight.detach()

    def get_intercept(self):
        """Return the detached bias of ``self.last_layer``."""
        return self.last_layer.bias.detach()

    def logits_to_scores(self, y_logits):
        """Apply softmax along dimension 1 of ``y_logits``."""
        return torch.nn.functional.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Return the class names, ordered by label index."""
        return ['negative', 'positive']


class Roberta(SST2Based):
    """SST-2 network backed by the ``Bhumika/roberta-base-finetuned-sst2`` checkpoint."""

    def __init__(self, split='val'):
        """Fix the checkpoint name and forward ``split`` to the parent class."""
        checkpoint = 'Bhumika/roberta-base-finetuned-sst2'
        super().__init__(hg_model_name=checkpoint, split=split)

    def create_model(self):
        """Load the sequence-classification model for ``self.hg_model_name``."""
        checkpoint = self.hg_model_name
        return AutoModelForSequenceClassification.from_pretrained(checkpoint)

    def create_truncated_model(self):
        """Build a model whose final projection is replaced by an identity.

        Returns:
            A ``(model, last_layer)`` pair: a freshly created model whose
            ``classifier.out_proj`` has been swapped for ``torch.nn.Identity``,
            and the original ``out_proj`` module that was removed.
        """
        truncated = self.create_model()
        head = truncated.classifier
        # Keep a handle on the original projection before it is replaced.
        final_projection = head.out_proj
        head.out_proj = torch.nn.Identity()
        return truncated, final_projection
