import torch
import numpy as np
import re
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset, Subset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from ..base import BaseNet


class ToPyTorchDataset(Dataset):
    """Expose a HuggingFace dataset through the ``torch.utils.data.Dataset`` API.

    Args:
        dset: An indexable, sized collection of mapping-like records
            (typically a HuggingFace ``datasets.Dataset``).
        keys: Optional sequence of field names. When provided, each item is
            returned as a list holding those fields' values in the given
            order; when ``None``, the raw record is returned unchanged.
    """

    def __init__(self, dset, keys=None):
        self.dset, self.keys = dset, keys

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, idx):
        record = self.dset[idx]
        if self.keys is not None:
            # Project the record onto the requested fields, preserving order.
            return [record[field] for field in self.keys]
        return record


class DistilBert(BaseNet):
    """DistilBERT emotion classifier wrapped behind the ``BaseNet`` interface.

    Uses the HuggingFace checkpoint
    ``bhadresh-savani/distilbert-base-uncased-emotion`` together with the
    ``emotion`` dataset.

    Args:
        split: Dataset split name. It must start with ``'train'``,
            ``'validation'`` or ``'test'``. That prefix selects the
            HuggingFace split, while the full string is used in
            :meth:`get_dataset_name`.
    """

    def __init__(self, split='test'):
        self.split = split
        self.hg_model_name = 'bhadresh-savani/distilbert-base-uncased-emotion'
        # Set before BaseNet.__init__ because create_model() reads
        # hg_model_name; presumably BaseNet.__init__ builds the model(s) —
        # TODO confirm.
        super().__init__()

    def get_tokenizer(self):
        """Load the tokenizer matching ``self.hg_model_name``."""
        return AutoTokenizer.from_pretrained(self.hg_model_name)

    def create_model(self):
        """Load the pretrained sequence-classification model."""
        return AutoModelForSequenceClassification.from_pretrained(self.hg_model_name)

    def create_truncated_model(self):
        """Build a copy of the model without its final classification layer.

        Returns:
            Tuple ``(model, last_layer)``. In ``model``, ``classifier`` has been
            replaced by ``torch.nn.Identity``, so its output is the features
            that would feed the classifier. ``last_layer`` is the removed
            classifier module.
        """
        model = self.create_model()
        last_layer = model.classifier
        model.classifier = torch.nn.Identity()
        return model, last_layer

    def forward_truncated(self, input):
        """Run the truncated model and return the first output element.

        With the classifier replaced by Identity, this first element holds
        the pre-classifier features rather than logits. Assumes
        ``self.truncated_model`` was set (presumably by BaseNet) — TODO confirm.
        """
        return self.truncated_model.forward(input)[0]

    def forward_whole(self, input):
        """Run the full model and return the first output element (the logits)."""
        return self.model.forward(input)[0]

    def get_dataset(self):
        """Load and tokenize the selected split of the ``emotion`` dataset.

        Returns:
            ToPyTorchDataset whose items are ``[input_ids, label]`` as torch
            tensors.

        Raises:
            ValueError: If ``self.split`` does not start with ``'train'``,
                ``'validation'`` or ``'test'``.
        """
        # Validate the split name first, so a bad name fails fast instead of
        # after the (potentially slow) dataset download/preparation.
        r = re.match(r'(train|validation|test)', self.split)
        if r is None:
            raise ValueError(f'Unknown split {self.split} for {self}.')
        split = r.group(1)

        ds = load_dataset('emotion', split=split)

        tokenizer = self.get_tokenizer()

        def encode(batch):
            # padding=True pads to the longest sequence within each map()
            # batch, so sequence lengths may differ between map() batches.
            return tokenizer(batch['text'], padding=True, truncation=True)

        ds_encoded = ds.map(encode, batched=True)
        ds_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
        # NOTE(review): attention_mask is formatted above but dropped here, and
        # forward_* pass only input_ids, so padding tokens are not masked
        # during inference — confirm this is intended.
        return ToPyTorchDataset(ds_encoded, keys=['input_ids', 'label'])

    def get_dataset_name(self):
        """Return a name of the form ``'TwitterEmotion_<split>'``."""
        return f'TwitterEmotion_{self.split}'

    def get_w(self):
        """Return the detached weight of the final classification layer.

        Assumes ``self.last_layer`` was set from create_truncated_model()
        (presumably by BaseNet) — TODO confirm.
        """
        return self.last_layer.weight.detach()

    def get_intercept(self):
        """Return the detached bias of the final classification layer."""
        return self.last_layer.bias.detach()

    def logits_to_scores(self, y_logits):
        """Convert logits to probabilities with a softmax over dim 1.

        Assumes ``y_logits`` is (batch, num_classes) — TODO confirm.
        """
        return torch.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Return class names; presumably ordered by label id — verify
        against the checkpoint's config."""
        return ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
