from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import typo
from dataset import * # data pre-processing
import nltk
from nltk.corpus import stopwords
from nltk.corpus import wordnet
import numpy as np
import random
# from textattack.augmentation import EasyDataAugmenter

############ Utils

def concat_tab_txt(data, cat_num_var_list, text_var):
    """Prepend "<column> <value> " pairs of tabular columns to a text column.

    For every column in ``cat_num_var_list`` (in list order) the string
    ``"<column name> <str(value)> "`` is built for each row. The concatenation
    of these pieces is prepended to ``text_var``, which is overwritten in the
    returned copy. ``data`` itself is not modified and no helper column is
    added to (or removed from) the result.

    Args:
      data (pandas dataframe): dataset containing the tabular and text columns.
      cat_num_var_list (list of str): categorical/numerical columns to serialize.
      text_var (str): text column that receives the serialized prefix.

    Returns:
      (pandas dataframe) copy of ``data`` whose ``text_var`` column starts with
      the serialized tabular values.

    Example:
      df = concat_tab_txt(df, ["Age", "Country"], "Description")
      # Description -> "Age 34 Country FR <original description>"
    """
    df = data.copy()
    # Build the prefix column-wise instead of with one row-wise apply per
    # variable. astype(object) exposes the plain cell values (also for
    # categorical columns) and map(str) applies str() to each of them.
    prefix = ""
    for var in cat_num_var_list:
        prefix = prefix + (var + " ") + df[var].astype(object).map(str) + " "
    df[text_var] = prefix + df[text_var]
    return df

############ Tabular shift

def orderSplit(data, variable, seed, target_size, sample_rate):
    """
    Sorts the dataset by the value of selected variable, splits the sorted dataset in three sections S_low (5%), S_middle (5-95%) and S_high (95-100%).
    The Target is constructed by sampling (with replacement) from S_low ∪ S_high and from S_middle with specific sampling rates.

    Args:
      data (pandas dataframe): Target dataset
      variable (str): variable used to sort and split the dataset
      seed (int): seed used for random sampling
      target_size (int): number of samples for Target dataset
      sample_rate (float): intensity of shift (fraction of target_size drawn from S_low U S_high)

    Returns:
      (pandas dataframes) Target dataset shifted based on the described scheme, with a fresh 0..target_size-1 index

    Raises:
      ValueError: if a section that must be sampled from is empty, e.g. when the
        dataset has fewer than 20 rows, so the 5% tails contain no rows.

    Inspired from Reference:
      Mougan et al., AAAI-2023
      "Monitoring Model Deterioration with Explainable Uncertainty Estimation via Non-parametric Bootstrap"

    Example:
    target = orderSplit(target, variable="Age", seed=0, target_size=1000, sample_rate=0.9)
    """
    dataset = data.copy()

    # sort based on variable values
    dataset = dataset.sort_values(variable)

    # split into 3 sections S_low (5%), S_middle (90%), S_high (5%).
    # .iloc keeps the slicing positional whatever the index type (a plain
    # dataset[a:b] is label-based on e.g. a float index).
    dataset_length = len(dataset)
    size = int(0.05 * dataset_length)
    S_low = dataset.iloc[:size]
    S_high = dataset.iloc[dataset_length - size:]
    S_low_high = pd.concat((S_low, S_high))
    S_middle = dataset.iloc[size:dataset_length - size]

    # number of rows drawn from each population
    S1_size = int(sample_rate * target_size)
    S2_size = target_size - S1_size

    # fail with an explicit message instead of numpy's generic sampling error
    if S1_size > 0 and S_low_high.empty:
        raise ValueError(
            f"orderSplit: S_low U S_high is empty for a dataset of {dataset_length} rows; "
            f"cannot draw {S1_size} samples from it"
        )
    if S2_size > 0 and S_middle.empty:
        raise ValueError(
            f"orderSplit: S_middle is empty for a dataset of {dataset_length} rows; "
            f"cannot draw {S2_size} samples from it"
        )

    # sample with replacement to construct final dataset
    S1 = S_low_high.sample(n = S1_size, random_state = seed, replace = True)
    S2 = S_middle.sample(n = S2_size, random_state = seed, replace = True)
    target_data = pd.concat((S1,S2)).reset_index(drop=True)

    return target_data
  
