import warnings
import copy
import bisect

from typing import Union
from .metadata import EntryUpdater 

from typing import List, Optional
from dataclasses import dataclass
import bisect


@dataclass
class OffsetMapping:
    """Non-decreasing piecewise-linear map from virtual-memory offsets (x) to file offsets (y).

    The lists are parallel: segment ``i`` begins at ``breakpoints[i]`` and follows
    the line ``y = slopes[i] * (x - breakpoints[i]) + intercepts[i]`` until the next
    breakpoint. Inputs to the left of the first breakpoint extrapolate segment 0.
    ``inverse`` binary-searches ``intercepts``, so it relies on them being
    non-decreasing (which follows from the mapping being non-decreasing).

    Args:
        breakpoints: Sorted x-values at which each segment begins.
        slopes: Slope of each segment. All slopes must be non-negative.
        intercepts: y-value of each segment at its own breakpoint.
    """
    breakpoints: List[float]
    slopes: List[float]
    intercepts: List[float]

    def __call__(self, x: float) -> float:
        """Map a virtual offset to a file offset.

        Args:
            x: Input to the mapping.

        Returns:
            The value of the segment whose breakpoint is the last one ``<= x``
            (segment 0 when ``x`` precedes every breakpoint).
        """
        seg = bisect.bisect_right(self.breakpoints, x) - 1
        if seg < 0:
            # Left of the first breakpoint: extend the first segment backwards.
            seg = 0
        start = self.breakpoints[seg]
        return self.intercepts[seg] + self.slopes[seg] * (x - start)

    def inverse(self, y: float) -> float:
        """Evaluate the generalized inverse of the mapping.

        The segment used is the first one whose intercept equals ``y`` if there is
        one, otherwise the last one whose intercept is below ``y`` (segment 0 when
        ``y`` is below every intercept).

        Args:
            y: Output value to invert.

        Returns:
            For a sloped segment, the x at which that segment's line reaches ``y``.
            For a flat (zero-slope) segment, that segment's breakpoint. When several
            x share the value ``y`` on a flat stretch, this is the smallest one.
            Returns ``-inf`` when ``y`` is below the first intercept and the first
            segment is flat, since no x can reach ``y`` in that case.
        """
        seg = bisect.bisect_left(self.intercepts, y)
        # bisect_left lands on the first intercept >= y; keep it only on an exact
        # hit, otherwise step back to the segment that starts below y.
        exact_hit = seg < len(self.intercepts) and self.intercepts[seg] == y
        if not exact_hit:
            seg -= 1
        if seg < 0:
            if self.slopes[0] == 0:
                return -float('inf')
            seg = 0
        slope = self.slopes[seg]
        origin = self.breakpoints[seg]
        if slope == 0:
            return origin
        return (y - self.intercepts[seg]) / slope + origin


