import argparse
import base64
import csv
import json
import mimetypes
import os
import pickle
import sys
import time
import traceback

import requests

# Lift the csv module's per-field size cap so very long cells (presumably the
# retrieved-context column such as "total_pred") can be read. Where a C long
# is 32 bits (e.g. 64-bit Windows) sys.maxsize does not fit and the call raises
# OverflowError, so fall back to the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)


def encode_image_to_base64(img_path: str):
    """Return the contents of *img_path* as a base64 text string.

    Args:
        img_path: Path of the file to encode.

    Returns:
        The base64 encoding of the raw file bytes, decoded as UTF-8 text.

    Raises:
        FileNotFoundError: if nothing exists at *img_path*.
    """
    if os.path.exists(img_path):
        with open(img_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    raise FileNotFoundError(f"Image file not found: {img_path}")


# Lowercase marker phrases; a context line containing any of them is dropped.
# All entries must be lowercase because they are compared against a
# lowercased line ("Here's" used to be capitalised and so never matched).
_FILLER_MARKERS = (
    "answer",
    "question",
    "want",
    "?",
    "okay",
    "would you",
    "let me",
    "let's",
    "do you",
    "breakdown",
    "here's",
    "image 3",
    "image 4",
    "image 5",
    "image 6",
    "image 7",
    "image 8",
    "image 9",
    "in short",
    "in essence",
)


def clean_context_text(text: str):
    """Remove conversational filler lines from generated context text.

    Each line is normalised for matching only: surrounding whitespace is
    stripped, Markdown bold markers (``**``) are removed, the typographic
    apostrophe U+2019 is mapped to ``'``, and the result is lowercased. A line
    is dropped when the normalised form *contains* any entry of
    ``_FILLER_MARKERS``. This is a substring test, not a prefix test, so for
    example every line containing "?" is removed.

    Args:
        text: Multi-line text to filter.

    Returns:
        The kept lines, unmodified, joined with newline characters.
    """
    kept_lines = []
    for line in text.splitlines():
        normalized = line.strip().replace("**", "").replace("\u2019", "'").lower()
        if not any(marker in normalized for marker in _FILLER_MARKERS):
            kept_lines.append(line)  # keep the original, non-normalized line
    return "\n".join(kept_lines)


def _image_data_url(image_path: str) -> str:
    """Return *image_path* as a base64 ``data:`` URL for an ``image_url`` part.

    The MIME type is guessed from the file extension. Anything that does not
    resolve to an ``image/*`` type falls back to ``image/jpeg``, the type that
    used to be hard-coded for every file. Some providers may reject an image
    whose declared media type does not match its bytes.

    Raises:
        FileNotFoundError: if *image_path* does not exist.
    """
    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/jpeg"
    return f"data:{mime};base64,{encode_image_to_base64(image_path)}"


def _fewshot_content(fewshot: dict, without_image: bool) -> list:
    """Build the message content parts for the few-shot demonstrations.

    Entry 0 of "image_paths", "questions" and "answers" is skipped; presumably
    it is the query itself (TODO confirm against the JSONL builder). Each demo
    becomes an "##Example k:" header, its image (unless *without_image*), and
    an optional "##Context:" block followed by the question and answer. A demo
    whose parts cannot be built is skipped as a whole, with a warning, so no
    dangling header is left in the prompt.
    """
    parts = []
    demos = zip(
        fewshot["image_paths"][1:],
        fewshot["questions"][1:],
        fewshot["answers"][1:],
    )
    for i, (fs_img_path, fs_q, fs_ans) in enumerate(demos):
        try:
            example = [{"type": "text", "text": f"##Example {i+1}:\n"}]
            if not without_image:  # Few-shot images are used
                # Only read the file when it is actually sent.
                example.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": _image_data_url(fs_img_path)},
                    }
                )
            reasoning_record = ""
            if "reasoning_record" in fewshot:
                # NOTE(review): indexed with i, while the lists above are
                # sliced [1:]. Confirm that reasoning_record omits the first
                # entry; otherwise contexts are shifted by one demo.
                reasoning_record = (
                    "##Context: " + fewshot["reasoning_record"][i] + "\n\n"
                )
            example.append(
                {
                    "type": "text",
                    "text": f"{reasoning_record}##Question: {fs_q}\n##Best Answer: {fs_ans}\n\n",
                }
            )
        except Exception as e:  # best-effort: a broken demo must not stop the run
            print(f"Warning: Skipping few-shot image '{fs_img_path}': {e}")
            continue
        parts.extend(example)
    return parts


