import torch
import sys
sys.path.insert(0, "../")

from inference.mm_utils import process_images, select_best_resolution
from inference.builder import load_pretrained_model

# Model identifier handed to load_pretrained_model. Its last path component is
# also used further down (vision_model_name_for_path) to name the output
# feature directory. Alternative checkpoint kept for reference:
# vision_model_name = 'Lin-Chen/open-llava-next-vicuna-7b'
vision_model_name = 'liuhaotian/llava-v1.6-vicuna-7b'
# No explicit cache directory is passed; presumably the loader then falls back
# to its default cache location (confirm in inference.builder).
cache_dir = None
# Avoid device_map="auto" which can place modules on the meta device under Accelerate.
# Load fully and move to CUDA with desired dtype inside the builder.
# NOTE(review): device placement happens inside load_pretrained_model, which is
# not visible from this file; the loop below relies on model.device being usable.
# context_len is returned by the loader but is not used anywhere in this script.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    vision_model_name,
    cache_dir,
    device_map=None,
    torch_dtype=torch.bfloat16,
)

import torch
import numpy as np

def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """Return how many patches an image of ``image_size`` is split into.

    The image is matched to one of the candidate resolutions in
    ``grid_pinpoints`` via ``select_best_resolution``; that resolution is then
    tiled with ``patch_size`` x ``patch_size`` tiles (partial tiles count as a
    full tile) and one extra "base" patch is added.

    Args:
        image_size: Size of the image as a list, tuple, ``torch.Tensor`` or
            ``numpy.ndarray``. It is forwarded unchanged (after ``tolist()``
            for tensors/arrays) to ``select_best_resolution``, so its axis
            order must match what that helper expects (TODO confirm against
            ``inference.mm_utils``).
        grid_pinpoints: List of candidate resolutions.
        patch_size: Side length of one tile, in the same units as the
            resolutions in ``grid_pinpoints``.

    Returns:
        int: Number of tiles covering the best resolution, plus one.

    Raises:
        TypeError: If ``grid_pinpoints`` is not a list, or ``image_size`` is
            not a list, tuple, tensor or ndarray.
        ValueError: If ``patch_size`` is 0 (raised by ``range``).
    """
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # Very important: a tensor/ndarray image_size must be converted to a plain
    # Python list first, otherwise the resolution selection computes wrong values.
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)

    # len(range(0, n, step)) is ceil(n / step) for positive values and keeps the
    # exact semantics (including errors) of iterating the range tile by tile.
    tiles_down = len(range(0, height, patch_size))
    tiles_across = len(range(0, width, patch_size))

    # +1 accounts for the base patch.
    return tiles_down * tiles_across + 1

def get_image_features(
    llm,
    pixel_values: torch.FloatTensor,
    image_sizes: torch.Tensor
):
    """Encode one randomly selected patch per image with the vision tower.

    For every image the number of available patches is derived from its size
    via ``image_size_to_num_patches``; a patch index is then drawn with
    ``np.random.randint(0, num_patches)`` and only that patch is passed
    through ``llm.get_vision_tower()``.

    Args:
        llm: Model exposing ``config.image_grid_pinpoints`` and
            ``get_vision_tower()`` (whose ``config.image_size`` is used as the
            tile size).
        pixel_values: Iterated per image and sliced along the first axis, so
            each element is assumed to hold that image's patches stacked on
            dim 0 (TODO confirm against ``process_images``). The caller in
            this file passes a list of tensors.
        image_sizes: One size per image, in the same order as ``pixel_values``.

    Returns:
        tuple: ``(image_features, numbers)`` where ``image_features`` is the
        vision-tower output for the selected patches concatenated on dim 0 and
        ``numbers`` is the list of selected patch indices, one per image.
    """
    grid_pinpoints = llm.config.image_grid_pinpoints
    tile_size = llm.get_vision_tower().config.image_size

    patches_per_image = [
        image_size_to_num_patches(
            image_size=size,
            grid_pinpoints=grid_pinpoints,
            patch_size=tile_size,
        )
        for size in image_sizes
    ]

    # One draw per image, in image order; the upper bound is exclusive.
    numbers = [np.random.randint(0, count) for count in patches_per_image]

    # Slicing with [k:k+1] keeps the leading axis so the picks can be concatenated.
    chosen_patches = [
        per_image[pick:pick + 1]
        for pick, per_image in zip(numbers, pixel_values)
    ]
    stacked = torch.cat(chosen_patches, dim=0)

    image_features = llm.get_vision_tower()(stacked)
    return image_features, numbers

import os
import json
from tqdm import tqdm
from PIL import Image

import argparse

# Command-line interface: which split to process and an optional cap on the
# number of images taken from each image directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mode",
    type=str,
    default='train',
    choices=['train', 'val'],
)
parser.add_argument(
    "--limit_samples",
    type=int,
    default=None,
)
args = parser.parse_args()
mode, limit_samples = args.mode, args.limit_samples

