import torch
import numpy as np
import re
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset, Subset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from ..base import BaseNet


class ToPyTorchDataset(Dataset):
    """Expose an indexable dataset (e.g. a HuggingFace one) as a torch ``Dataset``.

    When ``keys`` is given, each item is returned as a list holding the
    values of those keys, in the order of ``keys``; otherwise the item of
    the wrapped dataset is returned untouched.
    """

    def __init__(self, dset, keys=None):
        """Store the wrapped dataset and the optional list of keys to extract."""
        self.dset = dset
        self.keys = keys

    def __len__(self):
        """Return the number of items of the wrapped dataset."""
        return len(self.dset)

    def __getitem__(self, idx):
        """Return item ``idx``, reduced to the selected keys when any are set."""
        record = self.dset[idx]
        selected = self.keys
        if selected is not None:
            return [record[name] for name in selected]
        return record


class BertBase(BaseNet):
    """Multilingual BERT sentiment classifier on the Amazon multilingual reviews.

    The network is the HuggingFace checkpoint named by ``hg_model_name`` and
    the data comes from the ``amazon_reviews_multi`` dataset. The ``split``
    argument selects what is loaded and follows the grammar
    ``<split>[_<language>][_title]``:

    * ``<split>`` is ``train``, ``validation`` or ``test``;
    * ``<language>`` is one of ``de``, ``en``, ``es``, ``fr``, ``ja``, ``zh``,
      or one of the groups ``known`` (en, fr, de, es) and ``unknown``
      (ja, zh); when omitted, the ``all_languages`` configuration is used;
    * the ``_title`` suffix feeds the review title to the model instead of
      the review body.

    NOTE(review): ``self.model``, ``self.truncated_model`` and
    ``self.last_layer`` are not assigned in this class; they are presumably
    populated by ``BaseNet`` from ``create_model`` and
    ``create_truncated_model`` -- confirm against the base class.
    """

    def __init__(self, split='test'):
        """Record the split and checkpoint name, then initialise the base class.

        Args:
            split: Split specification, see the class docstring.
        """
        # Set before ``super().__init__()``, presumably because the base
        # initialiser already needs them -- confirm against ``BaseNet``.
        self.split = split
        self.hg_model_name = 'nlptown/bert-base-multilingual-uncased-sentiment'
        super().__init__()

    def get_tokenizer(self):
        """Return the pretrained tokenizer matching ``hg_model_name``."""
        return AutoTokenizer.from_pretrained(self.hg_model_name)

    def create_model(self):
        """Return the pretrained sequence-classification model."""
        return AutoModelForSequenceClassification.from_pretrained(self.hg_model_name)

    def create_truncated_model(self):
        """Return the model with its classifier removed, and that classifier.

        Returns:
            A ``(model, last_layer)`` pair where ``model.classifier`` has been
            replaced by ``torch.nn.Identity`` and ``last_layer`` is the
            original classifier module.
        """
        model = self.create_model()
        last_layer = model.classifier
        model.classifier = torch.nn.Identity()
        return model, last_layer

    def forward_truncated(self, input):
        """Run the truncated model and return the first element of its output.

        Args:
            input: Sequence ``(input_ids, attention_mask, token_type_ids)``.
        """
        input_ids, attention_mask, token_type_ids = input
        return self.truncated_model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            )[0]

    def forward_whole(self, input):
        """Run the full model and return the first element of its output.

        Args:
            input: Sequence ``(input_ids, attention_mask, token_type_ids)``.
        """
        input_ids, attention_mask, token_type_ids = input
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            )[0]

    def loader_iter_to_input_label(self, input):
        """Split one loader item into ``(model inputs, label)``.

        The label is the last element of ``input``; everything before it is
        returned as the model input.
        """
        label = input[-1]
        input = input[:-1]
        return input, label

    def get_dataset(self):
        """Load, tokenize and wrap the reviews selected by ``self.split``.

        Returns:
            A ``ToPyTorchDataset`` whose items are
            ``[input_ids, attention_mask, token_type_ids, label]``.

        Raises:
            ValueError: If ``self.split`` does not follow the
                ``<split>[_<language>][_title]`` grammar.
        """
        # ``fullmatch`` rather than ``match``: the latter only anchors at the
        # start, so a malformed split such as 'test_xx' would be accepted and
        # silently fall back to all languages.
        r = re.fullmatch(r'(train|validation|test)(_(de|en|es|fr|ja|zh|known|unknown))?(_title)?', self.split)

        if r is None:
            raise ValueError(f'Unknown split {self.split} for {self}.')

        split = r.group(1)
        language = r.group(3)
        review_input = 'review_body' if r.group(4) is None else 'review_title'
        print(f'Review input: {review_input}.')

        if language is None:
            language = 'all_languages'

        if language == 'known':
            # Languages known by the model
            languages = ['en', 'fr', 'de', 'es']
        elif language == 'unknown':
            # Languages unknown by the model
            languages = ['ja', 'zh']
        else:
            languages = [language]

        # One dataset configuration per language, merged into a single dataset.
        ds = concatenate_datasets([
            load_dataset("amazon_reviews_multi", lang)[split]
            for lang in languages
        ])

        # Encode selected dataset with selected labels
        tokenizer = self.get_tokenizer()

        # Star ratings 1..5 become the zero-based labels 0..4. Built once
        # here instead of once per mapped batch; an unexpected rating still
        # raises ``KeyError``.
        star_to_label = {
            1: 0,
            2: 1,
            3: 2,
            4: 3,
            5: 4,
        }

        def encode(e):
            # NOTE(review): ``padding=True`` pads to the longest sequence of
            # each batch handed to ``encode``, so sequence length may differ
            # between batches -- confirm the downstream loader copes.
            d = tokenizer(e[review_input], padding=True, truncation=True)
            d['label'] = [star_to_label[s] for s in e['stars']]
            return d

        ds_encoded = ds.map(encode, batched=True)
        keys = ["input_ids", "attention_mask", "token_type_ids", "label"]
        ds_encoded.set_format("torch", columns=keys)
        ds_encoded = ToPyTorchDataset(ds_encoded, keys=keys)

        return ds_encoded

    def get_dataset_name(self):
        """Return the dataset identifier, which embeds the split specification."""
        return f'AmazonReviewMulti_{self.split}'

    def get_w(self):
        """Return the detached weight of the last (classifier) layer."""
        return self.last_layer.weight.detach()

    def get_intercept(self):
        """Return the detached bias of the last (classifier) layer."""
        return self.last_layer.bias.detach()

    def logits_to_scores(self, y_logits):
        """Convert logits to scores with a softmax over dimension 1."""
        return torch.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Return the human-readable name of each class, in label order."""
        return [
            '1 star',
            '2 stars',
            '3 stars',
            '4 stars',
            '5 stars',
        ]
