import urllib.parse
import boto3
import urllib
import uuid
import os
import shutil
import trimesh
import botocore
from PIL import Image, ImageFilter, ImageChops
import numpy as np
from enum import Enum
from openai import OpenAI
import json
import pathspec
import requests
import base64
from functools import lru_cache
import itertools
from tqdm import tqdm
import pyjson5
import re
import math

# Shared OpenAI client; stays None when OPENAI_API_KEY is not set at import time.
openai_client = None
if os.getenv('OPENAI_API_KEY') is not None:
    openai_client = OpenAI()

s3 = boto3.client('s3')

# This only works if run by the account that owns the S3 bucket!
# AWS_ACCOUNT_ID is sent as ``bucketOwner`` for Nova s3Location image parts.
# It is pre-set to None so the name always exists, even when the caller
# identity cannot be resolved (no/expired credentials, no network).
AWS_ACCOUNT_ID = None
sts = boto3.client('sts')
try:
    AWS_ACCOUNT_ID = sts.get_caller_identity()['Account']
except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError):
    # Best effort only: importing this module must not require AWS access.
    pass
    
class ProgramType(Enum):
    """
    How a material program is represented on disk.

    GRAPH programs live in ``graph.json`` and DSL programs in ``code.py``
    (see ``list_all_programs`` and ``get_representation``).
    """
    DSL = 1    # Python DSL source: code.py
    GRAPH = 2  # JSON graph: graph.json

def open_file(uri, mode='r'):
    """
    Open a file that lives either on local disk or in S3.

    Args:
        uri: a local path, an ``s3://bucket/key`` URI (scheme matched
            case-insensitively), or a dict with ``'Bucket'`` and ``'Key'``.
        mode: passed to the builtin ``open`` for local paths only; S3
            objects are always returned as their streaming body.

    Returns:
        A local file object, or the S3 object's ``Body`` stream.
    """
    if isinstance(uri, dict):
        return open_s3(uri['Bucket'], uri['Key'])
    is_s3_uri = uri.lower().startswith('s3://')
    if is_s3_uri:
        return open_s3_file(uri)
    return open(uri, mode)

def get_text(location_or_text: str):
    """
    Return the contents of a local or s3:// text file, or, when the argument
    does not name an existing file, return the argument itself unchanged.

    Bytes read from S3 are decoded with the default (UTF-8) codec.
    """
    if not is_file(location_or_text):
        # Not a file location: treat the argument as literal text.
        return location_or_text
    with open_file(location_or_text) as handle:
        data = handle.read()
    return data.decode() if isinstance(data, bytes) else data

def file_dir(filepath):
    """
    Return everything before the last '/' in *filepath* ('' if there is none).

    Purely string based, so it works for local paths and s3:// URIs alike.
    """
    segments = filepath.split('/')
    segments.pop()
    return '/'.join(segments)

def directory_exists(dir):
    """
    Return whether *dir* exists as a directory.

    s3:// URIs (scheme matched case-insensitively) are checked with
    ``directory_exists_s3``; anything else with ``os.path.isdir``.
    """
    on_s3 = dir.lower().startswith('s3://')
    return directory_exists_s3(dir) if on_s3 else os.path.isdir(dir)

def remove_s3_dir(uri):
    """
    Delete every object inside the S3 "directory" named by *uri*.

    The key prefix is normalised to end in '/', so ``s3://b/foo`` removes
    ``s3://b/foo/...`` but not sibling keys such as ``s3://b/foobar/...``.
    This matches how ``directory_exists_s3`` treats the prefix. A bucket-root
    URI (empty key) still removes every object in the bucket.
    """
    bucket, prefix = parse_s3_uri(uri)
    if prefix and not prefix.endswith('/'):
        prefix += '/'
    s3resource = boto3.resource('s3')
    s3resource.Bucket(bucket).objects.filter(Prefix=prefix).delete()

def remove_dir(location):
    """
    Recursively delete a local directory or an S3 "directory".

    Local deletion ignores errors, so a missing directory is not an error.
    """
    if not location.lower().startswith('s3://'):
        shutil.rmtree(location, ignore_errors=True)
        return
    remove_s3_dir(location)

def directory_exists_s3(dir):
    """
    Return True if the S3 "directory" *dir* contains at least one object.

    The prefix is normalised to end in '/', so ``s3://b/foo`` does not match
    ``s3://b/foobar``. Because the listing uses ``Delimiter='/'``, objects that
    only live in sub-folders are reported under ``CommonPrefixes`` rather
    than ``Contents``. Either one means the directory exists.
    """
    bucket, key = parse_s3_uri(dir)
    key = key.rstrip('/') + '/'
    resp = s3.list_objects_v2(Bucket=bucket, Prefix=key, Delimiter='/', MaxKeys=1)
    return bool(resp.get('Contents')) or bool(resp.get('CommonPrefixes'))

def open_s3_file(uri):
    """
    Return the streaming body of the S3 object at ``s3://bucket/key``.

    The first 5 characters (the scheme) are dropped. The rest is split at the
    first '/' into bucket and key by plain string operations, so '?' and '#'
    stay part of the key.
    """
    bucket, _, key = uri[5:].partition('/')
    response = s3.get_object(Bucket=bucket, Key=key)
    return response['Body']

def open_s3(Bucket, Key):
    """Return the streaming ``Body`` of the object ``s3://Bucket/Key``."""
    response = s3.get_object(Bucket=Bucket, Key=Key)
    return response['Body']