def emptyCategory(data, variables, threshold, target_size, seed):
    """
    Replace the values of categorical variables with empty values.
    A Target of target_size rows is first drawn with train_test_split (random_state=seed) and re-indexed 0..target_size-1.
    Then, for each row, with probability "threshold", a random count n in [1, len(variables)] is drawn and n variables
    are picked from "variables"; their values in that row are set to " " (a single space).

    Randomness: row i uses its own generator seeded with seed + i, so results are reproducible
    and NumPy's global random state is left untouched.

    Args:
      data (pandas dataframe): original dataset with features and target variable.
      variables (list of str): list of categorical variables.
      threshold (float): proportion of data where shift is applied. (intensity of shift)
      target_size (int): number of samples for Target dataset.
      seed (int): seed for random choice of data to be modified.

    Returns:
      (pandas dataframe) Target dataset in which some categorical values are replaced by " ".
      If "variables" is empty, the sampled Target is returned unchanged.

    Example:
    target = emptyCategory(target, variables = categorical_var, threshold = 0.9, target_size = 1000, seed = 0)

    """
    dataset = data.copy()
    _, dataset = train_test_split(dataset, test_size = target_size, random_state = seed)
    dataset = dataset.reset_index(drop=True)

    # nothing to blank out (randint(1, 1) would otherwise raise)
    if len(variables) == 0:
        return dataset

    for row_pos in dataset.index:  # RangeIndex after reset_index: label == position
        # Local generator: same draws as np.random.seed(seed + row_pos) followed
        # by np.random.rand/randint/choice, without touching the global RNG.
        rng = np.random.RandomState(seed + row_pos)
        if rng.rand() < threshold:  # apply shift with probability = threshold
            n = rng.randint(1, len(variables) + 1)
            # NOTE: drawn WITH replacement (kept for reproducibility), so a variable
            # may be picked more than once and fewer than n distinct columns blanked.
            cat_vars = rng.choice(variables, size=n)
            for var in cat_vars:
                dataset.loc[row_pos, var] = " "

    return dataset
    
############ Text shift

def typos(data, variable, num_typos, target_size, seed):
    """
    Include a given number of typos in the text fields.
    There are six types of random typos: char swap, remove char, add char, replace char, remove space, add space.

    The sequence of num_typos typo types is drawn once (seeded by "seed"); typo k of that sequence is then
    applied to every text field with typo.StrErrer(..., seed=seed + k), so all rows share the same seed at a given step.
    NumPy's global random state is not modified by the type draw.

    Args:
      data (pandas dataframe): original Target dataset
      variable (str): text variable where typos will be added
      num_typos (int): number of typos applied to each text field (intensity of shift)
      target_size (int): number of samples for Target dataset
      seed (int): seed for random choice of typos (types) and affected words

    Returns:
      (pandas dataframe) Target dataset with typos in text fields, re-indexed 0..target_size-1

    Reference: typo package https://github.com/ranvijaykumar/typo

    Example:
    target = typos(target, variable = text_var, num_typos = 50, target_size = 1000, seed = 0)
    """
    dataset = data.copy()
    _, dataset = train_test_split(dataset, test_size = target_size, random_state = seed)

    # Randomly draw num_typos typo types. A local RandomState(seed) yields the same
    # draws as np.random.seed(seed) + np.random.choice without reseeding NumPy globally.
    typo_types = ["char_swap", "missing_char", "extra_char", "nearby_char", "skipped_space", "random_space"]
    typo_choices = np.random.RandomState(seed).choice(typo_types, size=num_typos)

    # Apply the typos one after another. Each typo type is also the name of the
    # typo.StrErrer method that produces it, so dispatch by name.
    for step_seed, error in enumerate(typo_choices, start=seed):
        method_name = str(error)
        dataset[variable] = dataset[variable].apply(
            # defaults bind the current step's values explicitly
            lambda text, s=step_seed, m=method_name: getattr(typo.StrErrer(text, seed=s), m)().result
        )

    dataset = dataset.reset_index(drop=True)

    return dataset
  
  
