import pickle
import random
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, Dataset, concatenate_datasets, DatasetDict
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from misc.utils import get_num_classes
import os

# Default data directory.
# NOTE(review): the functions visible in this file hard-code their own paths
# and do not read this name -- confirm whether importing modules use it.
data_path = './data'

# NOTE(review): ratio_train, n_client and data_all are not read by the code
# visible in this file; presumably experiment settings consumed elsewhere --
# confirm before removing.
ratio_train = 0.6
seed = 1234  # passed to the random / NumPy / torch seeding calls below
n_client = [20]
data_all = ['20_newsgroups', 'multi_sent', 'ag_news']
INF = np.iinfo(np.int64).max  # largest value representable as int64

# Seed the Python, NumPy and torch random generators when the module is imported.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

def get_tokenizer(model_path: str, max_length: int):
    """Load a tokenizer through ``AutoTokenizer.from_pretrained``.

    Args:
        model_path: Model identifier or directory handed to ``from_pretrained``.
        max_length: Forwarded unchanged as the ``max_length`` keyword argument.

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained``.
    """
    # local_files_only=True is always passed, so the files must already be
    # available locally.
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        local_files_only=True,
        max_length=max_length,
    )
    return loaded_tokenizer

# Token length passed to the module-level tokenizer created below.
max_length = 256

# Module-level tokenizer, built as a side effect of importing this file.
# The path is a placeholder and has to be replaced with a real local model path.
# NOTE(review): the dataset functions below each build their own local
# tokenizer, so this instance looks unused by the visible code -- confirm.
tokenizer = get_tokenizer("your_local_path_to_model/distilbert-base-multilingual-cased", max_length)

def data_load(dataset='20_newsgroups', client_num=20, partition='random', beta=0.5,
              model='distilbert-base-multilingual-cased', datadir='./data'):
    """Dispatch to the loader for ``dataset`` and return its partitioned data.

    Args:
        dataset: One of ``"multi_sent"``, ``"20_newsgroups"`` or ``"ag_news"``.
        client_num: Number of clients; forwarded to the loader as ``n_parties``.
        partition: Partition strategy name; printed and forwarded to the loader.
        beta: Value forwarded to the loader as ``beta``.
        model: Model name forwarded to the loader.
        datadir: Data directory forwarded to the loader.

    Returns:
        The ``(client2data, nlp_dataset)`` pair produced by the selected loader.

    Any other ``dataset`` value terminates through
    ``exit("No implementation error")``.
    """
    print('* Partitioning data (num_party: {} by {}, beta: {})'.format(client_num, partition, beta))

    # Choose the loader first; every loader takes the same keyword arguments.
    if dataset == "multi_sent":
        prepare = partition_multi_lang_data
    elif dataset == "20_newsgroups":
        prepare = load_and_prepare_20newsgroups
    elif dataset == "ag_news":
        prepare = load_and_prepare_ag_news
    else:
        exit("No implementation error")

    client2data, nlp_dataset = prepare(
        dataset=dataset,
        model=model,
        datadir=datadir,
        partition=partition,
        n_parties=client_num,
        beta=beta,
    )
    return client2data, nlp_dataset


def save_data(nlp_dataset, net_dataidx_map, nlp_dataset_path="./data/nlp_dataset.pkl",
              net_dataidx_map_path="./data/net_dataidx_map.pkl"):
    """Pickle the dataset and the client index map to two files.

    Missing parent directories are created first, so a fresh checkout without
    a ``data`` directory does not fail after the data has been prepared.

    Args:
        nlp_dataset: Object written to ``nlp_dataset_path``.
        net_dataidx_map: Object written to ``net_dataidx_map_path``.
        nlp_dataset_path: Destination file for ``nlp_dataset``.
        net_dataidx_map_path: Destination file for ``net_dataidx_map``.

    Raises:
        OSError: If a directory cannot be created or a file cannot be written.
    """
    outputs = (
        (nlp_dataset, nlp_dataset_path),
        (net_dataidx_map, net_dataidx_map_path),
    )
    for payload, target_path in outputs:
        parent_dir = os.path.dirname(target_path)
        # A bare file name has no directory component; nothing to create then.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(target_path, "wb") as f:
            pickle.dump(payload, f)


