import time
import ray
import requests
import torch
import os
import re
import subprocess
import sys
import uuid
from datetime import datetime
import tempfile
import shutil
from pathlib import Path
import numpy as np

# Add project root directory to system path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

from openrlhf.utils.logging_utils import init_logger

# Feature flag read once, at import time, from the USE_RULE_BASED environment
# variable. "true", "1" or "yes" (case-insensitive) enable rule-based scoring;
# any other value, or an unset variable, leaves it disabled.
USE_RULE_BASED = os.environ.get("USE_RULE_BASED", "False").lower() in ("true", "1", "yes")
print(f"######### Using rule-based evaluation: {USE_RULE_BASED}")

# The rule-based helpers are imported only when the flag is on, so
# rule_based_evaluation() raises NameError if it is called while the flag is off.
if USE_RULE_BASED:
    from examples.scripts.values_utils import extract_values, compare_values
    from examples.scripts.chart_utils import extract_chart_type, compare_chart_type
    from examples.scripts.layout_utils import extract_layout, compare_layout
    from examples.scripts.text_utils import extract_title, compare_title, extract_labels, compare_labels

# Module-level logger used by the functions below.
logger = init_logger(__name__)

# File that reward_func appends its per-query report to
# (REWARD_LOG_PATH, default "reward.log").
LOG_PATH = os.environ.get("REWARD_LOG_PATH", "reward.log")
# Directory holding the temporary scripts and rendered images
# (TEMP_IMAGE_DIR, default "<system temp dir>/render_images").
TEMP_DIR = os.environ.get("TEMP_IMAGE_DIR", os.path.join(tempfile.gettempdir(), "render_images"))

# Ensure temporary directory exists
os.makedirs(TEMP_DIR, exist_ok=True)

# Host of the judge model server (REMOTE_RM_URL, default "localhost").
# Only the host is configurable; scheme, port and path are fixed below.
remote_rm_url = os.environ.get("REMOTE_RM_URL", "localhost")
print(f"######### remote_rm_url: {remote_rm_url}")

from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.
# NOTE(review): the key is a placeholder; presumably the server does not
# validate it - confirm before pointing this at an authenticated endpoint.
openai_api_key="YOUR_OPENAI_API_KEY"
openai_api_base = f"http://{remote_rm_url}:8080/v1"

# Single client instance, created at import time and used by modeleval_images.
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

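# Prompt sent to the judge model together with the reference image and the
# generated image. extract_score() parses the "Score:" line of the reply, so
# keep the response format requested here in sync with the patterns used there.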
modeleval_prompt = """You are an excellent judge at evaluating visualization chart plots. The first image (reference image) is created using ground truth matplotlib code, and the second image (AI-generated image) is created using matplotlib code generated by an AI assistant. Your task is to score how well the AI-generated plot matches the ground truth plot.

### Scoring Methodology:
The AI-generated image's score is based on the following criteria, totaling a score out of 100 points:

1. **Chart Types (20 points)** Does the AI-generated image include all chart types present in the reference image (e.g., line charts, bar charts, etc.)?
2. **Layout (10 points)** Does the arrangement of subplots in the AI-generated image match the reference image (e.g., number of rows and columns)?
3. **Text Content (20 points)** Does the AI-generated image include all text from the reference image (e.g., titles, annotations, axis labels), excluding axis tick labels?
4. **Data (20 points)** How accurately do the data trends in the AI-generated image resemble those in the original image and is the number of data groups the same as in the reference image?
5. **Style (20 points)** Does the AI-generated image match the original in terms of colors (line colors, fill colors, etc.), marker types (point shapes, line styles, etc.), legends, grids, and other stylistic details?
6. **Clarity (10 points)** Is the AI-generated image clear and free of overlapping elements?

### Evaluation:
Compare the two images head to head and provide a detailed assessment. Use the following format for your response:


---

Comments:
- Chart Types: ${your comment and subscore}
- Layout: ${your comment and subscore}
- Text Content: ${your comment and subscore}
- Data: ${your comment and subscore}
- Style: ${your comment and subscore}
- Clarity: ${your comment and subscore}

Score: ${your final score out of 100}

---

Please use the above format to ensure the evaluation is clear and comprehensive.
"""

# Regex capturing the body of a user turn: everything between
# "<|im_start|>user\n" and the next "<|im_end|>" (callers pass re.DOTALL).
problem_pattern = r"<\|im_start\|>user\n(.*?)<\|im_end\|>"
# Regex matching the marker that opens an assistant turn; the response text
# starts immediately after it.
response_prefix = r"<\|im_start\|>assistant\n"

def get_response_from_query(q: str):
    """Return the assistant response contained in the full sequence *q*.

    The response is everything after the first match of ``response_prefix``,
    with the known end-of-sequence tokens removed and surrounding whitespace
    stripped. An empty string is returned when no assistant marker is found.
    """
    prefix_match = re.search(response_prefix, q)
    if prefix_match is None:
        return ""

    stop_tokens = ("<|im_end|>", "<｜end▁of▁sentence｜>", "<|endoftext|>")
    tail = q[prefix_match.end():]
    # Tokens are removed one after another, in this order, wherever they occur.
    for token in stop_tokens:
        tail = tail.replace(token, "")
    return tail.strip()

def get_query_from_query(q: str):
    """Return the text of the first user turn found in *q*.

    Args:
        q: Full prompt/response sequence containing chat-template markers.

    Returns:
        The first capture of ``problem_pattern`` (matched with ``re.DOTALL``).
        When the pattern does not occur, or *q* is not text that ``re`` can
        scan, *q* itself is returned unchanged.
    """
    try:
        matches = re.findall(problem_pattern, q, re.DOTALL)
    except TypeError:
        # Non-text input cannot be scanned; keep the historical fallback of
        # handing the argument back instead of failing.
        return q
    return matches[0] if matches else q

def extract_code_from_text(text):
    """Return the body of the first ```python fenced block in *text*.

    Whitespace directly inside the fences is dropped. ``None`` is returned
    when *text* contains no such block.
    """
    fenced_block = re.compile(r"```python\s*(.+?)\s*```", re.DOTALL)
    found = fenced_block.search(text)
    return found.group(1) if found else None

def render_image_from_code(code, save_path):
    """Run plotting *code* in a subprocess so that it writes its figure to *save_path*.

    Before execution the code is rewritten so that its output lands on
    ``save_path``. Only the first rule below whose marker occurs in the code
    is applied:

    1. ``fig.write_image(...)`` - first argument replaced by ``save_path``
    2. ``fig.savefig(...)``     - first argument replaced by ``save_path``
    3. ``fig.show()``           - replaced by ``fig.write_image(save_path)``
    4. ``plt.savefig(...)``     - first argument replaced by ``save_path``
    5. ``plt.show()``           - replaced by ``plt.savefig(save_path)``
    6. none of the above        - ``plt.savefig(save_path)`` is appended

    The rewritten script is stored in ``TEMP_DIR`` under the stem of
    ``save_path``, executed with the current interpreter (30 second timeout)
    and deleted afterwards.

    Args:
        code: Python source expected to draw a figure.
        save_path: Destination of the rendered image.

    Returns:
        ``(True, "Success")`` when the process exits with status 0 and
        ``save_path`` exists afterwards; otherwise ``(False, message)`` where
        the message carries the captured stderr or the raised exception.
    """
    script_path = os.path.join(TEMP_DIR, f"{Path(save_path).stem}.py")

    # repr() yields a valid Python string literal even when the path contains
    # quotes or backslashes (e.g. Windows paths); pasting the raw path between
    # hand-written quotes would corrupt the generated script.
    path_literal = repr(str(save_path))

    def redirect_first_argument(call_regex):
        """Replace the first argument of every matching call with the target path."""
        pattern = r'(' + call_regex + r'\s*\()([\'"].*?[\'"]|[a-zA-Z0-9_]+)(.*)(\))'
        # A callable replacement is inserted verbatim; a template string would
        # have its backslashes interpreted as escapes / group references.
        return re.sub(
            pattern,
            lambda m: f"{m.group(1)}{path_literal}{m.group(3)}{m.group(4)}",
            code,
        )

    if "fig.write_image" in code:
        modified_code = redirect_first_argument(r'fig\.write_image')
    elif "fig.savefig" in code:
        modified_code = redirect_first_argument(r'fig\.savefig')
    elif "fig.show()" in code:
        modified_code = code.replace("fig.show()", f"fig.write_image({path_literal})\n")
    elif "plt.savefig" in code:
        modified_code = redirect_first_argument(r'plt\.savefig')
    elif "plt.show()" in code:
        modified_code = code.replace("plt.show()", f"plt.savefig({path_literal})\n")
    else:
        modified_code = code + f"\nplt.savefig({path_literal})"

    try:
        # Python reads source files as UTF-8 by default, so the script must be
        # written as UTF-8 regardless of the locale's preferred encoding.
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(modified_code)

        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True,
            text=True,
            timeout=30
        )

        success = result.returncode == 0 and os.path.exists(save_path)
        if success:
            return True, "Success"
        error_msg = result.stderr if result.stderr else "Unknown error"
        return False, f"Failed: {error_msg}"
    except Exception as e:
        # Covers timeouts and I/O errors alike; callers only need the message.
        return False, f"Exception: {str(e)}"
    finally:
        # Delete temporary Python file
        if os.path.exists(script_path):
            os.remove(script_path)

def encode_image_to_base64(image_path):
    """Return the contents of *image_path* as a base64 text string.

    ``None`` is returned when the path does not exist.
    """
    import base64

    if os.path.exists(image_path):
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        return base64.b64encode(raw_bytes).decode("utf-8")
    return None

def modeleval_images(reference_img_path, generated_img_path, try_max_times=5):
    """Ask the judge model to compare the reference and the generated image.

    Both images are sent as base64 data URLs together with
    ``modeleval_prompt`` in a single user message, reference image first.

    Args:
        reference_img_path: Path of the image rendered from the ground truth.
        generated_img_path: Path of the image rendered from the completion.
        try_max_times: Maximum number of request attempts.

    Returns:
        The text content of the judge's reply, or the string
        ``"Failed to encode images to base64"`` when either image could not
        be read (no request is sent in that case).

    Raises:
        RuntimeError: If every attempt failed; the last underlying error is
            attached as the cause.
    """
    # Convert images to base64 encoding
    reference_img_base64 = encode_image_to_base64(reference_img_path)
    generated_img_base64 = encode_image_to_base64(generated_img_path)

    if not reference_img_base64 or not generated_img_base64:
        logger.error("Failed to encode images to base64")
        return "Failed to encode images to base64"

    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": modeleval_prompt,
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{reference_img_base64}",
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{generated_img_base64}",
                    }
                }
            ],
        }
    ]

    last_error = None
    for attempt in range(1, try_max_times + 1):
        try:
            chat_response = client.chat.completions.create(
                model="Qwen2.5-VL-72B-Instruct",
                messages=conversation,
                max_tokens=2048,
                temperature=0.,
            )
            return chat_response.choices[0].message.content
        except Exception as e:
            last_error = e
            logger.warning(
                f"Unexpected error in modeleval_images (attempt {attempt}/{try_max_times}): {e}"
            )
        # Back off only when another attempt follows; sleeping after the final
        # failure would just delay the error.
        if attempt < try_max_times:
            time.sleep(1)

    raise RuntimeError(
        f"Request error for {try_max_times} times. Please check the API server."
    ) from last_error


def extract_score(evaluation_text):
    """Extract the final score (0-100) from the judge's evaluation text.

    Lookup order:

    1. the first occurrence of ``Score: X/100``;
    2. otherwise the last occurrence of ``Score: X``.

    Args:
        evaluation_text: Reply of the judge model. Anything that is not a
            string (for example ``None`` when the reply had no text content)
            is treated as "no score".

    Returns:
        The score as a float, capped at 100.0 so that a malformed reply such
        as ``Score: 250`` cannot push the normalised reward above 1.0.
        0.0 when no score can be found.
    """
    if not isinstance(evaluation_text, str):
        return 0.0

    # Look for score in format "Score: X/100"
    final_match = re.search(r"Score:\s*(\d+(?:\.\d+)?)/100", evaluation_text)
    if final_match:
        score = float(final_match.group(1))
    else:
        # Fall back to the last "Score: X"; earlier ones may be sub-scores.
        candidates = re.findall(r"Score:\s*(\d+(?:\.\d+)?)", evaluation_text)
        if not candidates:
            return 0.0  # If score cannot be extracted, return 0
        score = float(candidates[-1])

    # The patterns only match non-negative numbers, so an upper bound suffices.
    return min(score, 100.0)

def rule_based_evaluation(completion_code, answer_code):
    """Score a completion against the answer with the rule-based comparators.

    Five components are evaluated in a fixed order; for each one the feature
    is extracted from both inputs and the pair is handed to the matching
    comparator. The total is the weighted sum of the component scores
    (weights: values 0.4, chart_type 0.3, layout 0.1, title 0.1, labels 0.1).

    Returns:
        ``(total_score, scores)`` where ``scores`` maps each component name
        to the comparator's result.

    NOTE(review): the range of the comparator results is defined by the
    helper modules and is not visible here - presumably [0, 1]; confirm.
    """
    # (component, weight, extractor, comparator), listed in evaluation order.
    components = (
        ("values", 0.4, extract_values, compare_values),
        ("chart_type", 0.3, extract_chart_type, compare_chart_type),
        ("layout", 0.1, extract_layout, compare_layout),
        ("title", 0.1, extract_title, compare_title),
        ("labels", 0.1, extract_labels, compare_labels),
    )

    scores = {}
    for name, _weight, extract, compare in components:
        from_completion = extract(completion_code)
        from_answer = extract(answer_code)
        scores[name] = compare(from_completion, from_answer)

    # Weighted total, accumulated in the same component order as above.
    total_score = sum(scores[name] * weight for name, weight, _, _ in components)

    return total_score, scores

def evaluate_chart_code(prompt, completion, answer):
    """Compute the chart and format rewards for one completion.

    Steps:

    * When ``USE_RULE_BASED`` is set, the rule-based score is computed from
      the raw ``completion`` and ``answer`` texts.
    * The python code blocks are extracted from both texts. If both exist,
      the completion is rendered; a successful render earns a format reward
      of 0.5 and triggers rendering of the reference. If that succeeds too,
      the judge model compares the images and its score is divided by 100.
    * The chart reward is the model score, or the mean of the model score and
      the rule-based score when ``USE_RULE_BASED`` is set.

    ``prompt`` is accepted but not used here.

    Returns:
        ``(chart_reward, format_reward, completion_code, evaluation_result)``
        where ``evaluation_result`` is a human-readable report.
    """
    # Defaults that are reported whenever a stage is skipped or fails.
    model_chart_reward = 0.0
    rule_chart_reward = 0.0
    format_reward = 0.0
    component_scores = {}
    model_evaluation = "Model evaluation not performed."

    # Rule-based scoring runs first and receives the unextracted texts.
    if USE_RULE_BASED:
        rule_chart_reward, component_scores = rule_based_evaluation(completion, answer)

    completion_code = extract_code_from_text(completion)
    answer_code = extract_code_from_text(answer)

    # A fresh random token per call gives each evaluation its own file names.
    token = str(uuid.uuid4())
    generated_img_path = os.path.join(TEMP_DIR, f"generated_{token}.png")
    reference_img_path = os.path.join(TEMP_DIR, f"reference_{token}.png")

    try:
        if not (completion_code and answer_code):
            error_msg = "Failed to extract code"
            if not completion_code:
                error_msg += " from completion"
            if not answer_code:
                error_msg += " from answer"
            model_evaluation = error_msg
            logger.error(error_msg)
        else:
            # Rendering the completion doubles as the format check.
            gen_success, gen_message = render_image_from_code(completion_code, generated_img_path)
            if not gen_success:
                model_evaluation = f"Generated image rendering failed: {gen_message}"
                logger.error("Failed to render generated image")
            else:
                format_reward = 0.5
                ref_success, ref_message = render_image_from_code(answer_code, reference_img_path)
                if not ref_success:
                    model_evaluation = f"Reference image rendering failed: {ref_message}"
                    logger.error("Failed to render reference image")
                else:
                    # Both images exist: let the judge model compare them.
                    model_evaluation = modeleval_images(reference_img_path, generated_img_path)
                    model_chart_reward = extract_score(model_evaluation) / 100.0

        if USE_RULE_BASED:
            chart_reward = (model_chart_reward + rule_chart_reward) / 2.0
            evaluation_method = "average (model + rule-based)"
            component_text = ", ".join(
                f"{key}={component_scores.get(key, 0):.4f}"
                for key in ("values", "chart_type", "layout", "title", "labels")
            )
            rule_based_info = (
                f"Rule-based component scores: {component_text}\n"
                f"Rule-based score: {rule_chart_reward:.4f}, "
            )
        else:
            chart_reward = model_chart_reward
            evaluation_method = "model-based only"
            rule_based_info = ""

        report_lines = [
            f"{model_evaluation}\n\n{rule_based_info}Model-based score: {model_chart_reward:.4f}",
            f"Final score ({evaluation_method}): {chart_reward:.4f}",
        ]
        evaluation_result = "\n".join(report_lines)

        return chart_reward, format_reward, completion_code, evaluation_result

    finally:
        # Remove whatever images were produced, on success and on failure.
        for leftover in (reference_img_path, generated_img_path):
            if os.path.exists(leftover):
                os.remove(leftover)

def reward_func(queries, prompts, labels, try_max_times=5, is_debug=False):
    """Score a batch of queries and append a report for each one to LOG_PATH.

    Args:
        queries: Full sequences (prompt plus response). With ``is_debug`` the
            query itself is treated as the response.
        prompts: Prompts aligned with ``queries``; only used as the query
            text when ``is_debug`` is set.
        labels: Reference answers aligned with ``queries``.
        try_max_times: Accepted for interface compatibility; it is not
            forwarded to the evaluation at present.
        is_debug: Skip chat-template parsing (see above).

    Returns:
        A dict with float32 tensors ``"rewards"`` (chart + format),
        ``"accuracy_rewards"`` (chart) and ``"format_rewards"``. Each tensor
        holds exactly one entry per (query, prompt, label) triple; a query
        with an empty response, or one whose processing raised, scores 0.0
        in all three.
    """
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []
    chart_rewards = []
    format_rewards = []

    with open(LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"----------------------------- {current_time} -----------------------------\n")
        for query, prompt, answer in zip(queries, prompts, labels):
            # Scores recorded for this query unless evaluation completes.
            total_reward = chart_reward = format_reward = 0.0
            try:
                response = get_response_from_query(query) if not is_debug else query
                if response == "":
                    f.write("Error: " + query + "\n")
                else:
                    if is_debug:
                        query_text_tmp = query_text = prompt
                    else:
                        query_text = get_query_from_query(query)
                        # Log only the text after the vision block; queries
                        # without the marker are logged in full instead of
                        # failing with IndexError before being evaluated.
                        segments = query_text.split("<|vision_end|>")
                        query_text_tmp = segments[1] if len(segments) > 1 else query_text

                    # Evaluate code and get reward scores
                    chart_reward, format_reward, _, evaluation_result = evaluate_chart_code(query_text, response, answer)

                    # Total reward is sum of chart similarity reward and format reward
                    total_reward = chart_reward + format_reward

                    f.write(f"===============================================================\n")
                    f.write("Query: " + query_text_tmp + "\n")
                    f.write("Response: " + response + "\n")
                    f.write("Answer: " + answer + "\n")
                    f.write("Evaluation Result: " + evaluation_result + "\n")
                    f.write(f"Chart Reward: {chart_reward}\tFormat Reward: {format_reward}\tTotal Reward: {total_reward}\n\n\n\n")
                    f.write(f"===============================================================\n")
            except Exception as e:
                logger.error(f"Error in reward_func: {e}")
                f.write(f"Error: {query}\nException: {e}\n")
                # Any failure, including one while writing the report, zeroes
                # the scores for this query.
                total_reward = chart_reward = format_reward = 0.0

            # Appending in one place keeps the three lists aligned with the
            # inputs; appending before the report was written could add a
            # second entry for the same query when the write failed.
            rewards.append(total_reward)
            chart_rewards.append(chart_reward)
            format_rewards.append(format_reward)

    return {
        "rewards": torch.tensor(rewards, dtype=torch.float32),
        "accuracy_rewards": torch.tensor(chart_rewards, dtype=torch.float32),
        "format_rewards": torch.tensor(format_rewards, dtype=torch.float32),
    }

def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards", is_debug=False):
    """Reward-model entry point; returns the result of reward_func.

    ``api_url`` and ``score_key`` are accepted but not used: the scores are
    computed in-process and the whole result dict is returned.
    """
    return reward_func(queries, prompts, labels, is_debug=is_debug)


@ray.remote
def remote_rm_fn_ray(api_url, queries, prompts, labels, score_key="rewards", is_debug=False):
    """Ray remote task that forwards all of its arguments to remote_rm_fn."""
    return remote_rm_fn(
        api_url,
        queries,
        prompts,
        labels,
        score_key=score_key,
        is_debug=is_debug,
    )


if __name__ == "__main__":
    # test utils
    url = "http:xxx/get_rm_score"
    # score = remote_rm_fn(url, ["example query"], ["example response"])
    # print(score)

    answer = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_i_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_i_c_perp = [0.225, 0.23, 0.90, 0.9, 1.6]\nv_o_i = [1.77, np.nan, 2.42, 2.0, np.nan]\nv_o_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nv_o_iii = [1.1, np.nan, 1.36, 1.7, np.nan]\n\n# Define x-axis positions for each method\nx = np.arange(len(methods))\n\n# Create subplots for different barrier types\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot for Ti_i barrier energies\nax1.plot(x, ti_i_c_parallel, label='Ti$_\\mathrm{i}$ $c_\\parallel$', marker='o', linestyle='-', color='skyblue')\nax1.plot(x, ti_i_c_perp, label='Ti$_\\mathrm{i}$ $c_\\perp$', marker='s', linestyle='--', color='salmon')\nax1.set_ylabel('Barrier Energy (eV)')\nax1.set_title('Barrier energies of Ti$_\\mathrm{i}$ ($c_\\parallel$ and $c_\\perp$) and V$_\\mathrm{O}$ (I, II, and III) paths')\nax1.legend(loc='upper left')\nax1.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to ax1\nfor i, txt in enumerate(ti_i_c_parallel):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i],txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_i_c_perp):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Plot for V_O barrier energies\nax2.plot(x, v_o_i, label='V$_\\mathrm{O}$ I', marker='^', linestyle='-', color='lightgreen')\nax2.plot(x, v_o_ii, label='V$_\\mathrm{O}$ II', marker='v', linestyle='--', color='orchid')\nax2.plot(x, v_o_iii, label='V$_\\mathrm{O}$ III', marker='d', linestyle=':', color='gold')\nax2.set_ylabel('Barrier Energy (eV)')\nax2.set_xticks(x)\nax2.set_xticklabels(methods, rotation=45, ha='right')\nax2.legend(loc='upper left')\nax2.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to 
ax2\nfor i, txt in enumerate(v_o_i):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_ii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_iii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Add a title to the entire figure\nfig.suptitle('Barrier Energies of Ti$_\\mathrm{i}$ and V$_\\mathrm{O}$ Paths by Various Methods', fontsize=14, y=1.02)\n\nplt.tight_layout()\nplt.show()\n\n```"
    completion = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_c_perpendicular = [0.23, 0.23, 0.9, 0.9, 1.6]\nvo_i = [1.77, np.nan, 2.42, 2.0, np.nan] # Use NaN for missing data\nvo_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nvo_iii = [1.10, np.nan, 1.36, 1.7, np.nan]\n\n# Create indices for plotting\nx_indices = np.arange(len(methods))\n\n# Create figure and subplots\nfig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot Ti barrier energies\naxes[0].plot(x_indices, ti_c_parallel, marker='o', linestyle='-', color='skyblue', label='Ti$_{\\mathrm{i}}$ c$_{\\parallel}$')\naxes[0].plot(x_indices, ti_c_perpendicular, marker='s', linestyle='--', color='salmon', label='Ti$_{\\mathrm{i}}$ c$_{\\perp}$')\n\n# Add values as text labels for Ti\nfor i, txt in enumerate(ti_c_parallel):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_parallel[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_c_perpendicular):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_perpendicular[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[0].set_ylabel('Barrier Energy (eV)')\naxes[0].set_title('Barrier energies of Ti$_{\\mathrm{i}}$ ($c_{\\parallel}$ and $c_{\\perp}$) and V$_{\\mathrm{O}}$ (I, II, and III) paths', fontsize=14)\naxes[0].legend()\naxes[0].grid(True, linestyle='--', alpha=0.6)\n\n# Plot VO barrier energies\naxes[1].plot(x_indices, vo_i, marker='^', linestyle='-', color='lightgreen', label='V$_{\\mathrm{O}}$ I')\naxes[1].plot(x_indices, vo_ii, marker='v', linestyle='--', color='orchid', label='V$_{\\mathrm{O}}$ II')\naxes[1].plot(x_indices, vo_iii, marker='d', linestyle=':', color='gold', label='V$_{\\mathrm{O}}$ III')\n\n# Add values as text labels for VO, handling NaN\nfor i, txt in enumerate(vo_i):\n    if not 
np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_i[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_ii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_ii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_iii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_iii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[1].set_ylabel('Barrier Energy (eV)')\naxes[1].set_xticks(x_indices)\naxes[1].set_xticklabels(methods, rotation=45, ha='right')\naxes[1].legend()\naxes[1].grid(True, linestyle='--', alpha=0.6)\n\n# Set overall title and adjust layout\nfig.suptitle('Calculated and Experimental Barrier Energies', fontsize=16, y=1.02)\nplt.tight_layout()\n\n# Show the plot\nplt.show()\n\n```"
    
    prompt = "This is an example of prompt."

    score = remote_rm_fn(url, [completion], [prompt], [answer], is_debug=True)
    print(score)
    
    # model-based only
    # {'rewards': tensor([1.4600]), 'accuracy_rewards': tensor([0.9600]), 'format_rewards': tensor([0.5000])}

    # average (model + rule-based)
    # {'rewards': tensor([1.4434]), 'accuracy_rewards': tensor([0.9434]), 'format_rewards': tensor([0.5000])}
