import argparse
import torch

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
import transformers
from dataclasses import dataclass, field
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
import os
from tqdm import tqdm
import pandas as pd
import re
def parse_args(argv=None):
    """Parse the command-line options of the LLaVA caption-rewriting script.

    :param argv: list of arguments to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Rewrite CC captions with LLaVA.')
    # model / generation options
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--num_workers", type=int, default=4)
    # cc rewrite options
    parser.add_argument('--data_path', type=str, help='directory containing the raw CC image/caption files',
                        default='./cc_data/train/00000/')
    parser.add_argument('--cc_data_path', type=str, help='tab-separated file of pre-processed CC captions',
                        default='./cc_replace/Train_GCC-training_output.csv')
    parser.add_argument('--batch_size', type=int, help='number of samples processed per batch',
                        default=3)
    parser.add_argument('--save_path', type=str, help='output directory for the rewritten captions',
                        default='./cc_rewrite/')
    return parser.parse_args(argv)
# Parsed once at import time: llava_output, CustomDataset.__getitem__ and
# rewrite_csv_files_multi read their settings from this module-level `args`.
args = parse_args()

def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL and convert it to RGB.

    :param image_file: local file path, or a URL starting with ``http://``/``https://``
    :param timeout: seconds to wait for the HTTP server before giving up
    :return: an RGB ``PIL.Image.Image`` whose pixel data is fully loaded
    :raises requests.HTTPError: if the URL answers with an error status
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # The context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(source) as image:
        return image.convert('RGB')

def load_images(image_files):
    """Load each path/URL in ``image_files`` with ``load_image``, keeping input order."""
    return [load_image(source) for source in image_files]

def read_txt_file(path, encoding="utf-8"):
    """Return the whole contents of the text file at ``path``.

    :param path: file to read
    :param encoding: text encoding; pinned to UTF-8 so results do not depend on
        the platform's locale encoding
    :return: the file contents (a trailing newline, if present, is kept)
    """
    with open(path, "r", encoding=encoding) as f:
        return f.read()

def clean_word(word):
    """Strip surrounding whitespace and any trailing periods from a caption.

    Whitespace is removed *before* the periods, so a caption read from a file
    such as ``"a dog.\\n"`` becomes ``"a dog"``. This avoids a doubled period
    when it is inserted into a prompt that ends in ``"%s."``.
    """
    return word.strip().rstrip(".").strip()

def image_encoder(model, image_processor, filepaths):
    """Load images and preprocess them into one float16 pixel batch on ``model.device``.

    :param model: model whose ``device`` receives the tensor
    :param image_processor: callable returning a dict with ``"pixel_values"``
    :param filepaths: image paths/URLs passed to ``load_images``
    :return: ``(pixel_values, sizes)``, where ``sizes`` holds each loaded image's ``.size``
    """
    pil_images = load_images(filepaths)
    sizes = [img.size for img in pil_images]
    pixel_values = image_processor(pil_images, return_tensors="pt")["pixel_values"]
    return pixel_values.to(model.device, dtype=torch.float16), sizes

def pad_sequence_to_max_length(sequence, max_length, padding_value=0):
    """Left-pad a 1-D tensor up to ``max_length``.

    Padding is prepended. A sequence already ``max_length`` long or longer is
    returned unchanged (it is not truncated).

    :param sequence: 1-D tensor to pad
    :param max_length: target length
    :param padding_value: fill value for the prepended positions
    :return: tensor of length ``max(len(sequence), max_length)`` with the
        dtype and device of ``sequence``
    """
    pad_len = max_length - len(sequence)
    if pad_len <= 0:
        return sequence
    # Allocate the padding on the sequence's own device so torch.cat also works on GPU tensors.
    padding = torch.full((pad_len,), padding_value, dtype=sequence.dtype, device=sequence.device)
    return torch.cat([padding, sequence])

def llava_output(model, images_tensor, input_ids, tokenizer):
    """Generate text for a batch of prompt/image pairs and decode it.

    Sampling settings (temperature, top_p, num_beams, max_new_tokens) come from
    the module-level ``args``. Sampling is enabled only when temperature > 0.

    :param model: LLaVA model whose ``generate`` accepts an ``images=`` tensor
    :param images_tensor: preprocessed images, aligned with the rows of ``input_ids``
    :param input_ids: padded prompt token ids, one row per sample
    :param tokenizer: tokenizer used to decode the generated ids
    :return: list of decoded strings, one per generated sequence
    """
    if args.debug:
        print(input_ids, images_tensor)
    # Follow the model's placement instead of assuming 'cuda'.
    device = model.device
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids.to(device=device, non_blocking=True),
            images=images_tensor.to(dtype=torch.float16, device=device, non_blocking=True),
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

def load_original_cc(data_path):
    '''
    Collect (image path, caption) pairs from a raw CC directory tree.

    Every ``.jpg`` below ``data_path`` is paired with the ``.txt`` file of the same
    stem in the same directory, and that file's full contents become the caption.
    Files whose names start with ``._`` are skipped (presumably macOS metadata
    files — confirm against the dataset).

    :param data_path: root directory of the raw data
    :return: DataFrame with columns ``filepath`` and ``title``
    '''
    records = []
    for root, _subdirs, names in tqdm(os.walk(data_path)):
        for name in tqdm(names):
            if not name.endswith(".jpg") or name.startswith("._"):
                continue
            stem = name[:-4]
            records.append({
                "filepath": os.path.join(root, name),
                "title": read_txt_file(os.path.join(root, stem + ".txt")),
            })
    return pd.DataFrame(records, columns=["filepath", "title"])

def load_predeal_cc(data_path):
    '''
    Load the pre-processed CC caption table.

    :param data_path: path to a tab-separated file with a header row
    :return: the table as a DataFrame
    '''
    return pd.read_csv(data_path, sep="\t")

# Custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset yielding ``(index, input_ids, image_tensor)`` for LLaVA generation.

    Item ``i`` pairs ``questions[i]`` with the image at ``image_paths[i]``. The steps are:

    1. Prefix the question with the image token(s).
    2. Wrap it in the conversation template.
    3. Tokenise it with ``tokenizer_image_token``.
    4. Load the image and preprocess it with ``process_images``.
    """

    def __init__(self, questions, image_paths, tokenizer, image_processor, model_config, conv_mode=None):
        """
        :param questions: prompt text per item
        :param image_paths: image file path per item, aligned with ``questions``
        :param tokenizer: tokenizer passed to ``tokenizer_image_token``
        :param image_processor: processor passed to ``process_images``
        :param model_config: model config; ``mm_use_im_start_end`` decides whether
            the image token is wrapped in start/end tokens
        :param conv_mode: key into ``conv_templates``; ``None`` falls back to the
            module-level ``args.conv_mode`` at item-fetch time
        """
        self.questions = questions
        self.image_folder = image_paths  # per-item image paths (attribute name kept for compatibility)
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode

    def __getitem__(self, index):
        qs = self.questions[index]
        image_file = self.image_folder[index]
        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        # Resolved per item so a conv mode set on the global args after construction is honoured.
        conv_mode = self.conv_mode if self.conv_mode is not None else args.conv_mode
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Close the file handle; convert() returns an independent, fully-loaded copy.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return index, input_ids, image_tensor

    def __len__(self):
        return len(self.questions)

