import torch
import torch.nn as nn

from ..utils.grid_size import find_rectangle_dimensions


def generate_cosine_wave_image(
    width: int,
    height: int,
    frequency_x: float,
    frequency_y: float,
    *,
    amplitude: float = 0.5,
    phase: float = 0.0,
) -> torch.Tensor:
    """Build a separable 2-D cosine pattern sampled on ``[-1, 1] x [-1, 1]``.

    ``image[i, j] = amplitude * cos(2*pi*frequency_x*x[i] + phase)
                              * cos(2*pi*frequency_y*y[j] + phase)``
    with ``x = linspace(-1, 1, width)`` and ``y = linspace(-1, 1, height)``.

    Args:
        width: Number of samples along the first axis (the ``x`` axis).
        height: Number of samples along the second axis (the ``y`` axis).
        frequency_x: Cosine frequency along ``x``; fractional values allowed.
        frequency_y: Cosine frequency along ``y``; fractional values allowed.
        amplitude: Scale applied to the product of cosines (default ``0.5``).
        phase: Phase offset added inside both cosines (default ``0.0``).

    Returns:
        A tensor of shape ``(width, height)`` in torch's default dtype. The
        first axis follows ``width`` because the grid uses matrix ("ij")
        indexing.
    """
    x = torch.linspace(-1, 1, width)
    y = torch.linspace(-1, 1, height)

    # Explicit "ij" indexing reproduces the legacy default layout and avoids
    # the UserWarning torch emits when ``indexing`` is omitted.
    grid_x, grid_y = torch.meshgrid(x, y, indexing="ij")

    return (
        amplitude
        * torch.cos(2 * torch.pi * frequency_x * grid_x + phase)
        * torch.cos(2 * torch.pi * frequency_y * grid_y + phase)
    )


def apply_grid_cosine_init_conv(layer: nn.Conv2d, alpha=1.0):
    """Modulate a Conv2d's filters with a 2-D cosine grid.

    The ``out_channels`` filters are arranged on a ``height x width`` grid
    given by ``find_rectangle_dimensions(area=out_channels)``. Every weight of
    the filter at grid cell ``(r, c)`` is multiplied by the cosine pattern's
    value at ``(r, c)``.

    Args:
        layer: Convolution whose ``weight`` is rescaled.
        alpha: Multiplier on the cosine frequencies (``width * alpha`` and
            ``height * alpha``).

    Returns:
        The same ``layer`` object, with ``layer.weight.data`` replaced. Shape,
        device and dtype of the weight are preserved.

    Raises:
        AssertionError: If ``layer`` is not an ``nn.Conv2d``.
    """
    assert isinstance(layer, nn.Conv2d), f"Expected a nn.Conv2d but got: {type(layer)}"
    weight = layer.weight.data
    original_shape = weight.shape  # (out_channels, in_channels/groups, kH, kW)
    output_size = original_shape[0]

    rectangle_dims = find_rectangle_dimensions(area=output_size)

    frequency_x = rectangle_dims.width * alpha
    frequency_y = rectangle_dims.height * alpha

    # Match the weight's dtype as well as its device: otherwise type promotion
    # (e.g. fp16 * fp32 -> fp32) would silently change the parameter's dtype.
    cosine_grid = generate_cosine_wave_image(
        rectangle_dims.height, rectangle_dims.width, frequency_x, frequency_y
    ).to(device=weight.device, dtype=weight.dtype)

    # NOTE(review): assumes rectangle_dims.height * rectangle_dims.width
    # == output_size, so each grid cell corresponds to exactly one filter.
    modulated = weight.reshape(
        rectangle_dims.height, rectangle_dims.width, -1
    ) * cosine_grid.unsqueeze(-1)
    ## set new weights for layer
    layer.weight.data = modulated.reshape(original_shape)
    return layer


def apply_grid_cosine_init_linear(layer: nn.Linear, alpha=1.0):
    """Modulate a Linear layer's rows with a 2-D cosine grid.

    The ``out_features`` weight rows are arranged on a ``height x width`` grid
    given by ``find_rectangle_dimensions(area=out_features)``. Every entry of
    the row at grid cell ``(r, c)`` is multiplied by the cosine pattern's
    value at ``(r, c)``.

    Args:
        layer: Linear layer whose ``weight`` is rescaled.
        alpha: Multiplier on the cosine frequencies (``width * alpha`` and
            ``height * alpha``).

    Returns:
        The same ``layer`` object, with ``layer.weight.data`` replaced. Shape,
        device and dtype of the weight are preserved.

    Raises:
        AssertionError: If ``layer`` is not an ``nn.Linear``.
    """
    assert isinstance(layer, nn.Linear), f"Expected a nn.Linear but got: {type(layer)}"
    weight = layer.weight.data
    original_shape = weight.shape  # (out_features, in_features)
    output_size = original_shape[0]

    rectangle_dims = find_rectangle_dimensions(area=output_size)

    frequency_x = rectangle_dims.width * alpha
    frequency_y = rectangle_dims.height * alpha

    # Match the weight's dtype as well as its device: otherwise type promotion
    # (e.g. fp16 * fp32 -> fp32) would silently change the parameter's dtype.
    cosine_grid = generate_cosine_wave_image(
        rectangle_dims.height, rectangle_dims.width, frequency_x, frequency_y
    ).to(device=weight.device, dtype=weight.dtype)

    # NOTE(review): assumes rectangle_dims.height * rectangle_dims.width
    # == output_size, so each grid cell corresponds to exactly one row.
    modulated = weight.reshape(
        rectangle_dims.height, rectangle_dims.width, -1
    ) * cosine_grid.unsqueeze(-1)
    ## set new weights for layer
    layer.weight.data = modulated.reshape(original_shape)
    return layer
