import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset, concatenate_datasets
from collections import Counter, defaultdict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from itertools import chain
import json
import bisect
import string
from copy import deepcopy
import argparse
import transformers
from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
    default_data_collator,
)

def parse_args(argv=None):
    """Parse the command-line options of the noun-phrase filtering script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            exactly what the former zero-argument call did. Passing an
            explicit list makes the function usable from tests and notebooks.

    Returns:
        argparse.Namespace with the fields ``model_path``, ``dataset_dir``,
        ``seed``, ``process_batch_size``, ``process_num_workers``,
        ``noun_phrase_frequency_threshold``,
        ``noun_phrase_frequency_threshold_strict`` and ``noun_phrase_select``.
    """
    # Queried once: used both as the worker-count default and in the log line.
    cpu_count = multiprocessing.cpu_count()

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', help='path to LaVIT checkpoint')

    parser.add_argument('--dataset_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/IC', help='path to dataset')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=cpu_count, help='preprocessing num workers')
    parser.add_argument('--noun_phrase_frequency_threshold', type=int, default=10, help="filter the noun_phrase whose frequency lower than the threshold")
    parser.add_argument('--noun_phrase_frequency_threshold_strict', type=int, default=20, help="filter the noun_phrase whose frequency lower than the threshold")
    parser.add_argument('--noun_phrase_select', type=int, default=50, help="select N urls for each noun_phrase")

    # Printed before parsing, so it is shown even when parsing exits early
    # (e.g. on --help or on an invalid argument).
    print('Number of available cores:', cpu_count)
    return parser.parse_args(argv)

