import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datasets import load_dataset, load_from_disk, DatasetDict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from itertools import chain
import json
import bisect
import argparse
import transformers
from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
    default_data_collator,
)


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Every option is declared once in a table of
    ``(flag, type, default, help)`` rows and registered in a loop. The
    number of available CPU cores is printed before parsing and is also the
    default for ``--process_num_workers``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    core_count = multiprocessing.cpu_count()

    option_table = [
        ('--project_name', str, 'MLLM', 'project name'),
        ('--model_path', str, 'YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', 'path to LaVIT checkpoint'),
        ('--dataset_dir', str, 'YOUR_ROOT_PATH/data/MLLM/IC', 'path to origin caption dir'),
        ('--dataset_name', str, 'Merged_new', 'dataset name'),
        ('--seed', int, 42, 'random seed'),
        ('--image_pad_token_id', int, -1, 'pad token id'),
        ('--image_start_token_id', int, 32000, 'image start token id'),
        ('--image_end_token_id', int, 32001, 'image end token id'),
        ('--max_length', int, 2048, 'max length'),
        ('--max_main_part_length', int, 512, 'max main part length'),
        ('--max_related_part_length', int, 512, 'max related part length'),
        ('--image_first_prob', float, 0.5, 'image first prob'),
        ('--denoise_prob', float, 0.0, 'denoise prob'),
        ('--process_batch_size', int, 1000, 'process batch size'),
        ('--process_num_workers', int, core_count, 'preprocessing num workers'),
        ('--sample_time', int, 1, 'sample time of the dataset'),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)

    # Report the core count before parsing, so it is shown even if parsing exits.
    print('Number of available cores:', core_count)
    return parser.parse_args()

def statistics(dataset_dict, dataset_name):
    """Print mean/max/min summaries of the 'train' and 'dev' splits.

    For each split the columns ``concept_count``, ``image_token_length`` and
    ``text_token_length`` are summarised, together with the per-concept
    averages of the two token-length columns.

    Args:
        dataset_dict: mapping with 'train' and 'dev' entries; each entry must
            expose ``num_rows`` and support column access by name.
        dataset_name: label used in the printed header of each split.

    Returns:
        None. The report is written to stdout.
    """
    column_names = ('concept_count', 'image_token_length', 'text_token_length')

    for split in ['train', 'dev']:
        split_dataset = dataset_dict[split]
        print(f"###############{dataset_name}_{split}: {split_dataset.num_rows}###############")

        # np.max / np.min raise on zero-size input, so an empty split is
        # reported and skipped instead of aborting the whole report.
        if split_dataset.num_rows == 0:
            print("empty split, no statistics available")
            continue

        # Fetch every column exactly once; each column access on the dataset
        # may materialise the full column.
        columns = {name: np.asarray(split_dataset[name]) for name in column_names}
        means = {name: np.mean(values) for name, values in columns.items()}
        maxima = {name: np.max(values) for name, values in columns.items()}
        minima = {name: np.min(values) for name, values in columns.items()}

        concept_count_mean = means['concept_count']
        image_token_length_mean = means['image_token_length']
        text_token_length_mean = means['text_token_length']

        print(f"concept_count mean: {concept_count_mean:.2f}, image_token_length mean: {image_token_length_mean:.2f}, text_token_length mean: {text_token_length_mean:.2f}, image_token_length per concept mean: {image_token_length_mean / concept_count_mean:.2f}, text_token_length per concept mean: {text_token_length_mean / concept_count_mean:.2f}")
        print(f"concept_count max: {maxima['concept_count']}, image_token_length max: {maxima['image_token_length']}, text_token_length max: {maxima['text_token_length']}")
        print(f"concept_count min: {minima['concept_count']}, image_token_length min: {minima['image_token_length']}, text_token_length min: {minima['text_token_length']}")
    
