import os
import numpy as np

import torch
import datasets
from datasets import load_dataset, concatenate_datasets, Value

from transformers import BertTokenizerFast, BertModel, BertForSequenceClassification
from transformers import MarianMTModel, MarianTokenizer

def shuffle_and_subset(datasets, szs=(100,100,100)):
    """Resize, in place, every dataset of each split in ``datasets``.

    Split keys are paired positionally with the sizes in ``szs`` (iteration
    stops at the shorter of the two); every dataset in a paired split is
    replaced by ``_shuffle_and_subset(ds, size)``. Returns None.
    """
    for split_name, target_size in zip(datasets, szs):
        resized = []
        for part in datasets[split_name]:
            resized.append(_shuffle_and_subset(part, target_size))
        datasets[split_name] = resized

def _shuffle_and_subset(ds, sz): 
    if len(ds) >= sz: 
        return ds.shuffle(seed=0).select(range(sz))
    else: 
        l = []
        while sz > len(ds): 
            l.append(ds)
            sz = sz - len(ds)
        l.append(ds.shuffle(seed=0).select(range(sz)))
        return concatenate_datasets(l)
        
def convert(datasets, f):
    """Apply ``f`` to every dataset of every split in ``datasets``, in place.

    Each ``datasets[key]`` list is replaced by a new list of ``f(ds)``
    results, in the same order. Returns None.
    """
    for key in datasets:
        datasets[key] = list(map(f, datasets[key]))

def mix_datasets(datasets, alphas, pop_size=100, oversample=1.1):
    """Mix ``datasets`` according to weights ``alphas`` into ``pop_size`` rows.

    Component ``i`` contributes ``round(alphas[i] * pop_size * oversample)``
    rows (drawn with ``_shuffle_and_subset``, which repeats a component that
    is too small). Components whose count rounds to 0 are skipped. The
    pooled rows are then shuffled and cut to exactly ``pop_size``.

    Args:
        datasets: indexable collection of datasets aligned with ``alphas``.
        alphas: 1-D mixing weights; a torch tensor or anything
            ``torch.as_tensor`` accepts (list, numpy array).
        pop_size: number of rows in the returned dataset.
        oversample: factor by which each component is over-drawn before the
            final cut to ``pop_size`` (1.1 matches the previous fixed value).

    Raises:
        ValueError: if every component rounds to zero rows, leaving nothing
            to mix.
    """
    alphas = torch.as_tensor(alphas)
    if not alphas.is_floating_point():
        alphas = alphas.float()
    # Plain Python ints, so downstream range()/select() get ints rather than
    # 0-dim tensors.
    counts = (alphas * pop_size * oversample).round().long().tolist()
    parts = [
        _shuffle_and_subset(datasets[i], count)
        for i, count in enumerate(counts)
        if count > 0
    ]
    if not parts:
        raise ValueError(
            f"all mixing weights round to zero rows (pop_size={pop_size}); nothing to mix"
        )
    return _shuffle_and_subset(concatenate_datasets(parts), pop_size)
        
def save_encodings(datasets, name, sample_size=100):
    """Encode a random sample of every dataset with BERT and save the pooled vectors.

    For each split key ``k`` and each dataset ``i`` in ``datasets[k]``:
    1. Draw up to ``sample_size`` rows at random using torch's global RNG
       (``torch.randperm``).
    2. Tokenize them with ``generic_tokenizer``.
    3. Run them through ``bert-base-cased`` on the GPU.
    4. Save the ``pooler_output`` as ``data/<k>/<name>/encoding_<i>.npy``.

    Args:
        datasets: dict mapping split name to a list of datasets; each needs
            the ``text`` column (read by ``generic_tokenizer``) and ``label``.
        name: subdirectory name for this collection under ``data/<k>/``.
        sample_size: maximum number of rows encoded per dataset.

    Raises:
        ValueError: if one of the datasets is empty.
    """
    os.makedirs('data', exist_ok=True)
    model = BertModel.from_pretrained("bert-base-cased")
    model.cuda()
    model.eval()
    columns = ['input_ids', 'token_type_ids', 'attention_mask', 'label']
    for k in datasets:
        for i, ds in enumerate(datasets[k]):
            path = os.path.join('data', k, name)
            os.makedirs(path, exist_ok=True)
            if len(ds) == 0:
                raise ValueError(f"datasets[{k!r}][{i}] is empty; nothing to encode")
            # Sample first, tokenize second. Only the rows that are encoded
            # need tokenizing, not the whole (max_length-padded) dataset.
            idx = torch.randperm(len(ds))[:sample_size].tolist()
            sample = ds.select(idx).map(generic_tokenizer, batched=True)
            sample.set_format(type='torch', columns=columns)
            # One batch containing the whole (already random) sample.
            loader = torch.utils.data.DataLoader(sample, batch_size=len(idx))
            batch = next(iter(loader))
            with torch.no_grad():
                out = model(input_ids=batch['input_ids'].cuda(), attention_mask=batch['attention_mask'].cuda())
                out = out.pooler_output.cpu()
            np.save(os.path.join(path, f'encoding_{i}.npy'), out.numpy())

