

import itertools
import os
import random

import findfile
import torch


def load_dataset(
    task_name="defense",
    dataset_name="sst2",
    split="train",
    data_dir=None,
    attack="",
    max_lines_per_file=10000,
):
    """Load ``text$LABEL$label`` examples from every matching data file.

    Files are located with ``findfile.find_files`` under ``data_dir``. The
    search keyword list is ``[task_name, dataset_name, split, attack]``. Each
    non-blank line is split on its *last* ``$LABEL$`` marker. The text and the
    label are both whitespace-stripped. The label is returned as an unparsed
    string; callers are responsible for converting it.

    Args:
        task_name: Search keyword passed to ``findfile.find_files``.
        dataset_name: Search keyword passed to ``findfile.find_files``.
        split: Search keyword passed to ``findfile.find_files``.
        data_dir: Directory to search. Falls back to ``"datasets"`` when falsy.
        attack: Search keyword passed to ``findfile.find_files``.
        max_lines_per_file: Maximum number of lines (blank ones included) read
            from each file. ``None`` reads every line.

    Returns:
        A list of ``{"text": str, "label": str}`` dicts. It is shuffled in
        place with the module-level ``random`` generator, and is empty when no
        file matches.

    Raises:
        ValueError: If a non-blank line contains no ``$LABEL$`` marker.
    """
    if not data_dir:
        data_dir = "datasets"

    data_files = findfile.find_files(data_dir, [task_name, dataset_name, split, attack])

    data = []
    for data_file in data_files or []:
        with open(data_file, "r", encoding="utf8") as f:
            # Stream the capped prefix instead of materialising the whole file.
            for line_no, line in enumerate(
                itertools.islice(f, max_lines_per_file), start=1
            ):
                # Blank lines (e.g. a trailing empty line) carry no example.
                if not line.strip():
                    continue
                # rpartition tolerates a literal "$LABEL$" inside the text.
                text, sep, label = line.rpartition("$LABEL$")
                if not sep:
                    raise ValueError(
                        f"{data_file}:{line_no}: missing '$LABEL$' separator"
                    )
                data.append({"text": text.strip(), "label": label.strip()})
    random.shuffle(data)
    return data


class Dataset:
    """Tokenized text dataset carrying three integer labels per example.

    Raw examples come from ``load_dataset``. Each example's ``label`` string
    must split on ``","`` into exactly three integer fields, read in the order
    classification, detection, adversarial. Every example is replaced by the
    tokenizer's output with ``label``, ``adv_label`` and ``det_label`` added.
    Every column is then converted with ``torch.tensor``.
    """

    def __init__(
        self,
        tokenizer,
        task_name="defense",
        dataset_name="sst2",
        split="train",
        data_dir=None,
        attack="",
        max_length=512,
    ):
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.split = split
        self.data_dir = data_dir
        # The ``attack`` keyword is stored under the attribute name ``kword``.
        self.kword = attack
        self.dataset = load_dataset(
            self.task_name, self.dataset_name, self.split, self.data_dir, self.kword
        )

        self.tokenizer = tokenizer

        for position, example in enumerate(self.dataset):
            cls_str, det_str, adv_str = example["label"].split(",")
            features = self.tokenizer(
                example["text"],
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            # Label columns are appended after the tokenizer's own columns,
            # in the order label, adv_label, det_label.
            features["label"] = int(cls_str)
            features["adv_label"] = int(adv_str)
            features["det_label"] = int(det_str)
            for column in list(features):
                features[column] = torch.tensor(features[column])
            self.dataset[position] = features

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded example (column name -> tensor) at ``idx``."""
        return self.dataset[idx]


class CLSDataset:
    """Tokenized text dataset carrying a single integer class label.

    Raw examples come from ``load_dataset``. Each example's ``label`` string is
    parsed with ``int``. Every example is replaced by the tokenizer's output
    with a ``label`` column added. Every column is then converted with
    ``torch.tensor``.
    """

    def __init__(
        self,
        tokenizer,
        task_name="classification",
        dataset_name="sst2",
        split="train",
        data_dir=None,
        attack="",
        max_length=512,
    ):
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.split = split
        self.data_dir = data_dir
        # The ``attack`` keyword is stored under the attribute name ``kword``.
        self.kword = attack
        self.dataset = load_dataset(
            self.task_name, self.dataset_name, self.split, self.data_dir, self.kword
        )

        self.tokenizer = tokenizer

        for position, example in enumerate(self.dataset):
            raw_label = example["label"]
            features = self.tokenizer(
                example["text"],
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            # The class label is appended after the tokenizer's own columns.
            features["label"] = int(raw_label)
            for column in list(features):
                features[column] = torch.tensor(features[column])
            self.dataset[position] = features

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded example (column name -> tensor) at ``idx``."""
        return self.dataset[idx]
