import numpy as np
import pickle
from datasets import load_dataset, DatasetDict
from tqdm import tqdm
from transformers import AutoTokenizer


# Token budget per chat-templated sequence: used as the length-filter threshold
# and as the padding length of the generated token/mask columns.
MAX_LEN: int = 2048


def filter_reward(dataset, columns, reward_to_filter, mask):
    """Reject rows whose reward columns contain a sentinel value.

    Args:
        dataset: Sized iterable of row mappings (e.g. a ``datasets.Dataset``).
        columns: Names of the columns to inspect in every row.
        reward_to_filter: Value whose presence (``in``) in any ``row[c]``
            marks the row as invalid.
        mask: Per-row keep mask indexable by row position. Rows whose entry
            is already 0 are skipped. The caller's mask is NOT modified.

    Returns:
        A copy of ``mask`` with the entries of newly rejected rows set to 0.
    """
    # Work on a copy so the caller's mask is not mutated behind its back.
    output_mask = mask.copy()

    for i, row in tqdm(enumerate(dataset), total=len(dataset)):
        if mask[i] == 0:
            continue
        # NOTE(review): assumes each row[c] is a container (e.g. a list of
        # per-turn rewards) so ``in`` is a membership test — confirm schema.
        if any(reward_to_filter in row[c] for c in columns):
            output_mask[i] = 0

    return output_mask


def filter_same_response(dataset, col_1, col_2, turn, mask):
    """Reject rows where two message-list columns hold the same message.

    Args:
        dataset: Sized iterable of row mappings (e.g. a ``datasets.Dataset``).
        col_1: Name of the first message-list column.
        col_2: Name of the second message-list column.
        turn: Selects which message to compare.

            * ``str``: name of a per-row column holding a turn index ``h``.
              The message at position ``2 * h + 1`` is compared. That is
              turn ``h``'s response, assuming messages alternate user and
              assistant starting with a user message at index 0 (TODO
              confirm there is no system message).
            * ``int`` (including numpy integers): a literal *message*
              index used directly in both lists, e.g. ``-1`` for the last
              message.
        mask: Per-row keep mask indexable by row position. Rows whose entry
            is already 0 are skipped. The caller's mask is NOT modified.

    Returns:
        A copy of ``mask`` with the entries of newly rejected rows set to 0.

    Raises:
        TypeError: if ``turn`` is neither a ``str`` nor an integer. Such
            values used to disable the filter without any warning.
    """
    if isinstance(turn, str):
        use_turn_column = True
    elif isinstance(turn, (int, np.integer)) and not isinstance(turn, bool):
        use_turn_column = False
    else:
        raise TypeError(
            'turn must be a column name (str) or a message index (int), '
            f'got {type(turn).__name__}'
        )

    # Work on a copy so the caller's mask is not mutated behind its back.
    output_mask = mask.copy()

    for i, row in tqdm(enumerate(dataset), total=len(dataset)):
        if mask[i] == 0:
            continue
        idx = row[turn] * 2 + 1 if use_turn_column else turn
        if row[col_1][idx] == row[col_2][idx]:
            output_mask[i] = 0

    return output_mask


def filter_length(dataset, tokenizer, columns, max_len, mask, turns=None):
    """Reject rows whose chat-templated conversations exceed ``max_len`` tokens.

    Args:
        dataset: Sized iterable of row mappings (e.g. a ``datasets.Dataset``).
        tokenizer: Object exposing ``apply_chat_template`` (a Hugging Face
            tokenizer).
        columns: Names of message-list columns to measure.
        max_len: Maximum allowed token count (inclusive).
        mask: Per-row keep mask indexable by row position. Rows whose entry
            is already 0 are skipped. The caller's mask is NOT modified.
        turns: Optional sequence parallel to ``columns``. Each entry is
            either ``None``, which measures the whole conversation, or the name
            of a per-row column holding a turn index ``h``. In the latter case
            only the first ``(h + 1) * 2`` messages are measured.
            ``None`` (the default) means no truncation for any column.

    Returns:
        A copy of ``mask`` with the entries of newly rejected rows set to 0.

    Raises:
        ValueError: if ``turns`` is given but its length differs from
            ``columns``. Previously ``zip`` silently skipped the extra
            columns.
    """
    columns = list(columns)
    if turns is None:
        # Previously the default crashed with ``zip(columns, None)``.
        turns = [None] * len(columns)
    else:
        turns = list(turns)
        if len(turns) != len(columns):
            raise ValueError(
                f'turns has {len(turns)} entries but columns has {len(columns)}'
            )

    # Work on a copy so the caller's mask is not mutated behind its back.
    output_mask = mask.copy()

    for i, row in tqdm(enumerate(dataset), total=len(dataset)):
        if mask[i] == 0:
            continue
        for c, t in zip(columns, turns):
            # Truncate to the conversation up to and including turn row[t].
            messages = row[c] if t is None else row[c][:(row[t] + 1) * 2]
            n_tokens = len(tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=False))
            if n_tokens > max_len:
                output_mask[i] = 0
                break

    return output_mask


