import os
import datasets
import pandas as pd
import torch

def custom_data_collator_arxiv(samples):
    """Collate (input_ids, attention_mask) samples into a language-modelling batch.

    Args:
        samples: Sequence of indexable samples; position 0 holds the input-id
            tensor and position 1 the attention-mask tensor. All tensors at a
            given position must share a shape so they can be stacked.

    Returns:
        Tuple ``(input_ids, labels, attention_mask)``. ``input_ids`` and
        ``attention_mask`` gain a new leading batch dimension. ``labels`` is an
        independent copy of ``input_ids``, so in-place edits to one do not
        affect the other.
    """
    def _stack_field(position):
        # Gather one field from every sample and stack along a new dim 0.
        return torch.stack([sample[position] for sample in samples])

    batch_ids = _stack_field(0)
    batch_mask = _stack_field(1)
    return batch_ids, batch_ids.clone(), batch_mask

def custom_data_collator_with_indices(samples):
    """Collate (input_ids, attention_mask, index) samples, keeping the indices.

    Args:
        samples: Sequence of indexable samples whose positions 0, 1 and 2 hold
            tensors (input ids, attention mask, and an index tensor; presumably
            the sample's row position in the dataset — verify against caller).

    Returns:
        Tuple ``(input_ids, attention_mask, indices)``, each stacked along a new
        leading batch dimension. Unlike ``custom_data_collator_arxiv``, no
        labels tensor is produced.
    """
    id_rows, mask_rows, index_rows = [], [], []
    for sample in samples:
        id_rows.append(sample[0])
    for sample in samples:
        mask_rows.append(sample[1])
    for sample in samples:
        index_rows.append(sample[2])
    return torch.stack(id_rows), torch.stack(mask_rows), torch.stack(index_rows)

def custom_data_collator_forget(samples):
    """Collate paired (forget_sample, retain_sample) items into two batches.

    Args:
        samples: Sequence of pairs. Element 0 of each pair is a forget-set
            sample and element 1 a retain-set sample; each of those holds the
            input-id tensor at position 0 and the attention mask at position 1.

    Returns:
        A two-element list ``[forget_batch, retain_batch]``. Each batch is a
        tuple ``(input_ids, labels, attention_mask)`` in which ``labels`` is an
        independent copy of ``input_ids``.
    """
    def _collate(group):
        # Stack one side of the pairs into an (ids, labels, mask) triple.
        token_ids = torch.stack([item[0] for item in group])
        mask = torch.stack([item[1] for item in group])
        return (token_ids, token_ids.clone(), mask)

    forget_group = [pair[0] for pair in samples]
    retain_group = [pair[1] for pair in samples]
    return [_collate(forget_group), _collate(retain_group)]

def load_tofu_train_dataset(
    dataset_dir: str = None,
    hf_dataset_name: str = None,
    hf_dataset_split: str = None,
    is_wtm: bool = False,
    forget_ratio: float = 0.0,
    dup: bool = False,
):
    """Load the TOFU training data and split it into forget / retain subsets.

    Args:
        dataset_dir: Directory readable by ``datasets.load_from_disk``. When
            given, it takes precedence over the Hugging Face Hub arguments.
        hf_dataset_name: Hub repository id; used only when ``dataset_dir`` is
            None.
        hf_dataset_split: Despite its name, this is forwarded as the second
            positional argument of ``datasets.load_dataset``, i.e. the dataset
            *configuration* name. The ``'full'`` split of the result is used.
        is_wtm: Not used by this function; kept for interface compatibility.
        forget_ratio: Fraction of rows, taken from the end of the data, that
            forms the forget set (delegated to ``split_dataset``).
        dup: If True, the semantic-duplicate rows returned by
            ``duplicate_forget_data`` are appended to both ``train_data`` and
            ``retain_data``. The forget set is left unchanged.

    Returns:
        ``(train_data, forget_data, retain_data)``. ``forget_data`` is None
        when ``forget_ratio`` is None or 0.

    Raises:
        ValueError: If neither ``dataset_dir`` nor ``hf_dataset_name`` is given.
    """
    if dataset_dir is not None:
        train_data = datasets.load_from_disk(dataset_dir)
    elif hf_dataset_name is not None:
        data = datasets.load_dataset(hf_dataset_name, hf_dataset_split)
        train_data = data['full']
    else:
        # Fail early with a clear message instead of load_dataset(None, ...).
        raise ValueError('Either dataset_dir or hf_dataset_name must be provided.')

    # Both sources are split the same way. The hub dataset's own 'forget' /
    # 'retain' splits are not used; forget_ratio determines the split instead.
    forget_data, retain_data = split_dataset(train_data, forget_ratio)

    if dup:
        # Append the duplicated forget set to retain_data and train_data.
        train_data, retain_data = duplicate_forget_data(
            train_data,
            retain_data,
            is_arxiv_data=False,
        )

    return train_data, forget_data, retain_data