def write_s3(bucket, key, data):
    """Upload *data* (passed as ``Body`` to ``put_object``) to ``s3://bucket/key``."""
    s3.put_object(Bucket=bucket, Key=key, Body=data)

def parse_s3_uri(uri, key=None):
    """
    Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    The key has its leading '/' removed and is otherwise kept verbatim:
    '#' is never treated as a fragment. A '?' also stays part of the key,
    because S3 keys may contain it, rather than being split off as a URL
    query string.

    Args:
        uri: the s3:// URI to split.
        key: unused; kept for backward compatibility with existing callers.

    Returns:
        ``(bucket, key)`` strings; ``key`` is '' for a bucket-root URI.
    """
    parsed = urllib.parse.urlparse(uri, allow_fragments=False)
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')
    if parsed.query:
        # urlparse splits at the first '?'; glue it back onto the key.
        key += '?' + parsed.query
    return bucket, key

def is_file_s3(uri):
    """
    Return True if *uri* names an existing S3 object that is not a
    directory marker (ContentType containing ``application/x-directory``).

    Any ``ClientError`` from ``head_object`` (missing key, access denied, ...)
    is treated as "not a file".
    """
    bucket, key = parse_s3_uri(uri)
    try:
        head = s3.head_object(Bucket=bucket, Key=key)
    except botocore.exceptions.ClientError:
        return False
    return 'application/x-directory' not in head['ContentType']

def read_json(location):
    """Parse and return the JSON document at a local path or s3:// URI."""
    with open_file(location) as handle:
        document = json.load(handle)
    return document

def is_file(location):
    """
    Return whether *location* is an existing file.

    s3:// URIs are checked with ``is_file_s3``; anything else with
    ``os.path.isfile``. The scheme is matched case-insensitively, consistent
    with ``open_file`` / ``directory_exists`` / ``copy_file``.
    """
    if location.lower().startswith('s3://'):
        return is_file_s3(location)
    return os.path.isfile(location)

def list_all_s3(uri, sorted=True):
    """
    List every object under the S3 prefix in *uri*, following pagination.

    Args:
        uri: ``s3://bucket/prefix`` to list (raw key-prefix match).
        sorted: sort the resulting keys lexicographically.

    Returns:
        A list of ``s3://bucket/key`` URIs; an empty list when nothing matches.
    """
    bucket, key = parse_s3_uri(uri)
    paginator = s3.get_paginator('list_objects_v2')
    keys = [
        item['Key'].lstrip('/')
        for page in paginator.paginate(Bucket=bucket, Prefix=key)
        # A page with no matching objects has no 'Contents' entry at all.
        for item in page.get('Contents', [])
    ]
    if sorted:
        keys.sort()
    return [f's3://{bucket}/{k}' for k in keys]

def list_all(dir, sorted = True):
    """
    List files at *dir*.

    *dir* may be:
      * a file (local or s3://) holding one path per line: its stripped,
        non-blank lines are returned;
      * an s3:// prefix: every object under it, as s3:// URIs;
      * a local directory: every file beneath it (recursively), as
        absolute paths.

    Args:
        dir: location to list.
        sorted: sort the result lexicographically.
    """
    if is_file(dir):
        with open_file(dir) as f:
            contents = f.read()
        if isinstance(contents, bytes):
            contents = contents.decode()
        # Skip blank lines so an empty listing file (or blank lines inside
        # it) does not yield '' entries.
        lines = [line.strip() for line in contents.split('\n') if line.strip()]
        if sorted:
            lines.sort()
        return lines
    if dir.startswith('s3://'):
        return list_all_s3(dir, sorted)
    abs_dir = os.path.abspath(dir)
    paths = [
        os.path.join(root, filename)
        for root, _dirnames, filenames in os.walk(abs_dir)
        for filename in filenames
    ]
    if sorted:
        paths.sort()
    return paths

def list_all_filtered(dir, pass_filter=None, reject_filter=None, sorted = True):
    """
    List files at *dir* (see ``list_all``), keeping those that match any
    gitwildmatch pattern in *pass_filter* and none in *reject_filter*.

    Args:
        dir: location to list.
        pass_filter: patterns to keep; defaults to ``['*']`` (everything).
        reject_filter: patterns to drop; defaults to none.
        sorted: sort the underlying listing.
    """
    # None sentinels instead of shared mutable list defaults.
    if pass_filter is None:
        pass_filter = ['*']
    if reject_filter is None:
        reject_filter = []
    all_files = list_all(dir, sorted)
    pass_spec = pathspec.PathSpec.from_lines('gitwildmatch', pass_filter)
    reject_spec = pathspec.PathSpec.from_lines('gitwildmatch', reject_filter)
    return list(reject_spec.match_files(pass_spec.match_files(all_files), negate=True))

class TempDir:
    """
    Context manager for a uniquely named scratch directory.

    ``with TempDir() as path:`` creates the directory and yields its path as a
    string ending in '/', so file names can be appended directly. On exit the
    directory is removed recursively (errors ignored).
    """
    def __init__(self, root='/tmp/', max_attempts = 10):
        self.root = root
        path = self._candidate(root)
        attempts = 0
        while os.path.exists(path) and attempts < max_attempts:
            path = self._candidate(root)
            attempts += 1
        if os.path.exists(path):
            # Never hand out (and later rmtree) a directory we did not create.
            raise FileExistsError(f'could not find an unused directory name under {root!r}')
        self.path = path

    @staticmethod
    def _candidate(root):
        """Return a random directory path under *root*, ending in '/'."""
        # os.path.join accepts a root with or without a trailing separator.
        return os.path.join(root, str(uuid.uuid4())) + '/'

    def __enter__(self):
        os.makedirs(self.path, exist_ok=True)
        return self.path

    def __exit__(self, exc_type, exc_value, exc_tb):
        shutil.rmtree(self.path, ignore_errors=True)