def _api_error_message(error):
    """Return the "message" of an API error object, or *error* itself."""
    return error.get("message", error) if isinstance(error, dict) else error


def _post_chat_completion(payload: dict, api_key: str) -> str:
    """POST *payload* to the OpenRouter chat-completions endpoint.

    Returns:
        The stripped content of the first choice, or "Error: Empty response"
        when the model returned no text (missing, empty or null content).

    Raises:
        SystemExit: with an "API error: ..." message on a transport, HTTP,
            JSON-decoding or API-level error.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    try:
        resp = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            data=json.dumps(payload),
            timeout=90,
        )
        resp.raise_for_status()
        data = resp.json()
    except requests.HTTPError as e:
        # Include the response body: the provider's explanation lives there.
        body = e.response.text[:1000] if e.response is not None else ""
        sys.exit(f"API error: {e} {body}".rstrip())
    except Exception as e:
        # `data` is unbound here, so report the exception itself.
        sys.exit(f"API error: {e}")

    if not isinstance(data, dict):
        sys.exit(f"API error: unexpected response: {data!r}")
    if "error" in data:
        sys.exit(f"API error: {_api_error_message(data['error'])}")
    choices = data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    ans = (message.get("content") or "").strip()
    return ans or "Error: Empty response"


def get_vlm_answer(
    image_path: str,
    question_from_csv: str,
    total_pred_from_csv: str,
    original_prediction_from_csv: str,
    api_key: str,
    model_id: str,
    fewshot: dict = None,
    without_image: bool = False,
):
    """Ask the VLM one question about one image and return its answer text.

    Builds a single user message whose content list contains, in order:
    - the instruction text;
    - the few-shot demonstrations from *fewshot* (if any);
    - the query image;
    - the query's context (*total_pred_from_csv*) and question.
    The message is then sent to the OpenRouter chat-completions endpoint.

    Args:
        image_path: Path of the query image. It is always sent, even when
            *without_image* is true.
        question_from_csv: Question about the query image.
        total_pred_from_csv: Text inserted as the query's "##Context:".
        original_prediction_from_csv: Currently unused; it is only referenced
            by a commented-out prompt variant.
        api_key: OpenRouter API key.
        model_id: OpenRouter model identifier.
        fewshot: Optional dict with parallel lists "image_paths", "questions",
            "answers" and, optionally, "reasoning_record".
        without_image: If true, few-shot demos are sent as text only.

    Returns:
        The stripped answer, or "Error: Empty response".

    Raises:
        FileNotFoundError: if *image_path* does not exist.
        SystemExit: on any request, HTTP, decoding or API-level error.
    """
    content_list = []

    # content_list.append(
    #     {
    #         "type": "text",
    #         "text": f"""Answer the last question with given image.
    #         Directly output the answer to the question.
    #         Considering Context, provide a best answer.
    #         There is an 4 examples below:

    #         """,
    #     }
    # )

    # content_list.append(
    #     {
    #         "type": "text",
    #         "text": f"""Answer the encyclopedic question about given image.
    #         Directly output the answer to the question according to context.
    #         Considering Context, provide a best answer.
    #         There is an 2 examples below:

    #         """,
    #     }
    # )

    # Prompt text is kept verbatim (including its wording and whitespace) so
    # results stay comparable with earlier runs.
    content_list.append(
        {
            "type": "text",
            "text": f"""Answer the knowledge-intensive question based on the provided image and context.
            Generate a concise and accurate answer grounded in the retrieved information.
            Use the context to support reasoning, and directly output the final answer.

            There are 3 examples are shown below:
        
            """,
        }
    )

    # total_pred_from_csv = clean_context_text(total_pred_from_csv)

    # content_list.append(
    #     {
    #         "type": "text",
    #         "text": f"""Answer the following questions based on the image and your knowledge.
    #         There is an 3 examples below:

    #         """,
    #     }
    # )

    if fewshot:
        content_list.extend(_fewshot_content(fewshot, without_image))

    # Main example appended to the same content list.
    content_list.append({"type": "text", "text": "\nNow, answer this question\n"})
    content_list.append(
        {"type": "image_url", "image_url": {"url": _image_data_url(image_path)}}
    )
    content_list.append(
        {
            "type": "text",
            "text": f"##Context: {total_pred_from_csv}\n##Question: {question_from_csv}\n##Best Answer: ",
        }
    )
    # content_list.append(
    #     {
    #         "type": "text",
    #         "text": f"##Context: {original_prediction_from_csv}\n\n##Question: {question_from_csv}\n##Best Answer: ",
    #     }
    # )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content_list}],
        "max_tokens": 256,
    }
    return _post_chat_completion(payload, api_key)