# Per-directory image caps: one entry per image directory, paired positionally
# with the train/val image directory lists when file names are selected.
# --limit_samples, when given, replaces the default cap for both splits.
if limit_samples is None:
    train_max_images = [200000]
    val_max_images = [5000]
else:
    train_max_images = [limit_samples]
    val_max_images = [limit_samples]

# Root of the data tree. The trailing slash combined with the '/' in the
# f-strings below produces a double slash in the derived paths.
datasets_dir = '../../../data/'
# Number of image files handled per iteration of the extraction loop.
batch_size = 16

# Model name without its organisation prefix ('org/name' -> 'name').
vision_model_name_for_path = vision_model_name.rsplit('/', 1)[-1]
# All outputs for this model live under one '<model>_mlp' directory.
_features_root = f'{datasets_dir}/{vision_model_name_for_path}_mlp'

# Training split: source image directories and output locations.
train_images_dir = [
    '../../../data/coco_subsets/train2017',  # len of train2017 = 115404
    # '../../data/docvqa/train/documents',  # len of docvqa = 10197
    # '../../data/chart_qa_images'
]
train_features_dir = f'{_features_root}/tensors'
train_features_json = f'{_features_root}/map.json'

# Validation split: source image directories and output locations.
val_images_dir = [
    '../../../data/coco_subsets/val2017',  # len of val2017 = 4515
    # '../../data/docvqa/train/documents',  # len of docvqa = 10197
    # '../../data/chart_qa_images'
]
val_features_dir = f'{_features_root}/tensors_val'
val_features_json = f'{_features_root}/map_val.json'

# Maps each saved feature file path to (source image path, sampled crop index);
# filled by the extraction loop and dumped to features_json at the end.
feature_image_map = {}

# Pick the inputs/outputs for the requested split. Training takes the first
# max_images names of each directory, validation takes the last max_images.
if mode == 'train':
    print("Generating train features", train_features_json)
    max_images_list = train_max_images
    images_dir_list = train_images_dir
    features_dir = train_features_dir
    features_json = train_features_json
    take_from_end = False
elif mode == 'val':
    print("Generating val features", val_features_json)
    max_images_list = val_max_images
    images_dir_list = val_images_dir
    features_dir = val_features_dir
    features_json = val_features_json
    take_from_end = True
else:
    raise ValueError(f"Unsupported mode {mode!r}; expected 'train' or 'val'")

image_names_list = []
for images_dir, max_images in zip(images_dir_list, max_images_list):
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # selected subset reproducible across runs and machines.
    names = sorted(os.listdir(images_dir))
    if max_images <= 0:
        # Guard the slice: names[-0:] would select the whole directory.
        names = []
    elif take_from_end:
        names = names[-max_images:]
    else:
        names = names[:max_images]
    image_names_list.append(names)

# Created only after the image directories were listed successfully.
os.makedirs(features_dir, mode=0o777, exist_ok=True)

# Encode the selected images batch by batch and save one feature tensor per
# image. inference_mode() already disables gradient tracking, so an additional
# no_grad() context is not needed.
with torch.inference_mode():
    for image_names, images_dir in zip(image_names_list, images_dir_list):
        for start in tqdm(range(0, len(image_names), batch_size)):
            batch_images, batch_sizes = [], []
            batch_image_paths, batch_feature_paths = [], []

            for image_name in image_names[start:start + batch_size]:
                image_path = os.path.join(images_dir, image_name)
                try:
                    # The context manager releases the file handle as soon as
                    # the RGB copy has been created.
                    with Image.open(image_path) as raw_image:
                        example = raw_image.convert('RGB')
                except Exception as e:
                    # Best effort: unreadable or non-image entries are
                    # reported and skipped instead of aborting the run.
                    print(f"Error processing image {image_path}: {e}")
                    continue

                # splitext strips only the extension, so file names containing
                # extra dots do not collapse onto the same feature file.
                feature_name = os.path.splitext(image_name)[0]
                batch_images.append(example)
                batch_sizes.append(example.size)
                batch_image_paths.append(image_path)
                batch_feature_paths.append(os.path.join(features_dir, f'{feature_name}.pt'))

            if not batch_images:
                # Every file in this slice failed to load; there is nothing to
                # preprocess or encode.
                continue

            inputs = process_images(batch_images, image_processor, model.config)
            inputs = [inp.to(model.device, dtype=torch.bfloat16) for inp in inputs]
            batch_features, batch_crop_numbers = get_image_features(model, inputs, batch_sizes)

            # Save the features of each image in the batch and record where
            # they came from.
            for idx, features in enumerate(batch_features):
                feature_path = batch_feature_paths[idx]
                feature_image_map[feature_path] = (batch_image_paths[idx], batch_crop_numbers[idx])

                # clone() gives the slice its own storage so that only this
                # image's features are serialized, not the whole batch tensor.
                torch.save(features.clone(), feature_path)

# The mapping is written once, after all batches have been processed.
with open(features_json, 'w') as map_file:
    json.dump(feature_image_map, map_file)