def tempdir() -> str:
    """
    Return a fresh, unique scratch directory path under /tmp, ending in '/'.

    The directory is not created here; callers create and clean it up.
    """
    return f'/tmp/{uuid.uuid4()}/'

def copy_file(src: str, dst: str):
    """
    Copy a single file between two locations, each either a local path or an
    ``s3://bucket/key`` URI (scheme matched case-insensitively).

    * S3 -> S3: the client's managed ``copy`` (handles large objects).
    * S3 -> local: streams the object into *dst* with ``download_fileobj``.
    * local -> S3: streams *src* with ``upload_fileobj``.
    * local -> local: ``shutil.copy2`` (also copies metadata).
    """
    s3src = src.lower().startswith('s3://')
    s3dst = dst.lower().startswith('s3://')

    if s3src and s3dst:
        src_bucket, src_key = parse_s3_uri(src)
        dst_bucket, dst_key = parse_s3_uri(dst)
        # ``s3`` is a low-level client: it has no ``meta.client`` attribute,
        # and the managed ``copy`` method lives on the client itself.
        s3.copy({'Bucket': src_bucket, 'Key': src_key}, dst_bucket, dst_key)
    elif s3src:
        src_bucket, src_key = parse_s3_uri(src)
        # Stream to disk instead of buffering the whole object in memory.
        with open(dst, 'wb') as f:
            s3.download_fileobj(src_bucket, src_key, f)
    elif s3dst:
        dst_bucket, dst_key = parse_s3_uri(dst)
        with open(src, 'rb') as f:
            s3.upload_fileobj(f, dst_bucket, dst_key)
    else:
        shutil.copy2(src, dst)

def move_s3_folder(src, dst):
    """
    Move every object under one S3 "folder" to another (copy, then delete).

    Args:
        src (str): source folder URI, e.g. ``s3://bucket/source-folder/``.
        dst (str): target folder URI, e.g. ``s3://bucket/target-folder/``.
            May be in a different bucket than *src*.

    Returns:
        None
    """
    src_bucket, source_folder = parse_s3_uri(src)
    dst_bucket, target_folder = parse_s3_uri(dst)

    # Ensure folder paths end with a '/' so 'a' does not also match 'ab/...'.
    if not source_folder.endswith('/'):
        source_folder += '/'
    if not target_folder.endswith('/'):
        target_folder += '/'

    if src_bucket == dst_bucket and source_folder == target_folder:
        # Copy-then-delete onto itself would destroy the data.
        return

    client = boto3.client('s3')

    # Collect every key up front: paginate past the 1000-key page limit, and
    # never re-list objects that the copies below create (e.g. when the
    # target folder is nested inside the source folder).
    paginator = client.get_paginator('list_objects_v2')
    source_keys = [
        obj['Key']
        for page in paginator.paginate(Bucket=src_bucket, Prefix=source_folder)
        for obj in page.get('Contents', [])
    ]

    for source_key in source_keys:
        target_key = target_folder + source_key[len(source_folder):]
        # Managed copy also handles objects above copy_object's 5 GB limit.
        client.copy({'Bucket': src_bucket, 'Key': source_key}, dst_bucket, target_key)
        client.delete_object(Bucket=src_bucket, Key=source_key)


def download_s3_folder(bucket_name, s3_folder, local_dir=None):
    """
    Download the contents of an S3 folder.

    Args:
        bucket_name: the name of the s3 bucket
        s3_folder: the folder path (key prefix) in the s3 bucket; '' means
            the whole bucket
        local_dir: a relative or absolute directory path in the local file
            system. When None, each object is written to a local path equal
            to its full key.

    Adapted from:
    https://stackoverflow.com/questions/49772151/download-a-folder-from-s3-using-boto3
    """
    # ``s3`` is a low-level client, which has no ``Bucket``; use the resource API.
    bucket = boto3.resource('s3').Bucket(bucket_name)
    prefix = s3_folder
    if prefix and not prefix.endswith('/'):
        # Treat the prefix as a folder so 'a' does not also pull in 'ab/...',
        # whose relative path would escape local_dir.
        prefix += '/'
    for obj in bucket.objects.filter(Prefix=prefix):
        if local_dir is None:
            target = obj.key
        else:
            # relpath() rejects an empty start path, so handle the bucket root.
            rel = os.path.relpath(obj.key, prefix) if prefix else obj.key
            target = os.path.join(local_dir, rel)
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        if obj.key.endswith('/'):
            # Directory marker objects have no content to download.
            continue
        bucket.download_file(obj.key, target)

def copy_dir(src: str, dst: str):
    """
    Recursively copy a directory between two locations, each either a local
    path or an ``s3://bucket/prefix`` URI.

    * S3 -> local: ``download_s3_folder``.
    * local -> S3: ``upload_dir_to_s3``.
    * local -> local: ``shutil.copytree`` into *dst* (existing dirs allowed).

    Raises:
        NotImplementedError: for S3 -> S3 copies, which are not supported.
    """
    s3src = src.lower().startswith('s3://')
    s3dst = dst.lower().startswith('s3://')

    if s3src and s3dst:
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        raise NotImplementedError('copy_dir does not support S3 -> S3 copies')
    elif s3src:
        src_bucket, src_key = parse_s3_uri(src)
        download_s3_folder(src_bucket, src_key, dst)
    elif s3dst:
        upload_dir_to_s3(src, dst)
    else:
        os.makedirs(dst, exist_ok=True)
        shutil.copytree(src, dst, dirs_exist_ok=True)