def _plot_frequency_distribution(selected_noun_phrase_counter, output_path, max_points=100001):
    """Plot the noun-phrase frequency before and after selection and save it as PDF.

    Args:
        selected_noun_phrase_counter: Dataset with a ``frequency`` column
            (sorted ascending by the caller) and an ``update_frequency``
            column.
        output_path: Destination path of the PDF figure.
        max_points: Upper bound on the number of most frequent noun phrases
            drawn. Fewer points are drawn when the counter has fewer rows.
    """
    frequency = list(selected_noun_phrase_counter["frequency"])
    update_frequency = sorted(selected_noun_phrase_counter["update_frequency"], reverse=True)

    # Clamp to the available rows so x and y always have matching lengths.
    num_points = min(max_points, len(frequency))
    if num_points == 0:
        print("No noun phrases available, skip plotting the frequency distribution.")
        return

    x = np.arange(num_points)
    # `frequency` is ascending, so the tail holds the most frequent phrases.
    y1 = frequency[-num_points:][::-1]
    y2 = update_frequency[:num_points]
    print(y1[:100], y2[:100])

    plt.figure(figsize=(10, 3))
    # NOTE(review): the counts in the two labels are hard-coded from a
    # previous run (see the recorded numbers in main); they are not derived
    # from the data plotted here.
    plt.plot(x, y1, label='Merged (>= 10): 4705117', color='#1F77B4', linewidth=2)
    plt.plot(x, y2, label='Merged (>= 20): 2326987', color='#FF7F0F', linewidth=2)
    plt.yscale('log')
    plt.xlim(x.min(), x.max())
    plt.xticks(list(range(0, max_points, 10000)), fontsize=14)
    plt.ylim(1, 10**7)
    plt.yticks([10**i for i in range(1, 8, 2)], fontsize=14)

    plt.xlabel('Unique noun-phrases (ordered by frequency in the descending order)', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    plt.legend(loc='upper right', fontsize=14)
    plt.fill_between(x, y1, facecolor='#8FBBD9')
    plt.fill_between(x, y2, facecolor='#C69D74')
    plt.grid(True, axis='y', linestyle=(2, (5, 5)), linewidth=0.5)
    plt.savefig(output_path, dpi=300, format="pdf")
    plt.close()


def main():
    """Merge, deduplicate and noun-phrase-filter the caption datasets.

    Every stage stores its result below ``<dataset_dir>/Merged_new`` and is
    skipped on later runs when its output path already exists:

    1. Concatenate the CapsFusion ``train`` split with the BLIP ``ccs`` and
       ``laion`` splits and drop rows whose ``url`` was already seen.
    2. Count how often every entry of the ``noun_chunks`` column occurs.
    3. Build an inverted index (noun phrase -> row indices) for phrases whose
       frequency reaches ``--noun_phrase_frequency_threshold``.
    4. Optionally narrow counter and index to
       ``--noun_phrase_frequency_threshold_strict``.
    5. Keep the first ``--noun_phrase_select`` rows of every indexed phrase
       and save the selected rows (also as parquet without ``noun_chunks``).
    6. Re-count the phrases on the selected rows whose url appears in the
       ``image_token/merged_new`` dataset and plot both distributions.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # NOTE(review): the tokenizer is not used by any later stage of this
    # function; it is still loaded so that an invalid --model_path keeps
    # failing early as before. Confirm whether it can be dropped.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token

    # ------------------------------------------------------------------
    # Stage 1: merge the source datasets and deduplicate by url.
    # ------------------------------------------------------------------
    merged_and_deduplicated_path = os.path.join(args.dataset_dir, 'Merged_new')
    if not os.path.exists(merged_and_deduplicated_path):
        cf_dataset = load_from_disk(os.path.join(args.dataset_dir, 'CapsFusion-120M', 'datasets_new'))
        blip_dataset = load_from_disk(os.path.join(args.dataset_dir, 'BLIP', 'datasets_new'))
        # The BLIP splits get an empty `caption_capsfusion` column so that
        # their schema can be concatenated with the CapsFusion split.
        for split in ['ccs', 'laion']:
            blip_dataset[split] = blip_dataset[split].add_column('caption_capsfusion', [''] * len(blip_dataset[split]))
        merged_dataset = concatenate_datasets([cf_dataset['train'], blip_dataset['ccs'], blip_dataset['laion']])
        del cf_dataset, blip_dataset

        print(f"merged_dataset: {merged_dataset.num_rows}")
        # Recorded from a previous run: merged_dataset: 222328586
        print(merged_dataset)
        print(merged_dataset[-1])

        # A pandas `drop_duplicates` on the whole table was tried first but
        # the process got killed, hence the hash + filter approach adapted from
        # https://huggingface.co/datasets/Finnish-NLP/mc4_fi_cleaned/blob/main/deduplicate.py
        def get_hash(example):
            """Return a deterministic signed 64-bit hash of the row's url."""
            # Built-in hash() of str is salted per interpreter, so workers
            # started with a fresh interpreter would disagree on the hash of
            # the same url. A fixed digest is identical in every process.
            # Imported here so the function is self-contained when it is
            # shipped to worker processes.
            import hashlib

            digest = hashlib.blake2b(str(example["url"]).encode("utf-8"), digest_size=8).digest()
            return {"hash": int.from_bytes(digest, "big", signed=True)}

        def unique_filter(example, uniques):
            """Keep a row only the first time its hash is encountered."""
            if example["hash"] in uniques:
                # Consume the hash so later rows with the same url are dropped.
                uniques.remove(example["hash"])
                return True
            return False

        merged_dataset = merged_dataset.map(
            get_hash,
            num_proc=args.process_num_workers,
            writer_batch_size=100000,
            desc="Hash",
        )

        uniques = set(merged_dataset.unique("hash"))
        frac = len(uniques) / len(merged_dataset)
        print(f"# Unique hash: {len(uniques)}")
        print(f"Fraction of duplicates: {1-frac:.2%}")
        # Recorded from a previous run:
        #   Unique hash: 215294645
        #   Fraction of duplicates: 3.16%

        # Deliberately run without num_proc: `uniques` is mutated while
        # filtering and must be shared by all rows.
        merged_dataset = merged_dataset.filter(
            unique_filter,
            fn_kwargs={"uniques": uniques},
            desc="Filter duplicate urls",
        )
        merged_dataset = merged_dataset.remove_columns('hash')
        print(f"After deduplicating, merged_dataset: {merged_dataset.num_rows}")
        # Recorded from a previous run: 215294645 rows after deduplicating.
        merged_dataset.save_to_disk(merged_and_deduplicated_path, max_shard_size="20GB")
    else:
        merged_dataset = load_from_disk(merged_and_deduplicated_path)
    print(f"merged_dataset: {merged_dataset}")

    # ------------------------------------------------------------------
    # Stage 2: noun-phrase frequency counter, sorted ascending by frequency.
    # ------------------------------------------------------------------
    noun_phrase_counter_path = os.path.join(merged_and_deduplicated_path, 'noun_phrase_counter')
    if not os.path.exists(noun_phrase_counter_path):
        noun_phrase_counter = Counter()
        # TODO: multi-processing here
        for example in tqdm(merged_dataset, desc='build noun phrase counter'):
            noun_phrase_counter.update(example['noun_chunks'])
        noun_phrase_counter = [{"noun_phrase": key, "frequency": value} for key, value in noun_phrase_counter.items()]
        noun_phrase_counter = Dataset.from_list(noun_phrase_counter)
        # Ascending order is relied upon by the bisect calls and the tail
        # selections below.
        noun_phrase_counter = noun_phrase_counter.sort("frequency")
        noun_phrase_counter.save_to_disk(noun_phrase_counter_path, max_shard_size="20GB")
    else:
        noun_phrase_counter = load_from_disk(noun_phrase_counter_path)
    print(f"# unique noun_chunks: {noun_phrase_counter.num_rows}")
    print(f"Top 10 frequent noun_chunks: {noun_phrase_counter[-10:]}")
    # Recorded from a previous run: unique noun_chunks: 279269573

    # ------------------------------------------------------------------
    # Stage 3: inverted index noun phrase -> row indices of merged_dataset.
    # ------------------------------------------------------------------
    noun_phrase_inverted_index_path = os.path.join(merged_and_deduplicated_path, f'noun_phrase_inverted_index_{args.noun_phrase_frequency_threshold}')
    if not os.path.exists(noun_phrase_inverted_index_path):
        # First counter row whose frequency reaches the threshold.
        left_bound = bisect.bisect_left(noun_phrase_counter['frequency'], args.noun_phrase_frequency_threshold)
        selected_noun_phrase_counter = noun_phrase_counter.select(range(left_bound, noun_phrase_counter.num_rows))
        print(f"# unique selected noun_chunks: {selected_noun_phrase_counter.num_rows}")
        # Recorded from a previous run: unique selected noun_chunks: 4705117
        print(f"selected_noun_phrase_counter: {selected_noun_phrase_counter}")
        # Insertion order follows the counter order, so index row i later
        # corresponds to counter row i of the selected tail.
        noun_phrase_inverted_index = {k: [] for k in selected_noun_phrase_counter["noun_phrase"]}

        def build_noun_phrase_inverted_index(example, idx):
            """Append the row index to the posting list of each known phrase."""
            for noun_phrase in example['noun_chunks']:
                if noun_phrase in noun_phrase_inverted_index:
                    noun_phrase_inverted_index[noun_phrase].append(idx)

        # Used purely for its side effect on the dict above (the function
        # returns nothing), so it runs in the main process without num_proc.
        merged_dataset.map(
            build_noun_phrase_inverted_index,
            with_indices=True,
            desc="Update inverted index",
        )

        def dict_gen():
            """Yield one record per noun phrase for Dataset.from_generator."""
            for k, v in noun_phrase_inverted_index.items():
                yield {"noun_phrase": k, "url_index": v}

        noun_phrase_inverted_index = Dataset.from_generator(dict_gen)
        noun_phrase_inverted_index.save_to_disk(noun_phrase_inverted_index_path, max_shard_size="20GB")
    else:
        noun_phrase_inverted_index = load_from_disk(noun_phrase_inverted_index_path)
    print(f"noun_phrase_inverted_index: {noun_phrase_inverted_index}")
    # Recorded from a previous run: 4705117 rows

    # ------------------------------------------------------------------
    # Stage 4: counter rows matching the index, optional strict threshold.
    # ------------------------------------------------------------------
    # The index covers the most frequent phrases, i.e. the tail of the
    # ascending counter. Deriving the selection here (not only inside the
    # strict branch) keeps `selected_noun_phrase_counter` defined when the
    # index was loaded from disk and the strict threshold is disabled.
    selected_noun_phrase_counter = noun_phrase_counter.select(
        range(noun_phrase_counter.num_rows - noun_phrase_inverted_index.num_rows, noun_phrase_counter.num_rows)
    )
    if args.noun_phrase_frequency_threshold_strict:
        left_bound_strict = bisect.bisect_left(selected_noun_phrase_counter['frequency'], args.noun_phrase_frequency_threshold_strict)
        strict_rows = range(left_bound_strict, selected_noun_phrase_counter.num_rows)
        noun_phrase_inverted_index_strict = noun_phrase_inverted_index.select(strict_rows)
        selected_noun_phrase_counter_strict = selected_noun_phrase_counter.select(strict_rows)
        # The strict value replaces the threshold so that the output paths
        # below are named after the threshold actually applied.
        args.noun_phrase_frequency_threshold = args.noun_phrase_frequency_threshold_strict
        print(f"strict noun_phrase_inverted_index: {noun_phrase_inverted_index_strict.num_rows} with noun_phrase_frequency_threshold_strict {args.noun_phrase_frequency_threshold_strict}")
        # Recorded from a previous run:
        #   strict noun_phrase_inverted_index: 2345223 with noun_phrase_frequency_threshold_strict 20
        noun_phrase_inverted_index = noun_phrase_inverted_index_strict
        selected_noun_phrase_counter = selected_noun_phrase_counter_strict
    print(f"selected_noun_phrase_counter: {selected_noun_phrase_counter}")
    print(f"noun_phrase_inverted_index: {noun_phrase_inverted_index}")

    # ------------------------------------------------------------------
    # Stage 5: select up to N rows per noun phrase.
    # ------------------------------------------------------------------
    selected_path = os.path.join(merged_and_deduplicated_path, f'filter_{args.noun_phrase_frequency_threshold}_{args.noun_phrase_select}')
    if not os.path.exists(selected_path):
        selected_url_index = set()
        for example in tqdm(noun_phrase_inverted_index, desc='select urls'):
            # The first N entries of each posting list are kept.
            selected_url_index.update(example['url_index'][:args.noun_phrase_select])
        print(f"# selected urls: {len(selected_url_index)}")
        # Recorded from a previous run: selected urls: 66435027
        selected_url_index = sorted(selected_url_index)
        selected_dataset = merged_dataset.select(selected_url_index)
        selected_dataset.save_to_disk(selected_path, max_shard_size="20GB")
    else:
        selected_dataset = load_from_disk(selected_path)
    print(f"selected_dataset: {selected_dataset}")

    # Save the selected dataset without the noun_chunks column.
    selected_dataset_clean_path = os.path.join(merged_and_deduplicated_path, f'filter_{args.noun_phrase_frequency_threshold}_{args.noun_phrase_select}.parquet')
    if not os.path.exists(selected_dataset_clean_path):
        selected_dataset.remove_columns('noun_chunks').to_parquet(selected_dataset_clean_path)

    # ------------------------------------------------------------------
    # Stage 6: re-count noun phrases on the downloaded selection (for the
    # frequency distribution figure).
    # ------------------------------------------------------------------
    selected_noun_phrase_counter_path = os.path.join(selected_path, 'selected_noun_phrase_counter')
    if not os.path.exists(selected_noun_phrase_counter_path):
        # Remove the un-downloaded urls from selected_dataset.
        downloaded_dataset = load_from_disk(os.path.join(merged_and_deduplicated_path, 'image_token', 'merged_new'))
        # union() accepts any iterable and avoids materialising a temporary
        # concatenation of both url columns.
        downloaded_urls = set(downloaded_dataset['train']['url']).union(downloaded_dataset['test']['url'])
        print(f"# downloaded urls: {len(downloaded_urls)}")
        # Recorded from a previous run: downloaded urls: 40135418
        del downloaded_dataset

        selected_dataset = selected_dataset.filter(
            lambda x: x['url'] in downloaded_urls,
            num_proc=args.process_num_workers,
            desc="Filter downloaded urls",
        )

        # Occurrence count of every selected phrase within the downloaded rows.
        selected_noun_phrase_frequency = {k: 0 for k in selected_noun_phrase_counter["noun_phrase"]}

        def build_selected_noun_phrase_inverted_index(example):
            """Increment the count of every selected phrase found in the row."""
            for noun_phrase in example['noun_chunks']:
                if noun_phrase in selected_noun_phrase_frequency:
                    selected_noun_phrase_frequency[noun_phrase] += 1

        # Side-effect only, therefore single process (see stage 3).
        selected_dataset.map(
            build_selected_noun_phrase_inverted_index,
            desc="Update selected inverted index",
        )

        selected_noun_phrase_counter = selected_noun_phrase_counter.map(
            lambda x: {"update_frequency": selected_noun_phrase_frequency[x['noun_phrase']]},
            num_proc=args.process_num_workers,
            desc="Update selected noun_phrase counter",
        )
        selected_noun_phrase_counter.save_to_disk(selected_noun_phrase_counter_path, max_shard_size="20GB")
    else:
        selected_noun_phrase_counter = load_from_disk(selected_noun_phrase_counter_path)
    print(f"# unique selected noun_chunks: {selected_noun_phrase_counter.num_rows}")
    # Recorded from a previous run: unique selected noun_chunks: 2345223

    # Phrases that still occur at least once after the download filter.
    selected_downloaded_noun_phrase_counter_path = os.path.join(selected_path, 'selected_downloaded_noun_phrase_counter')
    if not os.path.exists(selected_downloaded_noun_phrase_counter_path):
        selected_downloaded_noun_phrase_counter = selected_noun_phrase_counter.filter(
            lambda x: x['update_frequency'] > 0,
            num_proc=args.process_num_workers,
            desc="Filter downloaded noun_phrase",
        )
        selected_downloaded_noun_phrase_counter.save_to_disk(selected_downloaded_noun_phrase_counter_path, max_shard_size="20GB")
    else:
        selected_downloaded_noun_phrase_counter = load_from_disk(selected_downloaded_noun_phrase_counter_path)
    print(f"# unique selected downloaded noun_chunks: {selected_downloaded_noun_phrase_counter.num_rows}")
    # Recorded from a previous run: unique selected downloaded noun_chunks: 2326987

    # Plot the noun_phrase frequency changes.
    _plot_frequency_distribution(
        selected_noun_phrase_counter,
        os.path.join(selected_path, 'Frequency_Distribution_Coverage.pdf'),
    )

# Run the pipeline only when the file is executed as a script, so importing
# the module has no side effects.
if __name__ == "__main__":
    main()
    # test()