import base64
import colorsys
import concurrent
import hashlib
import io
import math
import os
import shutil
import sys
import threading
import time
import uuid
from contextlib import contextmanager

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from PIL import Image
from absl import flags
from absl import logging
from matplotlib import animation
from shapely.geometry import Polygon


def create_code_snapshot(snapshots_dir, project_dir=None):
    """Create a code snapshot for reproducibility.

    Source/config files (by extension) are copied from the project tree into
    a fresh ``snapshot_<timestamp>_<id>`` directory, the ``configs``
    directory is copied in full (following symlinks), and ``data`` is exposed
    through a symlink instead of being copied.

    Copying is best effort: a file that cannot be copied is logged as a
    warning and skipped rather than aborting the snapshot.

    Args:
        snapshots_dir: Directory to create snapshot in.
        project_dir: Optional explicit project directory. If not provided,
            inferred from __file__. Use explicit path when running from
            worktrees to ensure symlinks are resolved correctly.

    Returns:
        Path of the newly created snapshot directory.
    """
    # Imported locally: a bare ``import concurrent`` does not load the
    # ``concurrent.futures`` submodule.
    from concurrent.futures import ThreadPoolExecutor

    os.makedirs(snapshots_dir, exist_ok=True)
    timestamp = time.strftime('%Y%m%d_%H%M%S')
    snapshot_id = f"snapshot_{timestamp}_{uuid.uuid4().hex[:8]}"
    snapshot_path = os.path.join(snapshots_dir, snapshot_id)

    if project_dir is None:
        script_dir = os.path.abspath(os.path.dirname(__file__))
        project_dir = os.path.abspath(os.path.join(script_dir, "../.."))
    else:
        project_dir = os.path.abspath(project_dir)
    os.makedirs(snapshot_path, exist_ok=True)

    # Collect the code files to copy.
    code_exts = {".py", ".yml", ".yaml", ".json", ".txt"}
    skipped_dirs = {'.git', '__pycache__', 'output', 'checkpoints', 'data',
                    'logs', 'venv', 'env', '.ipynb_checkpoints', 'snapshots'}
    # Never descend into the snapshots directory itself, whatever its name,
    # so earlier snapshots are not copied into the new one.
    snapshots_root = os.path.abspath(snapshots_dir)

    to_copy = []
    dest_dirs = set()
    for root, dirs, files in os.walk(project_dir):
        # In-place assignment prunes the walk.
        dirs[:] = [
            d for d in dirs
            if d not in skipped_dirs
            and os.path.abspath(os.path.join(root, d)) != snapshots_root
        ]
        dest = os.path.join(snapshot_path, os.path.relpath(root, project_dir))
        for f in files:
            if os.path.splitext(f)[1] in code_exts:
                to_copy.append((os.path.join(root, f), os.path.join(dest, f)))
                dest_dirs.add(dest)

    # Create destination directories up front so workers never race on them.
    for d in dest_dirs:
        os.makedirs(d, exist_ok=True)

    def cp(pair):
        """Copy one (src, dst) pair; return True on success."""
        src, dst = pair
        try:
            shutil.copy2(src, dst)
        except OSError as err:
            logging.warning("Could not copy %s into snapshot: %s", src, err)
            return False
        return True

    # ThreadPoolExecutor rejects max_workers=0, so skip the pool entirely
    # when nothing matched.
    if to_copy:
        with ThreadPoolExecutor(
                max_workers=min(32, len(to_copy))) as ex:
            # Consuming the iterator surfaces unexpected worker exceptions.
            copied = sum(ex.map(cp, to_copy))
        if copied != len(to_copy):
            logging.warning("Snapshot %s is incomplete: copied %d of %d files",
                            snapshot_path, copied, len(to_copy))

    # Ensure configs are included even when configs/ is a symlink.
    configs_src = os.path.join(project_dir, "configs")
    if os.path.isdir(configs_src):
        configs_dst = os.path.join(snapshot_path, "configs")
        shutil.copytree(configs_src, configs_dst, symlinks=False,
                        dirs_exist_ok=True)

    # Provide large data paths via symlink to avoid copying datasets.
    def _symlink_dir(link_name: str, target: str) -> None:
        link_path = os.path.join(snapshot_path, link_name)
        if os.path.lexists(link_path):
            return
        os.symlink(target, link_path)

    data_src = os.path.join(project_dir, "data")
    if os.path.isdir(data_src):
        _symlink_dir("data", os.path.realpath(data_src))

    return snapshot_path