def upload_dir_to_s3(src_dir: str, dst_uri: str):
    """
    Upload every file under the local directory *src_dir* (recursively) to
    the S3 prefix *dst_uri*. Relative paths are kept, with '/' as separator.
    """
    dst_bucket, dst_key = parse_s3_uri(dst_uri)
    # Use a '/'-terminated prefix, except at the bucket root, where it would
    # create keys that start with '/'.
    dst_prefix = dst_key.rstrip('/') + '/' if dst_key.strip('/') else ''
    for dirpath, _, filenames in os.walk(src_dir):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            rel_path = os.path.relpath(path, src_dir).replace('\\', '/')
            with open(path, 'rb') as f:
                s3.upload_fileobj(f, dst_bucket, dst_prefix + rel_path)
            

def path_append(path: str, name: str):
    """
    Join *path* and *name* with exactly one '/' between them.

    Trailing '/' or '\\' characters on *path* and leading ones on *name* are
    dropped first. Works for local paths and s3:// URIs alike.
    """
    head = path.rstrip('/\\')
    tail = name.lstrip('/\\')
    return f'{head}/{tail}'

def load_mesh(location: str):
    """Load ``thickened_mc.obj`` from the example directory *location* with trimesh."""
    obj_path = path_append(location, 'thickened_mc.obj')
    with open_file(obj_path) as handle:
        mesh = trimesh.load(handle, 'obj')
    return mesh

def load_voxel_mesh(location: str):
    """Load ``vox_surface.obj`` from the example directory *location* with trimesh."""
    obj_path = path_append(location, 'vox_surface.obj')
    with open_file(obj_path) as handle:
        mesh = trimesh.load(handle, 'obj')
    return mesh

def load_image_bytes(location: str):
    """Return the raw bytes of the image at *location*, via ``try_image_cache``."""
    image_bytes = try_image_cache(location, return_bytes=True)
    return image_bytes

def try_image_cache(image_src, return_bytes = True):
    """
    Fetch an image, optionally going through a local cache of S3 objects.

    When the ``S3CACHE`` environment variable is set, the object is cached at
    ``$S3CACHE/<extension>/<bucket>/<key>`` and downloaded only on a miss.

    Args:
        image_src: an ``s3://`` URI, a URL containing ``.s3.`` (converted to
            an s3:// URI), or a local path.
        return_bytes: if True return the image bytes; otherwise return the
            cache-relative name ``<bucket>/<key>``.

    Returns:
        The image bytes when *return_bytes* is True. Otherwise the cache name,
        or None when ``S3CACHE`` is unset.

    Note (from the original author): returning None for a non-cached image
    when *return_bytes* is False can break LLaVA formatting. A general file
    or URL cache would be a better design.
    """
    if '.s3.' in image_src:
        image_src = url_to_s3_file(image_src)
    s3_cache_path = os.getenv('S3CACHE')
    if not s3_cache_path:
        # No cache configured: there is no cache name; read straight from source.
        return _read_image_source(image_src) if return_bytes else None

    bucket, key = parse_s3_uri(image_src)
    key = key.strip('/')
    bucket = bucket.strip('/')
    key_type = key.split('.')[-1]
    cache_name = bucket + '/' + key
    cache_loc = os.path.join(s3_cache_path, key_type, cache_name)

    if os.path.exists(cache_loc):
        if not return_bytes:
            return cache_name
        with open(cache_loc, 'rb') as f:
            return f.read()

    os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
    image_bytes = _read_image_source(image_src)
    # Write to a unique temp file, then atomically rename it into place, so
    # an interrupted write never leaves a truncated file that later looks
    # like a valid cache hit.
    tmp_loc = f'{cache_loc}.{uuid.uuid4().hex}.tmp'
    try:
        with open(tmp_loc, 'wb') as f:
            f.write(image_bytes)
        os.replace(tmp_loc, cache_loc)
    finally:
        if os.path.exists(tmp_loc):
            os.remove(tmp_loc)
    return image_bytes if return_bytes else cache_name


def _read_image_source(image_src):
    """Read all bytes from a local path or s3:// URI and close the stream."""
    # 'rb' matters for local paths (text mode would try to decode the image);
    # open_file ignores the mode for S3 objects.
    with open_file(image_src, 'rb') as f:
        return f.read()

def load_image(location: str):
    """
    Open an image from a local path or an s3:// URI as a PIL Image.

    Existing local paths are opened directly. For other locations the bytes
    are read fully into memory first, so the lazily-decoded image never
    depends on a stream that the ``with`` block has already closed.
    """
    import io

    if os.path.exists(location):
        return Image.open(location)
    with open_file(location, 'rb') as f:
        data = f.read()
    return Image.open(io.BytesIO(data))

def clone_test(vis1, vis2):
    """
    Decide whether two renders look like the same material.

    The pixel-wise absolute difference is median filtered (window size 7) to
    suppress isolated noise. Channel values above 20 are summed and divided
    by the total number of channel values. The renders count as clones when
    that score is below 0.1.
    """
    smoothed = ImageChops.difference(vis1, vis2).filter(ImageFilter.MedianFilter(7))
    values = np.asarray(smoothed).ravel()
    score = float(values[values > 20].sum() / values.size)
    return score < .1