def load_arxiv_train_dataset(
    dataset_dir: str = None,
    hf_dataset_name: str = None,
    hf_dataset_split: str = None,
    is_wtm: bool = False,
    forget_ratio: float = 0.0,
    dup: bool = False,
    **kwargs,
):
    """Load the arXiv training data together with its forget / retain subsets.

    Args:
        dataset_dir: Despite its name, the path to a pickled pandas DataFrame
            with a ``'Summary'`` column. Each entry becomes one row of a
            ``'text'`` column. When given, this path takes precedence over the
            hub arguments, and ``forget_ratio`` decides the split.
        hf_dataset_name: Hub repository id; used only when ``dataset_dir`` is
            None.
        hf_dataset_split: Forwarded as the second positional argument of
            ``datasets.load_dataset`` (the configuration name). The ``'full'``,
            ``'forget'`` and ``'retain'`` splits of the result are used, and
            ``forget_ratio`` is ignored on this path.
        is_wtm: Not used by this function; kept for interface compatibility.
        forget_ratio: Fraction of rows, taken from the end, that forms the
            forget set when loading from ``dataset_dir`` (see ``split_dataset``).
        dup: If True, the semantic-duplicate rows returned by
            ``duplicate_forget_data`` are appended to both ``train_data`` and
            ``retain_data``. The forget set is left unchanged.
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_data, forget_data, retain_data)``.

    Raises:
        ValueError: If neither ``dataset_dir`` nor ``hf_dataset_name`` is given.
    """
    if dataset_dir is not None:
        # SECURITY NOTE: unpickling can execute arbitrary code. Only point
        # dataset_dir at trusted files.
        df = pd.read_pickle(dataset_dir)
        train_data = datasets.Dataset.from_dict({'text': df['Summary'].tolist()})
        forget_data, retain_data = split_dataset(train_data, forget_ratio)
    elif hf_dataset_name is not None:
        data = datasets.load_dataset(hf_dataset_name, hf_dataset_split)
        train_data = data['full']
        # The hub dataset ships its own forget / retain splits; forget_ratio
        # is not applied here.
        forget_data = data['forget']
        retain_data = data['retain']
    else:
        # Fail early with a clear message instead of load_dataset(None, ...).
        raise ValueError('Either dataset_dir or hf_dataset_name must be provided.')

    if dup:
        # Append the duplicated forget set to retain_data and train_data.
        train_data, retain_data = duplicate_forget_data(
            train_data,
            retain_data,
            is_arxiv_data=True,
        )

    return train_data, forget_data, retain_data


def split_dataset(train_data, forget_ratio: float = None):
    """Split ``train_data`` into a trailing forget slice and a leading retain slice.

    The last ``int(len(train_data) * forget_ratio)`` rows form the forget set,
    and the remaining leading rows form the retain set. Row order is preserved.

    Args:
        train_data: Object supporting ``len()`` and ``.select(indices)``
            (e.g. a ``datasets.Dataset``).
        forget_ratio: Fraction in (0, 1] of rows to put in the forget set.
            ``None`` or ``0.0`` means there is no forget set.

    Returns:
        ``(forget_data, retain_data)``. When ``forget_ratio`` is None or 0,
        ``forget_data`` is None and ``retain_data`` is ``train_data`` itself
        (not a copy). A very small positive ratio can still yield an empty,
        non-None forget selection when the product truncates to 0.

    Raises:
        ValueError: If ``forget_ratio`` is outside (0, 1].
    """
    if forget_ratio is None or forget_ratio == 0.0:
        print(f'[WARNING] Forget set is empty (forget_ratio = {forget_ratio})')
        return None, train_data

    # An explicit check instead of ``assert``, so validation survives ``python -O``.
    # Without it, a ratio > 1 gives negative indices.
    if not 0.0 < forget_ratio <= 1.0:
        raise ValueError(f'forget_ratio must be in (0, 1], got {forget_ratio}')

    total_num_rows = len(train_data)
    forget_num_rows = int(total_num_rows * forget_ratio)
    retain_num_rows = total_num_rows - forget_num_rows
    retain_data = train_data.select(range(retain_num_rows))
    forget_data = train_data.select(range(retain_num_rows, total_num_rows))
    return forget_data, retain_data


def duplicate_forget_data(
    train_data,
    retain_data,
    is_arxiv_data: bool = False,
):
    """Append the hub-hosted semantic duplicates of the forget set to train and retain.

    The ``'semantic_duplicate'`` split is read via ``datasets.load_dataset`` from
    ``Glow-AI/WaterDrum-Ax`` (config ``unwatermarked_forget_01``) for arXiv data,
    or from ``Glow-AI/WaterDrum-TOFU`` (config ``unwatermarked_forget_10``)
    otherwise. It is concatenated onto the end of both datasets.

    NOTE(review): the config is fixed regardless of the caller's forget_ratio;
    confirm it matches the forget split actually in use.

    Args:
        train_data: Dataset to extend with the duplicate rows.
        retain_data: Dataset to extend with the duplicate rows.
        is_arxiv_data: Selects the arXiv source when True, TOFU otherwise.

    Returns:
        ``(train_data, retain_data)``, each with the duplicate rows appended.
    """
    print('Duplicating forget set...')
    repo_id, config_name = (
        ('Glow-AI/WaterDrum-Ax', 'unwatermarked_forget_01')
        if is_arxiv_data
        else ('Glow-AI/WaterDrum-TOFU', 'unwatermarked_forget_10')
    )
    duplicates = datasets.load_dataset(repo_id, config_name)['semantic_duplicate']

    augmented_train = datasets.concatenate_datasets([train_data, duplicates])
    augmented_retain = datasets.concatenate_datasets([retain_data, duplicates])
    print('num_duplicated_rows:', len(duplicates))

    return augmented_train, augmented_retain

