import os
import sys
import json
import uuid
import random
import time
from textwrap import indent
import openai
from openai import AsyncOpenAI
import argparse
import logging
import base64
from dotenv import load_dotenv
from PIL import Image
import io
import glob
from pathlib import Path

from prompt import (
    read_img_prompt,
    image_improvement_prompt,
    image_verifier_prompt,
    initial_solution_prompt,
    self_improvement_prompt,
    verifier_physics_prompt,
    verify_general_prompt,
    correction_prompt,
    physics_precheck_reminder,
    verification_remider
)

# Configuration
# Settings of the currently selected model. They start out unset and are
# filled in from MODEL_REGISTRY by process_single_problem_file().
MODEL_NAME = None
API_BASE_URL = None
API_KEY = None

# Connection settings per model, looked up by the name given on the command
# line. The "your_api_key" / "your_url" values look like placeholders that
# must be replaced before use.
MODEL_REGISTRY = {
    "Qwen/Qwen2.5-VL-32B-Instruct": dict(
        model_id="Qwen/Qwen2.5-VL-32B-Instruct",
        api_key="your_api_key",
        base_url="your_url",
    ),
    "gemini-2.5-flash-thinking": dict(
        model_id="gemini-2.5-flash-thinking",
        api_key="your_api_key",
        base_url="your_url",
    ),
    "intern-s1": dict(
        model_id="intern-s1",
        api_key="your_api_key",
        base_url="your_url",
    ),
}

# Global variables for logging
_log_file = None          # open text handle that mirrors stdout, or None when logging is off
original_print = print    # keep the builtin reachable before `print` is rebound below


def log_print(*args, **kwargs):
    """Print to stdout and mirror the same text into the active log file.

    All arguments are forwarded unchanged to the builtin ``print``. When a log
    file is open, the same text is appended to it and flushed immediately.
    ``sep`` and ``end`` are honoured for the log copy as well, so the log
    matches what was printed; other keyword arguments (``file``, ``flush``)
    only affect the builtin call.
    """
    original_print(*args, **kwargs)
    if _log_file is None:
        return

    # print() treats sep=None / end=None as "use the default".
    sep = kwargs.get("sep")
    end = kwargs.get("end")
    sep = " " if sep is None else sep
    end = "\n" if end is None else end

    _log_file.write(sep.join(str(arg) for arg in args) + end)
    _log_file.flush()


# From here on every print() call in this module is also written to the log.
print = log_print

def set_log_file(log_file_path):
    """Open ``log_file_path`` for writing and make it the active log file.

    The file is truncated if it already exists. If another log file is
    currently active it is closed once the new one has been opened, so
    switching logs does not leak file handles.

    Args:
        log_file_path: Path of the log file. A falsy value leaves the current
            logging state untouched.

    Returns:
        True when the log was opened (or no path was given), False when the
        file could not be opened; in that case the previous log stays active.
    """
    global _log_file
    if not log_file_path:
        return True

    try:
        new_handle = open(log_file_path, 'w', encoding='utf-8')
    except Exception as e:
        print(f"Error opening log file {log_file_path}: {e}")
        return False

    # Swap first, then close, so print() never sees a closed handle.
    previous = _log_file
    _log_file = new_handle
    if previous is not None:
        previous.close()
    return True

def close_log_file():
    """Close the active log file, if any, and turn file logging off."""
    global _log_file
    if _log_file is None:
        # Nothing is open; calling this repeatedly is harmless.
        return
    _log_file.close()
    _log_file = None

def read_img_agent(client, image_parts, model_name):
    """
    Ask the model for a first text readout of the given images.

    The request consists of a fixed system instruction plus one user message
    made of `read_img_prompt` followed by every entry of `image_parts`.
    Returns the stripped reply text, or "" when there are no image parts or
    the request fails for any reason (the error is printed).
    """
    if not image_parts:
        return ""

    conversation = [
        {
            "role": "system",
            "content": "You are a careful scientific figure-reading assistant. "
                       "Extract facts only from the given images and the provided instructions.",
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": read_img_prompt}, *image_parts],
        },
    ]

    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=conversation,
            temperature=0.6,
            stream=False,
        )
        reply = completion.choices[0].message.content
        return reply.strip() if reply else ""
    except Exception as e:
        print(f"[read_img_agent] Error: {e}")
        return ""