@contextmanager
def time_block(msg: str):
    """Context manager: log elapsed time of its block.

    The elapsed time is logged at INFO level when the block exits, whether
    it completes normally or raises; an exception raised inside the block
    still propagates to the caller.

    Args:
        msg: Label placed at the start of the log line.
    """
    t0 = time.perf_counter()
    try:
        yield
    finally:
        # In ``finally`` so failing blocks are timed too.
        logging.info("%s took %.2f s", msg, time.perf_counter() - t0)


def parse_valid_rectangle(points, grid=84):
    """Normalize rectangle input to four ordered corner points and validate it.

    Args:
        points: Either two points (opposite corners of an axis-aligned
            rectangle) or four corner points, each an ``(x, y)`` pair.
        grid: Currently unused; kept for interface compatibility.

    Returns:
        A list of four corner points. For two-point input the order is
        bottom-left, top-left, top-right, bottom-right. For four-point input
        the points are sorted by ``(x, y)`` and the last two are swapped so
        consecutive points trace the outline.

    Raises:
        ValueError: If the number of points is not 2 or 4, or if the points
            do not form a valid, non-degenerate rectangle.
    """
    # If only two points are provided, assume they represent the diagonal corners
    if len(points) == 2:
        (x1, y1), (x2, y2) = points

        # Ensure correct ordering for bottom-left and top-right
        bottom_left = (min(x1, x2), min(y1, y2))
        top_right = (max(x1, x2), max(y1, y2))

        # Compute the other two corners
        bottom_right = (top_right[0], bottom_left[1])
        top_left = (bottom_left[0], top_right[1])

        points = [bottom_left, top_left, top_right, bottom_right]
    elif len(points) == 4:
        # Sorting by (x, y) yields left pair then right pair ...
        points = sorted(points, key=lambda x: (x[0], x[1]))

        # ... and swapping the last two makes consecutive points adjacent.
        points[2], points[3] = points[3], points[2]
    else:
        # Wrap in a 1-tuple: with a bare tuple argument the % operator would
        # try to spread its items over the single %s and raise TypeError.
        raise ValueError(
                'Expected 2 or 4 points for a rectangle, got %s' % (points,))

    # A rectangle's area equals the area of its minimum rotated bounding
    # rectangle; any other quadrilateral is strictly smaller.
    polygon = Polygon(points)
    if (not polygon.is_valid or polygon.area == 0 or
            not math.isclose(polygon.minimum_rotated_rectangle.area,
                             polygon.area)):
        raise ValueError(
                "Invalid rectangle coordinates: {} do not form a valid rectangle".format(
                        points))

    return points


def brighten_colors(color_dict, brightness_factor=1.5):
    """
    Brighten the colors in the given color dictionary.

    Each color is converted to HLS, its lightness is multiplied by
    ``brightness_factor`` (clamped to the range [0, 1]) and the result is
    converted back to a lowercase ``#rrggbb`` string.

    Args:
        color_dict: Dictionary mapping node names to hex color strings of
            the form ``#rrggbb``.
        brightness_factor: Factor to increase brightness (values > 1.0 make colors brighter)

    Returns:
        A new dictionary with brightened colors.
    """

    def to_channel(value):
        # Round instead of truncating so float round-off (e.g. 191.9999...)
        # cannot darken a channel by one step; clamp to the valid byte range.
        return max(0, min(255, round(value * 255)))

    new_color_dict = {}

    for node_name, color in color_dict.items():
        # Convert hex to RGB (0-1 scale)
        r = int(color[1:3], 16) / 255.0
        g = int(color[3:5], 16) / 255.0
        b = int(color[5:7], 16) / 255.0

        # Note: colorsys uses HLS ordering (hue, lightness, saturation).
        h, l, s = colorsys.rgb_to_hls(r, g, b)

        # Scale lightness, keeping it inside [0, 1] so the conversion back
        # to RGB cannot produce out-of-range (or negative) channels.
        new_l = max(0.0, min(1.0, l * brightness_factor))

        # Convert back to RGB
        new_r, new_g, new_b = colorsys.hls_to_rgb(h, new_l, s)

        new_color_dict[node_name] = "#{:02x}{:02x}{:02x}".format(
                to_channel(new_r),
                to_channel(new_g),
                to_channel(new_b),
        )

    return new_color_dict