def _split_modal_runs(token_ids, image_start_token_id):
    """Split a token sequence into maximal runs of image ids and of text ids.

    An id counts as an image id when it is ``>= image_start_token_id``;
    every other id is a text id.

    Returns:
        tuple[list[list[int]], list[list[int]]]: ``(image_runs, text_runs)``,
        each in order of appearance.
    """
    image_runs, text_runs = [], []
    current_run, current_is_image = [], None
    for token_id in token_ids:
        is_image = token_id >= image_start_token_id
        if is_image != current_is_image:
            # Modality changed: close the run that just ended, if any.
            if current_run:
                (image_runs if current_is_image else text_runs).append(current_run)
            current_run, current_is_image = [], is_image
        current_run.append(token_id)
    if current_run:
        (image_runs if current_is_image else text_runs).append(current_run)
    return image_runs, text_runs


def test():
    """Spot-check the saved datasets by decoding one random row per split.

    For every sampling turn the dataset is loaded from disk, its statistics
    are printed, and for each of the 'train' and 'dev' splits one random row
    is split into text and image runs. The text runs are decoded and joined
    with ``[image]``; the first image run and the row's url are printed too.
    """
    args = parse_args()

    # The tokenizer does not depend on the turn, so load it a single time.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token

    for turn in range(args.sample_time):
        # Must match the directory layout used by main() when saving.
        turn_dataset_path = os.path.join(args.dataset_dir, args.dataset_name, f"turn_{turn}_{args.denoise_prob}")
        final_dataset = load_from_disk(turn_dataset_path)
        statistics(final_dataset, f"{args.dataset_name}/turn_{turn}")

        for split in ['train', 'dev']:
            num_rows = final_dataset[split].num_rows
            if num_rows == 0:
                print(f"###############{args.dataset_name}/turn_{turn}_{split}: empty split, nothing to sample###############")
                continue

            # randrange excludes the upper bound; randint(0, num_rows) could
            # return num_rows itself and index one past the last row.
            row_index = random.randrange(num_rows)
            row = final_dataset[split][row_index]  # fetch the row once
            url, combined_desc, concept_count = row['url'], row['combined_desc'], row['concept_count']

            image_token_lists, text_token_lists = _split_modal_runs(combined_desc, args.image_start_token_id)

            text_tokens = [tokenizer.decode(text_token_list) for text_token_list in text_token_lists]
            text_tokens = "[image]".join(text_tokens)
            print(f"###############{args.dataset_name}/turn_{turn}_{split}, ###############")
            print(f"row_index: {row_index}, concept_count: {concept_count}, text_token_lists: {len(text_token_lists)}, image_token_lists: {len(image_token_lists)}")
            print(text_tokens)
            # A row without any image run prints an empty list instead of raising.
            print(image_token_lists[0] if image_token_lists else [])
            print(url)