@dataclass
class DataCollatorForVisualTextGeneration(object):
    """Collate ``(index, input_ids, image_tensor)`` items into one generation batch.

    Token sequences are padded to the longest in the batch with the tokenizer's
    ``pad_token_id``, on the side given by ``tokenizer.padding_side``. Image
    tensors are stacked, so every image tensor in a batch must have the same shape.
    """

    # String annotation: evaluated lazily, so defining the class does not touch `transformers`.
    tokenizer: "transformers.PreTrainedTokenizer"

    def pad_sequence(self, input_ids, batch_first, padding_value):
        """Pad a list of 1-D id tensors, honouring ``tokenizer.padding_side``.

        ``torch.nn.utils.rnn.pad_sequence`` only pads on the right. For left
        padding, each sequence is reversed, right-padded, and the result is
        reversed back along the sequence dimension.
        """
        left = self.tokenizer.padding_side == "left"
        if left:
            input_ids = [torch.flip(ids, [0]) for ids in input_ids]
        padded = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=batch_first,
            padding_value=padding_value)
        if left:
            # The sequence axis is dim 1 for (batch, seq) and dim 0 for (seq, batch).
            padded = torch.flip(padded, [1 if batch_first else 0])
        return padded

    def __call__(self,
                 batch: List[Tuple[int, torch.Tensor, torch.Tensor]]
                 ) -> Tuple[Tuple[int, ...], torch.Tensor, torch.Tensor]:
        """Return ``(indices, padded_input_ids, stacked_images)`` for ``batch``.

        :raises ValueError: if the tokenizer has no ``pad_token_id``
        """
        indices, input_ids, images = zip(*batch)
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError("tokenizer.pad_token_id is None; set a pad token before batching prompts")
        input_ids = self.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=pad_token_id)
        images = torch.stack(images, dim=0)
        return indices, input_ids, images

