import sys
import csv
import os
import base64
import json
import argparse
import time
import pickle
import sys
from preprocess_answer2 import clean_context_text

# Add LLaVA and PyTorch related imports
import torch
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from llava.utils import disable_torch_init

# Raise the csv module's per-field size limit so very large text fields in the
# input CSV do not trigger csv.Error. sys.maxsize does not fit in the C long that
# field_size_limit uses on some platforms (e.g. 64-bit Windows) and raises
# OverflowError there, so back off until the platform accepts the value.
# (Previously this call was made twice, unguarded.)
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 10


# Image placeholder followed by a blank line; prefixed to the question in the
# first user turn of the prompt.
IMAGE_TOKEN = f"{DEFAULT_IMAGE_TOKEN}\n\n"
# Context template; not referenced elsewhere in this file.
CONTEXT_PROMPT = "# {context}\n"
# Special-token ids, presumably ReflectiVA's retrieval / relevance tokens —
# TODO confirm against the model's tokenizer. Not referenced elsewhere in this file.
RET_TOKEN = 128251
REL_TOKEN = 128253

# Define a placeholder string for images in prompts (not referenced elsewhere in this file)
IMAGE_PLACEHOLDER = DEFAULT_IMAGE_TOKEN + "\n"


def get_llava_vlm_answer(
    image_path: str,
    question_from_csv: str,
    context_from_csv: str,  # Context paragraph; the caller passes the row's "total_pred"
    tokenizer,
    model,
    image_processor,
    conv_mode: str,
    model_config,  # model.config, used by process_images
    temperature: float = 0.2,
    top_p: float = None,  # Consistent with release_retrieval.py defaults
    num_beams: int = 1,
    max_new_tokens: int = 128,
):
    """
    Answer a question about one image with a local LLaVA-style model, given a context paragraph.

    The prompt replays a retrieval-style exchange:
    1. The user asks the question together with the image.
    2. The assistant turn is pre-filled with "[Retrieval]".
    3. The user supplies the cleaned context paragraph and asks for a short answer.

    A single generation pass is then run. The few-shot prompting from the original
    OpenRouter-based get_vlm_answer is not used here.

    Args:
        image_path: Path to the image file.
        question_from_csv: Question text.
        context_from_csv: Raw context text; cleaned with clean_context_text before use.
        tokenizer, model, image_processor: Objects returned by load_pretrained_model.
        conv_mode: Key into llava.conversation.conv_templates.
        model_config: Model config forwarded to process_images.
        temperature: Sampling temperature; a value <= 0 disables sampling.
        top_p, num_beams, max_new_tokens: Forwarded to model.generate.

    Returns:
        The stripped decoded answer. Returns "Error: Inference failed" if loading the
        image, preprocessing it, generating, or decoding raises. In that case the error
        is printed and the run continues.
    """
    context_from_csv = clean_context_text(context_from_csv)

    device = model.device
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], IMAGE_TOKEN + f"{question_from_csv}")
    conv.append_message(conv.roles[1], "[Retrieval]")
    conv.append_message(
        conv.roles[0],
        f"Consider this paragraph: <paragraph> {context_from_csv} </paragraph>. Give a short answer.",
    )
    # A None message leaves the final assistant turn open for generation.
    conv.append_message(conv.roles[1], None)
    prompt_for_tokenizer = conv.get_prompt()
    input_ids = (
        tokenizer_image_token(
            prompt_for_tokenizer, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)
        .to(device)
    )

    # 0 is a valid token id, so only fall back to EOS when the pad id is really unset
    # (a plain `or` would also replace a pad id of 0).
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    try:
        # Image loading is inside the try: one missing or corrupt file yields an error
        # answer for this row. Otherwise it would abort the run, and the resumed run
        # would abort again at the same row. The context manager closes the file handle.
        with Image.open(image_path) as raw_img:
            pil_img = raw_img.convert("RGB")

        # Add a batch dimension and place the pixels on the same device as input_ids
        # (this was previously hard-coded to "cuda").
        image_tensor = (
            process_images([pil_img], image_processor, model_config)[0]
            .unsqueeze(0)
            .to(device=device, dtype=torch.float16, non_blocking=True)
        )

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                # One (width, height) entry per image in the batch.
                image_sizes=[pil_img.size],
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                pad_token_id=pad_token_id,
            )

        # NOTE(review): the whole output is decoded, which assumes generate returns only
        # newly generated tokens (not prompt + answer) — confirm for this model.
        answer = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        return answer.strip()

    except Exception as e:
        print(f"Error during LLaVA model inference: {e}")
        return "Error: Inference failed"