def set_clone_test(*materials):
    """
    Render every material and find all pairs that are visual clones.

    Args:
        *materials: material directory locations (see ``visualize_material``).

    Returns:
        A list of index pairs ``(i, j)`` with ``i < j`` into *materials*
        whose renders pass ``clone_test``.
    """
    all_renders = [visualize_material(mat) for mat in materials]
    clones = []
    pairs = list(itertools.combinations(range(len(all_renders)), 2))
    for i, j in tqdm(pairs):
        if clone_test(all_renders[i], all_renders[j]):
            clones.append((i, j))
    return clones
    
def renders_equal(material, *others):
    """
    Compare the render of *material* against each of *others*.

    Returns:
        One boolean per entry of *others*: True when ``clone_test`` considers
        the two renders the same.
    """
    reference = visualize_material(material)
    return [clone_test(reference, visualize_material(candidate)) for candidate in others]

def download_and_convert_base64(image_url: str, as_url=True, timeout=60) -> str:
    """
    Fetch an image and return it Base64-encoded.

    S3 sources (``s3://`` URIs, or URLs containing ``.s3.``, which are
    converted to s3:// URIs) are read through ``load_image_bytes``. Anything
    else is downloaded with an HTTP GET.

    :param image_url: URL or s3:// URI of the image
    :param as_url: if True return a ``data:image/png;base64,...`` data URL,
        otherwise the bare Base64 string. The MIME type is always PNG.
    :param timeout: seconds to wait on the HTTP server. Without a timeout, a
        stalled server would block the call indefinitely.
    :return: the Base64 data URL or bare Base64 string
    :raises requests.HTTPError: if the HTTP download fails
    """
    if '.s3.' in image_url:
        image_url = url_to_s3_file(image_url)

    if image_url.startswith('s3://'):
        image_data = load_image_bytes(image_url)
    else:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()  # Raises an HTTPError if the download failed
        image_data = response.content
    encoded_image = base64.b64encode(image_data).decode('utf-8')

    if as_url:
        return f"data:image/png;base64,{encoded_image}"
    return encoded_image

def format_prompt_openai(prompt, b64_images=False):
    """
    Split a prompt into OpenAI chat "content" parts.

    Images are embedded in the prompt as ``<[image_url]>``; surrounding
    whitespace inside the markers is ignored. Each image becomes an
    ``image_url`` part and the text between images becomes ``text`` parts. A
    trailing text part is always appended, even when it is empty.

    Args:
        prompt: the prompt text.
        b64_images: if True, inline each image as a Base64 data URL.

    Raises:
        ValueError: if a '<[' has no matching ']>' after it.
    """
    rest = prompt
    contents = []

    while '<[' in rest:
        start = rest.index('<[')
        # Look for the closing marker only after the opening one, so a stray
        # ']>' earlier in the text is not mistaken for it.
        end = rest.index(']>', start + 2)
        if start > 0:
            contents.append({
                "type": "text",
                "text": rest[:start]
            })
        image_url = rest[start + 2:end].strip()
        if b64_images:
            image_url = download_and_convert_base64(image_url)
        contents.append({
            "type": "image_url",
            "image_url": {
                "url": image_url
            }})
        rest = rest[end + 2:]
    contents.append({
        "type": "text",
        "text": rest
    })

    return contents

class LLM(Enum):
    """
    Model families whose message-content format this module can produce
    (see ``text_content`` / ``image_content``).
    """
    GPT = 1    # OpenAI-style content parts
    NOVA = 2   # plain {"text": ...} / {"image": ...} parts
    LLAVA = 3  # {"type": "image", "source": ...} image parts
    O1 = 4     # formatted exactly like GPT


def text_content(text: str, llm: LLM = LLM.GPT) -> dict:
    """
    Build a text content part for *llm*.

    NOVA uses ``{"text": ...}``. Every other model uses
    ``{"type": "text", "text": ...}``.
    """
    if llm == LLM.NOVA:
        return {"text": text}
    return {"type": "text", "text": text}

def image_content(image_src: str, llm: LLM = LLM.GPT, inference=False) -> dict:
    """
    Build an image content part for *llm*.

    * GPT / O1: an ``image_url`` part holding a Base64 PNG data URL.
    * NOVA: with *inference*, inline Base64 bytes. Otherwise an
      ``s3Location`` reference owned by ``AWS_ACCOUNT_ID``.
    * anything else: ``{"type": "image", "source": ...}``. S3 sources are
      normalised to s3:// URIs and, when ``S3CACHE`` is set, replaced by their
      local cache name from ``try_image_cache``.
    """
    if llm in (LLM.GPT, LLM.O1):
        data_url = download_and_convert_base64(image_src, as_url=True)
        return {"type": "image_url", "image_url": {"url": data_url}}

    if llm == LLM.NOVA:
        if inference:
            source = {"bytes": download_and_convert_base64(image_src, as_url=False)}
        else:
            source = {
                "s3Location": {
                    "uri": url_to_s3_file(image_src),
                    "bucketOwner": AWS_ACCOUNT_ID
                }
            }
        return {"image": {"format": "png", "source": source}}

    is_s3_source = image_src.lower().startswith('s3://') or '.s3.' in image_src
    if is_s3_source:
        if '.s3.' in image_src:
            image_src = url_to_s3_file(image_src)
        if os.getenv('S3CACHE'):
            cached_name = try_image_cache(image_src, return_bytes = False)
            return {"type": "image", "source": cached_name}
    return {"type": "image", "source": image_src}