# DataLoader
def create_data_loader(questions, image_paths, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
    """Build an order-preserving DataLoader over (question, image path) pairs.

    Items come from ``CustomDataset`` and are batched by
    ``DataCollatorForVisualTextGeneration``. ``shuffle=False`` keeps batches in input order.
    """
    return DataLoader(
        CustomDataset(questions, image_paths, tokenizer, image_processor, model_config),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=DataCollatorForVisualTextGeneration(tokenizer=tokenizer),
    )


def rewrite_csv_files_multi(model, image_processor, tokenizer, path, save_path):
    '''
    Rewrite CC captions with LLaVA and summarise the intent of each rewritten caption.

    Rows of the pre-processed caption table ``args.cc_data_path`` are processed in
    chunks of ``args.batch_size``. For every image the model first rewrites the
    original caption. It then summarises the rewritten caption for the same image.
    Each finished chunk is appended to a tab-separated file in ``save_path``. On
    restart, rows already present in that file are skipped.

    :param model: LLaVA model used for generation
    :param image_processor: image processor matching ``model``
    :param tokenizer: tokenizer matching ``model``
    :param path: root of the raw CC data; only needed when building the table from
        raw files with ``load_original_cc`` (see the comment below)
    :param save_path: output directory (created if missing)
    '''
    data_type = 'train'
    save_file_path = os.path.join(save_path, "Train_GCC-training_output.csv") if data_type == "train" \
        else os.path.join(save_path, "Validation_GCC-1.1.0-Validation_output.csv")
    columns = ["filepath", "title", "title_rewrite", "summary"]

    # llava_prompt v1.3. Each instruction ends with a newline so the numbered items
    # and the inserted caption do not run into the following sentence.
    prompt_rewrite = "You are a powerful image captioner. " \
                     "Rewrite the original caption like human description in:\n" \
                     "1. Describing object's appearance, position, and the relationship or interaction between them.\n" \
                     "2. Describe what is in the background and what is in the foreground.\n" \
                     "3. Focusing on colors, styles, materials, and any notable accessories.\n" \
                     "4. Identify the domain of the image (e.g., sculpture, cartoon, origami).\n" \
                     "The original caption ==> %s.\n" \
                     "Do not describe the contents by itemizing them in list form. " \
                     "Minimize aesthetic descriptions as much as possible. " \
                     "Please provide a answer without starting with 'The image,' in 50 words or fewer ==> [Insert text here]"

    prompt_analysis = "You are a powerful human manipulate intent analyst. " \
                      "Infer human manipulate intents in an image caption.\n" \
                      "The image caption ==> %s.\n" \
                      "Minimize aesthetic descriptions as much as possible. " \
                      "Please summarize the main points of the answer without starting with 'The image,' in 16 words or fewer ==> [Insert text here]"

    print("Loading original image and captions...")
    # Without a pre-processed table, build it from the raw files instead:
    # original_cc_df = load_original_cc(os.path.join(path, data_type))
    original_cc_df = load_predeal_cc(args.cc_data_path)
    batch_size, num_rows = args.batch_size, len(original_cc_df)

    # Resume: every data row already in the output file has been processed.
    processed_rows = 0
    if os.path.exists(save_file_path) and os.path.getsize(save_file_path) > 0:
        processed_rows = len(pd.read_csv(save_file_path, sep="\t"))
    os.makedirs(save_path, exist_ok=True)

    for start_row in tqdm(range(processed_rows, num_rows, batch_size)):
        epoch_data = original_cc_df.iloc[start_row:start_row + batch_size]
        filepaths = epoch_data['filepath'].tolist()
        # Empty cells are read as NaN; turn them into strings so clean_word can handle them.
        titles = epoch_data['title'].fillna('').astype(str).tolist()

        # Pass 1: rewrite the original captions.
        title_rewrites = _generate_for_images(
            model, image_processor, tokenizer, filepaths,
            [prompt_rewrite % clean_word(title) for title in titles], batch_size)
        # Pass 2: summarise each rewritten caption against the same image.
        summaries = _generate_for_images(
            model, image_processor, tokenizer, filepaths,
            [prompt_analysis % clean_word(rewrite) for rewrite in title_rewrites], batch_size)
        if args.debug:
            print(title_rewrites)
            print(summaries)

        batch_df = pd.DataFrame(
            {"filepath": filepaths, "title": titles, "title_rewrite": title_rewrites, "summary": summaries},
            columns=columns)
        # Append so finished chunks survive an interruption; write the header only once.
        write_header = not os.path.exists(save_file_path) or os.path.getsize(save_file_path) == 0
        batch_df.to_csv(save_file_path, mode="a", header=write_header, index=False, sep="\t")

    print("Saved to", save_file_path)
    print("Done!")


def _generate_for_images(model, image_processor, tokenizer, filepaths, questions, batch_size):
    """Run one LLaVA generation per (image, question) pair and return the answers in input order."""
    data_loader = create_data_loader(
        questions,
        filepaths,
        tokenizer,
        image_processor,
        model.config,
        batch_size=batch_size,
        num_workers=args.num_workers,
    )
    answers = [None] * len(questions)
    for indices, input_ids, image_tensor in data_loader:
        decoded = llava_output(model, image_tensor, input_ids, tokenizer)
        # The collator passes dataset indices through, so each answer lands in its own slot.
        for index, text in zip(indices, decoded):
            answers[index] = text.strip()
    return answers

def _infer_conv_mode(model_name):
    """Return the conversation-template key that matches ``model_name``.

    Rules are checked in order and the first case-insensitive substring match
    wins. The specific "v1.6-34b" must therefore come before the generic "v1".
    Names matching no rule fall back to "llava_v0".
    """
    name = model_name.lower()
    for marker, mode in (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    ):
        if marker in name:
            return mode
    return "llava_v0"


def main(args):
    """Load the LLaVA model, settle the conversation template and rewrite the captions.

    :param args: parsed command-line options. ``args.conv_mode`` is set to the
        inferred template unless the user supplied one; a mismatch with the
        inferred template only triggers a warning.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    # Forward --device so the model is placed where the user asked.
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, device=args.device
    )
    conv_mode = _infer_conv_mode(model_name)

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    rewrite_csv_files_multi(model, image_processor, tokenizer, args.data_path, args.save_path)


if __name__ == "__main__":
    # Run with the options parsed at import time by `args = parse_args()`.
    main(args)