def save_intermediate_results(path: str, results: list):
    """
    Pickle ``results`` to ``path`` as a checkpoint, best-effort.

    The data is written to ``<path>.tmp`` first and then moved over ``path`` with
    os.replace. A crash or a failed dump therefore never truncates or corrupts the
    previous checkpoint: readers see either the old file or the complete new one.
    Failures are printed as a warning and not raised, and the temporary file is
    removed.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as pf:
            pickle.dump(results, pf)
        os.replace(tmp_path, path)
    except Exception as e_save:
        print(f"Warning: Could not save intermediate results: {e_save}")
        try:
            os.remove(tmp_path)
        except OSError:
            # The temp file may never have been created (e.g. missing directory).
            pass


def _load_checkpoint(pkl_path: str) -> list:
    """Return the rows saved in ``pkl_path``, or [] if it is absent or unreadable."""
    if not os.path.exists(pkl_path):
        return []
    try:
        # The checkpoint is written by this script itself; never point this at an
        # untrusted pickle, since unpickling can execute arbitrary code.
        with open(pkl_path, "rb") as pf:
            rows = pickle.load(pf)
        print(f"Resuming from {len(rows)} rows...")
        return rows
    except Exception as e_load:
        # Not a bare except: Ctrl-C while loading must still abort the run.
        print(f"Warning: Could not load checkpoint {pkl_path}, starting over: {e_load}")
        return []


def process_csv_resumable(
    input_csv_path: str,
):
    """
    Run the ReflectiVA model over every row of ``input_csv_path``, resuming if possible.

    For each row, the model answers ``question`` about ``image_path``, using
    ``total_pred`` as the context paragraph. The answer is stored in the row's
    ``prediction`` column. Rows already present in ``<base>_reflectiva.pkl`` are
    skipped. That checkpoint is refreshed every 10 rows and again at the end.
    The result, all input columns plus ``prediction``, is written to
    ``<base>_reflectiva.csv``.
    """
    model_path = "aimagelab/ReflectiVA"
    model_path = os.path.expanduser(model_path)
    model_name = "llava_llama_3.1"
    # model_name = get_model_name_from_path(model_path)  # Ensure model name is set
    conv_mode = "llama_3_1"
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path, None, model_name
    )
    base, _ = os.path.splitext(input_csv_path)
    pkl = f"{base}_reflectiva.pkl"
    out_csv = f"{base}_reflectiva.csv"

    rows = _load_checkpoint(pkl)
    # Checkpointed rows are assumed to be the first len(rows) rows of the CSV.
    start = len(rows)

    # newline="" is required by the csv module to parse quoted fields that
    # contain embedded newlines correctly.
    with open(input_csv_path, "r", encoding="utf-8", newline="") as infile:
        reader = csv.DictReader(infile)
        if not reader.fieldnames:
            print("Error: CSV empty or missing header.")
            save_intermediate_results(pkl, rows)
            return
        fields = list(reader.fieldnames)
        # Every output row gets a "prediction" key. DictWriter raises ValueError for
        # keys missing from fieldnames, so add the column if the input lacks it.
        if "prediction" not in fields:
            fields.append("prediction")

        for idx, row in enumerate(reader):
            if idx < start:
                continue
            img = row.get("image_path")
            print(f"Line {idx+1}: {img}")
            row["prediction"] = get_llava_vlm_answer(
                img,
                row.get("question"),
                row.get("total_pred", "N/A"),  # Context string
                tokenizer,
                model,
                image_processor,
                conv_mode,
                model.config,
            )
            rows.append(row)
            if (idx + 1) % 10 == 0:
                save_intermediate_results(pkl, rows)
                print(f"Saved checkpoint at line {idx+1}")

    save_intermediate_results(pkl, rows)
    with open(out_csv, "w", encoding="utf-8", newline="") as outf:
        writer = csv.DictWriter(outf, fieldnames=fields)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Finished! Output: {out_csv}")


def main():
    """Parse command-line arguments and run ReflectiVA inference over the input CSV."""
    parser = argparse.ArgumentParser(
        description="Answer CSV questions about images with a local ReflectiVA (LLaVA) model"
    )
    # Required: without it, the script used to load the whole model and then
    # crash on open("").
    parser.add_argument(
        "--csv", required=True, dest="input_csv_path", help="Input CSV file path."
    )
    # Kept so existing command lines keep working. The checkpoint path is derived
    # from the CSV path, so this value is not used.
    parser.add_argument(
        "--pkl",
        default="",
        dest="input_pkl_path",
        help="Input PKL file path (currently unused).",
    )
    args = parser.parse_args()

    process_csv_resumable(
        args.input_csv_path,
    )


if __name__ == "__main__":
    main()