def format_prompt_contents(prompt, llm: LLM = LLM.GPT, inference=False):
    """
    Split a prompt into a list of content parts for *llm*.

    Images are embedded as ``<[image_url]>``; surrounding whitespace inside
    the markers is ignored. Each image becomes a part built by
    ``image_content`` and the text between images becomes parts built by
    ``text_content``. A trailing text part is always appended, even when it
    is empty.

    Raises:
        ValueError: if a '<[' has no matching ']>' after it.
    """
    rest = prompt
    contents = []

    while '<[' in rest:
        start = rest.index('<[')
        # Look for the closing marker only after the opening one, so a stray
        # ']>' earlier in the text is not mistaken for it.
        end = rest.index(']>', start + 2)
        if start > 0:
            contents.append(text_content(rest[:start], llm))
        image_url = rest[start + 2:end].strip()
        contents.append(image_content(image_url, llm, inference))
        rest = rest[end + 2:]
    contents.append(text_content(rest, llm))

    return contents


def run_prompt_openai(
        messages,
        model="gpt-4o-mini",
        temperature=1,
        max_completion_tokens=4096,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        n=1,
        return_response=False
        ):
    """
    Send a chat-completions request through the module-level OpenAI client.

    Returns:
        The first choice's message content, or the whole response object
        when *return_response* is True (e.g. to read all *n* choices).

    Raises:
        RuntimeError: if no client exists because OPENAI_API_KEY was unset
            when this module was imported.
    """
    if openai_client is None:
        raise RuntimeError(
            'OpenAI client is not configured; set OPENAI_API_KEY before importing this module')
    response = openai_client.chat.completions.create(
        model=model,
        messages=messages,
        response_format={
            "type": "text"
        },
        temperature=temperature,
        max_completion_tokens=max_completion_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        n=n
    )
    if return_response:
        return response
    return response.choices[0].message.content

def format_for_openai_batch(messages,
        model="gpt-4o-mini",
        temperature=1,
        max_completion_tokens=4096,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        n=1):
    """
    Serialise one chat-completions request body to a JSON string. It uses
    the same parameters and defaults as ``run_prompt_openai``.
    """
    body = dict(
        model=model,
        messages=messages,
        response_format={"type": "text"},
        temperature=temperature,
        max_completion_tokens=max_completion_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        n=n,
    )
    return json.dumps(body)

def url_to_s3_file(url):
    """
    Convert a virtual-hosted S3 HTTPS URL into an ``s3://bucket/key`` URI.

    ``https://my.bucket.s3.us-east-1.amazonaws.com/a/b.png`` becomes
    ``s3://my.bucket/a/b.png``. Any query string, such as presigned-URL
    parameters, is dropped.
    """
    parsed = urllib.parse.urlparse(url, allow_fragments=False)
    host = parsed.netloc
    # The bucket is everything before the '.s3.' / '.s3-' endpoint part of
    # the host. This keeps dotted bucket names intact. Fall back to the first
    # label for hosts without that marker.
    match = re.match(r'(.+?)\.s3[.-]', host)
    bucket = match.group(1) if match else host.split('.')[0]
    key = parsed.path.lstrip('/')
    return f"s3://{bucket}/{key}"

def url_to_s3(url):
    """
    IMPORTANT: This returns the s3 DIRECTORY that contains the url, not the
    file itself.

    ``https://my.bucket.s3.us-east-1.amazonaws.com/a/b/c.png`` becomes
    ``s3://my.bucket/a/b``. A file at the bucket root gives ``s3://bucket/``.
    """
    parsed = urllib.parse.urlparse(url, allow_fragments=False)
    host = parsed.netloc
    # The bucket is everything before the '.s3.' / '.s3-' endpoint part of
    # the host. This keeps dotted bucket names intact.
    match = re.match(r'(.+?)\.s3[.-]', host)
    bucket = match.group(1) if match else host.split('.')[0]
    # Drop the final path segment (the file name) to get its directory.
    key = parsed.path.rpartition('/')[0].strip('/')
    return f"s3://{bucket}/{key}"

def uri_to_url(uri, region='us-east-1'):
    """
    Convert ``s3://bucket/key`` into a virtual-hosted S3 HTTPS URL.

    Args:
        uri: the s3:// URI.
        region: AWS region used in the endpoint host (default ``us-east-1``).

    The key is inserted as-is (not percent-encoded).
    """
    bucket, key = parse_s3_uri(uri)
    return f"https://{bucket}.s3.{region}.amazonaws.com/{key}"

def rendered_urls(uri):
    """
    Map each render view name ('top', 'front', 'right', 'top_right') to the
    HTTPS URL of ``<view>.png`` inside the S3 directory *uri*.
    """
    base_url = uri_to_url(uri.rstrip('/')).rstrip('/')
    views = ('top', 'front', 'right', 'top_right')
    return {view: f'{base_url}/{view}.png' for view in views}

@lru_cache(maxsize=5000)
def graph_representation(location):
    """Return ``<location>/graph.json`` re-serialised as 2-space-indented JSON (cached)."""
    graph_path = location.rstrip('/') + '/graph.json'
    with open_file(graph_path) as handle:
        parsed = json.load(handle)
    return json.dumps(parsed, indent=2)

