import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import reduce
import operator


def riesz_transform_2d(x, dim=None, norm="ortho"):
    """Take the real-input N-D FFT of ``x`` over ``dim``.

    This is a plain ``torch.fft.rfftn`` call; no Riesz multiplier is applied
    here, despite the name.

    Args:
        x: real-valued input tensor.
        dim: axes to transform; ``None`` means the last two axes ``(-2, -1)``.
        norm: FFT normalisation mode forwarded to ``rfftn``.

    Returns:
        The complex half-spectrum produced by ``torch.fft.rfftn``.
    """
    axes = (-2, -1) if dim is None else dim
    return torch.fft.rfftn(x, dim=axes, norm=norm)


def inverse_riesz_transform_2d(x_ft, s=None, dim=None, norm="ortho"):
    """Invert a real-input FFT of ``x_ft`` over ``dim`` back to real space.

    This is a plain ``torch.fft.irfftn`` call.

    Args:
        x_ft: complex half-spectrum (as produced by ``rfftn``).
        s: output signal sizes along ``dim``; ``None`` lets ``irfftn`` infer them.
        dim: axes to transform; ``None`` means the last two axes ``(-2, -1)``.
        norm: FFT normalisation mode forwarded to ``irfftn``.

    Returns:
        The real-valued tensor produced by ``torch.fft.irfftn``.
    """
    axes = (-2, -1) if dim is None else dim
    return torch.fft.irfftn(x_ft, s=s, dim=axes, norm=norm)