def generate_short_names(macro_names):
    """
    Creates unique short identifiers (at most 3 characters) for macro names
    using capitals and numbers.

    Each name is hashed with MD5; the first three digest bytes are encoded
    with a 36-symbol alphabet and truncated to three characters. Collisions
    are resolved by cycling the last character; if every variant of that
    prefix is taken, the first unused 3-character code is used instead.
    Assignment is order-dependent: earlier names keep their code regardless
    of which names follow. A name that appears more than once keeps the code
    assigned at its first occurrence.

    Args:
        macro_names: Iterable of macro name strings.

    Returns:
        Dict mapping each distinct name to its short identifier.

    Raises:
        ValueError: If no unused identifier is left.
    """

    # Use uppercase letters first, then numbers
    chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    base = len(chars)

    def to_base36(num):
        """Convert number to custom base encoding using our character set"""
        result = ''
        while num:
            num, rem = divmod(num, base)
            result = chars[rem] + result
        return result or chars[0]  # Default to 'A' instead of '0'

    # Lazy enumeration of every 3-character code, shared by all names: codes
    # it has already skipped were in use and stay in use, so the exhaustive
    # fallback costs at most one pass over the code space in total.
    all_codes = (a + b + c for a in chars for b in chars for c in chars)

    name_map = {}
    used_shorts = set()

    for name in macro_names:
        if name in name_map:
            # Repeated name: keep its existing code instead of colliding
            # with itself and consuming another one.
            continue

        digest = hashlib.md5(name.encode()).digest()
        original_short = to_base36(int.from_bytes(digest[:3], 'big'))[:3]
        short_name = original_short

        if short_name in used_shorts:
            # First try the other last-character variants, starting just
            # after the hashed one and wrapping around the alphabet.
            prefix = original_short[:2]
            start = chars.index(original_short[-1])
            for step in range(1, base + 1):
                candidate = prefix + chars[(start + step) % base]
                if candidate not in used_shorts:
                    short_name = candidate
                    break
            else:
                # Every variant of this prefix is taken; fall back to any
                # free code rather than looping forever.
                short_name = next(
                        (code for code in all_codes
                         if code not in used_shorts), None)
                if short_name is None:
                    raise ValueError(
                            "Ran out of unique short names for %d macros"
                            % (len(name_map) + 1))

        used_shorts.add(short_name)
        name_map[name] = short_name

    return name_map


def format_scientific(num, zero_threshold=1e-9):
    """
    Render ``num`` in scientific notation with two decimals,
    e.g. 2696798 -> 2.70e+06.

    Numbers whose magnitude is below ``zero_threshold`` are shown as "0" so
    that floating point noise such as 8.59e-30 does not clutter the output.
    """
    magnitude = abs(num)
    return "0" if magnitude < zero_threshold else "{:.2e}".format(num)


def image_to_base64_str(image, format="PNG"):
    """Serialize an image and return it as a base64 text string.

    The image is written into an in-memory buffer through its
    ``save(buffer, format=...)`` method and the resulting bytes are
    base64-encoded.

    Args:
        image: Object exposing ``save(fileobj, format=...)``, e.g. a PIL Image.
        format: Image format name forwarded to ``save``.

    Returns:
        The base64 encoding of the serialized image as a ``str``.
    """
    stream = io.BytesIO()
    image.save(stream, format=format)
    encoded = base64.b64encode(stream.getvalue())
    return str(encoded, "utf-8")


def downsize_image(image, max_size=800):
    """Downsize an image while maintaining aspect ratio.

    Images whose longest side already fits within ``max_size`` (including
    zero-sized images) are returned unchanged; the function never upscales.

    Args:
        image: Object with a ``size`` pair and a ``resize`` method, e.g. a
            PIL Image.
        max_size: Maximum allowed length of the longest side, in pixels.

    Returns:
        The original image if no shrinking is needed, otherwise a resized
        copy produced with LANCZOS resampling.

    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """
    if max_size < 1:
        raise ValueError("max_size must be at least 1, got %r" % (max_size,))

    longest = max(image.size)
    # Also covers longest == 0, which would otherwise divide by zero.
    if longest <= max_size:
        return image

    ratio = max_size / longest
    # Keep every side at least 1 px: for extreme aspect ratios truncation
    # would otherwise produce a zero-length side that cannot be resized to.
    new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
    return image.resize(new_size, Image.Resampling.LANCZOS)