@lru_cache(maxsize=5000)
def code_representation(location):
    """Return the text of ``<location>/code.py``, decoding bytes as UTF-8 (cached)."""
    with open_file(location.rstrip('/') + '/code.py') as handle:
        source = handle.read()
    if isinstance(source, bytes):
        return source.decode()
    return source

@lru_cache(maxsize=5000)
def get_representation(location, representation: ProgramType = ProgramType.GRAPH):
    """
    Return the program text stored at *location*: the pretty-printed
    graph.json for GRAPH, otherwise the code.py source (cached).
    """
    if representation == ProgramType.GRAPH:
        return graph_representation(location)
    return code_representation(location)

def read_text_file(path):
    """Read a local or s3:// text file and return a str, decoding bytes as UTF-8."""
    with open_file(path) as handle:
        raw = handle.read()
    return raw.decode(encoding='utf-8') if isinstance(raw, bytes) else raw

@lru_cache(maxsize=5000)
def get_description(location):
    """Return the (cached) text of ``description.txt`` in the example directory *location*."""
    return read_text_file(path_append(location, 'description.txt'))


def list_all_programs(
        data_location: str,
        program_type: ProgramType = ProgramType.GRAPH
):
    """
    Return the directories under *data_location* that contain a program file:
    ``graph.json`` for GRAPH programs, ``code.py`` otherwise.
    """
    program_file = "graph.json" if program_type == ProgramType.GRAPH else "code.py"
    spec = pathspec.PathSpec.from_lines('gitwildmatch', [f"**/{program_file}"])
    return [file_dir(match) for match in spec.match_files(list_all(data_location))]

def visualize_material(location):
    """
    Tile the four renders of a material into one 1024x1024 RGB image.

    Reads ``top.png``, ``front.png``, ``right.png`` and ``top_right.png`` from
    *location*, and pastes each at its top-left corner:

        top_right (0, 0)   | top   (512, 0)
        front     (0, 512) | right (512, 512)

    Assumes each render is 512x512 (TODO confirm); the code does not resize.
    """
    base = location.rstrip('/')
    images = {view: load_image(f'{base}/{view}.png')
              for view in ('top', 'front', 'right', 'top_right')}
    canvas = Image.new('RGB', (1024, 1024))
    placement = (
        ('top_right', (0, 0)),
        ('top', (512, 0)),
        ('front', (0, 512)),
        ('right', (512, 512)),
    )
    for view, corner in placement:
        canvas.paste(images[view], corner)
    return canvas

def merge_defaults(args: dict, default_args: dict) -> dict:
    """
    Recursively fill in missing entries of *args* from *default_args*.

    *args* is updated in place and also returned. Values already present in
    *args* win. Nested dicts are merged key by key when both sides are dicts;
    if *args* holds a non-dict where the default is a dict, the value from
    *args* is kept. Missing defaults are deep-copied into *args*, so later
    changes to the result never alter *default_args*.
    """
    import copy

    for key, default in default_args.items():
        if key not in args:
            args[key] = copy.deepcopy(default)
        elif isinstance(default, dict) and isinstance(args[key], dict):
            merge_defaults(args[key], default)
    return args


def round_to_delineation(value, delineation):
    """
    Snap *value* to the nearest multiple of *delineation* and format it.

    Args:
        value: the number to round.
        delineation: step size; must be 0.1, 0.5 or 0.05.

    Returns:
        A string with 2 decimals for a 0.05 step, otherwise 1 decimal.

    Raises:
        ValueError: for any other step size.
    """
    if delineation not in (0.1, 0.5, 0.05):
        raise ValueError("Delineation must be either 0.1, 0.5, or 0.05")
    snapped = round(value / delineation) * delineation
    decimals = 2 if delineation == 0.05 else 1
    return f"{snapped:.{decimals}f}"

def load_and_format_properties(example):
    """
    Read ``structure_info.json`` from the example directory *example* and
    return its simulated properties as rounded strings.

    Keys: E, K, nu and rho are rounded to 0.1 steps, K to 0.5 and G to 0.05
    (see ``round_to_delineation``).
    """
    sim_data = read_json(path_append(example, 'structure_info.json'))
    # (output key, structure_info field, rounding step)
    fields = (
        ('E', 'sim_E_VRH', 0.1),
        ('K', 'sim_K_VRH', 0.5),
        ('G', 'sim_G_VRH', 0.05),
        ('nu', 'sim_nu_VRH', 0.1),
        ('rho', 'sim_occupied_volume', 0.1),
    )
    return {name: round_to_delineation(sim_data[field], step) for name, field, step in fields}

def parse_record_id(record_id, examples = None):
    """
    Parse an 11 character (QRTIINNNNNN) record ID into:
    Query Type (Q): T = Train, Z = Test, N = KNN, R = KRN
    Rep Type (R): D = DSL, G = Graph
    Task (T): G = Generate, P = Predict
    Inputs (II): Im = Images, Co = Code, Pr = Properties, IC = Images+Code, IP = Images+Properties
    Id Number (N): 6 digit ID number

    Unrecognised codes map to 'Unknown'. Returns a dict of descriptive strings
    plus the integer ``id_number``. When *examples* is non-empty, it also
    includes ``example = examples[id_number]``.
    """
    query_types = {'T': 'Train', 'Z': 'Test', 'N': 'KNN', 'R': 'KRN'}
    rep_types = {'D': 'DSL', 'G': 'Graph'}
    tasks = {'G': 'Generate', 'P': 'Predict'}
    inputs = {
        'Im': 'Images',
        'Co': 'Code',
        'Pr': 'Properties',
        'IC': 'Images+Code',
        'IP': 'Images+Properties',
    }

    parsed = {
        'query_type': query_types.get(record_id[0], 'Unknown'),
        'rep_type': rep_types.get(record_id[1], 'Unknown'),
        'task': tasks.get(record_id[2], 'Unknown'),
        'input_type': inputs.get(record_id[3:5], 'Unknown'),
        'id_number': int(record_id[5:]),
    }
    if examples:
        parsed['example'] = examples[parsed['id_number']]
    return parsed