def get_openai_client(API_KEY, API_BASE_URL):
    """Create a synchronous OpenAI SDK client for the given endpoint.

    Args:
        API_KEY: API key for the endpoint; must be non-empty.
        API_BASE_URL: Base URL of the OpenAI-compatible endpoint; must be
            non-empty.

    Returns:
        An ``openai.OpenAI`` client.

    Exits the process with status 1 (after printing the reason) when either
    argument is missing or the client cannot be constructed.
    """
    if not API_KEY:
        print("Error: API_KEY is not set.")
        sys.exit(1)

    if not API_BASE_URL:
        print("Error: API_BASE_URL is not set.")
        sys.exit(1)

    try:
        return openai.OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    except Exception as e:
        print(f"Error creating OpenAI client: {e}")
        sys.exit(1)


# Breakpoint Resumption
def _should_skip_by_log(problem_json_path: str, out_dir: Path) -> (bool, Path):
    """
Derive the log file name as <stem>.log from the title JSON file name (e.g., IPhO_2025_2_C_1.log). 
If the log file already exists in the out_dir, return (True, log file path) to indicate that it 
should be skipped; otherwise, return (False, expected log file path).
    """
    problem_path = Path(problem_json_path)
    base_stem = problem_path.stem
    log_path = out_dir / f"{base_stem}.log"
    return log_path.exists(), log_path

def improvement_image(client, image_parts, image_info, model_name, error_info=None):
    """
    Ask the model to improve an earlier image readout.

    The conversation replays the image parts as a user turn and `image_info`
    as the assistant's earlier answer, then asks for an improvement with the
    global `image_improvement_prompt` (system prompt: `read_img_prompt`).

    Args:
        client: OpenAI SDK client used for the request.
        image_parts: Content parts describing the images; when empty, "" is
            returned without calling the API.
        image_info: The readout to improve.
        model_name: Model identifier passed to the API.
        error_info: Optional bug report (a string, or an iterable of strings)
            appended as extra user message(s).

    Returns:
        The stripped improved readout, or "" if the request fails (the error
        is printed).
    """
    if not image_parts:
        return ""

    image_history = [
        {"role": "user", "content": image_parts},
        {"role": "assistant", "content": image_info},
    ]

    # build_messages() iterates over other_prompts, so a bare string would be
    # split into one user message per character. Always hand it a list.
    if not error_info:
        extra_prompts = None
    elif isinstance(error_info, str):
        extra_prompts = [error_info]
    else:
        extra_prompts = list(error_info)

    messages = build_messages(
        read_img_prompt,
        image_improvement_prompt,
        other_prompts=extra_prompts,
        previous_messages=image_history,
    )

    try:
        resp = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.6,
            stream=False
        )
        return (resp.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"[image_selfimprovement] Error: {e}")
        return ""

def verify_image(client, image_parts, image_newinfo, model_name):
    """
    Ask the model to verify an image readout against the images themselves.

    The user message is multimodal: the original `image_parts` followed by a
    text part holding the readout under an "### information ###" header. The
    system prompt is the global `image_verifier_prompt`. The image parts are
    sent as content parts (not formatted into a string) so the model receives
    the actual images; no "### Image ###" header is added here, the caller's
    `image_parts` is expected to carry it.

    Returns:
        "" when the reply contains "yes" (case-insensitive), when there are no
        image parts, or when the request fails (the error is printed).
        Otherwise the bug report: the text before the "Detailed Verification"
        marker, or the whole reply when that marker is missing.

    NOTE(review): a failed request is indistinguishable from a passed
    verification for callers, because both return "".
    """
    if not image_parts:
        return ""

    user_content = list(image_parts)
    user_content.append(
        {"type": "text", "text": f"### information ###\n\n{image_newinfo}"}
    )
    messages = build_messages(image_verifier_prompt, user_content)

    try:
        resp = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.6,
            stream=False
        )
        verification_summary = (resp.choices[0].message.content or "").strip()
        print(f">>>>>>>Verification Summary: {verification_summary}")
        bug_report = ""
        if "yes" not in verification_summary.lower():
            # Without the marker the extraction returns ''; fall back to the
            # full reply so a negative verdict is never reported as a pass.
            bug_report = (
                extract_detailed_solution(verification_summary, "Detailed Verification", False)
                or verification_summary
            )
        return bug_report
    except Exception as e:
        print(f"[image_verifier] Error: {e}")
        return ""