def log_tensorboard_image(writer, tag, array_2d, global_step):
    """
    Log a numeric array to TensorBoard as a viridis color-mapped RGB image.

    Args:
        writer: Object exposing ``add_images(tag, tensor, global_step=...)``.
        tag: Name under which the image is logged.
        array_2d: Numpy array of values to visualize (expected 2D, H x W).
        global_step: Step value passed through to the writer.
    """
    values = array_2d.astype(np.float32)

    # Scale linearly between the array's own min and max before coloring.
    scaler = colors.Normalize(vmin=values.min(), vmax=values.max(), clip=True)
    rgba = cm.ScalarMappable(norm=scaler, cmap='viridis').to_rgba(values)

    # Discard alpha and move channels first: (H, W, 4) -> (3, H, W).
    rgb_channels_first = np.moveaxis(rgba[..., :3], -1, 0)

    # Add a leading batch axis -> (1, 3, H, W), as add_images takes a batch.
    batch = torch.from_numpy(rgb_channels_first[None, ...])
    writer.add_images(tag, batch, global_step=global_step)


def expand_clockwise_coords(coordinates):
    """Return the corners of the bounding box of ``coordinates``, grown by one.

    The maximum x and y are each increased by 1, while the minimum bounds
    are kept as they are.

    Args:
        coordinates: Sequence of points indexable as ``c[0]`` (x), ``c[1]`` (y).

    Returns:
        Four ``(x, y)`` tuples in the order bottom-left, bottom-right,
        top-right, top-left (clockwise when y grows downward).
    """
    xs = [point[0] for point in coordinates]
    ys = [point[1] for point in coordinates]

    left, bottom = min(xs), min(ys)
    right, top = max(xs) + 1, max(ys) + 1

    return [(left, bottom), (right, bottom), (right, top), (left, top)]


def copy_ignoring_permissions(src, dst):
    """Recursively copy the contents of ``src`` into ``dst``, data only.

    Directories are created with ``os.makedirs`` (default mode) and files
    are copied with ``shutil.copyfile``, so permission bits and other
    metadata of the source are not carried over.

    Args:
        src: Existing source directory.
        dst: Destination directory; created if missing.
    """
    os.makedirs(dst, exist_ok=True)

    for name in os.listdir(src):
        src_path = os.path.join(src, name)
        dst_path = os.path.join(dst, name)

        if not os.path.isdir(src_path):
            # Plain file: copy contents only.
            shutil.copyfile(src_path, dst_path)
            continue

        # Sub-directory: descend, creating it with default permissions.
        copy_ignoring_permissions(src_path, dst_path)


def get_clamped_corners(corners, grid_size):
    """Validate that every corner lies inside the grid and return them.

    Despite the name, coordinates are not clamped: a corner outside
    ``[0, grid_size]`` on either axis raises instead.

    Args:
        corners: Iterable of ``(x, y)`` pairs.
        grid_size: Inclusive upper bound for both coordinates.

    Returns:
        List of the corners as ``(x, y)`` tuples, in input order.

    Raises:
        ValueError: On the first corner found outside the grid bounds.
    """
    validated = []
    for x, y in corners:
        if 0 <= x <= grid_size and 0 <= y <= grid_size:
            validated.append((x, y))
            continue
        raise ValueError(
                "Coordinate (%d, %d) outside grid bounds [0, %d]"
                % (x, y, grid_size))
    return validated


def get_placeholder_mask(all_suggestions, grid_size, next_x, next_y):
    """
    Build a placement mask that keeps the current macro out of regions
    reserved for future suggestions.

    Args:
        all_suggestions: A list of future region suggestions (corners);
            ``None`` entries are ignored.
        grid_size: The dimension of the placement grid (e.g., 84).
        next_x: The width of the *current* macro being placed (in grid units).
        next_y: The height of the *current* macro being placed (in grid units).

    Returns:
        A (grid_size, grid_size) float32 numpy array, indexed [row, col]
        with rows taken from y and columns from x, where 0 marks a forbidden
        anchor position for the current macro and 1 an allowed one.

    Raises:
        ValueError: Propagated from ``get_clamped_corners`` when a corner
            lies outside the grid.
    """
    # Everything is allowed until a reserved region says otherwise.
    mask = np.ones((grid_size, grid_size), dtype=np.float32)

    for region in all_suggestions:
        if region is None:
            continue

        # Rejects regions with corners outside the grid.
        points = get_clamped_corners(region, grid_size)

        cols = [col for col, _ in points]
        rows = [row for _, row in points]

        # A macro anchored at (col, row) spans next_x columns and next_y
        # rows, so anchors up to (size - 1) cells before the region's low
        # edge would already reach into it; the blocked window is extended
        # by that amount and clipped to the grid.
        col_lo = max(0, min(cols) - next_x + 1)
        row_lo = max(0, min(rows) - next_y + 1)
        # NOTE(review): the slice end is exclusive, so the region's max
        # row/col itself stays allowed — confirm that the suggestion corners
        # use exclusive upper bounds.
        col_hi = min(grid_size, max(cols))
        row_hi = min(grid_size, max(rows))

        mask[row_lo:row_hi, col_lo:col_hi] = 0.0

    return mask


