import re
import sys
import os
import shutil
from importlib import import_module
from typing import Callable
import matplotlib.pyplot as plt
from llm_utils.base_imports import base_imports
import jax
import jax.numpy as jnp

def extract_python_code(response_text: str) -> str | None:
    """
    Extracts the Python code block enclosed in ```python ... ``` markers
    from a larger text string.

    It assumes the code block starts on a line immediately after '```python'
    and ends on the line immediately before '```'.

    Args:
        response_text: The string potentially containing the code block.

    Returns:
        The extracted Python code as a string, or None if no valid block
        is found. Returns the content of the *last* found block if multiple
        exist, matching the typical prompt structure where implementation
        is last.
    """
    # Use regex to find all ```python ... ``` blocks
    # re.DOTALL makes '.' match newlines as well
    # non-greedy '.*?' ensures it finds the shortest match for the content
    pattern = r"```python\s*\n(.*?)\n```"
    matches = re.findall(pattern, response_text, re.DOTALL)

    if matches:
        # Return the last found match, as the prompt structure
        # usually puts the final implementation code last.
        return matches[-1].strip() # Strip leading/trailing whitespace from the code itself
    else:
        # No match found
        return None

def get_reward_fn(reward_fn_code: str, tmp_dir: str = 'temp_rewards') -> Callable:
    """
    Writes reward-function source to ``<tmp_dir>/reward_fn.py`` and imports
    the ``reward_fn`` callable it defines.

    The written file is ``base_imports`` followed by ``reward_fn_code``. Any
    previously imported copy of the module and its cached bytecode are
    discarded first, so the returned callable always reflects the code that
    was just written.

    NOTE: importing the file executes ``reward_fn_code``. Only pass code you
    are prepared to run in this process.

    Args:
        reward_fn_code: Python source expected to define ``reward_fn``.
        tmp_dir: Directory the module is written to. It is imported as a
            dotted package path, so it must be a relative path reachable
            from an entry of ``sys.path`` (typically the working directory).

    Returns:
        The ``reward_fn`` attribute of the freshly imported module.

    Raises:
        Exception: Any error raised while importing the module or looking up
            ``reward_fn`` is printed and then re-raised unchanged.
    """
    # Function-scope imports: the file only imports ``import_module``.
    from importlib import invalidate_caches
    from importlib.util import cache_from_source

    fn_name = "reward_fn"

    os.makedirs(tmp_dir, exist_ok=True)
    source_path = os.path.join(tmp_dir, fn_name + ".py")
    # Python reads source files as UTF-8 by default, so write them that way
    # instead of relying on the platform's locale encoding.
    with open(source_path, "w", encoding="utf-8") as f:
        f.write(base_imports + '\n' + reward_fn_code)

    # Build the dotted module name from the normalised directory so that
    # spellings such as 'temp_rewards/' or './temp_rewards' import too.
    package_parts = [
        part
        for part in os.path.normpath(tmp_dir).split(os.sep)
        if part not in ("", ".")
    ]
    module_name = ".".join(package_parts + [fn_name])

    # Forget any previously imported version of this module.
    sys.modules.pop(module_name, None)

    try:
        # Bytecode is validated against the source's mtime and size, so two
        # same-sized files written in quick succession could reuse a stale
        # .pyc. Remove exactly the cache file that belongs to this source.
        try:
            os.remove(cache_from_source(source_path))
        except (FileNotFoundError, NotImplementedError):
            pass  # Nothing cached, or bytecode caching is unavailable.

        # The source file was created after interpreter start-up; finders
        # may hold a stale directory listing until caches are invalidated.
        invalidate_caches()

        module = import_module(module_name)
        return getattr(module, fn_name)
    except Exception as e:
        print(f"Error loading module {module_name}: {str(e)}")
        raise

def test_reward_fn(reward_fn):
    try:
        logits = jnp.array([0.0, 1.0, -1.0, 2.0, -2.0])
        reward = reward_fn(logits)
        return True
    except Exception as e:
        print(f"Error testing reward function: {str(e)}")
        return False
    
def _plot_return_curve(ax, results, label):
    """Draw one average-return curve with a +/- one std band on ``ax``."""
    episode_returns = results['metrics']["returned_episode_returns"]
    # Average over the last two axes (steps & envs) and over axis 0.
    # NOTE(review): axis 0 is presumably the seed/run axis - confirm.
    avg_returns_per_update = episode_returns.mean(axis=(-1, -2, 0))
    # Spread across axis 0 of the per-entry averages.
    std_returns_per_update = episode_returns.mean(axis=(-1, -2)).std(axis=0)

    ax.plot(avg_returns_per_update, label=label)
    ax.fill_between(
        range(len(avg_returns_per_update)),
        avg_returns_per_update - std_returns_per_update,
        avg_returns_per_update + std_returns_per_update,
        alpha=0.2,
    )


def plot_best(best_base, save_dir, best_sample=None):
    """
    Plot average return per update and save it as ``returns.png``.

    The curve for ``best_base`` is always drawn (labelled 'Base'); when
    ``best_sample`` is given, its curve is added (labelled 'Best Sample').
    Each curve is surrounded by a shaded band of +/- one standard deviation.

    Args:
        best_base: Object whose ``data['results']['metrics']`` mapping holds
            a "returned_episode_returns" array.
        save_dir: Directory that receives ``returns.png``. The directory is
            not created here.
        best_sample: Optional object with the same layout as ``best_base``.
    """
    # Use a dedicated figure rather than pyplot's implicit global one, so
    # repeated calls do not pile curves from earlier calls onto the plot.
    fig, ax = plt.subplots()
    try:
        _plot_return_curve(ax, best_base.data['results'], 'Base')
        if best_sample is not None:
            _plot_return_curve(ax, best_sample.data['results'], 'Best Sample')

        ax.legend()
        ax.set_xlabel("Update")
        ax.set_ylabel("Average Return")
        # os.path.join works with or without a trailing separator on save_dir.
        fig.savefig(os.path.join(save_dir, 'returns.png'))
    finally:
        # Always release the figure so it does not accumulate in pyplot.
        plt.close(fig)
    