import re
import torch
from torch import Tensor
import numpy as np
import torch.nn as nn
import healpy as hp
import e3nn
from e3nn import o3

from src.utils import nearest_rotmat, rotation_error
from src.models import Encoder


def s2_near_identity_grid(max_beta=np.pi / 8, n_alpha=8, n_beta=3):
    """Sample S2 points arranged in rings around the north pole.

    There are ``n_beta`` rings at polar angles max_beta/n_beta, ..., max_beta,
    each holding ``n_alpha`` evenly spaced azimuths in [0, 2pi).

    :return: tensor of shape (2, n_alpha * n_beta) with rows (alpha, beta);
        column ``i_alpha * n_beta + i_beta`` is azimuth i_alpha on ring i_beta
    """
    ring_betas = torch.arange(1, n_beta + 1) * max_beta / n_beta
    # sample n_alpha + 1 values and drop the last: 2*pi coincides with 0
    ring_alphas = torch.linspace(0, 2 * np.pi, n_alpha + 1)[:-1]
    alpha_grid, beta_grid = torch.meshgrid(ring_alphas, ring_betas, indexing="ij")
    return torch.stack([alpha_grid.flatten(), beta_grid.flatten()])

def so3_near_identity_grid(max_beta=np.pi / 8, max_gamma=2 * np.pi, n_alpha=8, n_beta=3, n_gamma=None):
    """Build rings of rotations around the identity.

    All rotations in one ring share the same beta. Alpha is sampled with
    ``torch.linspace(0, 2*pi, n_alpha)`` and the final (2*pi) sample is
    dropped, so only ``n_alpha - 1`` alpha values are used. Gamma takes
    ``n_gamma`` values in [-max_gamma, max_gamma] and is then offset by -alpha.

    :param n_gamma: defaults to ``n_alpha`` (similar to regular representations)
    :return: tensor of shape (3, (n_alpha - 1) * n_beta * n_gamma) with rows
        (alpha, beta, gamma); alpha varies slowest, gamma fastest
    """
    gamma_count = n_alpha if n_gamma is None else n_gamma
    betas = torch.arange(1, n_beta + 1) * max_beta / n_beta
    alphas = torch.linspace(0, 2 * np.pi, n_alpha)[:-1]
    raw_gammas = torch.linspace(-max_gamma, max_gamma, gamma_count)
    grid_a, grid_b, grid_g = torch.meshgrid(alphas, betas, raw_gammas, indexing="ij")
    # offset gamma by -alpha for every grid point
    grid_g = grid_g - grid_a
    return torch.stack([grid_a.flatten(), grid_b.flatten(), grid_g.flatten()])

def s2_healpix_grid(rec_level: int=0, max_beta: float=np.pi/6):
    """Return the HEALPix grid points within ``max_beta`` of the pole.

    :param rec_level: HEALPix recursion level; ``nside = 2**rec_level``
    :param max_beta: radius (radians) of the disc queried with
        ``hp.query_disc`` around the healpy vector ``(0, 0, 1)``
    :return: float32 tensor of shape (2, n_in_disc) with rows (alpha, beta)
    """
    n_side = 2**rec_level
    pix_in_disc = hp.query_disc(nside=n_side, vec=(0, 0, 1), radius=max_beta)
    # healpy returns (theta, phi); theta plays the role of beta, phi of alpha
    beta, alpha = hp.pix2ang(n_side, pix_in_disc)
    alpha = torch.from_numpy(alpha)
    beta = torch.from_numpy(beta)
    return torch.stack((alpha, beta)).float()

