import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
from collections import Counter
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from itertools import chain
import json
import bisect
import string

import spacy
from spacy.lang.en import EnglishDefaults
from typing import Iterator, Tuple, Union

from spacy.errors import Errors
from spacy.symbols import NOUN, PRON, PROPN
from spacy.tokens import Doc, Span
import argparse
import transformers
from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
    default_data_collator,
)

def parse_args():
    """Parse the command-line options of the noun-chunk preprocessing script.

    Prints the number of available CPU cores (also used as the default for
    ``--process_num_workers``) before parsing ``sys.argv``.

    NOTE(review): the ``--model_path`` help text mentions a LaVIT checkpoint
    while its default points at a Llama-2 directory -- confirm which is meant.

    Returns:
        argparse.Namespace: the parsed options.
    """
    core_count = multiprocessing.cpu_count()

    # (flag, type, default, help) -- registered in this order.
    option_table = (
        ('--model_path', str, 'YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', 'path to LaVIT checkpoint'),
        ('--dataset_dir', str, 'YOUR_ROOT_PATH/data/MLLM/IC', 'path to dataset'),
        ('--dataset_name', str, 'CapsFusion-120M', 'dataset name'),
        ('--dataset_type', str, 'parquet', 'dataset file type'),
        ('--seed', int, 42, 'random seed'),
        ('--process_batch_size', int, 200, 'process batch size'),
        ('--process_num_workers', int, core_count, 'preprocessing num workers'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    print('Number of available cores:', core_count)
    return cli.parse_args()

def recursive_noun_end(word, end):
    """Return the furthest token index reachable through prepositional links.

    Walks ``word.children`` recursively, descending only into children whose
    ``dep_`` is ``prep``, ``pobj`` or ``pcomp``. The index of every ``pobj`` /
    ``pcomp`` visited can push the result to the right; ``prep`` tokens are
    only traversed.

    Args:
        word: token-like object exposing ``children``; each child exposes
            ``dep_`` and ``i``.
        end: current rightmost index; the result is never smaller than this.

    Returns:
        int: the largest of ``end`` and the indices of all reachable
        ``pobj`` / ``pcomp`` tokens.
    """
    furthest = end
    for node in word.children:
        relation = node.dep_
        if relation not in ("pobj", "pcomp", "prep"):
            continue
        if relation != "prep":
            # Objects/complements of a preposition close the extended phrase.
            furthest = max(node.i, furthest)
        furthest = max(recursive_noun_end(node, furthest), furthest)
    return furthest

def custom_noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
    """
    Yield noun phrases found in a dependency parse. Works on both Doc and Span.
    Modified from https://github.com/explosion/spaCy/blob/master/spacy/lang/en/syntax_iterators.py

    Two passes are made over ``doclike``:

    1. Base noun chunks headed by a NOUN/PROPN token (pronouns are skipped),
       never nested inside a chunk already produced by this pass.
    2. Extended chunks that additionally cover the prepositional phrases
       attached to the head (see ``recursive_noun_end``). These are yielded
       after all base chunks and are not checked against them for overlap.

    Yields:
        ``(start, end, label)`` tuples, where ``end`` is exclusive and
        ``label`` is the string-store id of ``"NP"``.

    Raises:
        ValueError: if the document has no dependency annotation.
    """
    # Label references:
    # https://spacy.io/models/en#en_core_web_sm-labels
    # ROOT, acl, acomp, advcl, advmod, agent, amod, appos, attr, aux, auxpass, case, cc, ccomp, compound, conj, csubj, csubjpass, dative, dep, det, dobj, expl, intj, mark, meta, neg, nmod, npadvmod, nsubj, nsubjpass, nummod, oprd, parataxis, pcomp, pobj, poss, preconj, predet, prep, prt, punct, quantmod, relcl, xcomp
    # https://github.com/clir/clearnlp-guidelines/blob/master/md/specifications/dependency_labels.md
    # https://universaldependencies.org/u/dep/
    # https://universaldependencies.org/u/pos/

    doc = doclike.doc  # A Span also exposes its parent Doc here.
    if not doc.has_annotation("DEP"):
        raise ValueError(Errors.E029)

    # Dependency relations whose NOUN/PROPN bearer heads a noun phrase.
    head_relations = (
        "oprd",
        "nsubj",
        "dobj",
        "nsubjpass",
        "pcomp",
        "pobj",
        "dative",
        "appos",
        "attr",
        "ROOT",
        "prep",
        "relcl",
    )
    intern = doc.vocab.strings.add
    np_deps = [intern(relation) for relation in head_relations]
    conj = intern("conj")
    np_label = intern("NP")
    # Pronouns (PRON) are deliberately left out.
    nominal_pos = (NOUN, PROPN)

    # Pass 1: base noun chunks.
    last_emitted = -1
    for token in doclike:
        # Skip non-nominals and anything that would nest in the previous chunk.
        if token.pos not in nominal_pos or token.left_edge.i <= last_emitted:
            continue
        governor = token
        if token.dep == conj:
            # Climb leftwards through the coordination to its first conjunct;
            # if that conjunct is an NP head, this token is one as well.
            governor = token.head
            while governor.dep == conj and governor.head.i < governor.i:
                governor = governor.head
        if governor.dep in np_deps:
            last_emitted = token.i
            yield token.left_edge.i, token.i + 1, np_label

    # Pass 2: chunks extended over attached prepositional phrases.
    for token in doclike:
        if token.pos in nominal_pos and token.dep in np_deps:
            extended_end = recursive_noun_end(token, token.i)
            if extended_end > token.i:
                yield token.left_edge.i, extended_end + 1, np_label

def print_noun_chunks(docs):
    """Print each document's text followed by the lemma of every noun chunk.

    Args:
        docs: iterable of objects exposing ``text`` and ``noun_chunks``; every
            chunk exposes ``lemma_``.
    """
    divider = "========================================"
    for parsed in docs:
        print(divider)
        print(parsed.text)
        for lemma in (chunk.lemma_ for chunk in parsed.noun_chunks):
            print(lemma)

def get_spacy_noun_extractor():
    """Load the ``en_core_web_lg`` pipeline with the custom noun-chunk iterator.

    Replaces ``EnglishDefaults.syntax_iterators`` (a class-level, process-wide
    setting) so that ``noun_chunks`` is served by ``custom_noun_chunks``, then
    loads the model without its ``ner`` component and prints the pipe names.

    Returns:
        The loaded spaCy pipeline.
    """
    # Assigned before loading -- presumably so the loaded pipeline picks it up.
    iterators = {"noun_chunks": custom_noun_chunks}
    EnglishDefaults.syntax_iterators = iterators

    pipeline = spacy.load("en_core_web_lg", exclude=["ner"])
    print(pipeline.pipe_names)
    return pipeline

def test():
    """Smoke test: print the noun chunks extracted from a few sample sentences."""
    sample_sentences = [
        "A black cat is chasing a small brown bird",
        "a dog in a field of flowers and grasses",
        "a wire hanger with a paper cover that reads we heart our customers",
        "Autonomous cars shift insurance liability toward manufacturers",
        "You know some birds are not meant to be caged,their feathers are just too bright.",
        "Messi with Argentina at the 2022 FIFA World Cup",
        "This is my bag",
        "dog and cat are good friends of human",
    ]
    nlp = get_spacy_noun_extractor()
    parsed_docs = nlp.pipe(sample_sentences)
    print_noun_chunks(parsed_docs)

def main():
    """Build (or load) a caption dataset annotated with spaCy noun chunks.

    If ``<dataset_dir>/<dataset_name>/datasets_new`` already exists it is
    loaded from disk. Otherwise the raw data is read, a ``noun_chunks`` column
    (unique chunk lemmas per caption, in first-seen order) is added, and the
    result is saved to that directory:

    * ``CapsFusion-120M``: reads ``capsfusion_1.parquet`` .. ``capsfusion_4.parquet``,
      extracts chunks from the ``capsfusion`` caption and renames the columns
      to ``caption_origin`` / ``caption_coco`` / ``caption_capsfusion`` / ``url``.
    * ``BLIP``: reads every ``*.json`` file in the dataset directory; the split
      name is the file-name prefix before the first ``_``. Files whose name
      contains ``synthetic`` provide ``caption_coco``, the others provide
      ``caption_origin``; both are de-duplicated on ``url`` and inner-joined
      on ``url`` before chunks are extracted from ``caption_coco``.

    Raises:
        NotImplementedError: for any other ``--dataset_name`` when no saved
            dataset exists yet.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    extractor = get_spacy_noun_extractor()

    def extract_noun_chunks(texts):
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # stored lists are reproducible across runs (set iteration order of
        # strings is not, because of hash randomisation).
        return [
            list(dict.fromkeys(chunk.lemma_ for chunk in doc.noun_chunks))
            for doc in extractor.pipe(texts)
        ]

    def extract_noun_chunks_function(examples, caption_key="caption"):
        # Batched `map` callback: adds the "noun_chunks" column to the batch.
        examples["noun_chunks"] = extract_noun_chunks(examples[caption_key])
        return examples

    source_dir = os.path.join(args.dataset_dir, args.dataset_name)
    datasets_path = os.path.join(source_dir, 'datasets_new')
    if not os.path.exists(datasets_path):
        if args.dataset_name == 'CapsFusion-120M':
            cf_dataset = load_dataset(
                args.dataset_type,
                data_files=[
                    os.path.join(source_dir, f'capsfusion_{i}.parquet') for i in range(1, 5)
                ],
            )
            print(f"Origin column names: {cf_dataset['train'].column_names}")
            # Drop everything except the caption/url columns and the new one.
            kept_columns = ('image_url', 'capsfusion', 'laion_2b', 'laion_coco', 'noun_chunks')
            remove_columns_names = [
                name for name in cf_dataset['train'].column_names if name not in kept_columns
            ]

            cf_dataset['train'] = cf_dataset['train'].map(
                lambda examples: extract_noun_chunks_function(examples, caption_key="capsfusion"),
                batched=True,
                batch_size=args.process_batch_size,
                num_proc=args.process_num_workers,
                remove_columns=remove_columns_names,
                desc="Extract noun chunks",
            )
            cf_dataset = cf_dataset.rename_columns({'laion_2b': 'caption_origin', 'laion_coco': 'caption_coco', 'capsfusion': 'caption_capsfusion', 'image_url': 'url'})
            print(f"Final column names: {cf_dataset['train'].column_names}")
            cf_dataset.save_to_disk(datasets_path, max_shard_size="20GB")
            dataset = cf_dataset
        elif args.dataset_name == "BLIP":
            blip_dataset, blip_dataset_origin = DatasetDict(), DatasetDict()
            for json_path in os.listdir(source_dir):
                if not json_path.endswith('.json'):
                    continue
                split_name = json_path.split('_')[0]
                # JSON is UTF-8; do not depend on the platform default encoding.
                with open(os.path.join(source_dir, json_path), 'r', encoding='utf-8') as f:
                    json_data_list = json.load(f)
                if 'synthetic' in json_path:
                    blip_dataset[split_name] = Dataset.from_list(json_data_list)
                else:
                    blip_dataset_origin[split_name] = Dataset.from_list(json_data_list)
                del json_data_list
            # Report every split instead of assuming a split named 'ccs' exists.
            print(f"Origin column names: {blip_dataset.column_names}")

            # merge blip_dataset and blip_dataset_origin
            blip_dataset_origin = blip_dataset_origin.rename_column('caption', 'caption_origin')
            blip_dataset = blip_dataset.rename_column('caption', 'caption_coco')
            for split_name in blip_dataset:
                merged = pd.merge(
                    blip_dataset[split_name].to_pandas().drop_duplicates(subset=['url'], keep='first', ignore_index=True),
                    blip_dataset_origin[split_name].to_pandas().drop_duplicates(subset=['url'], keep='first', ignore_index=True),
                    on=['url'],
                    how='inner',
                    validate='one_to_one',
                )
                # preserve_index=False never materialises the pandas index as
                # an '__index_level_0__' column. Whether that column appears by
                # default depends on the index type pd.merge returns, so
                # removing it afterwards could raise when it was absent.
                blip_dataset[split_name] = Dataset.from_pandas(merged, preserve_index=False)
                del merged
            del blip_dataset_origin
            # extract noun chunks
            for split_name in blip_dataset:
                blip_dataset[split_name] = blip_dataset[split_name].map(
                    lambda examples: extract_noun_chunks_function(examples, caption_key="caption_coco"),
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    desc="Extract noun chunks",
                )
            print(f"Final column names: {blip_dataset.column_names}")
            blip_dataset.save_to_disk(datasets_path, max_shard_size="20GB")
            dataset = blip_dataset
        else:
            raise NotImplementedError
    else:
        dataset = load_from_disk(datasets_path)
    print(dataset)
    # Author's notes, presumably from an earlier BLIP run (row counts and final columns):
    # 11774713 96990957
    # ['caption_coco', 'url', 'caption_origin', 'noun_chunks']


if __name__ == "__main__":
    # Script entry point: build or load the noun-chunk annotated dataset.
    main()
    # Call test() instead of main() to print the noun chunks of a few sample sentences.
    # test()