def main():
    """Build the prompted, tokenized image-caption dataset and save it to disk.

    Steps:
      1. Parse arguments, configure logging and seed the random generators.
      2. Load the tokenizer and register the ``<image>`` / ``</image>``
         special tokens.
      3. For each sampling turn, map every split of the source dataset
         through ``prompt_and_tokenize_function``, save the result under
         ``<dataset_dir>/<dataset_name>/turn_<turn>_<denoise_prob>`` and
         print its statistics.

    Raises:
        ValueError: if the dataset name is unknown, or if the id the
            tokenizer assigned to the image start token differs from
            ``--image_start_token_id``.
    """
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
    tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)
    image_start_token = tokenizer.additional_special_tokens[0]
    image_start_token_id = tokenizer.additional_special_tokens_ids[0]
    # Explicit raise rather than assert, so the check also runs under `python -O`.
    if image_start_token_id != args.image_start_token_id:
        raise ValueError(
            f"tokenizer assigned id {image_start_token_id} to {image_start_token!r}, "
            f"but --image_start_token_id is {args.image_start_token_id}"
        )

    def tokenize_by_llama(inputs):
        """Tokenize text without adding BOS/EOS and return the input ids."""
        return tokenizer(inputs, add_special_tokens=False)["input_ids"]

    # Prompt prefixes for the two layouts: image before the captions, or after them.
    ic_template = {
        "image_first":{
            "introduction_prefix": "Image: ",
            "caption_prefix": "\nCaption: ",
            "detailed_caption_prefix": "\nDetailed caption: ",
            "style_prefix": "\nStyle: ",
        },
        "text_first":{
            "caption_prefix": "Caption: ",
            "detailed_caption_prefix": "\nDetailed caption: ",
            "style_prefix": "\nStyle: ",
            "image_prefix": "\nImage: ",
        }
    }

    def prompt_and_tokenize_function(examples, idxes):
        """Turn one batch of raw rows into tokenized image/caption prompts.

        Rows with an empty short caption, or whose prompt would not fit in
        ``args.max_length``, are dropped from the batch. Every column that is
        passed through unchanged is filtered to the surviving rows.

        Args:
            examples: batch of columns; the caption/style/image-token columns
                are consumed and removed from it.
            idxes: dataset row indexes of the batch rows, used in messages.

        Returns:
            The same ``examples`` mapping with the added columns
            ``combined_desc``, ``concept_count``, ``image_token_length``,
            ``text_token_length`` and ``image_first``.
        """
        image_token_lists = examples["image_tokens"]
        if args.dataset_name != "JourneyDB":
            short_captions = [caption if caption is not None else "" for caption in examples["caption_coco"]]
            detailed_captions = [caption if caption is not None else "" for caption in examples["caption_capsfusion"]]
            styles = ["" for _ in detailed_captions]
            consumed_columns = ('caption_origin', 'caption_coco', 'caption_capsfusion')
        else:
            short_captions = [caption if caption is not None else "" for caption in examples["caption_coco"]]
            detailed_captions = [caption if caption is not None else "" for caption in examples["caption_gpt"]]
            styles = [style if style is not None else "" for style in examples["style"]]
            consumed_columns = ('caption_coco', 'caption_gpt', 'style')
        for column_name in consumed_columns:
            examples.pop(column_name)
        examples.pop("image_tokens")

        # Whatever is left is carried over as-is and must follow any row drops.
        passthrough_columns = list(examples.keys())

        combined_descs, concept_counts, image_token_length_list, text_token_length_list, image_first_list = [], [], [], [], []
        kept_rows = []  # batch positions of the rows that survive
        rows = zip(image_token_lists, short_captions, detailed_captions, styles)
        for batch_index, (image_token_list, short_caption, detailed_caption, image_style) in enumerate(rows):
            # Cut the image tokens at the first pad id; a list without any
            # pad id is used whole instead of raising ValueError.
            try:
                pad_position = image_token_list.index(args.image_pad_token_id)
            except ValueError:
                pad_position = len(image_token_list)
            image_token_list = image_token_list[:pad_position]

            # image_first_prob is the probability of the image-first layout.
            # Indexing a two-item list with the boolean would select
            # "text_first" when the draw succeeds, i.e. the inverse.
            first_type = "image_first" if random.random() < args.image_first_prob else "text_first"
            template_cur = ic_template[first_type]

            if not short_caption:
                print(f"{first_type} of {idxes[batch_index]}-th row should have image_coco_caption")
                continue

            caption_text = f"{template_cur['caption_prefix']}{short_caption}"
            # With probability denoise_prob only the short caption is kept.
            if random.random() >= args.denoise_prob:
                if detailed_caption:
                    caption_text += f"{template_cur['detailed_caption_prefix']}{detailed_caption}"
                if image_style:
                    caption_text += f"{template_cur['style_prefix']}{image_style}"

            if first_type == "image_first":
                cur_combined_desc = f"{template_cur['introduction_prefix']}{image_start_token}{caption_text}"
            else:
                cur_combined_desc = f"{caption_text}{template_cur['image_prefix']}{image_start_token}"

            combined_desc_tokens = tokenize_by_llama(cur_combined_desc)
            # +1 for the image end token; the -2 below leaves room for BOS and EOS.
            combined_desc_len = len(combined_desc_tokens) + len(image_token_list) + 1
            if combined_desc_len > args.max_length - 2:
                # Not adding the row to kept_rows is what records the drop, so
                # the pass-through columns stay aligned with the new columns.
                print(f"{idxes[batch_index]}-th row is dropped: {combined_desc_len} tokens do not fit max_length {args.max_length}")
                continue

            # Insert the image tokens and the end token right after the start token.
            image_position_index = combined_desc_tokens.index(args.image_start_token_id)
            combined_desc_tokens = (
                combined_desc_tokens[:image_position_index + 1]
                + image_token_list
                + [args.image_end_token_id]
                + combined_desc_tokens[image_position_index + 1:]
            )

            # The length check above guarantees the sequence fits without truncation.
            combined_desc = [tokenizer.bos_token_id] + combined_desc_tokens + [tokenizer.eos_token_id]
            combined_descs.append(combined_desc)
            image_first_list.append(1 if first_type == "image_first" else 0)
            kept_rows.append(batch_index)

            # statistics
            combined_desc = np.array(combined_desc)
            concept_counts.append(int(np.sum(combined_desc == args.image_end_token_id)))
            # we don't count special tokens
            image_token_length_list.append(int(np.sum(combined_desc > args.image_end_token_id)))
            # the two subtracted tokens are BOS and EOS
            text_token_length_list.append(int(np.sum(combined_desc < args.image_start_token_id)) - 2)

        if len(kept_rows) != len(image_token_lists):
            for key in passthrough_columns:
                column = examples[key]
                examples[key] = [column[row] for row in kept_rows]

        examples["combined_desc"] = combined_descs
        examples["concept_count"] = concept_counts
        examples["image_token_length"] = image_token_length_list
        examples["text_token_length"] = text_token_length_list
        examples["image_first"] = image_first_list

        # Internal invariant: every returned column has one entry per kept row.
        assert all(len(examples[key]) == len(kept_rows) for key in passthrough_columns)
        assert len(examples["concept_count"]) == len(kept_rows)

        return examples

    splits = ['train', 'dev']
    if args.dataset_name in ["Merged_new", "laion-coco-aesthetic"]:
        origin_dataset = load_from_disk(os.path.join(args.dataset_dir, args.dataset_name, 'image_token', 'merged_new'))
        origin_splits = ['train', 'test']
        remove_columns = ['image_tokens', 'caption_origin', 'caption_coco', 'caption_capsfusion']
    elif args.dataset_name == "JourneyDB":
        origin_dataset = load_from_disk(os.path.join(args.dataset_dir, args.dataset_name, 'image_token_new'))
        origin_splits = ['train', 'valid']
        remove_columns = ['image_tokens', 'caption_gpt', 'caption_coco', 'style']
    else:
        raise ValueError("Unknown dataset name")

    for turn in range(args.sample_time):
        combined_dataset = DatasetDict()
        # Cache files are cleaned before each turn's map call.
        # NOTE(review): presumably so every turn re-runs the random sampling - confirm.
        origin_dataset.cleanup_cache_files()
        for split, origin_split in zip(splits, origin_splits):
            combined_dataset[split] = origin_dataset[origin_split].map(
                prompt_and_tokenize_function,
                batched=True,
                with_indices=True,
                batch_size=args.process_batch_size,
                num_proc=args.process_num_workers,
                remove_columns=remove_columns,
                desc="Construct final dataset",
            )

        turn_dataset_path = os.path.join(args.dataset_dir, args.dataset_name, f"turn_{turn}_{args.denoise_prob}")
        os.makedirs(turn_dataset_path, exist_ok=True)
        combined_dataset.save_to_disk(turn_dataset_path, max_shard_size="20GB")
        statistics(combined_dataset, f"{args.dataset_name}/turn_{turn}")
        print(combined_dataset)
        
        
if __name__ == "__main__":
    # Run the dataset construction first, then the spot-check of its output.
    for stage in (main, test):
        stage()