# Shared BERT tokenizer used by generic_tokenizer; loaded once when the module is imported.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path='bert-base-cased')

def generic_tokenizer(e):
    """Tokenize ``e['text']`` with the module-level ``tokenizer``.

    Uses padding="max_length" with truncation enabled; works on a single
    example or a batch, depending on what ``e['text']`` holds.
    """
    text = e['text']
    return tokenizer(text, truncation=True, padding="max_length")

# SST
# def sst_tokenizer(e): 
#     return tokenizer(e['sentence'], padding="max_length", truncation=True)

# def sst_hf2ch(dataset): 
#     dataset = dataset.map(sst_tokenizer, batched=True)
#     dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
#     return dataset

def threshold(e):
    """Binarize ``e['label']`` in place at 0.5 and return ``e``.

    Labels below 0.5 become int 0; everything else becomes int 1.
    """
    below = e['label'] < 0.5
    e['label'] = 0 if below else 1
    return e

# create SST sources, targets and test sets
def _sst_label_range(lo, hi, include_hi=False):
    """Return a filter predicate keeping examples with ``lo <= label < hi``
    (``lo <= label <= hi`` when ``include_hi`` is set)."""
    if include_hi:
        return lambda e: lo <= e['label'] <= hi
    return lambda e: lo <= e['label'] < hi


def _sst_binary_split(split):
    """Load SST ``split`` and return [label < 0.5, label >= 0.5] subsets with int32 0/1 labels."""
    dataset = load_dataset('sst', 'default', split=split)
    subsets = []
    # The upper half is closed at 1 so examples with label == 1.0 are kept.
    for lo, hi, include_hi in ((0, 0.5, False), (0.5, 1, True)):
        subset = dataset.filter(_sst_label_range(lo, hi, include_hi))
        subset = subset.map(threshold)
        subset = subset.cast_column("label", Value(dtype='int32', id=None))
        subsets.append(subset)
    return subsets


def create_sst():
    """Build the SST source/target/test dataset collections.

    Returns a dict with:
      * 'sources': 10 subsets of the train split. Subset ``i`` holds the
        examples whose float label lies in [i*0.1, (i+1)*0.1). The last bin
        is closed, [0.9, 1.0], so examples with label == 1.0 are not dropped.
      * 'targets' / 'test': the validation / test split, each divided into
        [label < 0.5, label >= 0.5], with labels binarized to int32 0/1 by
        ``threshold``.
    In every subset the ``sentence`` column is renamed to ``text``.
    """
    sst_datasets = {
        'sources': [],  # 250
        'targets': [],  # 500
        'test': []  # 1000
    }

    dataset = load_dataset('sst', 'default', split='train')
    for i in range(10):
        bin_filter = _sst_label_range(i * 0.1, (i + 1) * 0.1, include_hi=(i == 9))
        sst_datasets['sources'].append(dataset.filter(bin_filter))

    sst_datasets['targets'] = _sst_binary_split('validation')
    sst_datasets['test'] = _sst_binary_split('test')

    for k in sst_datasets:
        sst_datasets[k] = [ds.rename_column('sentence', 'text') for ds in sst_datasets[k]]
    return sst_datasets

# emoji
# def emoji_tokenizer(e): 
#     return tokenizer(e['text'], padding="max_length", truncation=True)

