import pandas as pd
# import transformers
# from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer

import torch

import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

class MovieReviewDataset(Dataset):
    """Map-style dataset that tokenizes a single review on every lookup.

    The class lives at module level (instead of inside ``get_imdb_dataset``)
    so that its instances can be pickled; locally defined classes cannot be,
    which breaks ``DataLoader`` worker processes started with ``spawn``.

    Args:
        reviews: Indexable sequence of review texts.
        targets: Indexable sequence of integer labels, aligned with
            ``reviews``.
        tokenizer: Callable tokenizer accepting the keyword arguments used
            in ``__getitem__`` (e.g. a ``BertTokenizer``).
        max_len: Length every encoded review is padded/truncated to.
    """

    def __init__(self, reviews, targets, tokenizer, max_len):
        self.reviews = reviews
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, item):
        """Return the encoded review and its label at position ``item``.

        Returns:
            dict with keys ``'review_text'`` (str), ``'input_ids'`` and
            ``'attention_mask'`` (flattened tensors produced by the
            tokenizer) and ``'targets'`` (``torch.long`` scalar tensor).
        """
        review = str(self.reviews[item])
        target = self.targets[item]

        # ``padding='max_length'`` is the supported spelling of the
        # deprecated ``pad_to_max_length=True``; calling the tokenizer
        # directly replaces the deprecated ``encode_plus``.
        encoding = self.tokenizer(
            review,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'review_text': review,
            # The tokenizer returns a leading batch axis of size 1 because of
            # ``return_tensors='pt'``; flatten drops it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long),
        }


def get_imdb_dataset(
    max_len,
    *,
    csv_path='./data/imdb/IMDB-Dataset.csv',
    pretrained_model_name='bert-base-cased',
    test_size=0.2,
    random_state=42,
):
    """Load the IMDB review CSV and build train/test datasets.

    The CSV is read with pandas and must provide ``review`` and
    ``sentiment`` columns. A ``sentiment_score`` column is derived from
    ``sentiment``: the exact string ``'positive'`` maps to 0 and every other
    value maps to 1.
    NOTE(review): this is the reverse of the common "positive == 1"
    convention; it is kept unchanged because callers may depend on it.

    Args:
        max_len: Length every encoded review is padded/truncated to.
        csv_path: Location of the CSV file.
        pretrained_model_name: Name passed to
            ``BertTokenizer.from_pretrained``.
        test_size: Fraction of rows held out for the test split.
        random_state: Seed for the train/test split.

    Returns:
        Tuple ``(ds_train, ds_test, vocab_size)`` where the first two items
        are ``MovieReviewDataset`` instances and ``vocab_size`` is the
        tokenizer's ``vocab_size``.
    """
    df = pd.read_csv(csv_path)
    df['sentiment_score'] = df.sentiment.apply(
        lambda label: 0 if str(label) == 'positive' else 1
    )

    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)

    df_train, df_test = train_test_split(
        df, test_size=test_size, random_state=random_state
    )

    def build_dataset(frame):
        # Convert to NumPy so lookups are positional; after the shuffle the
        # frame's index labels no longer run from 0..len-1.
        return MovieReviewDataset(
            reviews=frame.review.to_numpy(),
            targets=frame.sentiment_score.to_numpy(),
            tokenizer=tokenizer,
            max_len=max_len,
        )

    return build_dataset(df_train), build_dataset(df_test), tokenizer.vocab_size