def seqLengthSplit(data, variable, seed, target_size, sample_rate, ascending):
    """
    Sorts the dataset by length of sequence for the selected text variable, splits the sorted dataset in two sections: S_low (first 10%) and S_high (remaining 90%).
    The Target is constructed by sampling (with replacement) from S_low and from S_high with specific sampling rates.
    Sequence length is the number of whitespace-separated tokens (len(text.split())).

    Args:
      data (pandas dataframe): Target dataset
      variable (str): text variable whose sequence length is used to sort and split the dataset
      seed (int): seed used for random sampling
      target_size (int): number of samples for Target dataset
      sample_rate (float): sampling rate for S_low (intensity of shift)
      ascending (bool): ascending order (True: S_low holds the shortest texts; False: the longest)

    Returns:
      (pandas dataframes) Target dataset shifted based on the described scheme, re-indexed 0..target_size-1.
      It also contains the helper column "seq_length" with the token counts.

    Raises:
      ValueError: if a section that must be sampled from is empty (e.g. fewer than 10 rows).

    Example:
    target = seqLengthSplit(data = target, variable = text_var, seed = 0, target_size = 1000, sample_rate = 0.9, ascending = False)
    """
    dataset = data.copy()

    # token count per text, computed column-wise (astype(object) avoids the
    # category-level mapping pandas performs on categorical columns)
    dataset["seq_length"] = dataset[variable].astype(object).map(lambda text: len(text.split()))
    dataset = dataset.sort_values("seq_length", ascending = ascending)
    dataset = dataset.reset_index(drop=True)

    # split into 2 sections S_low (10%), S_high (90%)
    dataset_length = len(dataset)
    size = int(0.1 * dataset_length)
    S_low = dataset.iloc[:size]
    S_high = dataset.iloc[size:]

    # number of rows drawn from each section
    S1_size = int(sample_rate * target_size)
    S2_size = target_size - S1_size

    # fail with an explicit message instead of numpy's generic sampling error
    if S1_size > 0 and S_low.empty:
        raise ValueError(
            f"seqLengthSplit: S_low is empty for a dataset of {dataset_length} rows; "
            f"cannot draw {S1_size} samples from it"
        )
    if S2_size > 0 and S_high.empty:
        raise ValueError(
            f"seqLengthSplit: S_high is empty for a dataset of {dataset_length} rows; "
            f"cannot draw {S2_size} samples from it"
        )

    # sample with replacement to construct final dataset
    S1 = S_low.sample(n = S1_size, random_state = seed, replace = True)
    S2 = S_high.sample(n = S2_size, random_state = seed, replace = True)
    target_data = pd.concat((S1,S2)).reset_index(drop=True)

    return target_data
  
def cutText(data, variable, cut_proportion, threshold, target_size, seed):
    """
    Cut the end of the text field in a random subset of rows.

    A Target of target_size rows is drawn with train_test_split (random_state=seed); its original index
    labels are kept (no reset). Each row is truncated with probability "threshold": only its first
    int(len(text) * (1 - cut_proportion)) characters are kept. Row i uses a generator seeded with seed + i,
    and NumPy's global random state is left untouched.

    Args:
      data (pandas dataframe): original dataset with features and target variable.
      variable (str): text variable.
      cut_proportion (float): proportion of characters removed from the end of the text (intensity of shift)
      threshold (float): proportion of rows whose text is cut. (intensity of shift)
      target_size (int): number of samples for Target dataset
      seed (int): seed for random choice of data to be modified.

    Returns:
      (pandas dataframe) Target dataset with truncated text fields.

    Example:
    target = cutText(target, variable = text_var, cut_proportion = 0.5, threshold = 0.9, target_size = 1000, seed = 0)

    """
    dataset = data.copy()
    _, dataset = train_test_split(dataset, test_size = target_size, random_state = seed)

    # cut end of text field
    keep_proportion = 1 - cut_proportion
    sentence_cuts = []
    for row_pos, sentence in enumerate(dataset[variable].values):
        # local generator: same draw as np.random.seed(seed + row_pos); np.random.rand()
        toss = np.random.RandomState(seed + row_pos).rand()
        if toss >= threshold:  # does not change the text
            sentence_cuts.append(sentence)
        else:  # keep only the leading part of the text
            sentence_cuts.append(sentence[:int(len(sentence) * keep_proportion)])

    # update the text field in the dataset
    dataset[variable] = sentence_cuts

    return dataset
  
  