def init_image_exploration(client, image_parts, model_name):
    """
    First exploration for image understanding:
    1. Read the images
    2. Improve the readout once
    3. Verify the result once

    Returns a pair ``(image_info, bug_report)``:
    - ``(None, None)`` when there are no image parts or the initial read
      produced nothing;
    - otherwise the improved readout (or the initial readout if the
      improvement pass returned nothing) and the bug report from
      verification, which is "" when verification raised no objection.
    """
    if not image_parts:
        return None, None

    # Step 1: Initial read
    image_info = read_img_agent(client, image_parts, model_name)
    if not image_info:
        return None, None

    # Step 2: Self improvement. improvement_image() returns "" on failure;
    # keep the initial readout in that case instead of discarding it.
    image_newinfo = improvement_image(client, image_parts, image_info, model_name) or image_info

    # Step 3: Verification
    bug_report = verify_image(client, image_parts, image_newinfo, model_name)

    return image_newinfo, bug_report

def final_image_agent(client, image_parts, model_name, consecutive_verify, max_rounds=10):
    """
    Multi-round image refinement agent.

    Starts from init_image_exploration() and then, for at most `max_rounds`
    rounds, improves the readout whenever the last verification produced a
    bug report and verifies the improved text again.

    Args:
        client: OpenAI SDK client.
        image_parts: Content parts describing the images.
        model_name: Model identifier passed to the API.
        consecutive_verify: Number of consecutive passing (or failing) rounds
            after which the loop stops.
        max_rounds: Upper bound on refinement rounds.

    Returns:
        The latest image readout, or None when there is nothing to refine
        (no images, or the initial read failed).
    """
    image_newinfo, bug_report = init_image_exploration(client, image_parts, model_name)
    if image_newinfo is None:
        # No images or no initial readout: skip the refinement loop entirely.
        return None

    error_count = 0
    correct_count = 1

    for i in range(max_rounds):
        print(f"[agent_image] Round {i}, correct_count={correct_count}, error_count={error_count}")

        # If verification fails
        if bug_report:
            error_count += 1
            correct_count = 0

            print(">>>>>> Verification failed, improving ...")
            improved = improvement_image(client, image_parts, image_newinfo, model_name, error_info=bug_report)
            # Carry the improved text forward; an empty result means the
            # improvement call failed, so keep the previous readout.
            if improved:
                image_newinfo = improved

            # Re-verify
            bug_report = verify_image(client, image_parts, image_newinfo, model_name)
        else:
            # Verification passed
            correct_count += 1
            error_count = 0
            print(">>>>>> Verification passed.")

        # Success condition
        if correct_count >= consecutive_verify:
            print(">>>>>> Found correct image interpretation.")
            return image_newinfo

        # Failure condition
        if error_count >= consecutive_verify:
            print(f">>>>>> Failed {error_count} consecutive verifications, returning last failed result.")
            return image_newinfo

    print(">>>>>> Reached max rounds, returning last failed result.")
    return image_newinfo


def read_file_content(filepath):
    """Return the full UTF-8 text of `filepath`.

    On any failure the reason is printed and the process exits with status 1;
    a missing file gets its own message.
    """
    try:
        with open(filepath, mode='r', encoding='utf-8') as source:
            text = source.read()
    except FileNotFoundError:
        failure = f"Error: File not found at '{filepath}'"
    except Exception as e:
        failure = f"Error reading file '{filepath}': {e}"
    else:
        return text

    print(failure)
    sys.exit(1)

def compress_image(img_path, max_size=(800, 800), quality=85):
    """Load an image, shrink it to fit `max_size`, and return it as JPEG bytes.

    Args:
        img_path: Path of the image file to read.
        max_size: (width, height) bound handed to ``Image.thumbnail``.
        quality: JPEG quality setting used when saving.

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixel data has been copied by convert().
    with Image.open(img_path) as source:
        img = source.convert("RGB")  # JPEG output needs RGB data
    img.thumbnail(max_size)
    output = io.BytesIO()
    img.save(output, format="JPEG", optimize=True, quality=quality)
    return output.getvalue()


def _load_image_parts(problem_path, image_paths):
    """Encode a problem's images as chat content parts (header text part + one image_url part each)."""
    image_parts = []
    if image_paths:
        image_parts.append({"type": "text","text": "### Image ###"})
    for img_rel in image_paths:
        try:
            # Image paths are resolved relative to the problem JSON file.
            img_abs = (problem_path.parent / img_rel).resolve()
            compressed_bytes = compress_image(str(img_abs), max_size=(800, 800), quality=85)
            img_b64 = base64.b64encode(compressed_bytes).decode("utf-8")
            image_parts.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{img_b64}",
                    "detail": "low"
                }
            })
        except Exception as e:
            # Best effort: an unreadable image is reported and skipped.
            print(f"Error processing image '{img_rel}': {e}")
    return image_parts


