import csv
import torch
import numpy as np

# Dataset names accepted by load_dataset_and_dataloader; each name is read from data/<name>.csv.
DATASETS = ["harmful_behaviors", "hp_qa_en", "harmful_strings", "forget01", "forget01_inst"]

def load_dataset_and_dataloader(tokenizer, dataset_name, batch_size,
                                csv_columns=[0, 1], test_split=0, shuffle=True, device="cuda:0"):
    '''
    Read one of the known CSV datasets and build tokenized datasets and dataloaders from it.

    The file `data/<dataset_name>.csv` is read with the `csv` module; its first row is
    treated as the header and every following row as a sample. The rows are then passed
    to `create_pytorch_dataset_from_csv`, whose result is returned unchanged.

    Args:
        tokenizer: Tokenizer forwarded to `create_pytorch_dataset_from_csv`.
        dataset_name (str): One of the names listed in `DATASETS`.
        batch_size (int): Batch size of the created dataloaders.
        csv_columns (list): Indices of the CSV columns to use. Overridden with `[0]`
            for the "harmful_strings" dataset, which only has one column.
        test_split (float): Fraction of the rows held out as test set (0 disables the split).
        shuffle (bool): Whether the training dataloader shuffles its samples.
        device (str): Device on which the token tensors are created.

    Returns:
        tuple: `(dataset_train, dataset_test, dataloader_train, dataloader_test)` as
        returned by `create_pytorch_dataset_from_csv`.

    Raises:
        ValueError: If `dataset_name` is not in `DATASETS` or the CSV file is empty.
    '''
    if dataset_name not in DATASETS:
        raise ValueError(f"Dataset {dataset_name} not found. Choose from {DATASETS}")

    if dataset_name == "harmful_strings":
        csv_columns = [0] # only contains one column

    file_path = f"data/{dataset_name}.csv"
    # newline='' is what the csv module expects so that quoted fields containing
    # line breaks are parsed correctly; the with-block guarantees the handle is closed.
    with open(file_path, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file)
        cols = next(reader, None)
        if cols is None:
            # Avoid leaking a bare StopIteration to the caller for an empty file.
            raise ValueError(f"Dataset file {file_path} is empty")
        print(f"Using columns: '{[cols[idx] for idx in csv_columns]}' from dataset '{dataset_name}'")
        dataset = list(reader)

    return create_pytorch_dataset_from_csv(
        tokenizer, dataset, dataset_name, batch_size,
        csv_columns=csv_columns, test_split=test_split, shuffle=shuffle, device=device,
    )

def create_pytorch_dataset_from_csv(tokenizer, string_list, dataset_name, batch_size, csv_columns=[0, 1], test_split=0, shuffle=False, device="cuda:0"):
    '''
    Tokenize selected columns of CSV rows and wrap them in PyTorch datasets and dataloaders.

    For every index in `csv_columns`, the corresponding entry of each row is collected,
    the whole column is tokenized in one call with `padding=True`, and the resulting
    `input_ids` are turned into a tensor on `device`. The tensors (one per column) form
    a `TensorDataset`. If `test_split > 0`, the leading rows become the training set and
    the trailing rows the test set (no shuffling is applied before splitting).

    Args:
        tokenizer: Callable invoked as `tokenizer(list_of_str, padding=True)`; the result
            must provide `["input_ids"]` as equally long id lists (e.g. a
            transformers.PreTrainedTokenizer with a padding token).
        string_list (list): Rows of strings; `row[column_idx]` is read for each selected column.
        dataset_name (str): Name used in the printed status message only.
        batch_size (int): Batch size of the created dataloaders.
        csv_columns (list): Column indices to tokenize; must not be empty.
        test_split (float): Fraction in [0, 1] of rows held out as test set; values <= 0
            disable the split.
        shuffle (bool): Whether the training dataloader shuffles. The test dataloader never does.
        device (str): Device on which the token tensors are created.

    Returns:
        tuple: `(dataset_train, dataset_test, dataloader_train, dataloader_test)`.
        `dataset_test` and `dataloader_test` are `None` when no split is requested.

    Raises:
        ValueError: If `csv_columns` is empty or `test_split` is greater than 1.
    '''
    if len(csv_columns) == 0:
        raise ValueError("csv_columns must contain at least one column index")
    if test_split > 1:
        # A fraction above 1 would yield a negative split index and a nonsensical split.
        raise ValueError(f"test_split must be a fraction <= 1, got {test_split}")

    tensor_list = []
    for column_idx in csv_columns:
        column_strings = [row[column_idx] for row in string_list]
        column_token_ids = tokenizer(column_strings, padding=True)["input_ids"]
        tensor_list.append(torch.tensor(column_token_ids, device=device))

    # All column tensors share the same first dimension (TensorDataset asserts this).
    num_rows = len(tensor_list[0])

    # split into train and test set
    if test_split > 0:
        split_idx = int(num_rows * (1 - test_split))
        train_tensor_list = [tensor[:split_idx] for tensor in tensor_list]
        test_tensor_list = [tensor[split_idx:] for tensor in tensor_list]
        dataset_train = torch.utils.data.TensorDataset(*train_tensor_list)
        dataset_test = torch.utils.data.TensorDataset(*test_tensor_list)
        dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle)
        dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
        print(f"Dataset: {dataset_name} | Split dataset into train and test set with {split_idx} train and {num_rows - split_idx} test samples")
    else:
        dataset_train = torch.utils.data.TensorDataset(*tensor_list)
        dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle)
        dataset_test = None
        dataloader_test = None
        print(f"Dataset: {dataset_name} | Using whole dataset as training data with {num_rows} rows")

    return dataset_train, dataset_test, dataloader_train, dataloader_test


class StringListsDataset(torch.utils.data.Dataset):
    '''
    Dataset over a collection of sliceable items (e.g. strings or lists).

    Every item is clipped to its first `max_length` elements on construction and
    the clipped items are kept in a NumPy array with `dtype=object`.
    '''

    def __init__(self, data, max_length = 256):
        clipped_items = []
        for item in data:
            # Slicing leaves items shorter than max_length untouched.
            clipped_items.append(item[:max_length])
        self.data = np.array(clipped_items, dtype=object)

    def __len__(self):
        # Number of entries along the first axis of the stored array.
        return self.data.shape[0]

    def __getitem__(self, indices):
        # Indexing is delegated to NumPy, so `indices` may be anything NumPy
        # accepts (a single int, a slice, a list of ints, ...).
        return self.data[indices]