def so3_healpix_grid(rec_level: int=3):
    """Returns healpix grid over so3
    https://github.com/google-research/google-research/blob/4808a726f4b126ea38d49cdd152a6bb5d42efdf0/implicit_pdf/models.py#L272

    Every HEALPix point on S2 (giving alpha, beta) is paired with every one of
    ``6 * 2**rec_level`` evenly spaced gamma values in [0, 2pi).

    alpha: 0-2pi around Y
    beta: 0-pi around X
    gamma: 0-2pi around Y

    rec_level | num_points | bin width (deg)
    ----------------------------------------
         0    |         72 |    60
         1    |        576 |    30
         2    |       4608 |    15
         3    |      36864 |    7.5
         4    |     294912 |    3.75
         5    |    2359296 |    1.875

    :return: float32 tensor of shape (3, npix * 6 * 2**rec_level) with rows
        (alpha, beta, gamma); column ``g * npix + p`` pairs S2 point ``p``
        with gamma value ``g``
    """
    n_side = 2**rec_level
    npix = hp.nside2npix(n_side)
    # Query healpy with a numpy index array and convert explicitly. With a torch
    # index, tensors only come back through torch's numpy ufunc-wrapping hook; if
    # plain ndarrays were returned instead, ``.repeat`` below would switch from
    # torch tiling to numpy's element-wise repetition and silently scramble the
    # (alpha, beta, gamma) pairing.
    beta, alpha = hp.pix2ang(n_side, np.arange(npix))
    alpha = torch.from_numpy(alpha)
    beta = torch.from_numpy(beta)
    gamma = torch.linspace(0, 2*np.pi, 6*n_side + 1)[:-1]

    # cartesian product: tile the S2 points once per gamma value, and hold each
    # gamma value fixed across one full tile
    alpha = alpha.repeat(len(gamma))
    beta = beta.repeat(len(gamma))
    gamma = torch.repeat_interleave(gamma, npix)
    return torch.stack((alpha, beta, gamma)).float()


def s2_irreps(lmax):
    """Irreps of a band-limited signal on S2.

    One copy of each degree ``l = 0..lmax``, all with parity +1.
    """
    degrees = range(lmax + 1)
    return o3.Irreps([(1, (degree, 1)) for degree in degrees])


def so3_irreps(lmax):
    """Irreps of a band-limited signal on SO(3).

    Degree ``l`` (parity +1) appears with multiplicity ``2l + 1``, for ``l = 0..lmax``.
    """
    degrees = range(lmax + 1)
    return o3.Irreps([(2 * degree + 1, (degree, 1)) for degree in degrees])


def flat_wigner(lmax, alpha, beta, gamma):
    """Concatenate flattened Wigner-D matrices for ``l = 0..lmax``.

    Each ``o3.wigner_D(l, ...)`` block is scaled by ``sqrt(2l + 1)`` and its
    last two axes are flattened, so the final axis has size sum_l (2l+1)**2.
    """
    blocks = []
    for degree in range(lmax + 1):
        scale = (2 * degree + 1) ** 0.5
        blocks.append(scale * o3.wigner_D(degree, alpha, beta, gamma).flatten(-2))
    return torch.cat(blocks, dim=-1)


def rotate_s2(s2_signal, alpha=0, beta=0, gamma=0):
    '''Rotate an S2 signal given by its spherical-harmonic coefficients.

    :param s2_signal: tensor whose last dimension holds the (lmax+1)**2
        harmonic coefficients; leading dimensions are treated as batch dims
    :param alpha, beta, gamma: Euler angles in radians (python numbers or
        scalar tensors)
    :return: rotated coefficients, same shape as ``s2_signal``
    :raises ValueError: if the last dimension is not a perfect square
    '''
    n_coeffs = s2_signal.shape[-1]
    lmax = round(n_coeffs ** 0.5) - 1
    if lmax < 0 or (lmax + 1) ** 2 != n_coeffs:
        raise ValueError(
            f"last dim of s2_signal must be (lmax+1)**2, got {n_coeffs}"
        )
    irreps = s2_irreps(lmax)
    # as_tensor accepts tensors without torch.tensor's copy warning; building
    # the angles on the signal's device keeps the Wigner matrix on that device
    alpha, beta, gamma = (
        torch.as_tensor(angle, dtype=torch.float32, device=s2_signal.device)
        for angle in (alpha, beta, gamma)
    )
    rotation = irreps.D_from_angles(alpha, beta, gamma)
    return torch.einsum("ij,...j->...i", rotation, s2_signal)