def animate_debug_images(
        debug_img_list,
        interval_ms=500,
        figsize=(6, 6),
        dpi=100,
        output_gif="debug_animation.gif",
):
    """Write a sequence of images to an animated GIF.

    Args:
        debug_img_list: Non-empty iterable of images convertible with
            ``np.array`` (e.g. PIL images or arrays), one per frame.
        interval_ms: Delay between frames in milliseconds.
        figsize: Matplotlib figure size in inches.
        dpi: Figure resolution.
        output_gif: Path of the GIF file to write.

    Raises:
        ValueError: If ``debug_img_list`` contains no images.
    """
    frames = [np.array(img) for img in debug_img_list]
    if not frames:
        # Fail early with a clear message, before any figure is created.
        raise ValueError("debug_img_list must contain at least one image")

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    try:
        im = ax.imshow(frames[0])
        ax.axis("off")

        def update(frame):
            im.set_data(frame)
            return (im,)

        ani = animation.FuncAnimation(
                fig, update, frames=frames, interval=interval_ms, blit=True,
                repeat=True
        )
        ani.save(output_gif, writer="pillow")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls accumulate figures in memory.
        plt.close(fig)


def override_flags_from_config(config_path):
    """Parses a YAML config file and overrides absl flags if keys match.

    Keys that are not defined absl flags and keys whose value is ``None``
    are ignored. An empty config file is treated as "no overrides".

    Args:
        config_path: Path to a YAML file whose top level is a mapping of
            flag name to value.

    Raises:
        ValueError: If the top level of the YAML document is not a mapping.
    """
    with open(config_path, "r") as f:
        config_dict = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if config_dict is None:
        logging.warning("Config file %s is empty; no flags overridden",
                        config_path)
        return

    if not isinstance(config_dict, dict):
        raise ValueError(
                "Config file %s must contain a mapping at the top level, "
                "got %s" % (config_path, type(config_dict).__name__))

    for key, value in config_dict.items():

        if value is None:
            continue

        # Only override if 'key' is an existing absl flag
        if key in flags.FLAGS:
            # NOTE(review): the raw YAML value is assigned to the flag object
            # directly; confirm whether flag parsing/validators should run.
            flags.FLAGS[key].value = value
            logging.info("Overriding flag %s with value %s", key, value)


@contextmanager
def suppress_all_output():
    """Context manager to suppress all stdout/stderr output at the file descriptor level.

    The descriptors behind ``sys.stdout`` and ``sys.stderr`` (falling back
    to 1 and 2 when a stream has no usable ``fileno()``) are pointed at the
    null device for the duration of the block and restored afterwards, even
    if the block raises.

    Python-level buffers are flushed before redirecting, so text written
    earlier is not swallowed, and again before restoring, so text written
    inside the block does not surface afterwards.
    """

    def _fileno(stream, fallback):
        # Replaced streams (e.g. in-memory buffers) may lack a descriptor.
        try:
            return stream.fileno()
        except (AttributeError, OSError, ValueError):
            return fallback

    def _flush_streams():
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except (AttributeError, OSError, ValueError):
                pass

    # Resolve the target descriptors once so the restore step hits the same
    # descriptors even if the block rebinds sys.stdout / sys.stderr.
    stdout_fd = _fileno(sys.stdout, 1)
    stderr_fd = _fileno(sys.stderr, 2)

    _flush_streams()

    # Nested try/finally so each descriptor acquired so far is released
    # even if a later acquisition fails.
    devnull_fd = os.open(os.devnull, os.O_WRONLY)
    try:
        saved_stdout_fd = os.dup(stdout_fd)
        try:
            saved_stderr_fd = os.dup(stderr_fd)
            try:
                # Redirect stdout and stderr to devnull at the fd level.
                os.dup2(devnull_fd, stdout_fd)
                os.dup2(devnull_fd, stderr_fd)
                yield
            finally:
                # Push buffered output into devnull before restoring.
                _flush_streams()
                os.dup2(saved_stdout_fd, stdout_fd)
                os.dup2(saved_stderr_fd, stderr_fd)
                os.close(saved_stderr_fd)
        finally:
            os.close(saved_stdout_fd)
    finally:
        os.close(devnull_fd)