def dataAugment(data, variable, threshold, augment_proportion, pct_words_to_swap, target_size, seed):
    """
    Apply EDA data augmentation (Wei et al., 2019: synonym replacement, random insertion,
    random swap, random deletion) to the leading part of the text in a random subset of rows.

    A Target of target_size rows is drawn with train_test_split (random_state=seed); index labels are
    kept. For row i, NumPy's global RNG is reseeded with seed + i and one uniform draw decides whether
    the row is augmented (draw < threshold). For an augmented row, the first
    int(len(text) * augment_proportion) characters go through the augmenter and the rest of the text
    is appended unchanged. Python's `random` module is seeded once with "seed" before the augmenter
    is built.

    Args:
      data (pandas dataframe): original Target dataset
      variable (str): text variable to augment
      threshold (float): proportion of data that is augmented in terms of number of dataset rows (intensity of shift)
      augment_proportion (float): proportion of data that is augmented in terms of sentence length (intensity of shift)
      pct_words_to_swap (float): probability that a word is attacked (intensity of shift)
      target_size (int): number of samples for Target dataset
      seed (int): seed for the row selection and the augmenter's randomness

    Returns:
      (pandas dataframe) Target dataset with augmented text fields

    Reference: EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks, EMNLP 2019

    Example:
    target = dataAugment(target, variable = text_var, pct_words_to_swap = 0.3, threshold = 0.9, augment_proportion = 0.3, target_size = 1000, seed = 0)
    """
    # NOTE(review): the textattack import of EasyDataAugmenter is commented out at the top of
    # this file; the name must reach module scope some other way (possibly via the
    # `dataset` star import) — confirm, otherwise this raises NameError.
    _, dataset = train_test_split(data.copy(), test_size = target_size, random_state = seed)

    # The global RNGs are (re)seeded on purpose: the augmenter may draw from them internally.
    random.seed(seed)
    augmenter = EasyDataAugmenter(pct_words_to_swap=pct_words_to_swap, transformations_per_example=1)

    augmented_texts = []
    for offset, text in enumerate(dataset[variable].values):
        np.random.seed(seed + offset)
        if np.random.rand() >= threshold:  # row left untouched
            augmented_texts.append(text)
            continue
        split_at = int(len(text) * augment_proportion)
        # NOTE(review): assumes augment() returns at least one candidate, also for a very short/empty prefix
        head = augmenter.augment(text[:split_at])[0]
        augmented_texts.append(head + text[split_at:])

    # write the (possibly augmented) texts back
    dataset[variable] = augmented_texts

    return dataset
  
def abbrev(data, variable, threshold, target_size, seed, abbrev_path="datasets/abbreviations.csv"):
    """
    Replace words by corresponding abbreviations.

    A Target of target_size rows is drawn with train_test_split (random_state=seed) and re-indexed.
    int(threshold * target_size) of those rows are selected (train_test_split, random_state=seed); their text is
    lower-cased and every WORD of the abbreviation table is replaced by its ABBREVIATION, in table order.
    The modified rows come first in the result, followed by the untouched ones.

    NOTE(review): str.replace substitutes every occurrence, including inside longer words
    (e.g. a WORD "at" also rewrites "that") — confirm this is intended for the abbreviation table.

    Args:
      data (pandas dataframe): original dataset with features and target variable.
      variable (str): text variable.
      threshold (float): proportion of rows that are modified. (intensity of shift)
      target_size (int): number of samples for Target dataset
      seed (int): seed for random choice of data to be modified.
      abbrev_path (str): ";"-separated CSV with columns WORD and ABBREVIATION.

    Returns:
      (pandas dataframe) dataset with abbreviations.

    Example:
    target = abbrev(target, variable=text_var, threshold = 0.9, target_size=1000, seed=0)

    """
    dataset = data.copy()
    _, dataset = train_test_split(dataset, test_size = target_size, random_state = seed)
    dataset = dataset.reset_index(drop=True)

    # load abbreviation dictionary
    abbrev_df = pd.read_csv(abbrev_path, sep=";")
    abbrev_dict = abbrev_df.set_index('WORD').to_dict()['ABBREVIATION']
    replacements = list(abbrev_dict.items())

    # Select the proportion to modify. sklearn rejects an integer train_size of 0 or
    # of len(dataset), so threshold <= 0 / threshold >= 1 are handled explicitly.
    n_modify = int(threshold * target_size)
    if n_modify <= 0:
        dataset1, dataset2 = dataset.iloc[:0].copy(), dataset
    elif n_modify >= len(dataset):
        dataset1, dataset2 = dataset.copy(), dataset.iloc[:0]
    else:
        dataset1, dataset2 = train_test_split(dataset, train_size = n_modify, random_state = seed)
        dataset1 = dataset1.copy()  # independent frame: the assignment below must not hit a slice

    def _abbreviate(text):
        """Lower-case text, then apply every (word, abbreviation) replacement in order."""
        text = text.lower()
        for word, short in replacements:
            text = text.replace(word, short)
        return text

    # modify words into corresponding abbreviations (one pass per row)
    dataset1[variable] = dataset1[variable].apply(_abbreviate)

    # concatenate both datasets
    dataset = pd.concat((dataset1, dataset2))

    return dataset
  
