from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import torch


class FinancialNewsDataset(Dataset):
    """Map-style dataset that serves pre-computed encodings with their labels.

    Args:
        encodings: Mapping from a field name (e.g. ``'input_ids'``) to an
            indexable collection with one entry per example. Entries may be
            tensors (or rows of a 2-D tensor) or plain Python values.
        labels: Indexable collection with one label per example. The dataset
            length is taken from this collection.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    @staticmethod
    def _as_new_tensor(value):
        """Return ``value`` as a tensor that does not share storage with it."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(<tensor>) also copies, but emits a UserWarning on
            # every call; detach().clone() is the recommended equivalent.
            return value.detach().clone()
        return torch.tensor(value)

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``.

        The dict holds one entry per key of ``encodings`` plus a ``'labels'``
        entry built from ``labels[idx]``.
        """
        item = {
            key: self._as_new_tensor(val[idx])
            for key, val in self.encodings.items()
        }
        item['labels'] = self._as_new_tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples (the length of ``labels``)."""
        return len(self.labels)
from torch.nn.utils.rnn import pad_sequence
def tokenize_data(dataset, tokenizer, max_length=64):
    """Tokenize, truncate and pad every record of ``dataset``.

    Args:
        dataset: Iterable of records; each record is indexed with ``'text'``
            and ``'labels'``. It is traversed exactly once.
        tokenizer: Object providing ``tokenize(text, add_eos=True)``,
            ``convert_to_tensor(tokens)`` and ``get_idx(symbol)``.
        max_length: Maximum number of token ids kept per example. Defaults to
            64, the value previously hard-coded here. ``None`` disables
            truncation.

    Returns:
        A tuple ``(encodings, labels)`` where ``encodings`` is
        ``{'input_ids': tensor}`` with all sequences padded to the longest
        (post-truncation) sequence using the tokenizer's ``"[PAD]"`` index,
        and ``labels`` is the list of per-record labels in dataset order.
    """
    # Collect both fields in a single pass so the records are only read once
    # (and so one-shot iterables work).
    texts = []
    labels = []
    for record in dataset:
        texts.append(record['text'])
        labels.append(record['labels'])

    tokenized_texts = [tokenizer.tokenize(text, add_eos=True) for text in texts]
    # NOTE(review): truncation happens after tokenization with add_eos=True;
    # if that appends an end-of-sequence token, sequences longer than
    # max_length lose it -- confirm this is intended.
    encoded_tensors = [
        tokenizer.convert_to_tensor(tokens)[:max_length]
        for tokens in tokenized_texts
    ]
    padded_sequences = pad_sequence(
        encoded_tensors,
        batch_first=True,
        padding_value=tokenizer.get_idx("[PAD]"),
    )
    encodings = {'input_ids': padded_sequences}
    return encodings, labels

def create_data_loader(dataset, tokenizer, batch_size=16, shuffle=True):
    """Tokenize ``dataset`` and wrap it in a ``DataLoader``.

    Args:
        dataset: Records to tokenize; passed straight to ``tokenize_data``.
        tokenizer: Tokenizer passed straight to ``tokenize_data``.
        batch_size: Number of examples per batch.
        shuffle: Whether the loader reshuffles the examples. Defaults to
            ``True``, the value previously hard-coded here; pass ``False``
            for a fixed iteration order (e.g. for evaluation).

    Returns:
        A ``DataLoader`` over a ``FinancialNewsDataset`` built from the
        tokenized records.
    """
    encodings, labels = tokenize_data(dataset, tokenizer)
    news_dataset = FinancialNewsDataset(encodings, labels)
    return DataLoader(news_dataset, batch_size=batch_size, shuffle=shuffle)



import pandas as pd

def get_bank_dataloaders(batch_size=16, *, use_augment=False,
                         augment_path="./augmented_dataset.csv"):
    """Build train and eval loaders for ``nickmuchi/financial-classification``.

    Args:
        batch_size: Batch size used for both loaders.
        use_augment: When true, rows read from ``augment_path`` are appended
            to the training split. Defaults to ``False``, matching the value
            previously hard-coded in this function.
        augment_path: CSV file holding the augmented rows. Only read when
            ``use_augment`` is true.

    Returns:
        A tuple ``(train_data_loader, eval_data_loader)``.
    """
    # Adjust split if necessary.
    train_dataset = load_dataset("nickmuchi/financial-classification", split='train')
    if use_augment:
        # Imported as a module so the Hugging Face ``Dataset`` class does not
        # shadow the torch ``Dataset`` imported at the top of this file.
        import datasets

        pd_augment = pd.read_csv(augment_path)
        print(pd_augment.columns)
        augmented_dataset = datasets.Dataset.from_pandas(pd_augment)
        train_dataset = datasets.concatenate_datasets([train_dataset, augmented_dataset])
    eval_dataset = load_dataset("nickmuchi/financial-classification", split='test')

    # Project tokenizer; imported here, at call time, as in the original code.
    from src.dataloaders.utils.vocabulary import BERT_tokenizer
    tokenizer = BERT_tokenizer()

    train_data_loader = create_data_loader(train_dataset, tokenizer, batch_size=batch_size)
    eval_data_loader = create_data_loader(eval_dataset, tokenizer, batch_size=batch_size)
    return train_data_loader, eval_data_loader
