import os.path
import pickle

from transformers import AutoTokenizer
from datasets import load_dataset
from utils import PathConfig
from collections import Counter
from val_sampler import sample_val
from datasets import concatenate_datasets


# Project path configuration, instantiated once when this module is imported.
PC = PathConfig()
# Root directory under which the raw dataset CSVs live
# (e.g. get_dataset reads 'reuters/reuters_train.csv' relative to it).
DATASET_PATH = PC.get_dataset_path()


def post_process(dst, model_name, max_length=512):
    """Tokenize a text-classification dataset so it can be fed to a model.

    The tokenizer for ``model_name`` is loaded from the local cache only
    (``local_files_only=True``). The ``text`` column is tokenized in batches,
    padded and truncated to exactly ``max_length`` tokens. The raw ``text``
    column is then dropped and ``label`` is renamed to ``labels``.

    Args:
        dst: Dataset with ``text`` and ``label`` columns (presumably a
            HuggingFace ``datasets.Dataset`` -- confirm against callers).
        model_name: Name or path of the pretrained tokenizer.
        max_length: Sequence length to pad/truncate every example to.
            Defaults to 512, the value previously hard-coded here.

    Returns:
        The tokenized dataset with a ``labels`` column and no ``text`` column.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    def tokenize_function(batch):
        # Fixed-length padding gives every example the same shape.
        return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=max_length)

    dst = dst.map(tokenize_function, batched=True)
    dst = dst.remove_columns(["text"])
    dst = dst.rename_column("label", "labels")
    return dst


def decode_to_sentence(dst, model_name):
    """Attach a decoded ``text`` column built from each example's ``input_ids``.

    The tokenizer for ``model_name`` is loaded from the local cache only.
    Each example's ``input_ids`` are decoded with special tokens skipped.
    The mapping runs with ``num_proc=4``.

    Args:
        dst: Tokenized dataset containing an ``input_ids`` column.
        model_name: Name or path of the tokenizer that produced the ids.

    Returns:
        The mapped dataset with a ``text`` field per example.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    def attach_text(example):
        sentence = tokenizer.decode(example['input_ids'], skip_special_tokens=True)
        return {'text': sentence}

    return dst.map(attach_text, num_proc=4)


def get_num_class(dst_name):
    """Return the number of target classes for a known dataset name.

    Raises:
        KeyError: If ``dst_name`` is not ``'newsgroups'``, ``'imdb'`` or ``'reuters'``.
    """
    classes_per_dataset = dict(newsgroups=20, imdb=2, reuters=2)
    return classes_per_dataset[dst_name]


def get_imbalanced_ratio(data_y, is_return_minimum=False):
    """Compute each class's size relative to the smallest class.

    Args:
        data_y: Iterable of (hashable, mutually sortable) class labels.
        is_return_minimum: If true, also return the smallest class count.

    Returns:
        A dict ``{label: count / min_count}`` with keys in ascending label
        order. If ``is_return_minimum`` is true, returns the tuple
        ``(ratio_dict, min_count)`` instead.

    Raises:
        ValueError: If ``data_y`` is empty (there is no minimum class).
    """
    label_counts = Counter(data_y)
    smallest = min(label_counts.values())
    ratios = {label: label_counts[label] / smallest for label in sorted(label_counts)}
    return (ratios, smallest) if is_return_minimum else ratios


def read_data_pool(folder_path, num_class, model_name):
    """Load and tokenize the per-class data-pool CSVs.

    For each class id ``0 .. num_class - 1`` this reads ``<folder_path>/<id>.csv``,
    tokenizes it with ``post_process`` and switches it to ``'torch'`` format.

    Returns:
        Dict mapping class id to that class's tokenized dataset.
    """
    pool = {}
    for label in range(num_class):
        csv_path = os.path.join(folder_path, f'{label}.csv')
        raw_split = load_dataset('csv', data_files=csv_path)['train']
        class_dst = post_process(raw_split, model_name)
        class_dst.set_format('torch')
        pool[label] = class_dst
    return pool


