import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from datetime import timedelta
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed, gather_object, InitProcessGroupKwargs
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
import argparse
import transformers
from transformers import (
    default_data_collator,
)
from PIL import Image, ImageFile
# Let Pillow decode image files whose data is cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module-level logger obtained through accelerate's `get_logger`.
logger = get_logger(__name__)


def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` just calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- would become ``True``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what ``bool('')`` used to give.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Build the command-line parser, print a hardware summary and parse.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, as before.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='use LaVIT image tokenizer to tokenize image into discrete token id sequence')
    parser.add_argument('--project_name', type=str, default='MLLM', help='project name')
    parser.add_argument('--group_name', type=str, default='data process', help='group name')
    parser.add_argument('--run_name', type=str, default='tokenize images', help='run name')

    parser.add_argument('--model_path', type=str, default='YOUR_ROOT_PATH/model/LaVIT-7B-v2', help='path to LaVIT checkpoint')
    parser.add_argument('--output_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/IC', help='path to save the output')
    parser.add_argument('--src_path', type=str, default='YOUR_ROOT_PATH/MLLM/src', help='path to src code')
    parser.add_argument('--dataset_name', type=str, default='Merged', help='dataset name')
    parser.add_argument('--dataset_shard_index', type=int, default=-1, help='dataset shard index')
    # `type=bool` would turn "--use_xformers False" into True; parse the text.
    parser.add_argument('--use_xformers', type=_str2bool, default=True, help='use xformers')
    parser.add_argument('--mixed_precision', type=str, default='bf16', help='mixed precision')
    parser.add_argument('--report_to', type=str, default='wandb', help='report to')
    parser.add_argument("--with_tracking", action="store_true", help="Whether to enable experiment trackers for logging.")
    parser.add_argument('--img_size', type=int, default=224, help='image size')
    parser.add_argument("--aspect_ratio_threshold", type=int, default=3, help="threshold for aspect ratio")
    parser.add_argument('--max_image_length', type=int, default=256, help='max image length')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--image_pad_token_id', type=int, default=-1, help='pad token id')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--per_device_batch_size', type=int, default=256, help='per device batch size')
    parser.add_argument('--process_num_workers', type=int, default=multiprocessing.cpu_count(), help='preprocessing num workers')
    print('Number of available cores:', multiprocessing.cpu_count())
    print('Number of available gpus:', torch.cuda.device_count())

    # Best-effort hardware report: querying device 0 raises when no usable GPU
    # exists. Catch Exception (not a bare except) so Ctrl-C still propagates.
    try:
        print('GPU model name:', torch.cuda.get_device_name(0))
        print('GPU memory size:', torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024, 'GB')
    except Exception:
        print('No GPU available.')

    args = parser.parse_args(argv)

    return args

def main():
    """Tokenize an image dataset into discrete LaVIT token ids and save it.

    Steps, in order:
      1. Parse the CLI arguments and create the ``Accelerator`` (optionally
         with experiment tracking).
      2. Load the images: per-folder ``imagefolder`` datasets for
         ``JourneyDB`` (then filtered by size / aspect ratio), or one
         ``webdataset`` shard for every other dataset name.
      3. Run the LaVIT dynamic visual tokenizer over every batch and right-pad
         each token sequence to ``--max_image_length`` with
         ``--image_pad_token_id``.
      4. Gather the per-process results and, on the main process, save them
         with ``save_to_disk`` under
         ``<output_dir>/<dataset_name>/image_token[/<shard_index>]``.

    Raises:
        NotImplementedError: for a non-JourneyDB dataset when
            ``--dataset_shard_index`` is left at ``-1``.
    """
    args = parse_args()

    # The LaVIT code is imported from the directory given by --src_path.
    import sys
    sys.path.append(args.src_path)
    from LaVIT import build_model, build_dynamic_tokenizer, LaVITImageProcessor, convert_weights_to_bf16

    # for JourneyDB's slow data loading: use a very long process-group timeout
    accelerator_kwargs = {
        "kwargs_handlers": [InitProcessGroupKwargs(timeout=timedelta(seconds=36000))],
    }
    if args.with_tracking:
        accelerator_kwargs["log_with"] = args.report_to
        accelerator_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        **accelerator_kwargs
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Logged only after basicConfig so the INFO record is actually emitted.
    logger.info(f'mixed_precision: {accelerator.mixed_precision}')
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Every process must take the same exit path: exiting on the main process
    # only would leave the other ranks blocked in wait_for_everyone().
    if args.output_dir is None:
        logger.warning(
            "There is no `args.output_dir` specified! Tokenized images cannot be saved."
        )
        sys.exit()
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    processor = LaVITImageProcessor(image_size=args.img_size)

    def extract_caption_from_webdataset_json(examples):
        """Lift the caption fields and the url out of the batched 'json' column."""
        for key in ['caption_origin', 'caption_coco', 'caption_capsfusion']:
            # The first record decides whether the whole batch carries this key.
            if key in examples['json'][0]:
                captions = [meta[key] for meta in examples['json']]
                if key == 'caption_capsfusion': # avoid None
                    captions = [caption if caption is not None else '' for caption in captions]
                examples[key] = captions
        examples['url'] = [meta['url'] for meta in examples['json']]
        examples.pop('json')
        return examples

    if args.dataset_name == 'JourneyDB':
        images_dataset = DatasetDict()
        folder_count = 200
        folder_list = [str(i).zfill(3) for i in range(folder_count)]
        for split in ['train', 'valid']:
            for folder_index, folder_name in enumerate(folder_list):
                images_path = os.path.join(args.output_dir, args.dataset_name, f'data/{split}/imgs/{folder_name}')
                cur_image_dataset = load_dataset("imagefolder", data_dir=images_path)['train']
                if folder_index == 0:
                    images_dataset[split] = cur_image_dataset
                else:
                    images_dataset[split] = datasets.concatenate_datasets([images_dataset[split], cur_image_dataset])

        dataset_keys = ['url', 'caption', 'style']
        images_dataset = images_dataset.remove_columns('folder')
        accelerator.print(f"before filter: {images_dataset}")
        # original author's note: 4189708, 234156 (presumably train / valid row counts)

        def filter_by_size_and_aspect_ratio(example):
            """Keep images at least img_size on both sides and not too elongated."""
            try:
                w, h = example['image'].size
                return w >= args.img_size and h >= args.img_size and w / h <= args.aspect_ratio_threshold and h / w <= args.aspect_ratio_threshold
            except Exception as e:
                # An image that cannot be inspected is reported and dropped.
                print(e)
                return False

        with accelerator.main_process_first():
            images_dataset = images_dataset.filter(
                filter_by_size_and_aspect_ratio,
                num_proc=args.process_num_workers,
                desc="Filtering by size and aspect ratio",
            )
            accelerator.print(f"after filter: {images_dataset}")
        # original author's note: 4189622 234151 (presumably row counts after filtering)

        train_dataset = images_dataset["train"]
        valid_dataset = images_dataset["valid"]
        accelerator.wait_for_everyone()

    else:
        with accelerator.main_process_first():
            if args.dataset_shard_index == -1:
                raise NotImplementedError

            # remember to remove all not-tar files in the images_path
            images_path = os.path.join(
                args.output_dir, args.dataset_name, 'images', f'{args.dataset_shard_index}'
            )
            images_dataset = load_dataset("webdataset", data_dir=images_path)

            # If a "KeyError: 'json'" error occurs here, a workaround is to load
            # the tar files one by one with
            # load_dataset("webdataset", data_files=<tar path>)['train'], skip the
            # ones that raise, merge the rest with datasets.concatenate_datasets,
            # and delete the offending tar file.

            accelerator.print(images_dataset)
            if args.dataset_name == 'Merged_new':
                dataset_keys = ['url', 'caption_origin', 'caption_coco', 'caption_capsfusion']
            else: # for laion-coco-aesthetic
                dataset_keys = ['url', 'caption_origin', 'caption_coco']
            images_dataset = images_dataset.remove_columns(['__key__', '__url__'])
            images_dataset = images_dataset.rename_column('jpg', 'image')
            images_dataset = images_dataset.map(
                extract_caption_from_webdataset_json,
                batched=True,
                num_proc=args.process_num_workers,
                batch_size=args.process_batch_size,
                remove_columns=['json'],
                desc="Extracting caption from webdataset json",
            )
            accelerator.print(images_dataset)
            train_dataset = images_dataset["train"]

    def custom_collate_fn(batch):
        """Collect the metadata columns and stack the preprocessed images."""
        batched_inputs = {k: [x[k] for x in batch] for k in dataset_keys}
        batched_inputs["image_tensor"] = torch.stack([processor(x['image'].convert("RGB")) for x in batch]).to(accelerator.device)
        return batched_inputs

    # NOTE(review): drop_last=True means a trailing incomplete batch is never
    # tokenized, so up to per_device_batch_size - 1 examples per dataloader are
    # missing from the output -- confirm this is intended.
    train_dataloader = DataLoader(
        train_dataset, shuffle=False, collate_fn=custom_collate_fn, batch_size=args.per_device_batch_size, drop_last=True
    )

    if args.dataset_name == 'JourneyDB':
        valid_dataloader = DataLoader(
            valid_dataset, shuffle=False, collate_fn=custom_collate_fn, batch_size=args.per_device_batch_size, drop_last=True
        )

    visual_tokenizer = build_dynamic_tokenizer(args.model_path, use_xformers=args.use_xformers, for_understanding=False).eval()
    # convert precision
    # convert_weights_to_bf16(visual_tokenizer)

    if args.dataset_name == 'JourneyDB':
        visual_tokenizer, train_dataloader, valid_dataloader = accelerator.prepare(
            visual_tokenizer, train_dataloader, valid_dataloader
        )
        dataloader_list = [train_dataloader, valid_dataloader]
        splits = ['train', 'valid']
    else:
        visual_tokenizer, train_dataloader = accelerator.prepare(
            visual_tokenizer, train_dataloader
        )
        dataloader_list = [train_dataloader]
        splits = ['train']

    # `prepare` only wraps the model (exposing `.module`) in distributed runs;
    # unwrap_model returns the underlying model in single-process runs too.
    tokenizer_model = accelerator.unwrap_model(visual_tokenizer)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # init_kwargs must be nested under the tracker name, and wandb's
        # keyword for the run group is `group`; a flat dict is ignored.
        accelerator.init_trackers(
            args.project_name,
            experiment_config,
            init_kwargs={"wandb": {"name": args.run_name, "group": args.group_name}},
        )

    total_batch_size = args.per_device_batch_size * accelerator.num_processes

    logger.info("***** Running tokenize images *****")
    logger.info(f"  Num train examples = {len(train_dataset)}")
    logger.info(f"  Total batch size = {total_batch_size}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_batch_size}")
    logger.info(f"  Total train steps = {len(train_dataloader)}")

    output_keys = dataset_keys + ['image_tokens']
    image_token_datasets = DatasetDict()
    for dataloader, split in zip(dataloader_list, splits):
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(len(dataloader)), disable=not accelerator.is_local_main_process)
        completed_steps = 0
        progress_bar.update(completed_steps)

        # One growing list per output column, local to this process.
        outputs = {key: [] for key in output_keys}

        for step, batch in enumerate(dataloader):
            with accelerator.autocast(), torch.no_grad():
                image_token_list = list(tokenizer_model.tokenize_image(batch['image_tensor'], add_special=False))
                # Token sequences vary in length: right-pad them into one
                # (batch, max_image_length) tensor filled with the pad id.
                image_tokens = torch.full(
                    (len(image_token_list), args.max_image_length),
                    args.image_pad_token_id,
                    dtype=torch.long,
                    device=accelerator.device,
                )
                for i, image_token in enumerate(image_token_list):
                    image_tokens[i, :len(image_token)] = image_token
                batch['image_tokens'] = image_tokens.cpu().numpy().tolist()
                batch.pop('image_tensor')
                for key, values in batch.items():
                    outputs[key].extend(values)

            if args.with_tracking:
                accelerator.log(
                    {
                        f"{split}_step": step
                    },
                    step=completed_steps,
                )

            progress_bar.update(1)
            completed_steps += 1

        accelerator.wait_for_everyone()

        # gather outputs: every process takes part, in the same key order
        gathered_outputs = {key: gather_object(values) for key, values in outputs.items()}

        if accelerator.is_main_process:
            image_token_datasets[split] = Dataset.from_dict(gathered_outputs)

    if accelerator.is_main_process:
        output_path = os.path.join(args.output_dir, args.dataset_name, 'image_token')
        if args.dataset_shard_index != -1:
            output_path = os.path.join(output_path, f'{args.dataset_shard_index}')
        os.makedirs(output_path, exist_ok=True)
        image_token_datasets.save_to_disk(
            output_path, max_shard_size="20GB"
        )
        logger.info(f"Saving dataset to {output_path}")
        accelerator.print(image_token_datasets)

    if args.with_tracking:
        accelerator.end_training()

# Run the tokenization pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