class S2Convolution(torch.nn.Module):
    """Convolution from S2 signals (``s2_irreps``) to SO(3) signals (``so3_irreps``).

    The kernel is a set of learnable values at the points of ``kernel_grid``
    (rows: alpha, beta). On each call these values are projected onto
    spherical harmonics up to ``lmax`` and used as the weights of an
    ``o3.Linear`` with ``f_in`` input and ``f_out`` output channels.
    """
    def __init__(self, f_in, f_out, lmax, kernel_grid):
        super().__init__()
        n_kernel_pts = kernel_grid.shape[1]
        kernel_values = torch.nn.Parameter(torch.randn(f_in, f_out, n_kernel_pts))
        self.register_parameter("w", kernel_values)  # [f_in, f_out, n_s2_pts]
        harmonics = o3.spherical_harmonics_alpha_beta(
            range(lmax + 1), *kernel_grid, normalization="component"
        )
        self.register_buffer("Y", harmonics)  # [n_s2_pts, psi]
        self.lin = o3.Linear(
            s2_irreps(lmax), so3_irreps(lmax),
            f_in=f_in, f_out=f_out, internal_weights=False,
        )

    def _kernel_coefficients(self):
        """Project the kernel values onto harmonics, scaled by 1/sqrt(n_pts)."""
        n_pts = self.Y.shape[0]
        return torch.einsum("ni,xyn->xyi", self.Y, self.w) / n_pts ** 0.5

    def forward(self, x):
        return self.lin(x, weight=self._kernel_coefficients())


class SO3Convolution(torch.nn.Module):
    """Convolution from SO(3) signals to SO(3) signals (``so3_irreps`` both sides).

    The kernel is a set of learnable values at the rotations of
    ``kernel_grid`` (rows: alpha, beta, gamma). On each call these values are
    projected onto flattened Wigner-D matrices and used as the weights of an
    ``o3.Linear`` with ``f_in`` input and ``f_out`` output channels.
    """
    def __init__(self, f_in, f_out, lmax, kernel_grid):
        super().__init__()
        n_kernel_pts = kernel_grid.shape[1]
        kernel_values = torch.nn.Parameter(torch.randn(f_in, f_out, n_kernel_pts))
        self.register_parameter("w", kernel_values)  # [f_in, f_out, n_so3_pts]
        self.register_buffer("D", flat_wigner(lmax, *kernel_grid))  # [n_so3_pts, psi]
        self.lin = o3.Linear(
            so3_irreps(lmax), so3_irreps(lmax),
            f_in=f_in, f_out=f_out, internal_weights=False,
        )

    def _kernel_coefficients(self):
        """Project the kernel values onto Wigner-D entries, scaled by 1/sqrt(n_pts)."""
        n_pts = self.D.shape[0]
        return torch.einsum("ni,xyn->xyi", self.D, self.w) / n_pts ** 0.5

    def forward(self, x):
        return self.lin(x, weight=self._kernel_coefficients())


class HarmonicS2Features(nn.Module):
    """Learnable S2 features stored directly as spherical-harmonic coefficients.

    ``forward`` takes no input and returns the ``(weight, bias)`` parameters:
    weight has shape (sphere_fdim, f_out, (lmax+1)**2) and bias has shape
    (sphere_fdim, f_out).
    """
    def __init__(self, sphere_fdim, lmax, f_out=1):
        super().__init__()
        self.fdim = sphere_fdim
        self.lmax = lmax
        n_coeffs = (lmax + 1) ** 2

        # weight: (f_in, f_out, (lmax+1)**2) with f_in == sphere_fdim
        initial_weight = torch.zeros((self.fdim, f_out, n_coeffs), dtype=torch.float32)
        self.weight = nn.Parameter(data=initial_weight, requires_grad=True)
        torch.nn.init.kaiming_uniform_(self.weight)

        # bias: (f_in, f_out)
        initial_bias = torch.zeros((self.fdim, f_out), dtype=torch.float32)
        self.bias = nn.Parameter(data=initial_bias, requires_grad=True)
        torch.nn.init.kaiming_uniform_(self.bias)

    def forward(self):
        params = (self.weight, self.bias)
        return params

    def __repr__(self):
        return f'HarmonicS2Features(fdim={self.fdim}, lmax={self.lmax})'