class Riesz2d(nn.Module):
    """Spectral convolution whose input spectrum is augmented with Riesz components.

    The input goes to the real 2-D Fourier domain and is multiplied by the two
    first-order Riesz multipliers ``i*u/|k|`` and ``i*v/|k|`` (zero at the DC
    bin). The original spectrum and both directional components are mixed with
    the scalars ``scale_x``, ``scale_r1`` and ``scale_r2``. The lowest
    ``modes1 x modes2`` frequencies are then mixed across channels with complex
    weights. These frequencies are taken from both ends of the first frequency
    axis. Finally the result is transformed back to a ``size x size`` grid.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: retained frequencies along the first spatial axis, taken from
            each end of the full FFT axis.
        modes2: retained frequencies along the second (rfft) axis.
        learnable_scale: if True the three mixing scalars are trainable
            parameters; otherwise they are fixed buffers equal to 1.0.
        scale_init: initial value of the learnable scalars:
            - ``'equal'`` gives 1/3.
            - ``'ones'`` gives 1.0.
            - ``'random'`` gives one U[0, 1) draw shared by all three.
            - Any other string falls back to 1.0.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2,
                 learnable_scale=True, scale_init='equal'):
        super(Riesz2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.learnable_scale = learnable_scale

        # Complex spectral weights, scaled by 1/(in_channels * out_channels).
        # weights1 acts on the leading (non-negative) rows of the first
        # frequency axis; weights2 acts on the trailing (negative) rows.
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

        if learnable_scale:
            if scale_init == 'equal':
                init_value = 1.0 / 3.0
            elif scale_init == 'ones':
                init_value = 1.0
            elif scale_init == 'random':
                init_value = torch.rand(1).item()
            else:
                init_value = 1.0

            self.scale_x = nn.Parameter(torch.tensor(init_value, dtype=torch.float32))
            self.scale_r1 = nn.Parameter(torch.tensor(init_value, dtype=torch.float32))
            self.scale_r2 = nn.Parameter(torch.tensor(init_value, dtype=torch.float32))
        else:
            self.register_buffer('scale_x', torch.tensor(1.0))
            self.register_buffer('scale_r1', torch.tensor(1.0))
            self.register_buffer('scale_r2', torch.tensor(1.0))

        # The cache bookkeeping stays persistent, for compatibility with existing
        # checkpoints. The Riesz grids are derived data rebuilt on demand, so
        # they are kept out of the state_dict. Persisting them broke strict
        # loading into a fresh module, whose grids are still None and therefore
        # are not loadable state.
        self.register_buffer('_freq_grids_initialized', torch.tensor(False))
        self.register_buffer('_local1', None, persistent=False)
        self.register_buffer('_local2', None, persistent=False)
        self.register_buffer('_spatial_size', torch.tensor([0, 0]))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load state, ignoring cached Riesz grids stored by older checkpoints."""
        # Checkpoints written while the grids were persistent still carry them.
        # Drop those entries so strict loading does not reject them as
        # unexpected keys; the grids are recomputed on the next forward.
        for name in ('_local1', '_local2'):
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _init_freq_grids(self, H, W, device):
        """Build (or reuse) the Riesz multipliers for an H x W input on ``device``.

        The cache counts as valid only if the grids actually exist with the
        expected ``(H, W // 2 + 1)`` shape on ``device``. The persistent
        ``_freq_grids_initialized`` / ``_spatial_size`` flags alone are not
        trusted, because they may come from a checkpoint that does not carry
        the grids themselves.
        """
        cached = self._local1
        if (cached is not None
                and tuple(cached.shape) == (H, W // 2 + 1)
                and cached.device == torch.device(device)
                and bool(self._freq_grids_initialized)
                and int(self._spatial_size[0]) == H
                and int(self._spatial_size[1]) == W):
            return

        # d=1/N makes the frequencies integer indices: full axis for H,
        # half (rfft) axis for W.
        u = torch.fft.fftfreq(H, d=1.0 / H, device=device)
        v = torch.fft.rfftfreq(W, d=1.0 / W, device=device)

        U, V = torch.meshgrid(u, v, indexing="ij")
        mesh = torch.sqrt(U ** 2 + V ** 2)

        # |k| is zero only at the DC bin; the multipliers are defined as 0 there
        # (where() discards the 0/0 NaN produced by the division).
        nonzero = mesh != 0
        self._local1 = 1j * torch.where(nonzero, U / mesh, torch.zeros_like(U))
        self._local2 = 1j * torch.where(nonzero, V / mesh, torch.zeros_like(V))
        self._spatial_size[0] = H
        self._spatial_size[1] = W
        # Update in place so the flag stays on the module's device instead of
        # being replaced with a fresh CPU tensor.
        self._freq_grids_initialized.fill_(True)

    def _apply_riesz_kernels(self, x_freq):
        """Multiply the spectrum by both Riesz multipliers (broadcast over batch/channel).

        Returns:
            Tuple ``(x_freq * _local1, x_freq * _local2)``.
        """
        directional_component_1 = x_freq * self._local1
        directional_component_2 = x_freq * self._local2
        return directional_component_1, directional_component_2

    def _combine_riesz_components(self, base_freq, dir1_freq, dir2_freq):
        """Weighted sum of the base spectrum and the two Riesz components."""
        combined = (self.scale_x * base_freq +
                    self.scale_r1 * dir1_freq +
                    self.scale_r2 * dir2_freq)
        return combined

    def _apply_spectral_convolution(self, riesz_combined, output_size):
        """Mix channels on the retained low modes and place them on an output spectrum.

        Args:
            riesz_combined: complex spectrum indexed ``[batch, in_channel, kx, ky]``
                (as the einsum subscripts require).
            output_size: side length of the spatial output grid.

        Returns:
            A ``(batch, out_channels, output_size, output_size // 2 + 1)`` cfloat
            spectrum that is zero outside the retained modes.
        """
        batch_size = riesz_combined.shape[0]
        freq_width = output_size // 2 + 1

        output_freq = torch.zeros(
            batch_size, self.out_channels, output_size, freq_width,
            device=riesz_combined.device, dtype=torch.cfloat
        )

        # Leading rows of the first frequency axis.
        low_freq_slice = riesz_combined[:, :, :self.modes1, :self.modes2]
        output_freq[:, :, :self.modes1, :self.modes2] = \
            torch.einsum("bixy,ioxy->boxy", low_freq_slice, self.weights1)

        # Trailing rows of the first frequency axis. NOTE(review): if
        # 2 * modes1 exceeds the axis length, the two regions overlap and this
        # write overwrites part of the previous one.
        high_freq_slice = riesz_combined[:, :, -self.modes1:, :self.modes2]
        output_freq[:, :, -self.modes1:, :self.modes2] = \
            torch.einsum("bixy,ioxy->boxy", high_freq_slice, self.weights2)

        return output_freq

    def forward(self, x, size=None):
        """Apply the Riesz-augmented spectral convolution.

        Args:
            x: real tensor of shape ``(batch, in_channels, H, W)``.
            size: side length of the square output grid; defaults to ``W``.

        Returns:
            Real tensor of shape ``(batch, out_channels, size, size)``.
        """
        if size is None:
            size = x.size(-1)

        _, _, H, W = x.shape
        self._init_freq_grids(H, W, x.device)

        frequency_representation = riesz_transform_2d(x, dim=[2, 3], norm="ortho")

        dir_component_1, dir_component_2 = self._apply_riesz_kernels(frequency_representation)

        riesz_enhanced = self._combine_riesz_components(
            frequency_representation, dir_component_1, dir_component_2
        )

        convolved_freq = self._apply_spectral_convolution(riesz_enhanced, size)

        spatial_output = inverse_riesz_transform_2d(
            convolved_freq, s=(size, size), dim=[2, 3], norm="ortho"
        )

        return spatial_output


class DynamicBlock2d(nn.Module):
    """Lifting layer, four Riesz spectral stages with 1x1 skip paths, and an MLP head.

    Channel widths follow ``width * k // 4`` for ``k`` in ``(2, 3, 4, 4, 5)``.
    Each spectral stage keeps ``modes * f // 4`` Fourier modes, with ``f`` in
    ``(4, 3, 2, 2)``. Every stage operates on a ``domain_size x domain_size``
    grid.
    """

    def __init__(self, in_dim, out_dim, domain_size, modes1, modes2, width):
        super(DynamicBlock2d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2

        self.width_list = [width * k // 4 for k in (2, 3, 4, 4, 5)]
        self.size_list = [domain_size] * 5
        self.grid_dim = 2

        # Lift (features + 2 grid coordinates) to the first channel width.
        self.fc0 = nn.Linear(in_dim + self.grid_dim, self.width_list[0])

        # Spectral stages conv0..conv3, registered in the same order as the
        # explicit attribute assignments they replace.
        for stage, factor in enumerate((4, 3, 2, 2)):
            self.add_module(
                f"conv{stage}",
                Riesz2d(
                    self.width_list[stage],
                    self.width_list[stage + 1],
                    self.modes1 * factor // 4,
                    self.modes2 * factor // 4,
                    learnable_scale=True,
                    scale_init='equal',
                ),
            )
        # Pointwise (kernel-size 1) skip paths w0..w3.
        for stage in range(4):
            self.add_module(
                f"w{stage}",
                nn.Conv1d(self.width_list[stage], self.width_list[stage + 1], 1),
            )

        head = self.width_list[4]
        self.fc1 = nn.Linear(head, head * 2)
        self.fc2 = nn.Linear(head * 2, head * 2)
        self.fc3 = nn.Linear(head * 2, out_dim)

    def forward(self, x):
        """Run the block.

        Expects ``x`` laid out as ``(batch, S, S, in_dim)``: the grid is built
        from ``x.shape[1]`` and concatenated on the last axis. The result keeps
        that channels-last layout, with ``out_dim`` channels.
        """
        n_batch = x.shape[0]
        side = x.shape[1]
        coords = self.get_grid(side, n_batch, x.device).permute(0, 2, 3, 1)

        h = self.fc0(torch.cat((x, coords), dim=-1))
        h = h.permute(0, 3, 1, 2)  # to channels-first for the spectral stages

        widths = self.width_list
        sizes = self.size_list
        stages = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        final_stage = len(stages) - 1
        for k, (spectral, pointwise) in enumerate(stages):
            spectral_out = spectral(h, sizes[k + 1])
            flat = h.view(n_batch, widths[k], sizes[k] ** 2)
            skip = pointwise(flat).view(n_batch, widths[k + 1], sizes[k], sizes[k])
            h = spectral_out + skip
            # No activation after the last spectral stage.
            if k != final_stage:
                h = F.selu(h)

        h = h.permute(0, 2, 3, 1)
        h = F.selu(self.fc1(h))
        h = F.selu(self.fc2(h))
        return self.fc3(h)

    def get_grid(self, S, batchsize, device):
        """Return a ``(batchsize, 2, S, S)`` float32 coordinate grid on ``device``.

        Channel 0 runs from 0 to 1 along axis 2; channel 1 runs from 0 to 1
        along axis 3.
        """
        axis = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
        along_rows = axis.reshape(1, 1, S, 1).repeat([batchsize, 1, 1, S])
        along_cols = axis.reshape(1, 1, 1, S).repeat([batchsize, 1, S, 1])
        return torch.cat((along_rows, along_cols), dim=1).to(device)
    
class R2d(nn.Module):
    """Wrapper around a single DynamicBlock2d that uses the same mode count on both axes."""

    def __init__(self, in_dim, out_dim, domain_size, modes, width):
        super(R2d, self).__init__()
        # ``modes`` is passed for both modes1 and modes2.
        self.conv1 = DynamicBlock2d(in_dim, out_dim, domain_size, modes, modes, width)

    def forward(self, x):
        """Forward ``x`` through the wrapped block."""
        return self.conv1(x)

    def count_params(self):
        """Return the total number of elements across all parameters.

        Uses ``numel`` so that 0-dim (scalar) parameters count as one element.
        The previous ``reduce(operator.mul, list(p.size()))`` raised TypeError
        on them, because their size is empty and no initial value was given.
        Complex parameters count one element per complex entry.
        """
        return sum(p.numel() for p in self.parameters())

