import torch
from torch.optim import Adam, SGD
import matplotlib.pyplot as plt
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint
import numpy as np
from einops import rearrange

class AntiAliasingLayer(torch.nn.Module):
    """Project signals on a node set onto a subspace derived from a mapping matrix ``M``.

    At construction time a matrix ``M`` is computed (see ``_calculate_M``) that
    relates ``basis`` (defined on ``nodes``) to ``sub_basis`` (defined on
    ``subsample_nodes``) through a 0/1 sampling matrix.  From ``M`` the layer
    derives ``L1_projector``; ``forward`` transforms the input with the
    conjugate-transposed ``basis``, applies ``L1_projector`` and transforms
    back with ``basis``.

    Registered buffers: ``basis``, ``sub_basis``, ``equi_projector`` (``None``
    when no ``equi_raynold_op`` is given), ``M``, ``L1_eigs``,
    ``L1_projector`` and ``up_sampling_basis``.  ``smoother``,
    ``sampling_matrix``, ``adjacency_matrix``, ``subsample_adjacency_matrix``
    and ``M_bar`` are plain tensor attributes.
    """

    def __init__(self,
                 *,
                 nodes: list,
                 adjaceny_matrix: np.ndarray,
                 basis: np.ndarray,
                 subsample_nodes: list,
                 subsample_adjacency_matrix: np.ndarray,
                 sub_basis: np.ndarray,
                 smooth_operator: str = "laplacian",
                 optim=Adam,
                 optim_kwargs=None,
                 smoothness_loss_weight=5.0,
                 iterations=50000,
                 device='cuda:0',
                 dtype=torch.cfloat,
                 graph_shift: np.ndarray = None,
                 equi_raynold_op: np.ndarray = None,
                 equi_constraint: bool = True,
                 equi_correction: bool = True,
                 mode: str = 'optim',
                 threshold=1e-3,
                 over_smooth: bool = False,
                 plot=True):
        """Build the layer and solve for ``M``.

        Args:
            nodes: Node collection of the full graph; only its length is used
                here, together with the indices in ``subsample_nodes``.
            adjaceny_matrix: Adjacency matrix of the full graph.
            basis: Basis on the full node set (one row per node).
            subsample_nodes: Indices (columns of the sampling matrix) of the
                nodes that are kept.
            subsample_adjacency_matrix: Adjacency matrix of the sub-graph
                (stored, not otherwise used by this class).
            sub_basis: Basis on the sub-sampled node set.
            smooth_operator: ``"adjacency"`` (row-normalised adjacency) or
                ``"graph_shift"`` (requires ``graph_shift``).  ``"laplacian"``
                and ``"normalized_laplacian"`` are rejected by an assertion,
                so the default value must currently be overridden.
            optim: Optimizer class used when ``mode == 'optim'``.
            optim_kwargs: Keyword arguments for ``optim``; defaults to
                ``{'lr': 0.0001, 'weight_decay': 0}``.
            smoothness_loss_weight: Weight of the smoothness term.
            iterations: Optimizer steps (``'optim'``) or ``maxiter``
                (``'linear_optim'``).
            device: Stored on the instance; ``M`` itself is always solved on
                the CPU.
            dtype: Target dtype of the tensors used at run time.
            graph_shift: Operator used when ``smooth_operator == "graph_shift"``.
            equi_raynold_op: Optional operator; its eigenvectors with
                eigenvalue within 1e-3 of 1 span the range of
                ``equi_projector``.
            equi_constraint: Add the equivariance term in ``'linear_optim'``
                mode (only applied when ``equi_raynold_op`` is given).
            equi_correction: Apply ``equi_projector`` to ``L1_projector``
                (only applied when ``equi_raynold_op`` is given).
            mode: ``'optim'``, ``'analytical'`` or ``'linear_optim'``.
            threshold: Entries of ``M`` with magnitude ``<= threshold`` are
                set to zero.
            over_smooth: Stored on the instance; not used by this class.
            plot: Plot the loss curve in ``'optim'`` mode.

        Raises:
            AssertionError: For the Laplacian-based smooth operators.
            ValueError: For an unknown smooth operator or mode.
        """
        super().__init__()
        self.nodes = nodes
        self.subsample_nodes = subsample_nodes

        self.device = device
        self.dtype = dtype
        self.iterations = iterations
        self.optim = optim
        self.mode = mode
        # Copy so that no two instances share (or mutate) one default dict.
        self.optim_kwargs = (
            {'lr': 0.0001, 'weight_decay': 0} if optim_kwargs is None
            else dict(optim_kwargs)
        )
        self.smoothness_loss_weight = smoothness_loss_weight
        self.plot = plot
        self.over_smooth = over_smooth

        self.adjacency_matrix = torch.tensor(adjaceny_matrix)
        self.smoother = None
        assert smooth_operator not in [ "laplacian", "normalized_laplacian"], "Laplacian not suitable for current implementation."
        self.smoother = self._build_smoother(smooth_operator, graph_shift).to(dtype)
        self.register_buffer("basis", torch.tensor(basis))

        self.equi_constraint = equi_constraint
        self.equi_raynold_op = None
        equi_projector = None
        if equi_raynold_op is not None:
            self.equi_raynold_op = torch.tensor(equi_raynold_op).to(dtype)
            # Projector onto the eigenspace of eigenvalue ~1.  Built only when
            # an operator was supplied; kept entirely in torch.
            ev, evec = torch.linalg.eigh(self.equi_raynold_op)
            evec = evec[:, torch.abs(ev - 1) < 1e-3]
            equi_projector = evec @ torch.linalg.inv(evec.T @ evec) @ evec.T
        self.register_buffer("equi_projector", equi_projector)

        self.subsample_adjacency_matrix = torch.tensor(subsample_adjacency_matrix)

        # Row i selects node subsample_nodes[i] of the full node set.
        self.sampling_matrix = torch.zeros(len(self.subsample_nodes), len(self.nodes))
        for row, node in enumerate(self.subsample_nodes):
            self.sampling_matrix[row, node] = 1

        self.register_buffer("sub_basis", torch.tensor(sub_basis))

        self.register_buffer("M", self._calculate_M(optim=self.optim,
                                                    optim_kwargs=self.optim_kwargs,
                                                    iterations=self.iterations,
                                                    device='cpu',
                                                    dtype=self.dtype,
                                                    smoothness_loss_weight=self.smoothness_loss_weight,
                                                    mode=self.mode,
                                                    plot=self.plot))

        self.M[torch.abs(self.M) <= threshold] = 0
        self.M_bar = self.M @ torch.linalg.inv(self.M.T @ self.M) @ self.M.T
        eigvals, eigvecs = torch.linalg.eig(self.M_bar)
        eigvecs = eigvecs[:, torch.abs(eigvals - 1) < 1e-7]
        eigvecs[torch.abs(eigvecs) < 1e-6] = 0

        self.register_buffer("L1_eigs", eigvecs)
        self.register_buffer("L1_projector", torch.tensor(self.l1_projector(self.M.numpy())).to(self.dtype))
        self.register_buffer("up_sampling_basis", (self.basis * self.basis.shape[0]**0.5 @ self.M).to(self.dtype)/self.sub_basis.shape[0]**0.5)

        if equi_correction and self.equi_projector is not None:
            self.L1_projector = self.equi_correction(self.L1_projector)

        # Cast last, so everything derived above is computed at the precision
        # the solver produced.  ``Tensor.to`` is not in-place: the result has
        # to be assigned back.
        self.M = self.M.to(self.dtype)
        self.basis = self.basis.to(self.dtype)
        self.sub_basis = self.sub_basis.to(self.dtype)
        self.sampling_matrix = self.sampling_matrix.to(self.dtype)

    def _build_smoother(self, smooth_operator, graph_shift):
        """Return the (un-cast) smoothing operator selected by ``smooth_operator``."""
        adjacency = self.adjacency_matrix
        if smooth_operator == "adjacency":
            return adjacency / torch.sum(adjacency, dim=1, keepdim=True)
        if smooth_operator == "laplacian":
            return torch.diag(torch.sum(adjacency, dim=1)) - adjacency
        if smooth_operator == "normalized_laplacian":
            degree_matrix = torch.diag(torch.sum(adjacency, dim=1))
            laplacian = degree_matrix - adjacency
            degree_matrix_power = torch.sqrt(1.0 / degree_matrix)
            # Off-diagonal zeros (and isolated nodes) become inf; zero them.
            degree_matrix_power[degree_matrix_power == float('inf')] = 0
            return degree_matrix_power @ laplacian @ degree_matrix_power
        if smooth_operator == "graph_shift" and graph_shift is not None:
            return torch.tensor(graph_shift)
        raise ValueError("Invalid smooth operator: ", smooth_operator)

    def l1_projector(self, M):
        """Return ``V @ pinv(V)`` for the eigenvectors ``V`` of ``M (M^T M)^-1 M^T`` with eigenvalue ~1.

        Args:
            M: 2-D NumPy array.

        Returns:
            Square NumPy array with ``M.shape[0]`` rows.
        """
        column_space_projector = M @ np.linalg.inv(M.T @ M) @ M.T
        eigvals, eigvecs = np.linalg.eig(column_space_projector)
        kept = eigvecs[:, np.abs(eigvals - 1) < 1e-7]
        return kept @ np.linalg.pinv(kept)

    def equi_correction(self, operator):
        """Apply ``equi_projector`` to the flattened ``operator`` and restore its shape.

        Raises:
            ValueError: If the layer has no ``equi_projector`` (no
                ``equi_raynold_op`` was supplied).
        """
        projector = getattr(self, "equi_projector", None)
        if projector is None:
            raise ValueError("equi_correction requires `equi_raynold_op` to be provided.")
        return (projector @ operator.flatten()).reshape(operator.shape)

    def _calculate_M(self,
                     optim=Adam,
                     optim_kwargs=None,
                     iterations=50000,
                     device='cuda:0',
                     dtype=torch.cfloat,
                     smoothness_loss_weight=0.01,
                     mode='optim',
                     plot=True):
        """Solve for the matrix ``M`` with the strategy named by ``mode``.

        Args:
            optim: Optimizer class (``'optim'`` mode).
            optim_kwargs: Keyword arguments for ``optim``; defaults to
                ``{'lr': 0.0001, 'weight_decay': 0}``.
            iterations: Optimizer steps, or ``maxiter`` for ``'linear_optim'``.
            device: Device on which the solve runs.
            dtype: Working dtype of the ``'optim'`` solve.
            smoothness_loss_weight: Weight of the smoothness term.
            mode: ``'optim'`` (gradient descent), ``'analytical'``
                (pseudo-inverse) or ``'linear_optim'`` (SciPy ``minimize``
                under a linear constraint).
            plot: Plot the loss curve (``'optim'`` mode only).

        Returns:
            ``M`` as a CPU tensor.

        Raises:
            ValueError: If ``mode`` is not one of the three supported values.
        """
        if optim_kwargs is None:
            optim_kwargs = {'lr': 0.0001, 'weight_decay': 0}

        if mode == 'optim':
            return self._solve_M_gradient(optim, optim_kwargs, iterations, device,
                                          dtype, smoothness_loss_weight, plot)
        if mode == 'analytical':
            return self._solve_M_analytical()
        if mode == 'linear_optim':
            return self._solve_M_constrained(iterations, device, smoothness_loss_weight)
        raise ValueError("Invalid mode: ", mode)

    def _solve_M_gradient(self, optim, optim_kwargs, iterations, device, dtype,
                          smoothness_loss_weight, plot):
        """Minimise ``|F - S FB M|^2 + w * trace(B^T L B)`` (``B = FB M``) by gradient descent."""
        # Same quantities the 'linear_optim' solver uses: sub-basis, basis,
        # sampling matrix and smoothing operator.
        F = self.sub_basis.clone().to(device, dtype=dtype)
        FB = self.basis.clone().to(device, dtype=dtype)
        S = self.sampling_matrix.clone().to(device, dtype=dtype)
        L = self.smoother.clone().to(device, dtype=dtype)
        M = torch.randn((S.shape[1], F.shape[1]), requires_grad=True, dtype=dtype, device=device)
        optimizer = optim([M], **optim_kwargs)

        print(F.shape, FB.shape, S.shape, L.shape, M.shape)

        print(" Inital Loss Reconstruction: ", torch.norm(F - S @ FB @ M, p=2))
        print(" Initial Smoothness by Laplacian loss :", torch.trace( (FB @ M).T @ L @ FB @ M))

        loss_list = []
        for _ in range(iterations):
            B = FB @ M
            loss = torch.norm(F - S @ B, p=2)**2 + smoothness_loss_weight*torch.trace(B.T @ L @ B).real
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        M = M.detach()
        print(" Final Loss Reconstruction :", torch.norm(F - S @ FB @ M, p=2))
        print(" Final Smoothness by Laplacian loss :", torch.trace((FB @ M).T @ L @ FB @ M).real)

        if plot:
            plt.plot(loss_list)
            # log scale along y-axis
            plt.yscale('log')
            plt.show()

        return M.cpu()

    def _solve_M_analytical(self):
        """Return ``pinv(S @ basis) @ sub_basis`` with ``S`` the sampling matrix."""
        print("===Using Analytical Solution====")
        sampled_basis = self.sampling_matrix.to(self.basis.dtype) @ self.basis
        return torch.linalg.pinv(sampled_basis) @ self.sub_basis

    def _solve_M_constrained(self, iterations, device, smoothness_loss_weight):
        """Solve for ``M`` with SciPy ``minimize`` subject to ``S FB M = F`` (tolerance 1e-7).

        The objective is the smoothness term, plus the equivariance term when
        ``equi_constraint`` is set and an ``equi_raynold_op`` is available.
        Returns a ``float64`` tensor.
        """
        print("===Using Linear Optimization====")

        high_precision_dtype = None
        if self.dtype == torch.cfloat:
            high_precision_dtype = torch.cdouble  # double-precision complex
        elif self.dtype == torch.float:
            high_precision_dtype = torch.float64

        # Bases are rescaled by sqrt(number of rows) before solving.
        F = self.sub_basis.clone().to(device, dtype=high_precision_dtype).numpy() * self.sub_basis.shape[0]**0.5
        FB = self.basis.clone().to(device, dtype=high_precision_dtype).numpy() * self.basis.shape[0]**0.5
        S = self.sampling_matrix.clone().to(device, dtype=high_precision_dtype).numpy()
        L = self.smoother.to(device, dtype=high_precision_dtype).numpy()

        use_equi = self.equi_constraint and self.equi_raynold_op is not None
        R = None
        if use_equi:
            R = self.equi_raynold_op.clone().to(device, dtype=high_precision_dtype).numpy()

        m_shape = (FB.shape[1], F.shape[1])
        tolerance = 1e-7

        # diag(FB^T (FB - L FB)) does not depend on M: compute it once
        # instead of on every objective evaluation.
        shift_err = np.diag(FB.T @ (FB - np.dot(L, FB))).reshape(-1, m_shape[0])

        def equivariance_loss(M_2d):
            """2-norm of ``R @ vec(L1) - vec(L1)`` for the projector built from ``M_2d``."""
            L1 = self.l1_projector(M_2d).reshape(-1)
            return np.linalg.norm(R @ L1 - L1, ord=2)

        def objective(flat_M):
            """Weighted smoothness term, plus the equivariance term when enabled."""
            M_2d = flat_M.reshape(m_shape)
            value = smoothness_loss_weight * np.mean(shift_err @ np.abs(M_2d))
            if use_equi:
                value = value + 1 * equivariance_loss(M_2d)
            return value

        initial_guess_M = np.random.randn(S.shape[1], F.shape[1]).flatten()

        # (S FB) M = F, written for the row-major flattened M.
        linear_constraint_matrix = np.kron(S @ FB, np.eye(F.shape[1]))
        target = F.flatten()
        linear_constraint = LinearConstraint(linear_constraint_matrix,
                                             target - tolerance,
                                             target + tolerance)

        # Unbounded variables; the equality is carried by the constraint.
        bounds = [(None, None)] * (m_shape[0] * m_shape[1])

        result = minimize(objective, initial_guess_M,
                          constraints=[linear_constraint], bounds=bounds,
                          options={'maxiter': iterations})

        M = result.x.reshape(m_shape)
        return torch.tensor(M, dtype=torch.float64)

    def _forward_f_transform(self, x, basis):
        """Apply the (conjugate-)transposed ``basis`` to ``x``.

        ``x`` is either 1-D, or 5-D with the transformed axis at position 2.
        The conjugate is taken when ``self.dtype`` is complex.

        Raises:
            ValueError: For an unsupported ``self.dtype`` or input rank.
        """
        if self.dtype in (torch.cfloat, torch.cdouble):
            B = torch.transpose(torch.conj(basis), 0, 1)
        elif self.dtype in (torch.float, torch.float64):
            B = torch.transpose(basis, 0, 1)
        else:
            raise ValueError("Invalid dtype: ", self.dtype)

        rank = len(x.shape)
        if rank == 1:
            return torch.matmul(B, x.to(basis.dtype))
        if rank == 5:
            return torch.einsum('fg,bcghw->bcfhw', B, x.to(basis.dtype))
        raise ValueError("Invalid shape: ", x.shape)

    def _inverse_f_transform(self, x, basis):
        """Apply ``basis`` to ``x`` (1-D, or 5-D with the transformed axis at position 2).

        ``x`` is cast to the dtype of ``basis`` in both cases.

        Raises:
            ValueError: For any other input rank.
        """
        rank = len(x.shape)
        if rank == 1:
            return torch.matmul(basis, x.to(basis.dtype))
        if rank == 5:
            return torch.einsum('fg,bcghw->bcfhw', basis, x.to(basis.dtype))
        raise ValueError("Invalid shape: ", x.shape)

    def fft(self, x):
        """Forward transform of ``x`` with the full-graph ``basis``."""
        return self._forward_f_transform(x, self.basis)

    def ifft(self, x):
        """Inverse transform of ``x`` with the full-graph ``basis``."""
        return self._inverse_f_transform(x, self.basis)

    def anti_aliase(self, x):
        """Transform ``x``, apply ``L1_projector`` and transform back.

        A 4-D input ``b (c g) h w`` is unfolded to ``b c g h w`` with
        ``g = len(self.nodes)``.  Any input that is 5-D at that point (so both
        4-D and 5-D inputs) is returned folded to ``b (c g) h w``; 1-D inputs
        stay 1-D.
        """
        if len(x.shape) == 4:
            x = rearrange(x, 'b (c g) h w -> b c g h w', g=len(self.nodes))
        batched = len(x.shape) == 5

        fh = self.fft(x)
        if batched:
            fh_p = torch.einsum('fg,bcghw->bcfhw', self.L1_projector, fh)
        else:
            fh_p = self.L1_projector @ fh
        x_out = self.ifft(fh_p)

        if batched:
            x_out = rearrange(x_out, 'b  c g h w -> b (c g) h w')
        return x_out

    def apply_subsample_matrix(self, x):
        """Left-multiply ``x`` by the 0/1 sampling matrix."""
        return self.sampling_matrix @ x

    def up_sample(self, x):
        """Transform ``x`` with ``sub_basis`` and reconstruct it with ``up_sampling_basis``.

        A 4-D input ``b (c g) h w`` is unfolded with
        ``g = len(self.subsample_nodes)``.  Any input that is 5-D at that point
        is returned folded to ``b (c g) h w``.
        """
        if len(x.shape) == 4:
            x = rearrange(x, 'b (c g) h w -> b c g h w', g=len(self.subsample_nodes))
        xh = self._forward_f_transform(x, self.sub_basis)
        x_upsampled = self._inverse_f_transform(xh, self.up_sampling_basis)
        if len(x.shape) == 5:
            x_upsampled = rearrange(x_upsampled, 'b c g h w -> b (c g) h w')
        return x_upsampled

    def forward(self, x):
        """Alias for ``anti_aliase``."""
        return self.anti_aliase(x)