def save_intermediate_results(path: str, results: list):
    """Pickle *results* to *path* as a checkpoint, replacing it atomically.

    The data is written to ``<path>.tmp`` first and then moved over *path*
    with ``os.replace``. A crash, Ctrl-C or pickling error mid-write therefore
    leaves the previous checkpoint intact instead of a truncated file, which
    the loader would treat as unreadable and restart from row 0.

    Failures are best-effort: a warning is printed, the temporary file is
    removed, and no exception is raised.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as pf:
            pickle.dump(results, pf)
        os.replace(tmp_path, path)
    except Exception as e_save:
        print(f"Warning: Could not save intermediate results: {e_save}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # nothing was created, or it is already gone


def _output_paths(
    input_csv_path: str, model_id: str, fewshot_data: list, without_image: bool
):
    """Return ``(pkl_path, csv_path)`` for a run, named after its configuration.

    The suffix appended to the input path (minus extension) has the form
    ``_updated[_with_<model>][_wo_image][_fewshot[_with_rr]]_clean``. The
    model tag is added only when the model name does not already occur in
    the input path. Only this new suffix is built, so an input path that
    already contains "_updated" is no longer rewritten.
    """
    base, _ = os.path.splitext(input_csv_path)
    # Keep only the last component of the model id: "google/x" and "openai/x"
    # still map to "x", and ids from other providers no longer inject a "/"
    # (which pointed the outputs into a nonexistent directory).
    model_name = model_id.rsplit("/", 1)[-1]
    suffix = "_updated"
    if model_name not in input_csv_path:
        suffix += f"_with_{model_name}"
    if without_image:
        suffix += "_wo_image"
    if fewshot_data:  # an empty list means "no few-shot", not an IndexError
        suffix += "_fewshot"
        if "reasoning_record" in fewshot_data[0]:
            suffix += "_with_rr"
    suffix += "_clean"
    return f"{base}{suffix}.pkl", f"{base}{suffix}.csv"


def _load_checkpoint(pkl: str) -> list:
    """Return the rows stored in the checkpoint *pkl*, or [] if absent/unreadable."""
    if not os.path.exists(pkl):
        return []
    try:
        # pickle can execute code: only load checkpoints this script wrote.
        with open(pkl, "rb") as pf:
            rows = pickle.load(pf)
        start = len(rows)
    except Exception as e:
        print(f"Warning: Could not load checkpoint '{pkl}', starting over: {e}")
        return []
    print(f"Resuming from {start} rows...")
    return rows


def process_csv_resumable(
    input_csv_path: str,
    api_key: str,
    model_id: str,
    fewshot_data: list = None,
    without_image: bool = False,
):
    """Answer every row of *input_csv_path* with the VLM, resuming if possible.

    Each row's "prediction" column is set to the model answer. Rows missing
    image_path or question get an "Error: ..." string instead. Processed rows
    are pickled to a checkpoint every 10 rows and whenever processing stops,
    including on sys.exit or Ctrl-C, so a rerun with the same arguments skips
    rows already done. After the last row, all rows are written to the output
    CSV.

    Args:
        input_csv_path: CSV with "image_path" and "question" columns;
            "total_pred" and "prediction" are read when present.
        api_key: OpenRouter API key.
        model_id: OpenRouter model identifier.
        fewshot_data: Optional list of few-shot dicts; entry ``i`` is used for
            data row ``i``. None or an empty list means no few-shot demos.
        without_image: Send few-shot demos without their images.
    """
    pkl, out_csv = _output_paths(input_csv_path, model_id, fewshot_data, without_image)
    fewshots = fewshot_data or []
    rows = _load_checkpoint(pkl)
    start = len(rows)

    try:
        # utf-8-sig drops a leading BOM, which would otherwise corrupt the
        # first header name. The csv module requires newline="".
        with open(input_csv_path, "r", encoding="utf-8-sig", newline="") as infile:
            reader = csv.DictReader(infile)
            if not reader.fieldnames:
                print("Error: CSV empty or missing header.")
                return
            fields = list(reader.fieldnames)
            # Every output row gets a "prediction" key; without this column
            # DictWriter would raise ValueError after all API calls are done.
            if "prediction" not in fields:
                fields.append("prediction")

            for idx, row in enumerate(reader):
                if idx < start:
                    continue  # already stored in the checkpoint
                img = row.get("image_path")
                q = row.get("question")
                tp = row.get("total_pred", "N/A")
                op = row.get("prediction", "N/A")
                print(f"Line {idx+1}: {img}")
                if not img or not q:
                    res = "Error: Missing image_path or question"
                else:
                    fs = fewshots[idx] if idx < len(fewshots) else None
                    res = get_vlm_answer(
                        img, q, tp, op, api_key, model_id, fs, without_image
                    )
                row["prediction"] = res
                rows.append(row)
                if (idx + 1) % 10 == 0:
                    save_intermediate_results(pkl, rows)
                    print(f"Saved checkpoint at line {idx+1}")
    finally:
        # Persist progress however the loop ends (normal completion, early
        # return, sys.exit from an API error, Ctrl-C).
        save_intermediate_results(pkl, rows)

    with open(out_csv, "w", encoding="utf-8", newline="") as outf:
        writer = csv.DictWriter(outf, fieldnames=fields)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Finished! Output: {out_csv}")


def main():
    """Command-line entry point: parse arguments and process the input CSV.

    The API key comes from --api_key or the OPENROUTER_API_KEY environment
    variable. --fewshot_jsonl, when given, is read as one JSON object per
    non-blank line. If it cannot be read, a warning is printed and the run
    proceeds without few-shot demos.
    """
    parser = argparse.ArgumentParser(
        description="CSV to VLM answers with single-list few-shot demos"
    )
    parser.add_argument(
        "--csv", default="", dest="input_csv_path", help="Input CSV file path."
    )
    # NOTE: --pkl is parsed but its value is not read anywhere in this script.
    parser.add_argument(
        "--pkl", default="", dest="input_pkl_path", help="Input PKL file path."
    )
    parser.add_argument(
        "--api_key", help="OpenRouter API key or set OPENROUTER_API_KEY env var."
    )
    parser.add_argument(
        "--model_id", required=True, help="VLM model ID (e.g., openai/gpt-4o)."
    )
    parser.add_argument(
        "--fewshot_jsonl", help="Few-shot JSONL aligned with CSV (optional)."
    )
    parser.add_argument(
        "--without-image",
        action="store_true",
        help="Only process text without images ",
    )
    args = parser.parse_args()

    # Fail fast with a usage message. Otherwise a missing --csv surfaces later
    # as an opaque FileNotFoundError for the empty path "".
    if not args.input_csv_path:
        parser.error("--csv is required.")

    api = args.api_key or os.environ.get("OPENROUTER_API_KEY")
    if not api:
        print("Error: API key required.")
        sys.exit(1)

    fs_list = []
    if args.fewshot_jsonl:
        try:
            with open(args.fewshot_jsonl, "r", encoding="utf-8") as jf:
                fs_list = [json.loads(line) for line in jf if line.strip()]
            print(f"Loaded {len(fs_list)} few-shot entries.")
        except Exception as e:
            print(f"Warning: Failed to load few-shot JSONL: {e}")

    process_csv_resumable(
        args.input_csv_path,
        api,
        args.model_id,
        fs_list,
        args.without_image,
    )


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