def process_single_problem_file(problem_json_path, model_name, other_prompts, max_runs, consecutive_verify, log_dir=None):
    """Solve every problem contained in one problem JSON file.

    The file must hold a JSON array of problem objects. For each problem a log
    file ``<id>.log`` is opened in `log_dir`, the images are read, and the
    solver is run up to `max_runs` times; the first accepted solution is
    written to ``final_solution_<id>.txt``. Problems whose log file already
    exists are skipped.

    Side effect: sets the module globals MODEL_NAME, API_KEY and API_BASE_URL
    from MODEL_REGISTRY, which the solver functions read.

    Args:
        problem_json_path: Path of the problem JSON file.
        model_name: Key into MODEL_REGISTRY.
        other_prompts: Optional comma-separated extra prompts.
        max_runs: Maximum solver runs per problem.
        consecutive_verify: Consecutive verification count passed to the agents.
        log_dir: Output directory; defaults to the problem file's directory.

    Returns:
        ``(all_ok, None)`` where `all_ok` is False if the file is not a JSON
        array or any processed problem ended without a solution.

    Raises:
        RuntimeError: if `model_name` is not in MODEL_REGISTRY.
    """
    global MODEL_NAME, API_KEY, API_BASE_URL
    import traceback

    cfg = MODEL_REGISTRY.get(model_name)
    if not cfg:
        raise RuntimeError(
            f"Failed finding configuration of '{model_name}' in MODEL_REGISTRY"
        )

    MODEL_NAME = cfg["model_id"]
    API_KEY = cfg["api_key"]
    API_BASE_URL = cfg["base_url"]

    problem_path = Path(problem_json_path)
    log_dir = Path(log_dir) if log_dir else problem_path.parent
    log_dir.mkdir(parents=True, exist_ok=True)

    content_raw = read_file_content(str(problem_json_path))
    data_raw = json.loads(content_raw)

    if not isinstance(data_raw, list):
        print(f"[ERROR] Expected array in {problem_json_path}, got {type(data_raw)}")
        return False, None

    other_prompts_list = other_prompts.split(',') if other_prompts else []

    all_ok = True
    for idx, prob in enumerate(data_raw):
        # str() guards against numeric ids, which have no .strip().
        prob_id = str(prob.get("id") or f"{problem_path.stem}_{idx}").strip()
        log_file_path = log_dir / f"{prob_id}.log"
        solution_file_path = log_dir / f"final_solution_{prob_id}.txt"

        if log_file_path.exists():
            print(f"[SKIP] {prob_id}: existing log -> {log_file_path.name}")
            continue

        if not set_log_file(str(log_file_path)):
            continue
        print(f"Logging to file: {log_file_path}")

        success = False
        try:
            image_paths = prob.get("image_question", []) or []
            image_parts = _load_image_parts(problem_path, image_paths)

            client = get_openai_client(API_KEY, API_BASE_URL)
            image_readout = final_image_agent(client, image_parts, MODEL_NAME, consecutive_verify, max_rounds=10)
            print(">>>>>>> Image reading summary:")
            print(image_readout)

            problem_statement = build_problem_statement(prob, image_readout=image_readout)

            for i in range(max_runs):
                print(f"\n\n>>>>>>>>>>>>>>>>>>>>>>>>>> Run {i+1} of {max_runs} ...")
                try:
                    sol = agent(problem_statement, consecutive_verify, other_prompts_list)
                    if sol is not None:
                        print(f">>>>>>> Found a correct solution in run {i+1}.")
                        with open(solution_file_path, "w", encoding='utf-8') as f:
                            f.write(sol)
                        print(f"Final solution saved to {solution_file_path}")
                        success = True
                        break
                except Exception as e:
                    print(f">>>>>>> Error in run {i+1}: {e}")
                    traceback.print_exc()
                    continue
        finally:
            # Always release the log handle, even when something above raises
            # or exits the process (the API helpers call sys.exit on errors).
            close_log_file()

        if not success:
            all_ok = False
            print(f"[FAILED] {prob_id}")
        else:
            print(f"[DONE] {prob_id} -> {solution_file_path}")

    return all_ok, None

