import os
import torch
import random
import numpy as np
from PIL import Image
from io import BytesIO
import datasets
import logging
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import pandas as pd
import multiprocessing
from utils import dataset_name_split_mapping
import argparse
import transformers
from transformers import (
    default_data_collator,
)

# Module-level logger created through accelerate's `get_logger`; main() passes
# `main_process_only=False` on the calls that should be logged by every process.
logger = get_logger(__name__)


def _str2bool(value):
    """Convert a command-line string such as 'true' / 'false' into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including 'False'), so boolean options need an
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string maps to False, matching what ``bool('')`` used to give.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Build the command-line parser, print hardware diagnostics and parse.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description='use LaVIT image tokenizer to tokenize image into discrete token id sequence')
    parser.add_argument('--project_name', type=str, default='MLLM', help='project name')
    parser.add_argument('--group_name', type=str, default='data process', help='group name')
    parser.add_argument('--run_name', type=str, default='tokenize images', help='run name')
    parser.add_argument('--model_path', type=str, default='YOUR_ROOT_PATH/model/LaVIT-7B-v2', help='path to LaVIT checkpoint')
    parser.add_argument('--output_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help='path to save the output')
    parser.add_argument('--src_path', type=str, default='YOUR_ROOT_PATH/MLLM/src', help='path to src code')
    parser.add_argument('--dataset_name', type=str, default='MMBench_DEV_EN', help='dataset name')
    parser.add_argument('--dataset_type', type=str, default='files', help='dataset type', choices=['files', 'PIL'])
    # `type=bool` would turn "--use_xformers False" into True; parse the text explicitly.
    parser.add_argument('--use_xformers', type=_str2bool, default=True, help='use xformers')
    parser.add_argument('--mixed_precision', type=str, default='bf16', help='mixed precision')
    parser.add_argument('--report_to', type=str, default='wandb', help='report to')
    parser.add_argument("--with_tracking", action="store_true", help="Whether to enable experiment trackers for logging.")
    parser.add_argument('--img_size', type=int, default=224, help='image size')
    parser.add_argument('--max_image_length', type=int, default=256, help='max image length')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--image_pad_token_id', type=int, default=-1, help='pad token id')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--per_device_batch_size', type=int, default=256, help='per device batch size')
    parser.add_argument('--process_num_workers', type=int, default=multiprocessing.cpu_count(), help='preprocessing num workers')
    print('Number of available cores:', multiprocessing.cpu_count())
    print('Number of available gpus:', torch.cuda.device_count())

    # The GPU report is best-effort diagnostics: any failure to query device 0
    # is reported as "no GPU" instead of aborting. `Exception` (rather than a
    # bare `except`) keeps KeyboardInterrupt / SystemExit propagating.
    try:
        print('GPU model name:', torch.cuda.get_device_name(0))
        print('GPU memory size:', torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024, 'GB')
    except Exception:
        print('No GPU available.')

    return parser.parse_args(argv)

def main():
    """Tokenize the images of one dataset with the LaVIT visual tokenizer.

    Steps, in order:
      1. Parse the command line and import the LaVIT package from ``--src_path``.
      2. Create the ``Accelerator`` and configure logging / seeding.
      3. Load the images, either from ``<output_dir>/<dataset_name>/images``
         (``--dataset_type files``) or from a dataset saved on disk
         (``--dataset_type PIL``), and attach the LaVIT image processor as an
         on-the-fly transform.
      4. Run ``tokenize_image`` batch by batch, pad every token sequence to
         ``--max_image_length`` with ``--image_pad_token_id`` and gather the
         results from all processes.
      5. On the main process, save the result as a ``DatasetDict`` under
         ``<output_dir>/<dataset_name>/image_token``.

    Raises:
        NotImplementedError: for an unsupported ``--dataset_type``.
        ValueError: if the tokenizer returns more tokens for an image than
            ``--max_image_length`` allows.
    """
    args = parse_args()

    # LaVIT lives in the project source tree named on the command line, so it
    # can only be imported once that path is known.
    import sys
    sys.path.append(args.src_path)
    from LaVIT import build_dynamic_tokenizer, LaVITImageProcessor

    accelerator_kwargs = {}
    if args.with_tracking:
        accelerator_kwargs["log_with"] = args.report_to
        accelerator_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        **accelerator_kwargs
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Logged after basicConfig: before it, INFO records are below the root
    # logger's default level and would be dropped.
    logger.info(f'mixed_precision: {accelerator.mixed_precision}')
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Every process must take the same decision here: exiting on the main
    # process only would leave the others blocked in wait_for_everyone().
    if args.output_dir is None:
        logger.warning(
            "There is no `args.output_dir` specified! Image tokens cannot be saved."
        )
        sys.exit()
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    processor = LaVITImageProcessor(image_size=args.img_size)

    def transforms(examples):
        """Replace the PIL image column of a batch with processed `image_tensor`s."""
        # MathVista stores its images under "decoded_image"; everything else uses "image".
        if args.dataset_name in ['MathVista_MultiChoice', 'MathVista_OpenEnded']:
            image_column = "decoded_image"
        else:
            image_column = "image"
        examples["image_tensor"] = [processor(img.convert("RGB")) for img in examples[image_column]]
        del examples[image_column]
        return examples

    def mmmu_image_transform(examples):
        """Flatten the `image_1` .. `image_7` columns into one `image` column, skipping None."""
        images = []
        for idx in range(len(examples['image_1'])):
            for i in range(1, 8):
                image = examples[f'image_{i}'][idx]
                if image is not None:
                    images.append(image)
        return {'image': images}

    with accelerator.main_process_first():
        if args.dataset_type == 'files':
            images_path = os.path.join(args.output_dir, args.dataset_name, 'images')
            images_dataset = load_dataset("imagefolder", data_dir=images_path)
            images_dataset = images_dataset.with_transform(transforms)
            train_dataset = images_dataset["train"]
        elif args.dataset_type == 'PIL':
            origin_dataset = os.path.join(args.output_dir, args.dataset_name)
            images_dataset = load_from_disk(origin_dataset)[dataset_name_split_mapping[args.dataset_name]]
            if args.dataset_name in ['MMMU_VAL_MultiChoice', 'MMMU_TEST_MultiChoice', 'MMMU_VAL_OpenEnded', 'MMMU_TEST_OpenEnded']:
                images_dataset = images_dataset.map(
                    mmmu_image_transform,
                    batched=True,
                    batch_size=args.process_batch_size,
                    num_proc=args.process_num_workers,
                    remove_columns=images_dataset.column_names,
                )
            else:
                # Keep only the image column(s); the remaining columns are not needed for tokenization.
                images_dataset = images_dataset.remove_columns(
                    [name for name in images_dataset.column_names if name not in ['image', 'decoded_image']]
                )
            images_dataset = images_dataset.with_transform(transforms)
            train_dataset = images_dataset
        else:
            raise NotImplementedError

    train_dataloader = DataLoader(
        train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_batch_size
    )

    visual_tokenizer = build_dynamic_tokenizer(args.model_path, use_xformers=args.use_xformers, for_understanding=False).eval()

    visual_tokenizer, train_dataloader = accelerator.prepare(
        visual_tokenizer, train_dataloader
    )
    # `prepare` only adds a `.module` wrapper in distributed runs; unwrap_model
    # returns the underlying model in both the single- and multi-process case.
    tokenizer_model = accelerator.unwrap_model(visual_tokenizer)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        experiment_config = vars(args)
        # init_kwargs is a nested dict keyed by tracker name; a flat dict is
        # silently ignored. Entries for trackers that are not active are skipped.
        accelerator.init_trackers(
            args.project_name,
            experiment_config,
            init_kwargs={"wandb": {"name": args.run_name, "group": args.group_name}},
        )

    total_batch_size = args.per_device_batch_size * accelerator.num_processes

    logger.info("***** Running tokenize images *****")
    logger.info(f"  Num train examples = {len(train_dataset)}")
    logger.info(f"  Total batch size = {total_batch_size}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_batch_size}")
    logger.info(f"  Total train steps = {len(train_dataloader)}")

    image_token_datasets = DatasetDict()
    for dataloader, split in zip([train_dataloader], ['train']):
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(len(dataloader)), disable=not accelerator.is_local_main_process)
        completed_steps = 0

        outputs = {'image_tokens': []}

        for step, batch in enumerate(dataloader):
            with accelerator.autocast(), torch.no_grad():
                image_token_list = list(tokenizer_model.tokenize_image(batch['image_tensor'], add_special=False))
                # Fixed-width buffer pre-filled with the pad id, created directly on the target device.
                image_tokens = torch.full(
                    (len(image_token_list), args.max_image_length),
                    args.image_pad_token_id,
                    dtype=torch.long,
                    device=accelerator.device,
                )
                for i, image_token in enumerate(image_token_list):
                    if len(image_token) > args.max_image_length:
                        raise ValueError(
                            f"image produced {len(image_token)} tokens, which exceeds "
                            f"--max_image_length={args.max_image_length}"
                        )
                    image_tokens[i, :len(image_token)] = image_token
                batch['image_tokens'] = image_tokens
                batch.pop('image_tensor')
                # gather_for_metrics collects the tensors from every process and
                # drops the samples duplicated to pad the last batch.
                batch = accelerator.gather_for_metrics(batch)
                for key in batch.keys():
                    outputs[key].extend(batch[key].cpu().tolist())

            if args.with_tracking:
                accelerator.log(
                    {
                        f"{split}_step": step
                    },
                    step=completed_steps,
                )

            progress_bar.update(1)
            completed_steps += 1

        accelerator.wait_for_everyone()

        if accelerator.is_main_process:
            image_token_datasets[split] = Dataset.from_dict(outputs)

    if accelerator.is_main_process:
        output_path = os.path.join(args.output_dir, args.dataset_name, 'image_token')
        os.makedirs(output_path, exist_ok=True)
        image_token_datasets.save_to_disk(
            output_path, max_shard_size="20GB"
        )
        logger.info(f"Saving dataset to {output_path}")
        accelerator.print(image_token_datasets)

    if args.with_tracking:
        accelerator.end_training()

# Run the tokenization pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