# def emoji_hf2ch(dataset): 
#     dataset = dataset.map(emoji_tokenizer, batched=True)
#     dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
#     return dataset

def create_emoji():
    """Build per-class subsets of the tweet_eval ``emoji`` dataset.

    Returns a dict whose 'sources', 'targets' and 'test' entries come from
    the train, validation and test splits respectively. Each entry is a list
    of 20 datasets; entry ``c`` holds the examples whose label (cast to
    int32) equals ``c``.
    """
    emoji_datasets = {
        'sources': [],
        'targets': [],
        'test': []
    }

    for split, subset_label in [('train', 'sources'), ('validation', 'targets'), ('test', 'test')]:
        dataset = load_dataset('tweet_eval', 'emoji', split=split)
        # The cast does not depend on the class: do it once per split.
        dataset = dataset.cast_column("label", Value(dtype='int32', id=None))
        for cls in range(20):
            subset = dataset.filter(lambda e: (e['label'] == cls))
            emoji_datasets[subset_label].append(subset)
    return emoji_datasets

# emotion
def create_emotion():
    """Build per-class subsets of the tweet_eval ``emotion`` dataset.

    Returns a dict whose 'sources', 'targets' and 'test' entries come from
    the train, validation and test splits respectively. Each entry is a list
    of 4 datasets; entry ``c`` holds the examples whose label (cast to
    int32) equals ``c``.
    """
    emotion_datasets = {
        'sources': [],
        'targets': [],
        'test': []
    }

    for split, subset_label in [('train', 'sources'), ('validation', 'targets'), ('test', 'test')]:
        dataset = load_dataset('tweet_eval', 'emotion', split=split)
        # The cast does not depend on the class: do it once per split.
        dataset = dataset.cast_column("label", Value(dtype='int32', id=None))
        for cls in range(4):
            subset = dataset.filter(lambda e: (e['label'] == cls))
            emotion_datasets[subset_label].append(subset)
    return emotion_datasets

# yelp
def create_yelp():
    """Build per-class subsets of ``yelp_review_full``.

    'sources' comes from the first half of the train split, 'targets' from
    the second half, and 'test' from the test split. Each entry is a list of
    5 datasets; entry ``c`` holds the examples whose label (cast to int32)
    equals ``c``.
    """
    # Named to avoid shadowing the imported ``datasets`` module.
    yelp_datasets = {
        'sources': [],
        'targets': [],
        'test': []
    }

    for split, subset_label in [('train[0%:50%]', 'sources'), ('train[50%:100%]', 'targets'), ('test', 'test')]:
        dataset = load_dataset('yelp_review_full', 'yelp_review_full', split=split)
        # The cast does not depend on the class: do it once per split.
        dataset = dataset.cast_column("label", Value(dtype='int32', id=None))
        for cls in range(5):
            subset = dataset.filter(lambda e: (e['label'] == cls))
            yelp_datasets[subset_label].append(subset)
    return yelp_datasets

# dydae
def create_dydae():
    """Build per-class subsets of the silicone ``dyda_e`` dataset.

    Returns a dict whose 'sources', 'targets' and 'test' entries come from
    the train, validation and test splits respectively. Each entry is a list
    of 7 datasets; entry ``c`` holds the examples whose ``Label`` (cast to
    int32) equals ``c``. Columns are then renamed ``Label`` -> ``label`` and
    ``Utterance`` -> ``text`` to match the other collections.
    """
    # Named to avoid shadowing the imported ``datasets`` module.
    dydae_datasets = {
        'sources': [],
        'targets': [],
        'test': []
    }

    for split, subset_label in [('train', 'sources'), ('validation', 'targets'), ('test', 'test')]:
        dataset = load_dataset('silicone', 'dyda_e', split=split)
        # The cast does not depend on the class: do it once per split.
        dataset = dataset.cast_column("Label", Value(dtype='int32', id=None))
        for cls in range(7):
            subset = dataset.filter(lambda e: (e['Label'] == cls))
            dydae_datasets[subset_label].append(subset)

    for k in dydae_datasets:
        dydae_datasets[k] = [ds.rename_column('Label', 'label') for ds in dydae_datasets[k]]
        dydae_datasets[k] = [ds.rename_column('Utterance', 'text') for ds in dydae_datasets[k]]
    return dydae_datasets


