from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd


def get_dataset(tokenizer, type_path, args):
  """Build the ``MolDataset`` for split ``type_path``.

  Only ``args.data_dir`` and ``args.max_seq_length`` are read from ``args``;
  they are forwarded as the dataset's ``data_dir`` and ``max_len``.
  """
  data_dir = args.data_dir
  max_len = args.max_seq_length
  return MolDataset(
      tokenizer=tokenizer,
      data_dir=data_dir,
      type_path=type_path,
      max_len=max_len,
  )


class MolDataset(Dataset):
  """Description -> SMILES dataset read from a tab-separated file.

  The split is loaded from ``<data_dir>/<type_path>.txt``, which must have
  ``description`` and ``SMILES`` columns. Every row is tokenized once, while
  the dataset is being constructed: the description becomes the source
  sequence and the SMILES string the target sequence. Two prompt token ids
  are spliced into every source sequence (see ``_insert_prompt_tokens``).

  Args:
    tokenizer: object exposing ``batch_encode_plus`` (presumably a Hugging
      Face tokenizer -- confirm against the caller).
    data_dir: directory containing the split files.
    type_path: split name; ``".txt"`` is appended to form the file name.
    max_len: length every source and target sequence is padded to.
  """

  # Position in the tokenized source at which the two prompt tokens go.
  PROMPT_POS = 7
  # Example i receives CYCLIC_TOKEN_BASE + i % CYCLE_LEN at PROMPT_POS ...
  CYCLIC_TOKEN_BASE = 32101
  CYCLE_LEN = 10
  # ... and EXAMPLE_TOKEN_BASE + i at PROMPT_POS + 1.
  # NOTE(review): these ids presumably index tokens added beyond the base
  # vocabulary. EXAMPLE_TOKEN_BASE + i grows with the dataset size, so the
  # model's embedding table must be at least that large -- confirm.
  EXAMPLE_TOKEN_BASE = 32111

  def __init__(self, tokenizer, data_dir, type_path, max_len=512):
    # Each split lives in "<data_dir>/<type_path>.txt".
    self.file_path = os.path.join(data_dir, type_path + ".txt")
    self.max_len = max_len
    self.tokenizer = tokenizer
    # Parallel lists: inputs[k] / targets[k] hold the encodings of row k.
    self.inputs = []
    self.targets = []

    self._build()

  def __len__(self):
    """Return the number of rows read from the split file."""
    return len(self.inputs)

  def __getitem__(self, index):
    """Return pre-tokenized example ``index`` as a dict of tensors.

    ``squeeze()`` removes size-1 dimensions, in particular the leading batch
    dimension produced by encoding a one-element batch.
    """
    source = self.inputs[index]
    target = self.targets[index]
    return {
        "source_ids": source["input_ids"].squeeze(),
        "source_mask": source["attention_mask"].squeeze(),
        "target_ids": target["input_ids"].squeeze(),
        "target_mask": target["attention_mask"].squeeze(),
    }

  def _build(self):
    """Populate ``self.inputs`` / ``self.targets`` from ``self.file_path``."""
    self._build_examples_from_files(self.file_path)

  def _encode(self, text):
    """Tokenize one string as a batch of one, padded to ``self.max_len``, as PyTorch tensors."""
    # NOTE(review): ``pad_to_max_length`` is deprecated in newer transformers
    # releases in favour of ``padding="max_length", truncation=True``; kept
    # here for compatibility with the tokenizer version this project pins.
    return self.tokenizer.batch_encode_plus(
        [text], max_length=self.max_len, pad_to_max_length=True, return_tensors="pt"
    )

  def _insert_prompt_tokens(self, encoding, example_idx):
    """Splice the two prompt token ids for ``example_idx`` into ``encoding`` in place.

    Everything from ``PROMPT_POS`` onward is shifted right by two positions
    (the final two positions fall off the end). ``PROMPT_POS`` then receives
    ``CYCLIC_TOKEN_BASE + example_idx % CYCLE_LEN``, ``PROMPT_POS + 1``
    receives ``EXAMPLE_TOKEN_BASE + example_idx``, and both are marked as
    attended in the attention mask.

    NOTE(review): if the description fills ``max_len``, its last two tokens
    (e.g. any end-of-sequence token) are dropped by the shift. If it has
    fewer than ``PROMPT_POS`` tokens, the prompt tokens land in the padding
    region (assuming right padding). Confirm both are acceptable.
    """
    pos = self.PROMPT_POS
    ids = encoding['input_ids'][0]
    mask = encoding['attention_mask'][0]

    # clone() because the source and destination slices overlap.
    ids[pos + 2:] = ids[pos:-2].clone()
    ids[pos] = self.CYCLIC_TOKEN_BASE + example_idx % self.CYCLE_LEN
    ids[pos + 1] = self.EXAMPLE_TOKEN_BASE + example_idx

    mask[pos + 2:] = mask[pos:-2].clone()
    mask[pos] = 1
    mask[pos + 1] = 1

  def _build_examples_from_files(self, path):
    """Tokenize every row of the TSV at ``path`` and append it to inputs/targets."""
    df = pd.read_csv(path, sep="\t")
    descriptions = df["description"]
    smiles = df["SMILES"]

    # enumerate() yields a positional counter independent of the DataFrame
    # index; the prompt-token ids are derived from that position.
    for i, (description, target_smiles) in enumerate(zip(descriptions, smiles)):
      tokenized_inputs = self._encode(description)
      self._insert_prompt_tokens(tokenized_inputs, i)

      tokenized_targets = self._encode(target_smiles)

      self.inputs.append(tokenized_inputs)
      self.targets.append(tokenized_targets)
      