############ Unseen label
  
def newClass(original_data, label_name, text_var, new_dataset_name, model_type, seed, target_size, sample_rate_new):
    """
    Includes in Target data a proportion of samples with unseen label.

    Rows of the dataset "new_dataset_name" whose label_name value does not occur in original_data are
    kept. Their 'Y' column is label-encoded and shifted by the number of distinct original labels, then
    int(sample_rate_new * target_size) of them and the remaining rows from original_data are sampled with
    replacement (random_state=seed) and concatenated (new rows first, fresh 0..target_size-1 index).

    NOTE(review): assumes 'Y' is the encoded target column produced by preprocess_dataset and that the
    original classes are encoded 0..k-1, so the new classes start at k — confirm against dataset.py.

    Args:
      original_data (pandas dataframe): Target dataset
      label_name (str): label name
      text_var (str): name of text variable
      new_dataset_name (str): name of dataset with unseen labels
      model_type (str): model type
      seed (int): seed used for random sampling
      target_size (int): number of samples for Target dataset
      sample_rate_new (float): intensity of shift (sampling from the new dataset)
    Returns:
      (pandas dataframes) Target dataset based on the described scheme
    Raises:
      ValueError: if new samples are requested but the new dataset has no unseen label.

    Example:
    target = newClass(target, label_name="Variety", text_var = 'Variety' ,new_dataset_name = "wine_100", model_type = "LateFuseBERT", seed=0, target_size=1000, sample_rate_new=0.9)
    """
    dataset_original = original_data.copy()

    # load dataset with new classes
    df_new = preprocess_dataset(new_dataset_name, model_type)
    if model_type == "AllTextBERT":
        # reset text field from the "<text_var>_original" column
        df_new[text_var] = df_new[text_var + '_original'].copy()

    # keep only new classes; .copy() so the 'Y' assignment below writes to an
    # independent frame rather than to a filtered slice of df_new
    seen_labels = dataset_original[label_name].unique()
    dataset_new = df_new[~df_new[label_name].isin(seen_labels)].copy()

    # number of rows drawn from each source
    S_new_size = int(sample_rate_new * target_size)
    S_original_size = target_size - S_new_size
    if S_new_size > 0 and dataset_new.empty:
        raise ValueError(
            f"newClass: dataset '{new_dataset_name}' has no label unseen in the Target data; "
            f"cannot draw {S_new_size} new-class samples"
        )

    # label encoding of target variable, shifted past the original classes
    le = LabelEncoder()
    dataset_new['Y'] = le.fit_transform(dataset_new['Y']) + dataset_original[label_name].nunique()

    # sample with replacement to construct final dataset
    S_new = dataset_new.sample(n = S_new_size, random_state = seed, replace = True)
    S_original = dataset_original.sample(n = S_original_size, random_state = seed, replace = True)
    target_data = pd.concat((S_new,S_original)).reset_index(drop=True)

    return target_data