import os
from turtle import pos
import pandas as pd
from typing import *
from .sentiment_analysis_dataset import PROCESSORS as SA_PROCESSORS
from .text_classification_dataset import PROCESSORS as TC_PROCESSORS
from .plain_dataset import PROCESSORS as PT_PROCESSORS
from .toxic_dataset import PROCESSORS as TOXIC_PROCESSORS
from .spam_dataset import PROCESSORS as SPAM_PROCESSORS
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
# from openbackdoor.utils.log import logger
import logging
logger = logging.getLogger(__name__)

import torch
# support loading transformers datasets from https://huggingface.co/docs/datasets/

# Merged registry of the dataset processors exported by each task-specific
# module. `load_dataset` looks entries up by lower-cased dataset name and calls
# the stored value with no arguments to obtain a processor. If two modules
# register the same key, the mapping unpacked later in this literal wins.
PROCESSORS = {
    **SA_PROCESSORS,
    **TC_PROCESSORS,
    **PT_PROCESSORS,
    **TOXIC_PROCESSORS,
    **SPAM_PROCESSORS,
}


def load_dataset(
            test=False,
            name: str = "sst-2",
            dev_rate: float = 0.1,
            load: Optional[bool] = False,
            clean_data_basepath: Optional[str] = None,
            **kwargs):
    r"""Load the train, dev, and test splits of a named dataset.

    Args:
        test (:obj:`bool`): If true, skip the train and dev splits and only
            load the test split.
        name (:obj:`str`): Dataset name; it is lower-cased and looked up in
            ``PROCESSORS``.
        dev_rate (:obj:`float`): Rate passed to the processor's ``split_dev``
            when the dataset ships no dev split.
        load (:obj:`bool`): If true and ``clean_data_basepath`` exists, read
            the ``train-clean``, ``dev-clean`` and ``test-clean`` CSV files
            from that directory instead of using a processor.
        clean_data_basepath (:obj:`str`, optional): Directory holding the
            cached clean CSV files.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        :obj:`dict`: A mapping with the keys ``"train"``, ``"dev"`` and
        ``"test"``. A split that could not be loaded (or was skipped because
        ``test`` is true) is ``None``.

    Raises:
        KeyError: If ``name`` is not registered in ``PROCESSORS``.
        SystemExit: If none of the three splits could be loaded.
    """

    # `os.path.exists(None)` raises TypeError, so only probe a real path; when
    # no path is given we fall through to the regular processor-based loading.
    if load and clean_data_basepath is not None and os.path.exists(clean_data_basepath):
        return {
            "train": load_clean_data(clean_data_basepath, "train-clean"),
            "dev": load_clean_data(clean_data_basepath, "dev-clean"),
            "test": load_clean_data(clean_data_basepath, "test-clean"),
        }

    processor = PROCESSORS[name.lower()]()
    train_dataset = None
    dev_dataset = None
    test_dataset = None

    if not test:
        try:
            train_dataset = processor.get_train_examples()
        except FileNotFoundError:
            logger.warning("Has no training dataset.")
        try:
            dev_dataset = processor.get_dev_examples()
        except FileNotFoundError:
            if train_dataset is None:
                # Nothing to carve a dev split out of; leave it as None rather
                # than handing None to `split_dev`.
                logger.warning("Has no dev dataset, and no training dataset to split one from.")
            else:
                logger.warning("Has no dev dataset. Split {} percent of training dataset".format(dev_rate*100))
                train_dataset, dev_dataset = processor.split_dev(train_dataset, dev_rate)

    try:
        test_dataset = processor.get_test_examples()
    except FileNotFoundError:
        logger.warning("Has no test dataset.")

    # Check whether anything was downloaded at all.
    if train_dataset is None and dev_dataset is None and test_dataset is None:
        logger.error("{} Dataset is empty. Either there is no download or the path is wrong. ".format(name)+ \
        "If not downloaded, please `cd datasets/` and `bash download_xxx.sh`")
        # Same exception the `exit()` builtin raises, without depending on the
        # `site` module having injected that helper.
        raise SystemExit

    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
        "test": test_dataset
    }
    # Any split may legitimately be None here (e.g. `test=True`), so report a
    # size of 0 for it instead of calling `len(None)`.
    train_size, dev_size, test_size = (
        0 if split is None else len(split)
        for split in (train_dataset, dev_dataset, test_dataset)
    )
    logger.info("{} dataset loaded, train: {}, dev: {}, test: {}".format(name, train_size, dev_size, test_size))

    return dataset

