import torch
import csv
import pandas as pd


class OurDatasets(torch.utils.data.Dataset):
    """Dataset that flattens selected CSV columns into one shuffled series of entries."""

    def __init__(self, csv_file, header=None, col_inds=None, drop_dup=False, sep=None):
        """
        csv_file: path to data
        header: forwarded to pandas.read_csv; row number of the header row, or None if there is none
        col_inds: positional indices of the columns that contain the data to use,
            default only the first column
        drop_dup: if True, duplicate entries are removed after the columns are flattened
        sep: column separator forwarded to pandas.read_csv; None lets pandas detect it

        Raises ValueError if col_inds does not match any column of the file.
        """
        if col_inds is None:
            # resolved here instead of in the signature to avoid a shared mutable default
            col_inds = [0]

        read_kwargs = {}
        if sep is None:
            # separator detection only exists in the python engine; naming it explicitly
            # avoids the ParserWarning pandas emits when it has to fall back from the C engine
            read_kwargs["engine"] = "python"
        frame = pd.read_csv(csv_file, quoting=csv.QUOTE_ALL, header=header, sep=sep,
                            on_bad_lines='skip', **read_kwargs)  # read data

        # rename the columns to their position as a string: '0', '1', ...
        frame.columns = [str(pos) for pos in range(len(frame.columns))]

        # which columns to use (kept in file order, not in the order given in col_inds)
        keep_cols = [name for pos, name in enumerate(frame.columns) if pos in col_inds]
        if not keep_cols:
            raise ValueError(
                f"col_inds={list(col_inds)!r} does not match any of the "
                f"{len(frame.columns)} column(s) read from {csv_file!r}"
            )

        # get the specified columns - flatten them into one long Series
        self.data = pd.concat([frame[name] for name in keep_cols])
        if drop_dup:
            self.data = self.data.drop_duplicates()
        self.data = self.data.sample(frac=1)  # do a random shuffle

    def __len__(self):
        """Return the number of entries after flattening (and optional de-duplication)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the entry at position idx of the shuffled data."""
        # positional access: the index labels repeat after concatenating several columns
        return self.data.iloc[idx]

class NLIStealFirstColumn(torch.utils.data.Dataset):
    """Dataset over the first column of a CSV file whose first line is a header row.

    The rows are shuffled once when the dataset is created.
    """

    def __init__(self, csv_file):
        """
        csv_file: path to data; the first line is consumed as the header
        """
        self.data = pd.read_csv(csv_file, quoting=csv.QUOTE_ALL, header=0)
        self.data = self.data.sample(frac=1)  # do a random shuffle

    # e.g. data/nli/nli_for_simcse.csv: has 275602 rows with 3 sentences each.
    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the first-column value of the row at position idx."""
        # Address row and column positionally in one step. The former `iloc[idx][0]`
        # looked up the integer 0 on a row whose labels are the header strings, which
        # relies on a positional fallback that pandas has deprecated, and it built a
        # full row Series for every item.
        return self.data.iloc[idx, 0]


class NLIStealFullDataset(torch.utils.data.Dataset):
    """Dataset over every cell of a CSV file whose first line is a header row.

    All columns are concatenated into one long Series, which is shuffled once
    when the dataset is created.
    """

    def __init__(self, csv_file):
        """
        csv_file: path to data; the first line is consumed as the header
        """
        frame = pd.read_csv(csv_file, header=0, quoting=csv.QUOTE_ALL)
        # one Series per column, appended one after the other
        per_column = [frame[name] for name in frame]
        flattened = pd.concat(per_column)
        self.data = flattened.sample(frac=1)  # random order

    # e.g. data/nli/nli_for_simcse.csv
    def __len__(self):
        """Return the number of entries (rows times columns)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the entry at position idx of the shuffled data."""
        return self.data.iloc[idx]

class TwoColumnDSFirstColumn(torch.utils.data.Dataset):
    """Dataset over the first column of a header-less CSV file.

    The rows are shuffled once when the dataset is created.
    """

    def __init__(self, csv_file):
        """
        csv_file: path to data; every line is read as data (no header row)
        """
        self.data = pd.read_csv(csv_file, quoting=csv.QUOTE_ALL, header=None)
        self.data = self.data.sample(frac=1)  # do a random shuffle

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the first-column value of the row at position idx."""
        # Address row and column positionally in one step instead of building a
        # whole row Series first and then picking its first element.
        return self.data.iloc[idx, 0]

class TwoColumnFullDataset(torch.utils.data.Dataset):
    """Dataset over every cell of a header-less CSV file.

    All columns are concatenated into one long Series, which is shuffled once
    when the dataset is created.
    """

    def __init__(self, csv_file):
        """
        csv_file: path to data; every line is read as data (no header row)
        """
        frame = pd.read_csv(csv_file, header=None, quoting=csv.QUOTE_ALL)
        # one Series per column, appended one after the other
        per_column = [frame[label] for label in frame]
        flattened = pd.concat(per_column)
        self.data = flattened.sample(frac=1)  # random order

    def __len__(self):
        """Return the number of entries (rows times columns)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the entry at position idx of the shuffled data."""
        return self.data.iloc[idx]


class TestDataset(torch.utils.data.Dataset):
    """The test dataset makes some assumptions on the CSV it gets.
    It expects the test data in the 'cleaned for DI'-format: 1 column, no header, no repetitions

    NOTE(review): by default the file is read with header=0, so its first line is
    consumed as the column name and is never returned as a sample. That contradicts
    the "no header" format described above; confirm which is intended. Pass
    header=None to keep the first line as data.
    """

    def __init__(self, csv_file, header=0):
        """
        csv_file: path to data
        header: forwarded to pandas.read_csv; 0 (default, the previous behaviour) treats
            the first line as a header row, None reads every line as data
        """
        self.data = pd.read_csv(csv_file, quoting=csv.QUOTE_ALL, header=header)
        self.data = self.data.sample(frac=1)  # do a random shuffle

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the first-column value of the row at position idx."""
        # Address row and column positionally in one step. The former `iloc[idx][0]`
        # looked up the integer 0 on a row labelled by the header string, which relies
        # on a positional fallback that pandas has deprecated.
        return self.data.iloc[idx, 0]