def generate_token_mask(dataset, tokenizer, col, turns, max_len):
    """Tokenize conversation prefixes and build per-turn response masks.

    For every conversation in ``dataset[col]`` and every ``t`` in ``turns``,
    two entries are appended:

    * ``<col>_turn=<t>_token``: chat-templated token ids of the messages up
      to and including turn ``t``'s response. They are produced with
      ``padding='max_length', max_length=max_len``.
    * ``<col>_turn=<t>_mask``: an int numpy array of length ``max_len``. It
      is 1 on the positions from the end of the prompt (templated with
      ``add_generation_prompt=True``) up to the end of the unpadded
      conversation, and 0 elsewhere.

    ``t == -1`` uses the whole conversation and treats its last message as
    the response. Otherwise the prompt is ``messages[:2 * t + 1]`` and the
    full sequence is ``messages[:2 * (t + 1)]``.

    Returns:
        A dict mapping the column names above to lists with one entry per
        row. For each ``t``, the ``_token`` key is inserted before the
        ``_mask`` key.
    """
    def prompt_and_full(messages, t):
        # Pick the (prompt, prompt + response) message slices for turn t.
        if t == -1:
            return messages[:-1], messages
        return messages[:t * 2 + 1], messages[:(t + 1) * 2]

    columns = {}
    for t in turns:
        columns[col + f'_turn={t}_token'] = []
        columns[col + f'_turn={t}_mask'] = []

    for messages in tqdm(dataset[col]):
        for t in turns:
            prompt, full = prompt_and_full(messages, t)
            # NOTE(review): padding='max_length' does not truncate. A sequence
            # longer than max_len yields a token list longer than the mask.
            # Callers appear to rely on filter_length having bounded
            # lengths — confirm for turns beyond the sampled h.
            padded_ids = tokenizer.apply_chat_template(
                full, tokenize=True, add_generation_prompt=False,
                padding='max_length', max_length=max_len)
            start = len(tokenizer.apply_chat_template(
                prompt, tokenize=True, add_generation_prompt=True))
            end = len(tokenizer.apply_chat_template(
                full, tokenize=True, add_generation_prompt=False))
            # Mask positions index the unpadded sequence from position 0. This
            # presumably assumes the tokenizer pads on the right — TODO
            # confirm padding_side.
            response_mask = np.zeros(max_len, dtype=int)
            response_mask[start:end] = 1
            columns[col + f'_turn={t}_token'].append(padded_ids)
            columns[col + f'_turn={t}_mask'].append(response_mask)

    return columns


def main():
    """Build the multi-turn dataset.

    The steps are: load, filter, tokenize and mask, split, then push.

    Filtering drops rows that contain a -99999 reward, rows where an
    alternative trajectory repeats the reference response, and rows whose
    templated conversations exceed ``MAX_LEN`` tokens. Token and mask columns
    are then added for each trajectory column. A 500-row test split is held
    out before uploading.
    """
    dataset_multi = load_dataset('TBD', split='train')

    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
    # A pad token is needed for padding='max_length' in generate_token_mask.
    # NOTE(review): this adds a new vocabulary entry; presumably the model's
    # embeddings must be resized to match — confirm downstream.
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # ================================filter====================================
    keep = np.ones(len(dataset_multi))

    # Drop rows where any reward column contains the -99999 sentinel.
    print('filtering reward')
    reward_columns = [
        'trajectory_reward',
        'trajectory_sampled_h_from_5_reward',
        'trajectory_H-1_reward',
        'trajectory_sampled_h_from_sampled_len_reward',
    ]
    keep = filter_reward(dataset_multi, reward_columns, -99999, keep)
    print(f'after filtering: {keep.sum()}')

    # Drop rows where an alternative trajectory repeats the reference response.
    print('filtering same response')
    same_response_specs = (
        ('trajectory_sampled_h_from_5', 'sampled_h_from_5'),
        ('trajectory_sampled_h_from_sampled_len', 'sampled_h_from_sampled_len'),
        ('trajectory_H-1', -1),
    )
    for other_column, turn in same_response_specs:
        keep = filter_same_response(dataset_multi, 'trajectory', other_column, turn, keep)
    print(f'after filtering: {keep.sum()}')

    # Drop rows whose (optionally truncated) conversations are too long.
    print('filtering length')
    length_specs = {
        'trajectory': None,
        'trajectory_sampled_h_from_5': 'sampled_h_from_5',
        'trajectory_sampled_h_from_sampled_len': 'sampled_h_from_sampled_len',
        'trajectory_H-1': None,
    }
    keep = filter_length(dataset_multi, tokenizer, list(length_specs), MAX_LEN, keep,
                         turns=list(length_specs.values()))
    print(f'after filtering: {keep.sum()}')

    # Persist the final keep-mask, presumably for inspection or reuse.
    with open('temp.pkl', 'wb') as handle:
        pickle.dump(keep, handle, protocol=pickle.HIGHEST_PROTOCOL)

    dataset_multi = dataset_multi.filter(lambda _, idx: keep[idx] != 0, with_indices=True)

    # ===========================generate & mask================================
    print('generating tokens and masks')
    generation_specs = (
        ('trajectory', list(range(5))),
        ('trajectory_sampled_h_from_5', list(range(5))),
        ('trajectory_sampled_h_from_sampled_len', list(range(5))),
        ('trajectory_H-1', [4]),
    )
    for source_column, turn_list in generation_specs:
        new_columns = generate_token_mask(dataset_multi, tokenizer, source_column, turn_list, MAX_LEN)
        for column_name, values in new_columns.items():
            dataset_multi = dataset_multi.add_column(column_name, values)

    # ==========================================================================

    # NOTE(review): no seed is passed, so the 500-row test split is presumably
    # not reproducible across runs — consider fixing one.
    dataset_multi = dataset_multi.train_test_split(test_size=500)
    dataset_multi.push_to_hub('TBD')


if __name__ == "__main__":
    main()
    