def load_data(nlp_dataset_path="./data/nlp_dataset.pkl", net_dataidx_map_path="./data/net_dataidx_map.pkl"):
    """Read back the two pickles written by ``save_data``.

    Args:
        nlp_dataset_path: Pickle file holding the dataset object.
        net_dataidx_map_path: Pickle file holding the client index map.

    Returns:
        Tuple ``(nlp_dataset, net_dataidx_map)``.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at files
    # this project produced itself.
    unpickled = []
    for source_path in (nlp_dataset_path, net_dataidx_map_path):
        with open(source_path, "rb") as handle:
            unpickled.append(pickle.load(handle))

    nlp_dataset, net_dataidx_map = unpickled
    return nlp_dataset, net_dataidx_map

import numpy as np
from numpy.random import dirichlet
from collections import defaultdict

def partition_multi_lang_data(dataset: str, datadir: str, model: str, partition: str, n_parties: int = 10,
                              max_length: int = 256, data_fraction: float = 0.1, data_fraction_test: float = 0.02,
                              model_checkpoint='distilbert-base-uncased', beta=0.5):
    """Tokenize a sub-sampled CSV corpus and split its training rows across clients.

    A pickled result under ``data/xglue_*.pkl`` is returned directly when it can
    be loaded. This function does not write that cache itself.

    Args:
        dataset: Dataset name; not used in the body.
        datadir: Data directory; not used in the body.
        model: Model name; not used in the body (the tokenizer path is hard-coded).
        partition: Partition strategy name; not used in the body.
        n_parties: Number of clients to split the training indices across.
        max_length: Padding/truncation length used when tokenizing.
        data_fraction: Fraction of shuffled training rows to keep.
        data_fraction_test: Fraction of shuffled test rows to keep.
        model_checkpoint: Not used in the body.
        beta: Dirichlet concentration for a per-label split. ``None`` or a
            non-positive value selects a shuffled equal-size split instead.

    Returns:
        Tuple ``(net_dataidx_map, nlp_dataset)``: a dict mapping client id to a
        list of training-row indices, and the tokenized ``DatasetDict`` with
        the ``text`` column removed and the format set to ``"torch"``.
    """
    global nlp_dataset, total_dataset

    # Best-effort cache lookup: any ordinary failure falls through to a rebuild.
    # Catching Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
    try:
        nlp_dataset, net_dataidx_map = load_data("data/xglue_nlp_dataset.pkl", "data/xglue_net_dataidx_map.pkl")
        return net_dataidx_map, nlp_dataset
    except Exception:
        print('dont direct load!')

    raw_splits = load_dataset('csv', data_files={
        'train': 'your_local_path_to_data',
        'test': 'your_local_path_to_data'
    })

    train_dataset = raw_splits['train']
    test_dataset = raw_splits['test']

    # Keep only a fixed-seed random fraction of each split.
    train_dataset_sampled = train_dataset.shuffle(seed=42).select(range(int(len(train_dataset) * data_fraction)))
    test_dataset_sampled = test_dataset.shuffle(seed=42).select(range(int(len(test_dataset) * data_fraction_test)))

    full_dataset = DatasetDict({'train': train_dataset_sampled, 'test': test_dataset_sampled})

    tokenizer = get_tokenizer(model_path="your_local_path_to_model/distilbert-base-multilingual-cased", max_length=max_length)

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)

    tokenized_dataset = full_dataset.map(tokenize_function, batched=True)

    if beta is not None and beta > 0:
        labels = tokenized_dataset['train']['label']

        # Group training-row indices by label.
        class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            class_indices[label].append(idx)

        net_dataidx_map = {i: [] for i in range(n_parties)}
        for label, indices in class_indices.items():
            np.random.shuffle(indices)
            # Draw per-client shares for this label and turn them into cut points.
            proportions = dirichlet([beta] * n_parties)
            proportions_cumsum = (np.cumsum(proportions) * len(indices)).astype(int)[:-1]
            split_indices = np.split(indices, proportions_cumsum)

            for i, idx in enumerate(split_indices):
                net_dataidx_map[i].extend(idx.tolist())

    else:
        # Shuffled split into n_parties near-equal chunks.
        all_train_indices = np.arange(len(tokenized_dataset['train']))
        np.random.shuffle(all_train_indices)
        indices_split = np.array_split(all_train_indices, n_parties)
        net_dataidx_map = {i: indices_split[i].tolist() for i in range(n_parties)}

    nlp_dataset = tokenized_dataset.remove_columns(['text'])
    nlp_dataset.set_format("torch")

    return net_dataidx_map, nlp_dataset


def load_and_prepare_20newsgroups(dataset: str, datadir: str, model: str, partition: str, n_parties: int = 10,
                                  beta: float = 0.5, max_length: int = 256, model_checkpoint='distilbert-base-uncased'):
    """Tokenize 20 Newsgroups and split its training rows across clients by label.

    A pickled result under ``data/20_*.pkl`` is returned directly when it can be
    loaded; otherwise the data is rebuilt and written to those files.

    Args:
        dataset: Dataset name; not used in the body.
        datadir: Data directory; not used in the body.
        model: Model path or name passed to ``get_tokenizer``.
        partition: Partition strategy name; not used in the body.
        n_parties: Number of clients to split the training indices across.
        beta: Dirichlet concentration for the per-label split.
        max_length: Padding/truncation length used when tokenizing.
        model_checkpoint: Not used in the body.

    Returns:
        Tuple ``(net_dataidx_map, nlp_dataset)``: a dict mapping client id to a
        list of training-row indices, and the tokenized ``DatasetDict`` with
        the ``text`` column removed and the format set to ``"torch"``.
    """
    global nlp_dataset, total_dataset

    # Best-effort cache lookup: any ordinary failure falls through to a rebuild.
    # Catching Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
    try:
        nlp_dataset, net_dataidx_map = load_data("data/20_nlp_dataset.pkl", "data/20_net_dataidx_map.pkl")
        return net_dataidx_map, nlp_dataset
    except Exception:
        print('dont direct load!')

    newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
    newsgroups_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

    train_df = pd.DataFrame({'text': newsgroups_train['data'], 'label': newsgroups_train['target']})
    test_df = pd.DataFrame({'text': newsgroups_test['data'], 'label': newsgroups_test['target']})

    train_dataset = Dataset.from_pandas(train_df)
    test_dataset = Dataset.from_pandas(test_df)
    full_dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})

    # get_tokenizer's parameter is named model_path; passing model= raised TypeError.
    tokenizer = get_tokenizer(model_path=model, max_length=max_length)

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)

    tokenized_dataset = full_dataset.map(tokenize_function, batched=True)

    train_labels = np.array(tokenized_dataset['train']['label'])

    # Redraw the whole partition until every client holds at least 10 samples.
    min_size = 0
    while min_size < 10:
        idx_batch = [[] for _ in range(n_parties)]
        for k in np.unique(train_labels):
            idx_k = np.where(train_labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, n_parties))
            # Zero the share of any client that already holds an average-sized
            # load, then renormalize over the remaining clients.
            proportions = np.array(
                [p * (len(idx_j) < len(train_labels) / n_parties) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Convert cumulative shares into cut points inside idx_k.
            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]

        min_size = min([len(idx_j) for idx_j in idx_batch])

    net_dataidx_map = {i: idx_batch[i] for i in range(n_parties)}

    nlp_dataset = tokenized_dataset.remove_columns(['text'])
    nlp_dataset.set_format("torch")

    save_data(nlp_dataset, net_dataidx_map, nlp_dataset_path="data/20_nlp_dataset.pkl",
              net_dataidx_map_path="data/20_net_dataidx_map.pkl")

    return net_dataidx_map, nlp_dataset

def create_dataframe(iterator):
    """Build a ``text``/``label`` DataFrame from ``(label, line)`` pairs.

    Args:
        iterator: Iterable yielding ``(label, line)`` pairs.

    Returns:
        ``pandas.DataFrame`` with a ``text`` column holding each ``line`` and a
        ``label`` column holding each ``label`` reduced by one.
    """
    # Consume the iterable once; it may be a one-shot iterator.
    rows = [(line, label - 1) for label, line in iterator]
    columns = {
        'text': [text for text, _ in rows],
        'label': [shifted for _, shifted in rows],
    }
    return pd.DataFrame(columns)

def load_and_prepare_ag_news(dataset: str, datadir: str, model: str, partition: str, n_parties: int = 10,
                             beta: float = 0.4, max_length: int = 256, model_checkpoint='distilbert-base-uncased',
                             data_fraction: float = 0.008):
    """Tokenize a sub-sampled AG News CSV and split its training rows evenly across clients.

    A pickled result under ``data/ag_*.pkl`` is returned directly when it can be
    loaded. This function does not write that cache itself.

    Args:
        dataset: Dataset name; not used in the body.
        datadir: Data directory; not used in the body.
        model: Model name; not used in the body (the tokenizer path is hard-coded).
        partition: Partition strategy name; not used in the body.
        n_parties: Number of clients to split the training indices across.
        beta: Accepted for signature parity; the split is always a shuffled
            equal-size split, so this value is not used.
        max_length: Padding/truncation length used when tokenizing.
        model_checkpoint: Not used in the body.
        data_fraction: Fraction of shuffled training rows to keep. The test
            split is kept whole.

    Returns:
        Tuple ``(net_dataidx_map, nlp_dataset)``: a dict mapping client id to a
        list of training-row indices, and the tokenized ``DatasetDict`` with
        the ``text`` column removed and the format set to ``"torch"``.
    """
    global nlp_dataset, total_dataset

    # Best-effort cache lookup: any ordinary failure falls through to a rebuild.
    # Catching Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
    try:
        nlp_dataset, net_dataidx_map = load_data("data/ag_nlp_dataset.pkl", "data/ag_net_dataidx_map.pkl")
        return net_dataidx_map, nlp_dataset
    except Exception:
        print('dont direct load!')

    raw_splits = load_dataset('csv', data_files={
        'train': 'your_local_path_to_ag_news_csv',
        'test': 'your_local_path_to_ag_news_csv'
    })

    train_dataset = raw_splits['train']
    test_dataset = raw_splits['test']

    # Keep only a fixed-seed random fraction of the training split.
    train_dataset_sampled = train_dataset.shuffle(seed=42).select(range(int(len(train_dataset) * data_fraction)))

    full_dataset = DatasetDict({'train': train_dataset_sampled, 'test': test_dataset})

    tokenizer = get_tokenizer(model_path="your_local_path/distilbert-base-multilingual-cased", max_length=max_length)

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)

    tokenized_dataset = full_dataset.map(tokenize_function, batched=True)

    # Shuffled split into n_parties near-equal chunks.
    all_train_indices = np.arange(len(tokenized_dataset['train']))
    np.random.shuffle(all_train_indices)
    indices_split = np.array_split(all_train_indices, n_parties)

    net_dataidx_map = {i: indices_split[i].tolist() for i in range(n_parties)}

    nlp_dataset = tokenized_dataset.remove_columns(['text'])
    nlp_dataset.set_format("torch")

    return net_dataidx_map, nlp_dataset