class SpatialS2Features(nn.Module):
    """Learnable S2 features defined by values on a HEALPix grid.

    On every call, ``forward`` converts the grid values to spherical-harmonic
    coefficients up to ``lmax`` and returns them together with the bias.
    """
    def __init__(self, sphere_fdim, lmax, rec_level=1, f_out=1):
        super().__init__()
        self.fdim = sphere_fdim
        self.lmax = lmax

        # max_beta=np.inf requests an unbounded disc -- presumably the whole sphere
        grid_alpha, grid_beta = s2_healpix_grid(max_beta=np.inf, rec_level=rec_level)
        n_grid = grid_alpha.shape[0]
        # Y: [n_grid, (lmax+1)**2]
        self.register_buffer(
            "Y",
            o3.spherical_harmonics_alpha_beta(range(lmax+1), grid_alpha, grid_beta, normalization='component'),
        )

        # weight: (f_in, f_out, n_grid) -- one value per grid point
        initial_weight = torch.zeros((self.fdim, f_out, n_grid), dtype=torch.float32)
        self.weight = nn.Parameter(data=initial_weight, requires_grad=True)
        torch.nn.init.kaiming_uniform_(self.weight)

        # bias: (f_in, f_out)
        initial_bias = torch.zeros((self.fdim, f_out), dtype=torch.float32)
        self.bias = nn.Parameter(data=initial_bias, requires_grad=True)
        torch.nn.init.kaiming_uniform_(self.bias)

    def forward(self):
        """Return ``(coefficients, bias)``.

        The coefficients are the grid values projected onto ``Y`` and scaled
        by 1/sqrt(n_grid). They have shape (f_in, 1, f_out, (lmax+1)**2).
        NOTE(review): that extra singleton axis at position 1 is absent from
        HarmonicS2Features' weight -- confirm the consumer expects it.
        """
        n_grid = self.Y.shape[0]
        coeffs = torch.einsum("ni,xyn->xyi", self.Y, self.weight) / n_grid**0.5
        return coeffs.unsqueeze(1), self.bias


class HarmonicS2Projector(nn.Module):
    """Project a feature map to spherical-harmonic coefficients with an MLP.

    The feature map is average-pooled, passed through Linear(C, 512) -> ReLU ->
    Linear(512, sphere_fdim * (lmax+1)**2), reshaped to
    (B, sphere_fdim, (lmax+1)**2), and passed through an ``S2Activation``.
    """
    def __init__(self,
                 input_shape: tuple,
                 sphere_fdim: int,
                 lmax: int,
                ):
        super().__init__()
        self.sphere_fdim = sphere_fdim
        self.n_harmonics = (lmax+1)**2

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden_dim = 512
        self.lin1 = nn.Linear(input_shape[0], hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, sphere_fdim*self.n_harmonics)

        self.norm_act = e3nn.nn.S2Activation(s2_irreps(lmax), torch.relu, 20)

    def forward(self, x: Tensor) -> Tensor:
        pooled = self.avg_pool(x).flatten(1)
        hidden = torch.relu(self.lin1(pooled))
        coeffs = self.lin2(hidden).view(hidden.size(0), self.sphere_fdim, self.n_harmonics)
        return self.norm_act(coeffs)


