import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re

class PositiveNegativeIdsDataset(Dataset):
    """Paired (positive, negative) prompts served as token-id tensors.

    Each CSV is read eagerly and must contain a ``prompt`` column. Row ``i``
    of the positive file is paired with row ``i`` of the negative file. The
    dataset length is the shorter file's row count, so surplus rows in the
    longer file are never served.

    Args:
        positive_csv: Anything ``pd.read_csv`` accepts, for positive prompts.
        negative_csv: Anything ``pd.read_csv`` accepts, for negative prompts.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')`` whose
            result is indexed with ``['input_ids']`` (presumably a Hugging
            Face tokenizer; confirm against callers).
        max_length: Length every prompt is padded or truncated to.
    """

    def __init__(self, positive_csv, negative_csv, tokenizer, max_length):
        # Positive file is read first, then the negative one.
        self.positive_prompts, self.negative_prompts = (
            pd.read_csv(source)['prompt'].tolist()
            for source in (positive_csv, negative_csv)
        )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Only as many pairs as the shorter prompt list can supply.
        return min(map(len, (self.positive_prompts, self.negative_prompts)))

    def _encode_ids(self, text):
        """Tokenize *text* to fixed-length ids and drop the leading batch axis."""
        batch = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # return_tensors='pt' on a single string yields a batch of one;
        # squeeze(0) removes that leading dimension.
        return batch['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Return ``(positive_ids, negative_ids)`` for pair *idx*."""
        pair = (self.positive_prompts[idx], self.negative_prompts[idx])
        return tuple(self._encode_ids(text) for text in pair)


class PositiveNegativeStrDataset(Dataset):
    """Paired (positive, negative) prompts served as raw strings.

    Both CSVs are read eagerly and must contain a ``prompt`` column. Row ``i``
    of each file forms pair ``i``. The length is the shorter file's row count.

    Args:
        positive_csv: Anything ``pd.read_csv`` accepts, for positive prompts.
        negative_csv: Anything ``pd.read_csv`` accepts, for negative prompts.
        tokenizer: Unused; kept so the signature matches the id datasets.
        max_length: Unused; kept so the signature matches the id datasets.
    """

    def __init__(self, positive_csv, negative_csv, tokenizer, max_length):
        # Positive file is read first, then the negative one.
        self.positive_prompts, self.negative_prompts = (
            pd.read_csv(source)['prompt'].tolist()
            for source in (positive_csv, negative_csv)
        )

    def __len__(self):
        # Surplus rows in the longer file are never served.
        return min(map(len, (self.positive_prompts, self.negative_prompts)))

    def __getitem__(self, idx):
        """Return the ``(positive_prompt, negative_prompt)`` strings for *idx*."""
        return self.positive_prompts[idx], self.negative_prompts[idx]
    

class SingleStrDataset(Dataset):
    """Raw prompt strings read from the ``prompt`` column of a single CSV.

    Args:
        csv: A path (``str`` or ``os.PathLike``) or buffer accepted by
            ``pd.read_csv``.
        tokenizer: Unused; kept so the signature matches the sibling datasets.
        max_length: Unused; kept so the signature matches the sibling datasets.
        encoding: Optional text encoding forwarded to ``pd.read_csv``. When
            ``None`` (the default), a source whose string form contains
            ``'ringbell'`` is read as ISO-8859-1 (the legacy special case).
            Any other source uses pandas' default encoding.
    """

    def __init__(self, csv, tokenizer, max_length, encoding=None):
        # str(csv) is used for the check so that pathlib.Path arguments work,
        # and so that file-like buffers are not iterated (and thus consumed)
        # by the membership test before pandas reads them.
        if encoding is None and 'ringbell' in str(csv):
            # Legacy filename heuristic: this file is presumably not UTF-8.
            encoding = 'ISO-8859-1'
        # encoding=None is pandas' own default, so non-ringbell sources are
        # read exactly as a plain pd.read_csv(csv) call would read them.
        self.prompts = pd.read_csv(csv, encoding=encoding)['prompt'].tolist()

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the prompt string at *idx*."""
        return self.prompts[idx]
    

class AdvPositiveNegativeIdsDataset(Dataset):
    """Paired prompt token ids with optional text augmentation at train time.

    Row ``i`` of the positive CSV is paired with row ``i`` of the negative CSV.
    Both files need a ``prompt`` column, and the length is the shorter file's
    row count. When ``train`` is true, each prompt is first passed through
    ``augmenter.augment`` and the first returned variant is tokenized. If the
    augmenter returns no variants, the original prompt is used.

    Args:
        positive_csv: Anything ``pd.read_csv`` accepts, for positive prompts.
        negative_csv: Anything ``pd.read_csv`` accepts, for negative prompts.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')`` whose
            result is indexed with ``['input_ids']`` (presumably a Hugging
            Face tokenizer; confirm against callers).
        max_length: Length every prompt is padded or truncated to.
        train: Whether ``__getitem__`` augments prompts before tokenizing.
        augmenter: Optional object with ``augment(text) -> list[str]``. If
            omitted, ``WordNetAugmenter(transformations_per_example=1)`` is
            built only when augmentation is actually needed.
    """

    def __init__(self, positive_csv, negative_csv, tokenizer, max_length,
                 train=False, augmenter=None):
        self.positive_prompts = pd.read_csv(positive_csv)['prompt'].tolist()
        self.negative_prompts = pd.read_csv(negative_csv)['prompt'].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.train = train
        self.augmenter = augmenter
        if self.train and self.augmenter is None:
            # Build eagerly in training mode so a missing augmentation
            # dependency surfaces at construction, not inside a DataLoader
            # worker. Evaluation datasets (train=False) never need it.
            self.augmenter = self._default_augmenter()

    @staticmethod
    def _default_augmenter():
        """Construct the augmenter used when none was injected."""
        # NOTE(review): WordNetAugmenter is not imported in the visible file
        # (presumably textattack.augmentation.WordNetAugmenter). It must be
        # available in module scope, or this raises NameError.
        return WordNetAugmenter(transformations_per_example=1)

    def __len__(self):
        # Surplus rows in the longer file are never served.
        return min(len(self.positive_prompts), len(self.negative_prompts))

    def _augment(self, text):
        """Return the first augmented variant of *text*, or *text* if none."""
        if self.augmenter is None:
            # `train` may have been switched on after construction.
            self.augmenter = self._default_augmenter()
        variants = self.augmenter.augment(text)
        # Guard against an empty result: an IndexError escaping __getitem__
        # would be read as "end of data" by the sequence-iteration protocol.
        return variants[0] if variants else text

    def _encode_ids(self, text):
        """Tokenize *text* to fixed-length ids and drop the leading batch axis."""
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Return ``(positive_ids, negative_ids)`` for pair *idx*."""
        positive_prompt = self.positive_prompts[idx]
        negative_prompt = self.negative_prompts[idx]
        if self.train:
            positive_prompt = self._augment(positive_prompt)
            negative_prompt = self._augment(negative_prompt)
        return self._encode_ids(positive_prompt), self._encode_ids(negative_prompt)