import os
import re, regex
import subprocess
import sys
import tempfile
import uuid
from pathlib import Path
from datetime import datetime

import torch
import nltk
from typing import Optional, Tuple

from examples.scripts.values_utils import extract_values, compare_values
from examples.scripts.chart_utils import extract_chart_type, compare_chart_type
from examples.scripts.layout_utils import extract_layout, compare_layout
from examples.scripts.text_utils import extract_title, compare_title, extract_labels, compare_labels

from openrlhf.utils.remote_rm_utils import extract_code_from_text, render_image_from_code

# Append-only reward log; the location can be overridden with REWARD_LOG_PATH.
LOG_PATH = os.getenv("REWARD_LOG_PATH", "reward.log")

# Scratch directory for images rendered while scoring; the location can be
# overridden with TEMP_IMAGE_DIR, otherwise it lives under the system temp dir.
TEMP_DIR = os.getenv(
    "TEMP_IMAGE_DIR",
    os.path.join(tempfile.gettempdir(), "chart_images"),
)

# Created at import time so later code can write into it without checking.
os.makedirs(TEMP_DIR, exist_ok=True)

choices = list("abcd")

# Chat-template patterns: the first captures the body of a user turn, the
# second matches the header that opens an assistant turn.
problem_pattern = r"<\|im_start\|>user\n(.*?)<\|im_end\|>"
response_prefix = r"<\|im_start\|>assistant\n"


def get_response_from_query(q: str):
    """Return the assistant reply embedded in a full chat transcript.

    Everything after the first match of ``response_prefix`` is taken as the
    reply. Known end-of-sequence markers are deleted wherever they occur in
    that text, and surrounding whitespace is stripped.

    Args:
        q: Full text containing the prompt followed by the response.

    Returns:
        The cleaned reply, or ``""`` when no assistant header is found.
    """
    header = re.search(response_prefix, q)
    if header is None:
        return ""

    reply = q[header.end():]
    for marker in ("<|im_end|>", "<｜end▁of▁sentence｜>", "<|endoftext|>"):
        reply = reply.replace(marker, "")
    return reply.strip()


def get_query_from_query(q: str):
    """Return the body of the first user turn found in ``q``.

    Args:
        q: Full chat transcript.

    Returns:
        The text captured by ``problem_pattern`` for the first user turn.
        When no user turn is present, or ``q`` is not text the pattern can be
        applied to, ``q`` itself is returned unchanged.
    """
    try:
        first_turn = re.search(problem_pattern, q, re.DOTALL)
    except TypeError:
        # Non-text input: keep the historical "hand the input back" fallback
        # without a bare ``except`` that would also swallow KeyboardInterrupt.
        return q
    return first_turn.group(1) if first_turn else q


def accuracy_reward_func(completion, answer):
    """Score a completion against the reference answer, component by component.

    For each component the same extractor is applied to ``completion`` and to
    ``answer`` and the two results are handed to the matching comparator.
    The comparator results are combined with fixed weights:
    values 0.4, chart_type 0.3, layout 0.1, title 0.1, labels 0.1.

    Args:
        completion: Model output passed to the extractors as-is.
        answer: Reference passed to the extractors as-is.

    Returns:
        ``(total_score, scores)`` where ``scores`` maps each component name
        to its comparator result and ``total_score`` is their weighted sum.
    """
    # (component name, extractor, comparator, weight), in evaluation order.
    components = (
        ("values", extract_values, compare_values, 0.4),
        ("chart_type", extract_chart_type, compare_chart_type, 0.3),
        ("layout", extract_layout, compare_layout, 0.1),
        ("title", extract_title, compare_title, 0.1),
        ("labels", extract_labels, compare_labels, 0.1),
    )

    scores = {}
    for name, extract, compare, _weight in components:
        predicted = extract(completion)
        expected = extract(answer)
        scores[name] = compare(predicted, expected)

    total_score = sum(scores[name] * weight for name, _, _, weight in components)
    return total_score, scores


