import torch
from torchtyping import TensorType
import torch
import math
from einops import rearrange


def get_power_spectrum(
    cortical_sheet: TensorType["height", "width", "e"]
) -> TensorType["height", "width", "e"]:
    """Return the centred 2-D power spectrum of every channel of a 3-D tensor.

    The FFT is taken jointly over axes 0 and 1, independently for each index of
    axis 2. The squared magnitude is then shifted so that the zero-frequency bin
    moves from (0, 0) to (height // 2, width // 2).

    Args:
        cortical_sheet: 3-D tensor laid out as (height, width, e).

    Returns:
        Real tensor with the same shape as ``cortical_sheet``.

    Raises:
        AssertionError: if ``cortical_sheet`` is not 3-D.
    """
    assert cortical_sheet.ndim == 3, f"Expected 3 dims but got: {cortical_sheet.ndim}"

    spatial_axes = (0, 1)
    spectrum = torch.fft.fft2(cortical_sheet, dim=spatial_axes)
    squared_magnitude = torch.abs(spectrum) ** 2

    # fftshift rolls each listed axis by size // 2, moving the DC bin to the centre.
    return torch.fft.fftshift(squared_magnitude, dim=spatial_axes)


def obtain_mask_from_radii(
    height: int, width: int, radius_inner: float, radius_outer: float
):
    """Return a boolean (height, width) mask selecting the centre disk and the outer region.

    Radius is the Euclidean distance from pixel (row, col) to the grid centre
    (height // 2, width // 2). A pixel is True when its radius is
    <= ``radius_inner`` or >= ``radius_outer``. Pixels strictly between the two
    radii are False.

    Args:
        height: number of rows of the mask.
        width: number of columns of the mask.
        radius_inner: inclusive upper bound of the inner (selected) disk.
        radius_outer: inclusive lower bound of the outer (selected) region.

    Returns:
        torch.bool tensor of shape (height, width).
    """
    # indexing="ij" gives (row, col) grids of shape (height, width). It is the
    # historical default, but passing it explicitly silences PyTorch's meshgrid
    # deprecation warning and pins the convention.
    y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    radius = torch.sqrt((x - width // 2) ** 2 + (y - height // 2) ** 2)

    inside_mask = radius <= radius_inner
    outside_mask = radius >= radius_outer

    return torch.logical_or(inside_mask, outside_mask)


def get_polygon_mask(height, width, num_sides, radius, thickness=1):
    """Return a boolean (height, width) mask that is False on a regular polygon's outline.

    The polygon is centred at (row, col) = (height // 2, width // 2). Vertex i is at
    column ``width // 2 + int(radius * cos(i * 2*pi / num_sides))`` and row
    ``height // 2 + int(radius * sin(i * 2*pi / num_sides))``; ``int`` truncates
    toward zero. Consecutive vertices, including last -> first, are joined with
    Bresenham's line algorithm. Every line pixel is stamped with a square brush
    exactly ``thickness`` pixels on a side, and brush pixels falling outside the
    grid are dropped.

    Args:
        height: number of rows of the mask.
        width: number of columns of the mask.
        num_sides: number of polygon vertices.
        radius: centre-to-vertex distance in pixels.
        thickness: brush side length in pixels; 0 draws nothing.

    Returns:
        torch.bool tensor of shape (height, width): True everywhere except on the outline.
    """
    mask = torch.ones((height, width), dtype=torch.bool)

    centre_x = width // 2
    centre_y = height // 2
    step = 2 * math.pi / num_sides
    vertices = [
        (
            centre_x + int(radius * math.cos(i * step)),
            centre_y + int(radius * math.sin(i * step)),
        )
        for i in range(num_sides)
    ]

    # Brush offsets [lo, hi) span exactly `thickness` pixels. The parentheses
    # matter: `-thickness // 2` parses as `(-thickness) // 2`, which floors and
    # adds an extra pixel for odd thickness.
    lo = -(thickness // 2)
    hi = (thickness + 1) // 2

    def _stamp(px, py):
        """Clear the brush square around (px, py), clipped to the grid."""
        r0, r1 = max(py + lo, 0), min(py + hi, height)
        c0, c1 = max(px + lo, 0), min(px + hi, width)
        # Skip empty / fully off-grid squares; a negative stop index would
        # otherwise count from the far edge.
        if r0 < r1 and c0 < c1:
            mask[r0:r1, c0:c1] = False

    # Pair each vertex with the next one, wrapping the last back to the first.
    for (x1, y1), (x2, y2) in zip(vertices, vertices[1:] + vertices[:1]):
        # Bresenham's line algorithm (all-octant integer form).
        dx = abs(x2 - x1)
        dy = -abs(y2 - y1)
        sx = 1 if x1 < x2 else -1
        sy = 1 if y1 < y2 else -1
        err = dx + dy
        while True:
            _stamp(x1, y1)
            if x1 == x2 and y1 == y2:
                break
            e2 = 2 * err
            if e2 >= dy:
                err += dy
                x1 += sx
            if e2 <= dx:
                err += dx
                y1 += sy

    return mask


def ring_loss(
    cortical_sheet: TensorType["height", "width", "e"],
    radius_inner: float,
    radius_outer: float,
) -> TensorType:
    """Average the centred power spectrum over the inner disk and outer region.

    Args:
        cortical_sheet: tensor laid out as (height, width, e). The power spectrum
            is taken per channel over (height, width) via ``get_power_spectrum``.
        radius_inner: spectrum pixels with radius <= this (from the centre) are kept.
        radius_outer: spectrum pixels with radius >= this are kept.

    Returns:
        Scalar tensor: the masked power spectrum averaged over ALL height*width*e
        entries. Entries strictly between the two radii contribute zero.

    Raises:
        AssertionError: if radius_inner >= radius_outer or a shape check fails.
    """
    assert (
        radius_inner < radius_outer
    ), f"Expected radius_inner : {radius_inner} to be less than radius_outer: {radius_outer}"

    # power_spectrum.shape: (height, width, e)
    power_spectrum = get_power_spectrum(cortical_sheet=cortical_sheet)

    assert (
        power_spectrum.ndim == 3
    ), f"Expected power spectrum shape to have 3 dims (height, width, e) but got: {power_spectrum.ndim}"
    assert power_spectrum.shape == cortical_sheet.shape

    height, width, _ = power_spectrum.shape

    # mask.shape: (height, width)
    mask = obtain_mask_from_radii(
        height=height,
        width=width,
        radius_inner=radius_inner,
        radius_outer=radius_outer,
    )
    assert mask.shape[0] == power_spectrum.shape[0]
    assert mask.shape[1] == power_spectrum.shape[1]

    # Use the mask as a 0/1 weight broadcast over the trailing channel axis.
    weight = mask.to(power_spectrum.device).float().unsqueeze(-1)
    loss = (power_spectrum * weight).mean()

    return loss


def polygon_loss(
    cortical_sheet: TensorType["e", "height", "width"],
    radius: int,
    thickness: int,
    num_sides: int,
) -> TensorType:
    """Average the centred power spectrum over every pixel NOT on a polygon outline.

    Args:
        cortical_sheet: tensor laid out as (e, height, width), per the annotation.
            The 2-D power spectrum is taken per channel over (height, width).
        radius: centre-to-vertex distance of the polygon, in spectrum pixels.
        thickness: brush size forwarded to ``get_polygon_mask``; must be < radius.
        num_sides: number of polygon vertices.

    Returns:
        Scalar tensor: mean of ``power_spectrum[:, mask]``, i.e. over all
        (channel, pixel) pairs where the polygon mask is True (off the outline).

    Raises:
        AssertionError: if thickness >= radius or cortical_sheet is not 3-D.
    """
    assert (
        thickness < radius
    ), f"Expected thickness : {thickness} to be less than radius: {radius}"
    assert (
        cortical_sheet.ndim == 3
    ), f"Expected cortical sheet to have 3 dims (e, height, width) but got: {cortical_sheet.ndim}"

    # get_power_spectrum expects (height, width, e) and transforms over its first
    # two axes. Move the channel axis last for the FFT, then back to the front so
    # the spectrum is (e, height, width).
    power_spectrum = get_power_spectrum(
        cortical_sheet=cortical_sheet.permute(1, 2, 0)
    ).permute(2, 0, 1)

    _, height, width = power_spectrum.shape

    mask = get_polygon_mask(
        height=height,
        width=width,
        num_sides=num_sides,
        radius=radius,
        thickness=thickness,
    ).to(cortical_sheet.device)

    # Mean power over every (channel, pixel) pair that is not on the polygon outline.
    loss = power_spectrum[:, mask].mean()

    return loss


def ring_loss_1d(cortical_sheet, freq_inner: int, freq_outer: int):
    """Sum the mean spectral magnitude of three bands of a 1-D spectrum.

    The FFT is taken along the last axis of the 2-D ``cortical_sheet`` and rolled
    by ``n // 2`` (n = size of that axis), so the zero-frequency bin sits at index
    ``c = n // 2``. With bounds clamped to [0, n], the bands of the rolled
    spectrum are:

    * inside: indices [c - freq_inner, c + freq_inner)
    * left:   indices [0, c - freq_outer)
    * right:  indices [c + freq_outer, n)

    Clamping means frequencies beyond the half-width select the whole available
    range instead of wrapping around through Python's negative indexing.

    Returns:
        Sum of ``abs().mean()`` over the non-empty bands.

    Raises:
        AssertionError: if ``cortical_sheet`` is not 2-D or every band is empty.
    """
    assert (
        cortical_sheet.ndim == 2
    ), f"Expected cortical sheet to have 2 dims (num_input_neurons, e) but got: {cortical_sheet.ndim}"
    freq_domain_data = torch.fft.fft(cortical_sheet)
    n = freq_domain_data.shape[-1]
    center_shift = n // 2
    freq_domain_data_rolled = torch.roll(freq_domain_data, shifts=center_shift, dims=-1)

    def _clamp(index):
        """Clamp a slice bound into [0, n] so it never counts from the end."""
        return min(max(index, 0), n)

    inner_lo = _clamp(center_shift - freq_inner)
    inner_hi = _clamp(center_shift + freq_inner)
    outer_lo = _clamp(center_shift - freq_outer)
    outer_hi = _clamp(center_shift + freq_outer)

    bands = (
        freq_domain_data_rolled[:, inner_lo:inner_hi],
        freq_domain_data_rolled[:, :outer_lo],
        freq_domain_data_rolled[:, outer_hi:],
    )
    losses = [band.abs().mean() for band in bands if band.numel() > 0]

    assert (
        len(losses) > 0
    ), "It is very likely that you are passing invalid freq_inner or freq_outer values such that there is nothing outside of those bounds. Please either reduce freq_outer or increase freq_inner"
    return sum(losses)