class SpatialS2Projector(nn.Module):
    def __init__(self,
                 input_shape: tuple,
                 sphere_fdim: int,
                 lmax: int,
                 coverage: float=0.9,
                 sigma: float=0.2,
                 n_subset: int=20,
                 max_beta: float=np.radians(90),
                 rec_level: int=2,
                 taper_beta: float=np.radians(75),
                ):
        '''Project from feature map to spherical signal

        HEALPix points within ``max_beta`` of the pole are projected
        orthographically onto the feature map. Each point gets a normalised
        Gaussian window around its projected location; the windows are stored
        as the learnable ``weight``. In ``forward``, the windowed sums of the
        feature map are converted to spherical-harmonic coefficients up to
        ``lmax``.

        :param input_shape: encoder output shape; ``input_shape[0]`` is the
            channel count and ``input_shape[1]`` is used as the spatial size
            along both axes (the feature map is assumed square)
        :param sphere_fdim: number of output channels of the 1x1 conv
        :param lmax: maximum spherical-harmonic degree of the output
        :param coverage: the farthest projected point is placed at this radius
            in the [-1, 1] feature-map coordinates
        :param sigma: Gaussian window width in the same [-1, 1] coordinates
        :param n_subset: number of grid points randomly sampled per forward
            call; ``None`` uses all points
        :param max_beta: radius (radians) of the HEALPix cap that is used
        :param rec_level: HEALPix recursion level of the grid
        :param taper_beta: for beta between ``taper_beta`` and ``max_beta`` the
            windows are scaled linearly from 1 down to 0; ignored when
            ``taper_beta >= max_beta``

        ToDo:
            subsample grid points in train mode (might require a buffer)
            add noise to grid location in train mode
        '''
        super().__init__()
        self.n_subset = n_subset

        fmap_size = input_shape[1]

        self.conv1x1 = nn.Conv2d(input_shape[0], sphere_fdim*1, 1)

        # north pole is at y=+1
        # (kernel_grid and xyz are plain attributes, not buffers; they are
        # only used during construction)
        self.kernel_grid = s2_healpix_grid(max_beta=max_beta, rec_level=rec_level)

        self.xyz = o3.angles_to_xyz(*self.kernel_grid)

        # orthographic projection: drop y and scale the (z, x) coordinates so
        # the farthest point lands at radius `coverage`
        max_radius = torch.linalg.norm(self.xyz[:,[0,2]], dim=1).max()
        sample_x = coverage * self.xyz[:,2] / max_radius # range -1 to 1
        sample_y = coverage * self.xyz[:,0] / max_radius

        # Gaussian window per point over the (fmap_size, fmap_size) grid -> (H, W, P).
        # `scale` is constant and cancels in the normalisation below, which
        # makes every point's window sum to 1 over H and W.
        gridx, gridy = torch.meshgrid(2*[torch.linspace(-1,1,fmap_size)], indexing='ij')
        scale = 1 / np.sqrt(2 * np.pi * sigma**2)
        data = scale * torch.exp(-((gridx.unsqueeze(-1) - sample_x).pow(2) \
                                   +(gridy.unsqueeze(-1) - sample_y).pow(2)) / (2*sigma**2) )
        data = data / data.sum((0,1), keepdims=True)

        # apply mask to taper magnitude near border if desired
        # (mask is 1 for beta <= taper_beta and falls linearly to 0 at max_beta)
        betas = self.kernel_grid[1]
        if taper_beta < max_beta:
            mask = ((betas - max_beta)/(taper_beta - max_beta)).clamp(max=1).view(1,1,-1)
        else:
            mask = torch.ones_like(data)

        # weight: (1, 1, H, W, P) so it broadcasts against (B, C, H, W, 1)
        data = (mask * data).unsqueeze(0).unsqueeze(0).to(torch.float32)
        self.weight = nn.Parameter(data= data, requires_grad=True)

        self.n_pts = self.weight.shape[-1]
        # indices of the grid points used in forward; a plain tensor (not a
        # buffer), so Module.to() does not move it
        self.ind = torch.arange(self.n_pts)

        # Y: [P, (lmax+1)**2] spherical harmonics at the grid points
        self.register_buffer(
            "Y", o3.spherical_harmonics_alpha_beta(range(lmax+1),
                                                   *self.kernel_grid,
                                                   normalization='component')
        )

    def forward(self, x: Tensor) -> Tensor:
        '''
        :x: float tensor of shape (B,C,H,W)
        :return: tensor of shape (B, sphere_fdim, (lmax+1)**2) holding the
            spherical-harmonic coefficients of the projected signal
        '''
        x = self.conv1x1(x)

        # NOTE(review): a fresh random subset is drawn on every call regardless
        # of self.training, so eval-mode outputs are stochastic too; after the
        # first call self.ind no longer covers all points.
        if self.n_subset is not None:
            self.ind = torch.randperm(self.n_pts)[:self.n_subset]

        # (B,C,H,W,1) * (1,1,H,W,P') summed over H and W -> (B,C,P')
        x = (x.unsqueeze(-1) * self.weight[..., self.ind]).sum((2,3))
        x = torch.relu(x)
        # point samples -> harmonic coefficients, normalised by sqrt(P')
        x = torch.einsum('ni,xyn->xyi', self.Y[self.ind], x) / self.ind.shape[0]**0.5
        return x