def format_reward_func(completion, **kwargs):
    """Reward completions whose extracted code renders successfully.

    The code returned by ``extract_code_from_text`` is rendered to a uniquely
    named PNG under ``TEMP_DIR``; the file is deleted again before returning,
    whether or not rendering worked.

    Args:
        completion: Model output to pull the code from.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        ``0.5`` when ``render_image_from_code`` reports success, ``0.0`` when
        it does not or when no code could be extracted.
    """
    code = extract_code_from_text(completion)
    if not code:
        return 0.0

    # A fresh UUID per call keeps concurrent scorers from sharing a file.
    image_path = os.path.join(TEMP_DIR, f"format_test_{uuid.uuid4()}.png")

    try:
        rendered, _ = render_image_from_code(code, image_path)
        reward = 0.5 if rendered else 0.0
    finally:
        # Remove the scratch image even if rendering raised.
        if os.path.exists(image_path):
            os.remove(image_path)
    return reward


def _score_single_query(query, answer):
    """Score one query and build the text to append to the reward log.

    Returns:
        ``(accuracy_reward, format_reward, log_entry)``. Both rewards are
        ``0.0`` when no assistant response can be found in ``query``.
    """
    response = get_response_from_query(query)
    if response == "":
        return 0.0, 0.0, f"Error: {query}\n"

    accuracy_reward, component_scores = accuracy_reward_func(response, answer)
    format_reward = format_reward_func(response)

    # The user text is only needed for the log. Prompts without a
    # "<|vision_end|>" marker are logged whole instead of raising IndexError,
    # which used to zero the reward of an otherwise valid sample.
    user_turn = get_query_from_query(query)
    segments = user_turn.split("<|vision_end|>")
    logged_query = segments[1] if len(segments) > 1 else user_turn

    separator = "===============================================================\n"
    log_entry = (
        separator
        + f"Query: {logged_query}\n"
        + f"Response: {response}\n"
        + f"Answer: {answer}\n"
        + f"Accuracy Reward: {accuracy_reward}\tFormat Reward: {format_reward}\n"
        + f"Component Scores: values={component_scores['values']:.4f}, "
        + f"chart_type={component_scores['chart_type']:.4f}, "
        + f"layout={component_scores['layout']:.4f}, "
        + f"title={component_scores['title']:.4f}, "
        + f"labels={component_scores['labels']:.4f}\n\n\n\n"
        + separator
    )
    return accuracy_reward, format_reward, log_entry


def reward_func(queries, prompts, labels):
    """Compute accuracy and format rewards for a batch and log the details.

    Args:
        queries: Full texts, each a prompt followed by the model response.
        prompts: Prompts matching ``queries``; only used to pair up the batch.
        labels: Reference answers matching ``queries``.

    Returns:
        A dict with float32 tensors ``"rewards"`` (accuracy + format),
        ``"accuracy_rewards"`` and ``"format_rewards"``, holding exactly one
        entry per ``(query, prompt, label)`` triple.

    A sample whose scoring raises gets ``0.0`` for every reward and the
    exception is written to the log at ``LOG_PATH``.
    """
    # queries is prompts + responses
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []
    accuracy_rewards = []
    format_rewards = []

    # UTF-8 is forced so non-ASCII model output cannot fail on the locale's
    # default encoding.
    with open(LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"----------------------------- {current_time} -----------------------------\n")
        for query, _prompt, answer in zip(queries, prompts, labels):
            try:
                accuracy_reward, format_reward, log_entry = _score_single_query(query, answer)
            except Exception as exc:
                # Per-sample boundary: one bad sample must not abort the batch,
                # but the cause is recorded instead of being dropped.
                accuracy_reward, format_reward = 0.0, 0.0
                log_entry = f"Error: {query}\nException: {exc!r}\n"

            # Append exactly once per sample so the tensors stay aligned with
            # the input batch even when scoring or log formatting fails.
            rewards.append(accuracy_reward + format_reward)
            accuracy_rewards.append(accuracy_reward)
            format_rewards.append(format_reward)
            f.write(log_entry)

    return {
        "rewards": torch.tensor(rewards, dtype=torch.float32),
        "accuracy_rewards": torch.tensor(accuracy_rewards, dtype=torch.float32),
        "format_rewards": torch.tensor(format_rewards, dtype=torch.float32),
    }