class OffsetMappingUpdater(EntryUpdater[OffsetMapping]):
    """Keeps an ``OffsetMapping`` consistent with byte-level edits to the mapped file.

    Every method returns a new mapping (or the unchanged ``entry``). The input
    mapping is never mutated; edits are applied to a deep copy.
    """

    def __init__(self, warn_replace: bool = False) -> None:
        """
        Args:
            warn_replace: Whether to emit a warning if bytes are replaced.
        """
        self.warn_replace = warn_replace

    def insert(self, entry: OffsetMapping, pos: int, value: bytes) -> Optional[OffsetMapping]:
        """Update the mapping after ``value`` was inserted into the file at offset ``pos``.

        With ``x_star = entry.inverse(pos)``:

        * outputs for x before ``x_star - 1`` are unchanged;
        * the segment starting at ``x_star - 1`` gets slope ``len(value)``, starting
          from the original value at ``x_star - 1``;
        * outputs for x >= ``x_star`` are shifted up by ``len(value)``.

        NOTE(review): this assumes no original breakpoint lies strictly between
        ``x_star - 1`` and ``x_star`` (e.g. integer breakpoints). Such a breakpoint
        would neither be shifted nor removed. TODO confirm.

        Args:
            entry: Mapping to update (not modified).
            pos: File offset at which the bytes were inserted.
            value: Inserted bytes. An empty value leaves the mapping unchanged.

        Returns:
            The updated mapping, or ``entry`` itself when ``value`` is empty.
        """
        if not value:
            return entry

        shift = len(value)
        new_metadata = copy.deepcopy(entry)

        x_star = entry.inverse(pos)
        x_star_prev = x_star - 1

        idx_star_prev = bisect.bisect_left(entry.breakpoints, x_star_prev)
        prev_inserted = 0
        if idx_star_prev < len(entry.breakpoints) and entry.breakpoints[idx_star_prev] == x_star_prev:
            # Breakpoint exists at x_star_prev - update the model there.
            new_metadata.slopes[idx_star_prev] = shift
            new_metadata.intercepts[idx_star_prev] = entry(x_star_prev)
        else:
            # Breakpoint does not exist at x_star_prev - create one.
            new_metadata.breakpoints.insert(idx_star_prev, x_star_prev)
            new_metadata.slopes.insert(idx_star_prev, shift)
            new_metadata.intercepts.insert(idx_star_prev, entry(x_star_prev))
            prev_inserted = 1

        idx_star = bisect.bisect_left(entry.breakpoints, x_star)
        # Position of x_star's breakpoint in new_metadata. The x_star_prev
        # breakpoint, if it was inserted, always precedes it.
        new_idx_star = idx_star + prev_inserted
        if idx_star >= len(entry.breakpoints) or entry.breakpoints[idx_star] != x_star:
            # Breakpoint does not exist at x_star - create one that continues the
            # original segment containing x_star. Clamp to segment 0 when x_star
            # lies left of every breakpoint, matching OffsetMapping.__call__;
            # index -1 would otherwise pick the last segment's slope.
            containing = max(idx_star - 1, 0)
            new_metadata.breakpoints.insert(new_idx_star, x_star)
            new_metadata.slopes.insert(new_idx_star, entry.slopes[containing])
            new_metadata.intercepts.insert(new_idx_star, entry(x_star))

        # Shift the curve up by len(value) for all x >= x_star.
        for idx in range(new_idx_star, len(new_metadata.breakpoints)):
            new_metadata.intercepts[idx] += shift

        return new_metadata

    def delete(self, entry: OffsetMapping, pos: Union[int, range]) -> Optional[OffsetMapping]:
        """Update the mapping after the file bytes in ``pos`` were deleted.

        With ``x_left = entry.inverse(pos.start - 1)`` and
        ``x_right = entry.inverse(pos.stop)``:

        * breakpoints before ``x_left`` are kept as they are;
        * the curve is flat at ``pos.start - 1`` from ``x_left`` up to ``x_right``;
        * outputs for x >= ``x_right`` are shifted down by ``len(pos)``;
        * original breakpoints strictly between ``x_left`` and ``x_right`` are removed.

        Args:
            entry: Mapping to update (not modified).
            pos: Deleted file offset or range of offsets. It is normalized via
                ``self._normalize_pos``, presumably provided by ``EntryUpdater``.

        Returns:
            The updated mapping, or ``entry`` itself when ``pos`` is empty.
        """
        pos = self._normalize_pos(pos)
        if pos.stop <= pos.start:
            return entry

        removed = len(pos)
        new_metadata = copy.deepcopy(entry)

        # All outputs for x <= x_left should be unchanged post-deletion.
        x_left = entry.inverse(pos.start - 1)
        # All outputs for x >= x_right should be translated down by len(pos) units post-deletion.
        x_right = entry.inverse(pos.stop)

        idx_left = bisect.bisect(entry.breakpoints, x_left) - 1
        left_inserted = 0
        if idx_left < 0 or entry.breakpoints[idx_left] != x_left:
            # Breakpoint does not exist at x_left - create a flat one.
            new_metadata.breakpoints.insert(idx_left + 1, x_left)
            new_metadata.slopes.insert(idx_left + 1, 0)
            new_metadata.intercepts.insert(idx_left + 1, pos.start - 1)
            left_inserted = 1
        else:
            # Breakpoint exists at x_left - flatten the model there.
            new_metadata.slopes[idx_left] = 0
            new_metadata.intercepts[idx_left] = pos.start - 1

        idx_right = bisect.bisect(entry.breakpoints, x_right) - 1
        right_inserted = 0
        # Guard idx_right < 0 like the x_left case above. Otherwise index -1 would
        # compare against and copy the slope of the *last* segment, instead of
        # extrapolating segment 0 as OffsetMapping.__call__ does.
        if idx_right < 0 or entry.breakpoints[idx_right] != x_right:
            # Breakpoint does not exist at x_right - create one that continues the
            # original segment containing x_right.
            insert_at = idx_right + 1 + left_inserted
            new_metadata.breakpoints.insert(insert_at, x_right)
            new_metadata.slopes.insert(insert_at, entry.slopes[max(idx_right, 0)])
            new_metadata.intercepts.insert(insert_at, pos.stop)
            right_inserted = 1

        # Positions of the x_left / x_right breakpoints in new_metadata.
        new_idx_left = idx_left + left_inserted
        new_idx_right = idx_right + right_inserted + left_inserted

        # Shift the curve down by len(pos) for all x >= x_right.
        for idx in range(new_idx_right, len(new_metadata.breakpoints)):
            new_metadata.intercepts[idx] -= removed

        # Remove breakpoints covering deleted positions in the original binary.
        del new_metadata.breakpoints[new_idx_left + 1:new_idx_right]
        del new_metadata.slopes[new_idx_left + 1:new_idx_right]
        del new_metadata.intercepts[new_idx_left + 1:new_idx_right]

        return new_metadata

    def replace(self, entry: OffsetMapping, pos: Union[int, range], value: bytes) -> Optional[OffsetMapping]:
        """Handle a same-length overwrite of the bytes in ``pos``.

        No bytes move, so the mapping is returned unchanged. If ``warn_replace``
        is set and ``pos`` is non-empty, a ``RuntimeWarning`` is emitted.

        Args:
            entry: Mapping to return.
            pos: Overwritten file offset or range of offsets. It is normalized via
                ``self._normalize_pos``, presumably provided by ``EntryUpdater``.
            value: Replacement bytes. Must cover exactly as many bytes as ``pos``.

        Returns:
            ``entry`` unchanged.

        Raises:
            ValueError: If ``len(value)`` differs from the number of bytes in ``pos``.
        """
        pos = self._normalize_pos(pos)

        pos_len = pos.stop - pos.start
        if len(value) != pos_len:
            raise ValueError(f"`pos` to replace covers {pos_len} bytes, but `value` contains {len(value)} bytes")

        if pos.stop > pos.start and self.warn_replace:
            warnings.warn("Bytes were replaced. Mapping may no longer be valid.", RuntimeWarning)

        # Metadata does not need to be updated.
        return entry