def translate(texts, model, tokenizer, language="fr"):
    """Translate a batch of texts with a seq2seq model such as MarianMT.

    Args:
        texts: iterable of source strings.
        model: seq2seq model exposing ``device`` and ``generate``.
        tokenizer: the tokenizer matching ``model``.
        language: target-language code. Unless it is "en", each text is
            prefixed with ``>>{language}<<``, the target-language token
            convention of multilingual Marian models.

    Returns:
        list of translated strings, one per input text (``[]`` for empty input).
    """
    texts = list(texts)
    if not texts:
        return []
    if language == "en":
        src_texts = [f"{text}" for text in texts]
    else:
        src_texts = [f">>{language}<< {text}" for text in texts]

    # prepare_seq2seq_batch is deprecated. Calling the tokenizer directly,
    # with the same padding ("longest") and truncation it used, returns int64
    # tensors without a float round-trip.
    encoded = tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt").to(model.device)
    with torch.no_grad():
        translated = model.generate(**encoded)

    # Convert the generated token ids back into text.
    return tokenizer.batch_decode(translated, skip_special_tokens=True)

def back_translate(texts, en_model, en_tokenizer, target_model, target_tokenizer, source_lang="en", target_lang="fr"):
    """Round-trip ``texts`` through ``target_lang`` and back into ``source_lang``.

    ``target_model``/``target_tokenizer`` translate into the pivot language;
    ``en_model``/``en_tokenizer`` translate the pivot text back. Returns a
    list of back-translated strings.
    """
    pivot_texts = translate(texts, target_model, target_tokenizer, language=target_lang)
    return translate(pivot_texts, en_model, en_tokenizer, language=source_lang)

def nlp_augment(ds, target_lang="es", batch_size=32):
    """Augment ``ds`` by back-translating its ``text`` column.

    Texts go English -> ``target_lang`` with Helsinki-NLP/opus-mt-en-ROMANCE,
    then back to English with Helsinki-NLP/opus-mt-ROMANCE-en. Both models
    run on the GPU. Texts are translated ``batch_size`` at a time instead of
    one ``generate`` call per example.

    Args:
        ds: dataset with a ``text`` column; other columns are kept unchanged.
        target_lang: pivot language code (default "es", as before).
        batch_size: number of texts sent through the models per ``generate`` call.

    Returns:
        the dataset with ``text`` replaced by its back-translation.
    """
    target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
    target_tokenizer = MarianTokenizer.from_pretrained(target_model_name)
    target_model = MarianMTModel.from_pretrained(target_model_name).cuda()

    en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
    en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
    en_model = MarianMTModel.from_pretrained(en_model_name).cuda()

    def translate_batch(batch):
        # Only 'text' is returned; map() leaves the remaining columns as they are.
        aug_texts = back_translate(batch['text'], en_model, en_tokenizer, target_model, target_tokenizer,
                                   source_lang="en", target_lang=target_lang)
        return {'text': aug_texts}

    return ds.map(translate_batch, batched=True, batch_size=batch_size)

# Registry mapping a dataset name to the function that builds its
# {'sources', 'targets', 'test'} collections.
create_datasets = dict(
    sst=create_sst,
    emoji=create_emoji,
    emotion=create_emotion,
    yelp=create_yelp,
    dydae=create_dydae,
)

# # rotten tomatoes
# def rt_tokenizer(e): 
#     return tokenizer(e['text'], padding="max_length", truncation=True)

# def rt_hf2ch(dataset): 
#     dataset = dataset.map(rt_tokenizer, batched=True)
#     dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
#     return dataset

# def create_rt(): 
#     rt_datasets = {
#         'sources': [], 
#         'targets': [], 
#         'test': []
#     }

#     # create emoji sources
#     for split,subset_label in [('train','sources'),('validation','targets'),('test','test')]: 
#         dataset = load_dataset('rotten_tomatoes', 'default', split=split)
#         for cls in range(2): 
#             subset = dataset.filter(lambda e: (e['label'] == cls))
#             rt_datasets[subset_label].append(subset)
#     return rt_datasets