def collate_fn(data):
    """Collate ``(text, label, poison_label)`` samples into a batch dict.

    Returns a dictionary with ``"text"`` (list of the sample texts),
    ``"label"`` (a ``torch.LongTensor`` built from the labels) and
    ``"poison_label"`` (list of the poison flags), all in input order.
    """
    # Unpack strictly into three fields so a malformed sample still fails here.
    samples = [(text, label, poison_label) for text, label, poison_label in data]
    return {
        "text": [sample[0] for sample in samples],
        "label": torch.LongTensor([sample[1] for sample in samples]),
        "poison_label": [sample[2] for sample in samples],
    }

def get_dataloader(dataset: Union[Dataset, List],
                    batch_size: Optional[int] = 4,
                    shuffle: Optional[bool] = True):
    """Wrap ``dataset`` in a ``DataLoader`` that batches samples with ``collate_fn``."""
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn,
    }
    return DataLoader(dataset, **loader_options)





def collate_fn_attn_version(data):
    '''
    Collate clean and poisoned samples into one batch with a uniform layout.

    Each item of ``data`` is a ``(text, label, poison_label)`` triple. When
    ``poison_label == 1`` the ``text`` slot is indexed as
    ``[poisoned_text, original_text, position, trigger]``; otherwise it is
    used directly as the sample text.

    output the dictionary:
        text: clean sample text, or the poisoned text of a poisoned sample
        label: ``torch.LongTensor`` of the sample labels
        poison_label: the poison flag of every sample, as given
        position: the position stored with a poisoned sample (the trigger
            insertion position); 0 for a clean sample
        ori_text: the text before the trigger was inserted; equal to ``text``
            for a clean sample
        trigger: the trigger stored with a poisoned sample; '' for a clean one
    '''
    batch = {
        "text": [],
        "label": [],
        "poison_label": [],
        "position": [],
        "ori_text": [],
        "trigger": [],
    }
    for sample, label, poison_label in data:
        if poison_label == 1:
            # Poisoned samples carry their metadata alongside the text.
            fields = {
                "text": sample[0],
                "position": sample[2],
                "ori_text": sample[1],
                "trigger": sample[3],
            }
        else:
            # Clean samples get neutral placeholders so every row has the same shape.
            fields = {
                "text": sample,
                "position": 0,
                "ori_text": sample,
                "trigger": '',
            }
        fields["label"] = label
        fields["poison_label"] = poison_label
        for key, value in fields.items():
            batch[key].append(value)
    # Only the labels are converted to a tensor; the other columns stay lists.
    batch["label"] = torch.LongTensor(batch["label"])
    return batch

def get_dataloader_attn_version(dataset: Union[Dataset, List],
                    batch_size: Optional[int] = 4,
                    shuffle: Optional[bool] = True):
    """Wrap ``dataset`` in a ``DataLoader`` that batches samples with ``collate_fn_attn_version``."""
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn_attn_version,
    }
    return DataLoader(dataset, **loader_options)





def load_clean_data(path, split):
    """Read ``<path>/<split>.csv`` and return its rows as a list of 3-tuples.

    Each tuple holds the values of columns 1, 2 and 3 of a row, in that
    order; column 0 is skipped.
    """
    csv_file = os.path.join(path, f'{split}.csv')
    rows = pd.read_csv(csv_file).values
    samples = []
    for row in rows:
        samples.append((row[1], row[2], row[3]))
    return samples

from .data_utils import wrap_dataset, wrap_dataset_lws, wrap_dataset_attn