def read_src_index_file(file_path):
    """Unpickle and return ``src_index.pkl`` from the directory ``file_path``."""
    index_path = os.path.join(file_path, 'src_index.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # files produced by this project, never on untrusted input.
    with open(index_path, 'rb') as handle:
        return pickle.load(handle)


def read_val_set(dst_index, pool_folder_path, num_class, model_name):
    """Build a validation set from selected rows of each per-class pool CSV.

    For each class id ``0 .. num_class - 1`` this reads
    ``<pool_folder_path>/<id>.csv`` and keeps only the rows ``dst_index[id]``.
    Those rows are tokenized with ``post_process`` and set to ``'torch'`` format.

    Args:
        dst_index: Presumably maps class id -> row indices into that class's
            CSV (as produced by ``sample_val``) -- verify against the sampler.
        pool_folder_path: Directory containing the per-class CSVs.
        num_class: Number of classes to read.
        model_name: Tokenizer name passed to ``post_process``.

    Returns:
        The per-class subsets concatenated in class-id order.
    """
    per_class_subsets = []
    for label in range(num_class):
        csv_path = os.path.join(pool_folder_path, f'{label}.csv')
        raw_split = load_dataset('csv', data_files=csv_path)['train']
        subset = post_process(raw_split.select(dst_index[label]), model_name)
        subset.set_format('torch')
        per_class_subsets.append(subset)
    return concatenate_datasets(per_class_subsets)


def get_dataset(dst_name, model_name, seed=None, is_post_process=True):
    """Load a named dataset.

    Supported ``dst_name`` values:

    * ``'reuters'`` -- reads ``reuters/reuters_train.csv`` and
      ``reuters/reuters_test.csv`` under ``DATASET_PATH`` and returns
      ``(train_dst, test_dst)``. When ``is_post_process`` is true, both
      splits are tokenized with ``post_process``, which renames the label
      column to ``labels``, and the per-class label counts are printed.
      The shape of each split is printed in either case.
    * ``'reuters_data_pool'`` -- returns ``(pool_set, src_index_dict)``.
      ``pool_set`` is the per-class tokenized pool from ``read_data_pool``.
      ``src_index_dict`` is the unpickled ``src_index.pkl`` from the same
      pool folder. ``is_post_process`` is not consulted here.

    Args:
        dst_name: Dataset identifier (see above).
        model_name: Tokenizer name used for tokenization.
        seed: Currently unused.
        is_post_process: Whether to tokenize the ``'reuters'`` splits.

    Returns:
        A 2-tuple whose contents depend on ``dst_name`` (see above).

    Raises:
        ValueError: If ``dst_name`` is not a supported dataset name. Failing
            here avoids callers later trying to unpack ``None``.
    """
    if dst_name == 'reuters':
        dst = load_dataset('csv', data_files={
            "train": os.path.join(DATASET_PATH, 'reuters/reuters_train.csv'),
            "test": os.path.join(DATASET_PATH, 'reuters/reuters_test.csv'),
        })
        train_dst, test_dst = dst['train'], dst['test']
        if is_post_process:
            train_dst = post_process(train_dst, model_name)
            test_dst = post_process(test_dst, model_name)
            print("train class statistics", sorted(Counter(train_dst["labels"]).items(), key=lambda x: x[0]))
            print("test class statistics", sorted(Counter(test_dst["labels"]).items(), key=lambda x: x[0]))
        print('train shape', train_dst.shape)
        print('test shape', test_dst.shape)
        return train_dst, test_dst
    if dst_name == 'reuters_data_pool':
        pool_set = read_data_pool(PC.get_data_pool_path('reuters'), get_num_class('reuters'), model_name)
        src_index_dict = read_src_index_file(PC.get_data_pool_path('reuters'))
        return pool_set, src_index_dict
    raise ValueError(
        f"Unsupported dataset name {dst_name!r}; expected 'reuters' or 'reuters_data_pool'."
    )


def get_val_set(dst_name, model_name, seed=None, method='DB_ADJOINT', **kwargs):
    """Sample and load a validation set for ``dst_name``.

    The raw (untokenized) training split provides the per-class imbalance
    ratios and the smallest class size. ``sample_val`` then picks
    per-class row indices, and ``read_val_set`` loads those rows from the
    data pool.

    Args:
        dst_name: Dataset name understood by ``get_dataset`` and ``get_num_class``.
        model_name: Tokenizer name used when loading the selected rows.
        seed: Passed to ``sample_val`` as ``random_seed``.
        method: Sampling method. ``'LZO'`` samples with ``'RANDOM'`` and
            uses the smallest training-class size per class. Any other value
            is passed to ``sample_val`` unchanged.
        **kwargs: For methods other than ``'LZO'``, must contain
            ``val_num_per_class`` (a missing key raises ``KeyError``).

    Returns:
        The validation dataset returned by ``read_val_set``.
    """
    raw_train, _ = get_dataset(dst_name, model_name, is_post_process=False)
    ratio_by_class, smallest_class_size = get_imbalanced_ratio(raw_train['label'], is_return_minimum=True)
    if method != 'LZO':
        sample_method, per_class = method, kwargs['val_num_per_class']
    else:
        sample_method, per_class = 'RANDOM', smallest_class_size
    dst_indexes = sample_val(dst_name=dst_name, method=sample_method, val_num_per_class=per_class,
                             random_seed=seed, class_ratio_list=ratio_by_class)
    return read_val_set(dst_indexes, PC.get_data_pool_path(dst_name), get_num_class(dst_name), model_name)


def split_dst_by_class(dst, num_class):
    """Partition ``dst`` into one filtered dataset per class id.

    The label column is ``labels`` if present (as after ``post_process``),
    otherwise ``label``.

    Returns:
        Dict mapping each class id ``0 .. num_class - 1`` to ``dst.filter(...)``
        restricted to rows with that label.
    """
    label_key = 'label'
    if 'labels' in dst.column_names:
        label_key = 'labels'
    # filter() is called inside each iteration, so the lambda sees the current c.
    return {c: dst.filter(lambda example: example[label_key] == c) for c in range(num_class)}


if __name__ == '__main__':
    # Smoke check: total Reuters size and how many examples carry label 1.
    train, test = get_dataset('reuters', model_name="bert-base-uncased")
    print(len(train) + len(test))
    label_one_total = sum(
        len(split.filter(lambda example: example['labels'] == 1)) for split in (train, test)
    )
    print(label_one_total)