class BaseSO3Predictor(nn.Module):
    """Common base for SO(3) predictors; builds the image encoder.

    :param num_classes: stored as ``self.num_classes``
    :param encoder: encoder spec string. If it contains ``'equiv'`` an
        ``EqEncoder`` is built, loading ``resnet_equiv_60.pt`` when the string
        also contains ``'pretrained'``. Otherwise the first run of digits
        (e.g. ``18`` in ``'resnet18'``) is passed to ``Encoder`` as its size,
        and ``'pretrained'`` in the string enables pretrained weights.
    :param pool_features: forwarded to the encoder
    :param kwargs: accepted and ignored
    :raises ValueError: if a non-equivariant ``encoder`` string contains no digits
    """
    def __init__(self,
                 num_classes: int=1,
                 encoder: str='resnet18',
                 pool_features: bool=False,
                 **kwargs
                ):
        super().__init__()
        self.num_classes = num_classes

        pretrained = 'pretrained' in encoder
        if 'equiv' in encoder:
            # NOTE(review): EqEncoder is not among the visible imports -- confirm
            # it is in scope, otherwise this branch raises NameError.
            self.encoder = EqEncoder(pool_features=pool_features)
            if pretrained:
                # torch.load unpickles the file; only load trusted checkpoints
                self.encoder.load_state_dict(torch.load('resnet_equiv_60.pt'))
        else:
            size_match = re.search(r'\d+', encoder)
            if size_match is None:
                raise ValueError(
                    f"encoder name {encoder!r} must contain the encoder size, e.g. 'resnet18'"
                )
            self.encoder = Encoder(int(size_match.group()), pretrained, pool_features)

    def save(self, path):
        """Save this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)


class I2S(BaseSO3Predictor):
    """Image-to-SO(3) pose predictor.

    ``forward`` runs the following stages:

    1. encoder
    2. ``projector`` (feature map to S2 harmonics)
    3. ``o3_conv`` (S2 -> SO(3) linear layer whose weights come from
       ``feature_sphere``)
    4. ``so3_activation``
    5. ``so3_conv``
    6. logits over a HEALPix grid of rotations

    Two rotation grids are kept. The ``train_rec_level`` grid is registered
    as buffers and used by ``forward``/``compute_loss``. The
    ``eval_rec_level`` grid is used by ``predict``/``compute_probabilities``.
    """
    def __init__(self,
                 num_classes: int=1,
                 sphere_fdim: int=512,
                 encoder: str='resnet50_pretrained',
                 projection_mode='spatialS2',
                 feature_sphere_mode='harmonicS2',
                 lmax: int=6,
                 f_hidden: int=8,
                 train_rec_level: int=3,
                 eval_rec_level: int=5,
                ):
        super().__init__(num_classes, encoder, pool_features=False)

        # projection stuff: encoder feature map -> S2 signal
        self.projector = {
            'spatialS2' : SpatialS2Projector,
            'harmonicS2' : HarmonicS2Projector,
        }[projection_mode](self.encoder.output_shape, sphere_fdim, lmax)

        # spherical conv stuff
        self.feature_sphere = {
            'spatialS2' : SpatialS2Features,
            'harmonicS2' : HarmonicS2Features,
        }[feature_sphere_mode](sphere_fdim, lmax, f_out=f_hidden)

        irreps_in = s2_irreps(lmax)
        self.o3_conv = o3.Linear(irreps_in, so3_irreps(lmax),
                                 f_in=sphere_fdim, f_out=f_hidden, internal_weights=False)

        self.so3_activation = e3nn.nn.SO3Activation(lmax, lmax, torch.relu, 10)
        so3_grid = so3_near_identity_grid()
        self.so3_conv = SO3Convolution(f_hidden, 1, lmax, so3_grid)

        # output rotations for training; buffers, so they follow the module's device
        self.train_rec_level = train_rec_level
        output_xyx = so3_healpix_grid(rec_level=train_rec_level)
        self.register_buffer(
            "output_wigners", flat_wigner(lmax, *output_xyx).transpose(0,1)
        )
        self.register_buffer(
            "output_rotmats", o3.angles_to_matrix(*output_xyx)
        )

        # output rotations for evaluation: plain attributes (not buffers), so
        # they are excluded from state_dict and are not moved by Module.to();
        # compute_probabilities moves the harmonics to the CPU to use them
        self.eval_rec_level = eval_rec_level
        output_xyx = so3_healpix_grid(rec_level=eval_rec_level)

        self.eval_wigners = flat_wigner(lmax, *output_xyx).transpose(0,1)
        self.eval_rotmats = o3.angles_to_matrix(*output_xyx)

    def forward(self, x, o, return_harmonics=False):
        """Compute logits over the training rotation grid.

        :param x: image batch passed to ``self.encoder``
        :param o: accepted but not used here
        :param return_harmonics: if True, return the SO(3) coefficients produced
            by ``so3_conv`` instead of grid logits
        :return: logits with the training-grid axis last (``squeeze(1)`` drops
            the single output channel of ``so3_conv``), or the raw harmonics
        """
        x = self.encoder(x)
        x = self.projector(x)

        # the S2 -> SO(3) layer gets its weights from the learned feature sphere
        weight, _ = self.feature_sphere()
        x = self.o3_conv(x, weight=weight)

        x = self.so3_activation(x)

        x = self.so3_conv(x)

        if return_harmonics:
            return x

        return torch.matmul(x, self.output_wigners).squeeze(1)

    @torch.no_grad()
    def predict(self, x, o, k=1):
        """Return the top-``k`` rotation matrices from the evaluation grid.

        :param k: number of top candidates (for top-k analysis)
        :return: ``eval_rotmats`` rows of shape (B, 3, 3) when ``k == 1``,
            otherwise (k, B, 3, 3) ordered by decreasing probability
        """
        probs = self.compute_probabilities(x, o)

        # (k, B): indices of the k most probable grid rotations per sample
        pred_ids = torch.topk(probs, k, dim=1).indices.transpose(0, 1)

        if k == 1:
            return self.eval_rotmats[pred_ids[0]]

        # advanced indexing with a (k, B) index tensor yields (k, B, 3, 3) directly
        return self.eval_rotmats[pred_ids]

    @torch.no_grad()
    def compute_probabilities(self, x, o):
        """Return softmax probabilities over the evaluation rotation grid (on CPU)."""
        harmonics = self.forward(x, o, return_harmonics=True)

        # move to cpu to avoid memory issues, at expense of speed
        harmonics = harmonics.cpu()

        logits = torch.matmul(harmonics, self.eval_wigners).squeeze(1)

        return torch.softmax(logits, dim=1)

    def compute_loss(self, img, cls, rot):
        """Cross-entropy against the training-grid rotation nearest to ``rot``.

        :return: ``(loss, loss_info)``, where ``loss_info`` holds ``cls_loss``
            (float) and ``acc``. Despite its name, ``acc`` holds
            ``rotation_error(rot, pred, 'angle')`` of the arg-max
            training-grid rotation, as a numpy array.
        """
        grid_signal = self.forward(img, cls)
        rot_id = nearest_rotmat(rot, self.output_rotmats)

        loss = nn.CrossEntropyLoss()(grid_signal, rot_id)

        with torch.no_grad():
            pred_id = grid_signal.max(dim=1)[1]
            pred_rotmat = self.output_rotmats[pred_id]
            acc = rotation_error(rot, pred_rotmat, 'angle')
            loss_info = dict(
                cls_loss = loss.item(),
                acc = acc.cpu().numpy(),
            )

        return loss, loss_info