def build_problem_statement(problem_obj: dict, image_readout: str = ""):
    """Render a problem object as a single text content part.

    The text lists the problem's ``id``, ``context`` and ``question`` fields
    (a missing field renders as an empty value), followed by the figure
    readout when one is given. Sections are separated by blank lines. The
    result is a one-element list holding a ``{"type": "text", ...}`` part.
    """
    sections = [
        f"#{field}#: {problem_obj.get(field, '')}"
        for field in ("id", "context", "question")
    ]
    if image_readout:
        sections.append(f"### Figure Reading (verbatim) ###\n{image_readout}")

    return [{"type": "text", "text": "\n\n".join(sections)}]
       
def build_messages(system_prompt, question_prompt, other_prompts=None, previous_messages=None):
    """Build the messages list for a chat completion call.

    Order of the result: the system prompt (if non-empty), then
    `previous_messages` (if any), then `question_prompt` as a user message,
    then one user message per entry of `other_prompts`.

    Args:
        system_prompt: System message text; skipped when falsy.
        question_prompt: Content of the main user message.
        other_prompts: Optional extra user prompts. May be an iterable of
            contents or a single string, which is treated as one prompt.
        previous_messages: Optional list of already-built message dicts.

    Returns:
        A new list of message dicts.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if previous_messages:
        messages.extend(previous_messages)
    messages.append({"role": "user", "content": question_prompt})
    if other_prompts:
        # Iterating a bare string would emit one message per character.
        if isinstance(other_prompts, str):
            other_prompts = [other_prompts]
        messages.extend({"role": "user", "content": prompt} for prompt in other_prompts)
    return messages

def send_api_request(client, model, messages):
    """Send a streaming chat completion request and return the answer text.

    The streamed chunks are concatenated into one string. Chunks flagged as a
    "thought" (via ``delta.extra_content['google']['thought']``) start a
    thought block whose content is dropped until a chunk containing
    ``</thought>`` arrives; text after that closing tag is kept.

    Args:
        client: OpenAI SDK client.
        model: Model identifier passed to the API.
        messages: Chat messages to send.

    Returns:
        The concatenated non-thought response text.

    On ``openai.APIError`` the error is printed, the function waits 60
    seconds and then exits the process with status 1.
    """
    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.6,
            # top_p=0.95,
            max_tokens=32768,
            stream=True
        )

        response_parts = []
        in_thought = False
        for chunk in stream:
            if not chunk.choices:
                continue

            delta = chunk.choices[0].delta

            # Robustly check for the start of a thought block
            extra_content = getattr(delta, 'extra_content', None)
            if extra_content and extra_content.get('google', {}).get('thought'):
                in_thought = True

            # Process content only if we are NOT inside a thought block
            if delta.content and not in_thought:
                response_parts.append(delta.content)

            # Check for end of a thought block
            if in_thought and delta.content and '</thought>' in delta.content:
                # Content might start right after the closing tag in the same chunk
                _, _, after_thought = delta.content.partition('</thought>')
                if after_thought:
                    response_parts.append(after_thought)
                in_thought = False

        return "".join(response_parts)
    except openai.APIError as e:
        # Report immediately; previously the message only appeared after the sleep.
        print(f"Error during API request: {e}")
        # Not every APIError subclass carries an HTTP response (connection
        # errors do not), so look the attribute up defensively.
        response = getattr(e, "response", None)
        if response is not None:
            print(f"Raw API Response (if available): {response.text}")
        # NOTE(review): presumably a cool-down before the process exits; confirm it is still wanted.
        time.sleep(60)
        sys.exit(1)

def extract_detailed_solution(solution, marker='Detailed Solution', after=True):
    """Split `solution` at the first occurrence of `marker` and return one side.

    With ``after=True`` the stripped text following the marker is returned,
    with ``after=False`` the stripped text preceding it. An empty string is
    returned when the marker does not occur.
    """
    position = solution.find(marker)
    if position < 0:
        return ''
    if after:
        tail = solution[position + len(marker):]
        return tail.strip()
    head = solution[:position]
    return head.strip()

def run_two_stage_verification(client, problem_statement, solution, physics_verify_model_name, general_verify_model_name, verbose=True):
    """
    Verify a solution in two stages.

    Stage 1: Physics precheck (quick, physics hygiene).
    - If FAIL: return the bug report from the precheck, mark the result as
      not good, and SKIP the general verification.
    - If PASS: run the general verification (verify_solution).

    The precheck passes when the first non-blank line of the model output
    contains "PASS" (case-insensitive).

    Returns:
        ``(bug_report, good_verify_text, precheck_passed)``. On a precheck
        failure this is ``(report, "no", False)``; otherwise the two values
        from verify_solution() followed by True.
    """
    dsol = extract_detailed_solution(solution)
    newst = f"### Problem ###\n\n{problem_statement}\n\n### Solution ###\n\n{dsol}\n\n{physics_precheck_reminder}"

    if verbose:
        print(">>>>>>> Start physics precheck.")
    pre_messages = build_messages(verifier_physics_prompt, newst)
    if verbose:
        print(f">>>>>>> Physics precheck messages:\n{json.dumps(pre_messages, indent=4)}")
    pre_out = send_api_request(client, physics_verify_model_name, pre_messages) or ""
    if verbose:
        print(f">>>>>>> Physics precheck result:\n{json.dumps(pre_out, indent=4)}")

    # Judge PASS / FAIL from the first line. A blank or whitespace-only reply
    # has no lines at all and counts as a failure instead of raising IndexError.
    stripped_out = pre_out.strip()
    verdict_line = stripped_out.splitlines()[0] if stripped_out else ""
    precheck_passed = ("PASS" in verdict_line.upper())

    if not precheck_passed:
        # Output the bug report directly: start from the "Bug Report" section;
        # if the format is not followed, it will degrade to the full text.
        if "Bug Report" in pre_out:
            bug_report = pre_out.split("Bug Report", 1)[1].strip().lstrip(":").strip()
        else:
            bug_report = pre_out
        if verbose:
            print(">>>>>>> Physics precheck FAIL -> skip general verify.")
            print(f">>>>>>> Bug report (from precheck):\n{json.dumps(bug_report, indent=4)}")
        return bug_report, "no", False

    # Stage 2: general verify
    if verbose:
        print(">>>>>>> Physics precheck PASS -> running general verify.")
    bug_report, good_verify_text = verify_solution(
        client,
        problem_statement=problem_statement,
        solution=solution,
        verify_model_name=general_verify_model_name,
        verbose=verbose
    )
    return bug_report, good_verify_text, True

def verify_solution(client, problem_statement, solution, verify_model_name, verbose=True):
    """Run the general verification pass on a solution.

    Two requests are made with `verify_model_name`: the first asks for a
    verification of the solution (system prompt `verify_general_prompt`), the
    second asks for a yes/no reading of that verification text.

    Returns:
        ``(bug_report, verdict)`` where `verdict` is the raw yes/no reply.
        `bug_report` is "" when the verdict contains "yes"; otherwise it is
        the verification text before the "Detailed Verification" marker, or
        the whole verification text when that marker is missing.
    """
    if verbose: print(verify_model_name)
    dsol = extract_detailed_solution(solution)
    newst = f"### Problem ###\n\n{problem_statement}\n\n### Solution ###\n\n{dsol}\n\n{verification_remider}"
    if verbose: print(">>>>>>> Start verification.")
    messages = build_messages(verify_general_prompt, newst)
    if verbose: print(f">>>>>>> Verification prompt messages:\n{json.dumps(messages, indent=4)}")
    out = send_api_request(client, verify_model_name, messages)
    if verbose: print(f">>>>>>> Verification results:\n{json.dumps(out, indent=4)}")

    check_correctness_prompt = f'Response in "yes" or "no". Is the following statement saying the solution is correct, or does not contain critical error?\n\n{out}'
    messages_check = build_messages("", check_correctness_prompt)
    o = send_api_request(client, verify_model_name, messages_check)
    if verbose: print(f">>>>>>> Is verification good?\n{json.dumps(o, indent=4)}")

    bug_report = ""
    if "yes" not in o.lower():
        # Fall back to the full verification text when the marker is absent,
        # so a rejected solution always comes with a non-empty report.
        bug_report = extract_detailed_solution(out, "Detailed Verification", False) or out
    if verbose: print(f">>>>>>>Bug report:\n{json.dumps(bug_report, indent=4)}")

    return bug_report, o

def check_if_solution_claimed_complete(client, solution):
    """Ask the model (MODEL_NAME) whether `solution` claims to be a complete solution.

    The raw reply is printed. Returns True when the reply contains "yes"
    (case-insensitive), False otherwise.
    """
    check_complete_prompt = f"Is the following text claiming that the solution is complete?\n==========================================================\n\n{solution}\n\n==========================================================\n\nResponse in exactly \"yes\" or \"no\". No other words."
    reply = send_api_request(client, MODEL_NAME, build_messages("", check_complete_prompt))
    print(reply)
    claimed_complete = "yes" in reply.lower()
    return claimed_complete

def init_explorations(client, problem_statement, verbose=True, other_prompts=None):
    """Produce a first solution, self-improve it once, and verify the result.

    Steps: request an initial solution, ask for a self-improvement with
    `self_improvement_prompt`, check that the improved solution claims to be
    complete, then run the two-stage verification on it. All requests use
    MODEL_NAME.

    Args:
        client: OpenAI SDK client.
        problem_statement: Content of the user message describing the problem.
        verbose: Forwarded to run_two_stage_verification().
        other_prompts: Optional extra user prompts for the initial request.

    Returns:
        ``(conversation, solution, bug_report, verdict)``, or four Nones when
        the solution does not claim to be complete.
    """
    # None instead of a shared mutable default list.
    other_prompts = list(other_prompts) if other_prompts else []

    messages = build_messages(initial_solution_prompt, problem_statement, other_prompts)
    print(f">>>>>> Initial prompt messages:\n{json.dumps(messages, indent=4)}")
    output1 = send_api_request(client, MODEL_NAME, messages)
    print(f">>>>>>> First solution: \n{json.dumps(output1, indent=4)}")

    conversation_history = [
        {"role": "user", "content": problem_statement},
        {"role": "assistant", "content": output1}
    ]

    print(f">>>>>>> Self improvement start:")
    messages2 = build_messages(initial_solution_prompt, self_improvement_prompt, previous_messages=conversation_history)
    solution = send_api_request(client, MODEL_NAME, messages2)
    print(f">>>>>>> Corrected solution: \n{json.dumps(solution, indent=4)}")

    print(f">>>>>>> Check if solution is complete:")
    # Judge the improved solution, i.e. the text that is verified and returned
    # below, rather than the superseded first draft.
    is_complete = check_if_solution_claimed_complete(client, solution)
    if not is_complete:
        print(f">>>>>>> Solution is not complete. Failed.")
        return None, None, None, None

    print(f">>>>>>> Two-stage verification (physics precheck -> general).")
    verify, good_verify, pre_ok = run_two_stage_verification(
        client,
        problem_statement=problem_statement,
        solution=solution,
        physics_verify_model_name=MODEL_NAME,
        general_verify_model_name=MODEL_NAME,
        verbose=verbose
    )
    print(f">>>>>>> Initial verification (bug report if any): \n{json.dumps(verify, indent=4)}")
    print(f">>>>>>> verify results: {good_verify} | physics_precheck_passed={pre_ok}")

    current_conversation = conversation_history + [
        {"role": "user", "content": self_improvement_prompt},
        {"role": "assistant", "content": solution}
    ]

    return current_conversation, solution, verify, good_verify

def agent(problem_statement, consecutive_verify, other_prompts=None, max_iterations=5):
    """Solve a problem with a generate -> verify -> correct loop.

    Starts from init_explorations() and then, for at most `max_iterations`
    rounds, asks for a corrected solution whenever the last verification did
    not pass and runs the two-stage verification again.

    Args:
        problem_statement: Content of the user message describing the problem.
        consecutive_verify: Number of consecutive passing verifications needed
            to accept a solution; the same number of consecutive failures
            aborts the attempt.
        other_prompts: Optional extra user prompts sent with the problem.
        max_iterations: Upper bound on verification rounds.

    Returns:
        The accepted solution text, or None when no solution was accepted
        (incomplete solution, too many failures, or rounds exhausted).
    """
    # None instead of a shared mutable default list.
    other_prompts = list(other_prompts) if other_prompts else []

    client = get_openai_client(API_KEY, API_BASE_URL)
    _, solution, verify, good_verify = init_explorations(client, problem_statement, True, other_prompts)

    if solution is None:
        print(">>>>>>> Failed in finding a complete solution.")
        return None

    error_count = 0
    correct_count = 1
    for i in range(max_iterations):
        print(f"Number of iterations: {i}, number of corrects: {correct_count}, number of errors: {error_count}")

        if "yes" not in good_verify.lower():
            correct_count = 0
            error_count += 1
            print(">>>>>>> Verification does not pass, correcting ...")

            # Replay the exchange in chronological order: problem (plus any
            # extra prompts), the rejected solution, then the correction
            # request together with the bug report.
            messages = build_messages(initial_solution_prompt, problem_statement, other_prompts)
            messages.append({"role": "assistant", "content": solution})
            messages.append({"role": "user", "content": f"{correction_prompt}\n{verify}"})
            print(f">>>>>>> New prompt messages:\n{json.dumps(messages, indent=4)}")
            solution = send_api_request(client, MODEL_NAME, messages)

            print(f">>>>>>> Corrected solution:\n{json.dumps(solution, indent=4)}")
            print(f">>>>>>> Check if solution is complete:")
            is_complete = check_if_solution_claimed_complete(client, solution)
            if not is_complete:
                print(f">>>>>>> Solution is not complete. Failed.")
                return None

        print(f">>>>>>> Two-stage verification (physics precheck -> general).")
        verify, good_verify, pre_ok = run_two_stage_verification(
            client,
            problem_statement=problem_statement,
            solution=solution,
            physics_verify_model_name=MODEL_NAME,
            general_verify_model_name=MODEL_NAME,
            verbose=True
        )

        if "yes" in good_verify.lower():
            print(">>>>>>> Solution is good, verifying again ...")
            correct_count += 1
            error_count = 0

        if correct_count >= consecutive_verify:
            print(">>>>>>> Correct solution found.")
            print(json.dumps(solution, indent=4))
            return solution
        elif error_count >= consecutive_verify:
            print(">>>>>>> Failed in finding a correct solution.")
            return None

    print(">>>>>>> Failed in finding a correct solution after max iterations.")
    return None

if __name__ == "__main__":
    # Command-line entry point: solve one problem file, or every *.json file
    # in a directory, writing logs and solutions to the output directory.
    parser = argparse.ArgumentParser(description='IPHO Problem Solver Agent')
    parser.add_argument('--model_name', required=True)
    parser.add_argument('--problem_file', nargs='?', help='Path to a single problem JSON file (deprecated; use --problem_path)')
    parser.add_argument('--problem_path', '-p', required=False,
                        help='Path to a problem JSON file or a directory containing multiple JSON files')
    parser.add_argument('--log', '-l', type=str,
                        help='Directory for logs and solutions (optional, created if missing). '
                             'Defaults to the directory of the problem file(s).')
    parser.add_argument('--other_prompts', '-o', type=str, help='Other prompts (optional, comma-separated)')
    parser.add_argument("--max_runs", '-m', type=int, default=5, help='Maximum number of solver runs per problem (default: 5)')
    parser.add_argument("--consecutive_verify", '-cv', type=int, default=2, help='The number of consecutive verification time (correct & wrong)')
    args = parser.parse_args()

    # Priority --problem_path > --problem_file
    path_arg = args.problem_path or args.problem_file
    if not path_arg:
        print("Error: Please provide --problem_path (file or directory), or --problem_file (single file).")
        sys.exit(1)

    p = Path(path_arg)
    if not p.exists():
        print(f"Error: Path not found: {p}")
        sys.exit(1)

    # Output directory: --log is optional, so fall back to the directory that
    # holds the problem file(s) instead of crashing on Path(None).
    if args.log:
        batch_out_dir = Path(args.log)
    else:
        batch_out_dir = p.parent if p.is_file() else p
    batch_out_dir.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] Log directory created/using: {batch_out_dir}")

    # single file
    if p.is_file():
        success, _ = process_single_problem_file(
            problem_json_path=str(p),
            model_name=args.model_name,
            other_prompts=args.other_prompts,
            max_runs=args.max_runs,
            consecutive_verify=args.consecutive_verify,
            log_dir=batch_out_dir
        )
        sys.exit(0 if success else 1)

    # traverse *.json
    json_files = sorted(Path(f) for f in glob.glob(str(p / "*.json")))
    if not json_files:
        print(f"No JSON files found under directory: {p}")
        sys.exit(1)

    all_ok = True
    for jf in json_files:
        print(f"\n===== Processing: {jf.name} =====")

        skip, found_log = _should_skip_by_log(str(jf), batch_out_dir)
        if skip:
            print(f"[SKIP] Detected existing log -> {found_log.name}. Skipping this problem.")
            continue

        ok, _ = process_single_problem_file(
            problem_json_path=str(jf),
            model_name=args.model_name,
            other_prompts=args.other_prompts,
            max_runs=args.max_runs,
            consecutive_verify=args.consecutive_verify,
            log_dir=batch_out_dir
        )
        if not ok:
            all_ok = False
            print(f"[FAILED] {jf.name}")
        else:
            # The per-problem solution files live in the output directory;
            # process_single_problem_file() does not return a single path.
            print(f"[DONE] {jf.name} -> {batch_out_dir}")

    sys.exit(0 if all_ok else 1)