import torch
from torch import nn
from .spherical_harmonics_ylm import SH

def SH_(args):
    return SH(*args)

class DiscretizedSphericalHarmonics(nn.Module):
    def __init__(self, legendre_polys: int = 10):
        """
        legendre_polys: determines the number of legendre polynomials.
                        more polynomials lead more fine-grained resolutions
        embedding_dims: determines the dimension of the embedding.
        """
        super(DiscretizedSphericalHarmonics, self).__init__()
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        self.embedding_dim = self.L * self.M

        lon = torch.tensor(torch.linspace(-180, 180, 360))
        lat = torch.tensor(torch.linspace(-90, 90, 180))
        lons, lats = torch.meshgrid(lon, lat)

        # ij indexing to xy indexing
        lons, lats = lons.T, lats.T

        phi = torch.deg2rad(lons + 180)
        theta = torch.deg2rad(lats + 90)

        Ys = []
        for l in range(self.L):
            for m in range(-l, l + 1):
                Ys.append(SH(m, l, phi, theta) * torch.ones_like(phi))

        self.Ys = torch.stack(Ys)
        self.Ys = self.Ys.permute(0, 2, 1)

    def forward(self, lonlat):

        lonlat = lonlat + torch.tensor([180,90], device=lonlat.device)

        Ys = interpolate_pixel_values(self.Ys.to(lonlat.device), lonlat).T
        return Ys

    def get_coeffs(self, l, m):
        """
        convenience function to store two triangle matrices in one where m can be negative
        """
        if m == 0:
            return self.weight[l, 0]
        if m > 0:  # on diagnoal and right of it
            return self.weight[l, m]
        if m < 0:  # left of diagonal
            return self.weight[-l, m]

    def get_weight_matrix(self):
        """
        a convenience function to restructure the weight matrix (L x M x E) into
        a double triangle matrix (L x 2 * L + 1 x E) where with legrende polynomials
        are on the rows and frequency components -m ... m on the columns.
        """
        unfolded_coeffs = torch.zeros(self.L, self.L * 2 + 1, self.E, device=self.weight.device)
        for l in range(0, self.L):
            for m in range(-l, l + 1):
                unfolded_coeffs[l, m + self.L] = self.get_coeffs(l, m)
        return unfolded_coeffs

def interpolate_pixel_values(image, points):
    num_points = len(points)
    rows, cols = image.size()[1], image.size()[2]

    # Convert sub-pixel coordinates to integer indices
    floor_coords = torch.floor(points).long()
    ceil_coords = torch.ceil(points).long()

    # Compute fractional parts for interpolation weights
    frac_coords = points - floor_coords.float()

    # Clamp the indices to ensure they are within image boundaries
    floor_coords[:, 0] = torch.clamp(floor_coords[:, 0], 0, rows - 1)
    floor_coords[:, 1] = torch.clamp(floor_coords[:, 1], 0, cols - 1)
    ceil_coords[:, 0] = torch.clamp(ceil_coords[:, 0], 0, rows - 1)
    ceil_coords[:, 1] = torch.clamp(ceil_coords[:, 1], 0, cols - 1)

    # Extract pixel values from the image
    floor_pixels = image[:, floor_coords[:, 0], floor_coords[:, 1]]
    ceil_pixels = image[:, ceil_coords[:, 0], ceil_coords[:, 1]]

    # Compute interpolation weights
    weights_floor = (1 - frac_coords[:, 0]) * (1 - frac_coords[:, 1])
    weights_ceil = frac_coords[:, 0] * (1 - frac_coords[:, 1])
    weights = torch.stack([weights_floor, weights_ceil], dim=1)

    # Interpolate pixel values
    interpolated_pixels = torch.sum(torch.stack([floor_pixels, ceil_pixels], dim=2) * weights.view(1, num_points, 2), dim=2)

    return interpolated_pixels