def load_jsonl(location):
    """
    Parse a JSON Lines file (local or s3://) into a list of values, one per
    non-blank line. Blank lines, e.g. a stray empty line between records,
    are skipped instead of making ``json.loads`` fail.
    """
    with open_file(location) as f:
        lines = f.readlines()
    return [json.loads(line) for line in lines if line.strip()]


def extract_blocks_broken_ticks(text):
    """
    Return the contents of triple-backtick blocks in *text*, using a regex.

    A block body must be non-empty and may not contain any backtick; use
    ``extract_blocks`` for bodies containing inline ticks.
    """
    block_pattern = re.compile(r'```([^`]+)```')
    return block_pattern.findall(text)

def extract_blocks(text):
    """
    Return the contents of triple-backtick blocks in *text*, in order.

    Fences pair up left to right. A final opening fence with no closing
    fence is ignored. Block bodies may contain single backticks.
    """
    # Splitting on the fence puts block bodies at odd indices. The last
    # segment is never a closed block, so it is excluded by the -1 bound.
    segments = text.split("```")
    return segments[1:-1:2]

def extract_and_classify_blocks(text, langs=('python', 'json')):
    """
    Extract triple-backtick blocks from *text* and group them by language tag.

    Args:
        text: text containing fenced blocks.
        langs: language tags to look for (immutable default).

    Returns:
        ``{lang: [body, ...]}`` for every tag in *langs*. A block is assigned
        to a tag when its contents start with that tag; the tag is removed
        and the rest is stripped.
    """
    blocks = extract_blocks(text)
    return {lang: [b[len(lang):].strip() for b in blocks if b.startswith(lang)] for lang in langs}


def parse_json(json_str):
    """
    Parse a JSON/JSON5 string that may contain comments into Python values.

    ``#`` comments are removed line by line (JSON5 has no ``#`` comment
    syntax). The result is then decoded with ``pyjson5.decode``, which accepts
    C-style ``//`` and ``/* */`` comments itself. A ``#`` only starts a
    comment outside a string literal, so values such as ``"#ff0000"`` or URLs
    with fragments survive intact.
    """
    lines = [_strip_hash_comment(line) for line in json_str.split('\n')]
    return pyjson5.decode('\n'.join(lines))


def _strip_hash_comment(line):
    """Return *line* cut at the first '#' that is outside a string literal."""
    quote = None      # quote character of the string currently open, if any
    escaped = False   # previous character was a backslash inside a string
    for i, ch in enumerate(line):
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in ('"', "'"):
            quote = ch
        elif ch == '#':
            return line[:i]
        elif line.startswith('//', i):
            # The rest is a // comment (quotes in it are not strings);
            # leave it for pyjson5 to skip.
            return line
    return line

import ast
def exec_with_return(code: str, env:dict):
    """
    Execute *code* in the namespace *env* and return the value of its final
    statement when that statement is an expression or an assignment.

    * trailing expression: the rest runs first, then it is evaluated and
      returned;
    * trailing ``=``, annotated or augmented assignment: the (first) target
      is evaluated after execution and returned;
    * anything else, or empty code: returns None.

    SECURITY: runs arbitrary code via exec/eval; never pass untrusted input.

    From:
    https://stackoverflow.com/questions/33908794/get-value-of-last-expression-in-exec-call
    """
    tree = ast.parse(code)
    result_source = None
    final = tree.body[-1] if tree.body else None
    if isinstance(final, ast.Expr):
        # Detach the trailing expression so exec() runs only what precedes it.
        result_source = ast.unparse(tree.body.pop())
    elif isinstance(final, ast.Assign):
        result_source = ast.unparse(final.targets[0])
    elif isinstance(final, (ast.AnnAssign, ast.AugAssign)):
        result_source = ast.unparse(final.target)
    exec(ast.unparse(tree), env)
    if not result_source:
        return None
    return eval(result_source, env)
    

def format_float(x):
    """
    Format *x* compactly with two significant figures.

    Non-zero values with magnitude >= 100 or < 0.1 use scientific notation
    with leading zeros stripped from the exponent (``1.2e+4``, ``1.2e-3``).
    Everything else uses the general ``.2g`` format (``5.7``, ``0.12``, ``0``).
    """
    magnitude = abs(x)
    if magnitude == 0 or not (magnitude >= 100 or magnitude < 0.1):
        return f"{x:.2g}"
    scientific = f"{x:.1e}"
    return re.sub(r'e([+-])0*(\d+)', r'e\1\2', scientific)
    
def round_to_1_sig_fig(x):
    """
    Round *x* to one significant figure.

    Magnitudes below 1e-6 return the int 0. Otherwise the result comes from
    ``round(x, ndigits)``, so its type follows *x*.
    """
    magnitude = abs(x)
    if magnitude < 1e-6:
        return 0
    leading_exponent = math.floor(math.log10(magnitude))
    return round(x, -int(leading_exponent))