"""
Contains the CLOMP class for mixture estimation using CLOMP algorithm.
"""
from typing import NoReturn, Literal

import time
from abc import abstractmethod
import torch
import scipy
from pdfo import pdfo
from loguru import logger
import torch.nn.functional as f
from pycle.compressive_learning.SolverTorch_G import SolverTorch_G
from pycle.utils.intermediate_storage import ObjectiveValuesStorage
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import math

import random

######Changes #####
from numpy.core.numerictypes import typecodes
######Changes #####



class hypercube():
    """Axis-aligned hypercube ``[a, b]^d`` used as a sampling/projection domain.

    The constructor stores its arguments unchanged as ``a`` (lower bound),
    ``b`` (upper bound) and ``dim`` (number of coordinates).
    """

    def __init__(self, a, b, d):
        self.a = a
        self.b = b
        self.dim = d

    def sample(self):
        """Draw one point uniformly at random in the cube, returned as a torch tensor."""
        point = np.random.uniform(self.a, self.b, self.dim)
        return torch.from_numpy(point)

    def project(self, x):
        """Clamp the first ``dim`` coordinates of ``x`` to the interval ``[a, b]``.

        A new float64 tensor is returned; ``x`` is only read.
        """
        clamped = np.zeros(self.dim)
        for axis in range(self.dim):
            coordinate = x[axis]
            if coordinate < self.a:
                clamped[axis] = self.a
            elif coordinate > self.b:
                clamped[axis] = self.b
            else:
                clamped[axis] = coordinate
        return torch.from_numpy(clamped)

    def grid(self, N):
        """Return a list of ``N`` random samples (a random cloud, not a regular lattice)."""
        return [self.sample() for _ in range(N)]
    
class ball():
    """Ball-shaped domain described by a centre ``a``, a radius ``R`` and a dimension ``d``."""

    def __init__(self, a, R, d):
        self.a = a
        self.R = R
        self.dim = d

    def sample(self):
        """Draw one random point and return it as a torch tensor.

        A standard normal vector of length ``dim + 2`` is normalised to unit
        Euclidean length; its first ``dim`` coordinates are scaled by ``R``
        and shifted by ``a``.
        """
        gaussian = np.random.normal(0, 1, self.dim + 2)
        direction = gaussian / (np.sum(gaussian ** 2) ** 0.5)
        point = self.a + self.R * direction[:self.dim]
        return torch.from_numpy(point)

    def project(self, x):
        """Rescale ``x`` by ``R / ||x||``.

        Every input is rescaled, including points whose norm is already
        smaller than ``R``, and the centre ``a`` is not used here.
        The scaled value must expose ``.numpy()``, so ``x`` is expected to be
        a torch tensor.
        """
        unit = (1 / np.linalg.norm(x)) * x
        rescaled = self.R * unit
        return torch.from_numpy(rescaled.numpy())

    def grid(self, N):
        """Return a list of ``N`` random samples (a random cloud, not a regular lattice)."""
        return [self.sample() for _ in range(N)]
    
    



def plot_this_2D_function_and_nodes(function, n, pos_list, a, b, M):
    """Show a heat map of a 2-D scalar function with one node and one path drawn on top.

    Parameters
    ----------
    function
        Callable evaluated at every grid point with a (2,)-shaped float64
        tensor ``(x, y)``; its result is stored as one scalar of the heat map.
    n
        Point drawn as a single black marker; only ``n[0]`` and ``n[1]`` are read.
    pos_list
        Sequence of points joined by a red line; only indices 0 and 1 of each
        point are read.
    a, b
        Bounds of the plotted square ``[a, b] x [a, b]``.
    M
        Number of grid points per axis (``M * M`` evaluations of ``function``).
    """
    axis_values = np.linspace(a, b, M)
    X, Y = np.meshgrid(axis_values, axis_values)

    # Z[row, col] holds function(x=X[row, col], y=Y[row, col]).
    Z = np.zeros((M, M))
    for row in range(M):
        for col in range(M):
            theta = torch.from_numpy(np.array((X[row, col], Y[row, col])))
            Z[row, col] = function(theta)

    plt.figure(figsize=(10, 8))
    # The extent must match the sampled square, otherwise the overlays drawn
    # below (given in data coordinates) do not line up with the heat map.
    plt.imshow(Z, extent=[a, b, a, b], origin='lower')
    plt.colorbar(label='Z')

    plt.scatter([n[0]], [n[1]], marker='o', color='black', s=20)

    x_coords = [theta[0] for theta in pos_list]
    y_coords = [theta[1] for theta in pos_list]
    plt.plot(x_coords, y_coords, marker='o', linestyle='-', color='red')

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()
    
    
def plot_this_2D_log_function_and_nodes(function, n, pos_list, a, b, M):
    """Show a signed-log heat map of a 2-D scalar function with one node and one path on top.

    The displayed quantity is ``-sign(f) * log(|f|)`` evaluated on an
    ``M x M`` grid of the square ``[a, b] x [a, b]``.

    Parameters
    ----------
    function
        Callable evaluated at every grid point with a (2,)-shaped float64
        tensor ``(x, y)``.
    n
        Point drawn as a single blue marker; only ``n[0]`` and ``n[1]`` are read.
    pos_list
        Sequence of points joined by a red line; only indices 0 and 1 of each
        point are read.
    a, b
        Bounds of the plotted square.
    M
        Number of grid points per axis.
    """
    axis_values = np.linspace(a, b, M)
    X, Y = np.meshgrid(axis_values, axis_values)

    Z = np.zeros((M, M))
    for row in range(M):
        for col in range(M):
            theta = torch.from_numpy(np.array((X[row, col], Y[row, col])))
            # Evaluate once and reuse the value for both the sign and the magnitude.
            value = function(theta)
            Z[row, col] = -np.sign(value) * np.log(np.abs(value))

    plt.figure(figsize=(10, 8))
    # The extent must match the sampled square, otherwise the overlays drawn
    # below (given in data coordinates) do not line up with the heat map.
    plt.imshow(Z, extent=[a, b, a, b], origin='lower')
    plt.colorbar()

    plt.scatter([n[0]], [n[1]], marker='o', color='blue', s=20)

    x_coords = [theta[0] for theta in pos_list]
    y_coords = [theta[1] for theta in pos_list]
    plt.plot(x_coords, y_coords, marker='o', linestyle='-', color='red')

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

    
def plot_this_2D_functions_and_trajectories(function, pos_list, a, b, M):
    """Show a heat map of a 2-D scalar function with a trajectory drawn on top.

    Parameters
    ----------
    function
        Callable evaluated at every grid point with a (2,)-shaped float64
        tensor ``(x, y)``; its result is stored as one scalar of the heat map.
    pos_list
        Sequence of visited points, joined by a red line. When it is not
        empty its first point is additionally drawn in blue.
    a, b
        Bounds of the plotted square ``[a, b] x [a, b]``.
    M
        Number of grid points per axis (``M * M`` evaluations of ``function``).
    """
    axis_values = np.linspace(a, b, M)
    X, Y = np.meshgrid(axis_values, axis_values)

    Z = np.zeros((M, M))
    for row in range(M):
        for col in range(M):
            theta = torch.from_numpy(np.array((X[row, col], Y[row, col])))
            Z[row, col] = function(theta)

    plt.figure(figsize=(10, 8))
    # The extent must match the sampled square, otherwise the trajectory
    # (given in data coordinates) does not line up with the heat map.
    plt.imshow(Z, extent=[a, b, a, b], origin='lower')

    x_coords = [theta[0] for theta in pos_list]
    y_coords = [theta[1] for theta in pos_list]
    plt.plot(x_coords, y_coords, marker='o', linestyle='-', color='r')
    if len(pos_list) > 0:
        # Highlight the starting point of the trajectory.
        plt.plot([pos_list[0][0]], [pos_list[0][1]], marker='o', linestyle='-', color='b')
    plt.colorbar(label='Z')

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

class D_OMP_G(SolverTorch_G):
    """
    Implementation of the CLOMP algorithm to fit the sketch of a mixture model to the sketch z of a distribution.

    CLOMP is an instance of the class :class:`pycle.compressive_learning.SolverTorch.SolverTorch`.

    This class can use gradient descent through `torch` to find the components of the mixture model
    or it can use derivative free optimization through the `pdfo` library.
    Derivative free optimization is slower and it doesn't support too high dimension (>100) but doesn't need the
    feature map to be derivable.

    To create a subclass inheriting from CLOMP algorithm, some methods must be overridden:

    - `randomly_initialize_several_mixture_components(self, int)` to define how to initialize a given number of mixture components.
    - `sketch_of_mixture_components(self, (KxD) tensor )` to define how to get the feature map of a single or K mixture components
    - `set_bounds_thetas(bounds)` to define the bounding box where to look for the mixture components.

    To have a better understanding, look at the code of the class :class:`pycle.compressive_learning.CLOMP_CKM.CLOMP_CKM`

    References
    ----------
    Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
    Information and Inference: A Journal of the IMA, 7(3), 447-508.
    https://arxiv.org/pdf/1606.02838.pdf
    """
    # cleaning make dynamic reference to SOlverTorch rst

    LST_OPT_METHODS_TORCH = ["adam", "lbfgs"]

    def __init__(self, *args, **kwargs):
        """Declare the solver hyper-parameters, run the parent constructor, then set local state.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        # Hyper-parameter slots are created (as None) before the parent
        # constructor runs; initialize_hyperparameters_optimization fills them.
        self.maxiter_inner_optimizations = self.tol_inner_optimizations = None
        self.nb_iter_max_step_5 = self.nb_iter_max_step_1 = None
        self.opt_method_step_1 = self.opt_method_step_34 = self.opt_method_step_5 = None
        self.lr_inner_optimizations = None

        super().__init__(*args, **kwargs)

        # Bounds stored for the mixture weights.
        self.weight_lower_bound = 1e-9
        self.weight_upper_bound = 2

        # Iteration counters of the three torch-based inner optimizations.
        self.count_iter_torch_find_optimal_weights = 0
        self.count_iter_torch_maximize_atom_correlation = 0
        self.count_iter_torch_minimize_cost_from_current_sol = 0

        self.s_sigma = self.dct_optim_method_hyperparameters.get("s_sigma")
        # Set after the parent constructor, so this value is the one in
        # effect once construction finishes.
        self.grid_mode = 'static_grid'

    def initialize_hyperparameters_optimization(self) -> None:
        """
        Copy the entries of dct_optim_method_hyperparameters onto attributes of the object.

        Keys that are missing from the dictionary fall back to these values::

            {
                "maxiter_inner_optimizations": 100,
                "tol_inner_optimizations": 1e-9,
                "nb_iter_max_step_5": 200,
                "nb_iter_max_step_1": 200,
                "opt_method_step_1": "lbfgs",
                "opt_method_step_34": "nnls",
                "opt_method_step_5": "lbfgs",
                "lr_inner_optimizations": 1,
                "s_sigma": None,
                "grid_mode": None
            }

        """
        hyperparameters = self.dct_optim_method_hyperparameters
        fallback_values = {
            "maxiter_inner_optimizations": 100,
            "tol_inner_optimizations": 1e-9,
            "nb_iter_max_step_5": 200,
            "nb_iter_max_step_1": 200,
            "opt_method_step_1": "lbfgs",
            "opt_method_step_34": "nnls",
            "opt_method_step_5": "lbfgs",
            "lr_inner_optimizations": 1,
            "s_sigma": None,
            "grid_mode": None,
        }
        # Each attribute carries the same name as its dictionary key.
        for attribute, fallback in fallback_values.items():
            setattr(self, attribute, hyperparameters.get(attribute, fallback))
        
        
    def plot_sign_hessian_function(self, function, a, b, M):
        """Plot, over the square ``[a, b] x [a, b]``, the definiteness of the matrix from evaluate_sigma_in_node.

        For every point of an ``M x M`` grid the matrix returned by
        ``self.evaluate_sigma_in_node(function, theta, 1e-7)`` is computed and
        the point is coloured with:

        - ``0`` when its eigenvalues contain a NaN,
        - ``-1`` when its smallest eigenvalue is non-positive,
        - ``1`` otherwise.

        The number of ``-1`` points is printed before the figure is shown.

        Parameters
        ----------
        function
            Callable passed through to ``evaluate_sigma_in_node``.
        a, b
            Bounds of the plotted square.
        M
            Number of grid points per axis.
        """
        axis_values = np.linspace(a, b, M)
        X, Y = np.meshgrid(axis_values, axis_values)

        Z = np.zeros((M, M))
        counter = 0  # number of grid points with a non-positive eigenvalue
        for row in range(M):
            for col in range(M):
                theta = torch.from_numpy(np.array((X[row, col], Y[row, col])))
                sigma = torch.from_numpy(self.evaluate_sigma_in_node(function, theta, 0.0000001))
                eigvals = torch.linalg.eigvalsh(sigma)
                if torch.isnan(torch.sum(eigvals)):
                    z = 0
                elif torch.min(eigvals) <= 0:
                    counter = counter + 1
                    z = -1
                else:
                    z = 1
                Z[row, col] = z

        print(counter)
        plt.figure(figsize=(10, 8))
        # The extent must match the sampled square so that the axes show the
        # coordinates at which the matrix was actually evaluated.
        plt.imshow(Z, extent=[a, b, a, b], origin='lower')
        plt.colorbar(label='Z')

        plt.xlabel('X')
        plt.ylabel('Y')
        plt.show()
        
    def centroids_recovery_ms(self, function, grad_function, alpha, beta, T, dom, mode):
        """Recover one node by an ascent on ``function`` started from a random point of ``dom``.

        The start point is re-sampled while ``function`` is non-positive there
        (the loop gives up after about one hundred re-samplings). Then ``T``
        steps ``x <- project(x +/- alpha * grad / f(x))`` are taken; the sign is
        ``+`` where ``f(x) > 0`` and ``-`` otherwise.

        Parameters
        ----------
        function
            Callable returning the scalar objective at a point.
        grad_function
            Callable returning the gradient of the objective; it receives a
            copy of the current point created with ``requires_grad=True``.
        alpha
            Step size.
        beta
            Unused; kept for interface compatibility.
        T
            Number of ascent steps.
        dom
            Domain object providing ``dim``, ``sample()`` and ``project(x)``.
        mode
            ``'dirac'``: the returned matrix is the identity of size ``dom.dim``.
            ``'gaussian'``: the matrix comes from ``self.evaluate_sigma_in_node``
            (applied to ``function.evaluate_using_sketch`` when that attribute
            exists, to ``function`` itself otherwise) with step ``1e-5``.

        Returns
        -------
            Four one-element lists: the final point, the value of ``function``
            there, the matrix described above, and the list of visited points.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``'dirac'`` nor ``'gaussian'``.
        """
        if mode not in ('dirac', 'gaussian'):
            raise ValueError(f"Unknown mode: {mode!r} (expected 'dirac' or 'gaussian')")

        d = dom.dim

        # Rejection sampling of the start point.
        x_0 = dom.sample()
        count = 0
        while function(x_0) <= 0:
            x_0 = dom.sample()
            if count > 99:
                break
            count += 1

        x_list = [x_0]
        x_tmp = x_0
        for _ in range(T):
            value = function(x_tmp)  # evaluated once per step
            x_tmp_g = Variable(x_tmp, requires_grad=True)
            step = alpha * (1 / value) * grad_function(x_tmp_g)
            x = x_tmp - step if value <= 0 else x_tmp + step
            x_tmp = dom.project(x)
            x_list.append(x_tmp)

        n = x_list[-1]
        w = function(n)
        if mode == 'dirac':
            sigma = np.eye(d)
        else:
            # evaluate_sigma_in_node is a method of this class, not a module-level function.
            log_target = getattr(function, 'evaluate_using_sketch', function)
            sigma = self.evaluate_sigma_in_node(log_target, n, 0.00001)

        return [n], [w], [sigma], [x_list]

    def evaluate_sigma_in_node(self, function, node, epsilon):
        """Return ``-2 * inv(H)`` where ``H`` is a finite-difference Hessian of ``log(function)`` at ``node``.

        Each entry ``H[i, j]`` is built from second differences of
        ``log(function)`` along ``e_i + e_j`` and ``e_i - e_j`` with step
        ``epsilon``.

        Parameters
        ----------
        function
            Callable returning a positive scalar at a point.
        node
            1-D point; its length ``node.shape[0]`` is the dimension ``d``.
        epsilon
            Finite-difference step.

        Returns
        -------
            (d, d) numpy array. ``numpy.linalg.inv`` raises ``LinAlgError``
            when the estimated Hessian is singular.
        """
        d = node.shape[0]
        basis = np.eye(d)  # basis[k] is the k-th canonical vector
        eps_sq = np.power(epsilon, 2)
        # The centre term does not depend on (i, j): evaluate it once.
        two_log_center = 2 * np.log(function(node))

        hessian = np.zeros((d, d))
        for i in range(d):
            for j in range(d):
                shift_sum = epsilon * (basis[i] + basis[j])
                shift_diff = epsilon * (basis[i] - basis[j])

                sigma_plus = (np.log(function(node + shift_sum))
                              + np.log(function(node - shift_sum))
                              - two_log_center) / eps_sq
                sigma_minus = (np.log(function(node + shift_diff))
                               + np.log(function(node - shift_diff))
                               - two_log_center) / eps_sq
                hessian[i, j] = (sigma_plus - sigma_minus) / 4

        return -2 * np.linalg.inv(hessian)


    def centroids_recovery___(self,function,N,s,dom,mode,grid):
        """Greedy selection of ``N`` nodes on ``grid``; appears to be a retired variant.

        At each of the ``N`` rounds the grid point maximising
        ``function(point) - recovered_function.evaluate(point)`` is selected,
        where ``recovered_function`` is rebuilt from the nodes found so far.

        NOTE(review): this method has no ``return`` statement, so the lists it
        builds are discarded and the caller receives ``None``.
        NOTE(review): ``copy`` and ``kernelized_mixture_dirac`` are not imported
        or defined in the visible part of this file, and
        ``evaluate_sigma_in_node`` is called as a bare name although it is
        defined as a method of this class - confirm before reviving this code.
        NOTE(review): ``sigma`` is left unassigned when ``mode`` is neither
        ``'dirac'`` nor ``'gaussian'``.
        """
        d = dom.dim
        n_grid = len(grid)

        # Accumulators for the nodes, weights and matrices found so far.
        recovered_n_list = []
        recovered_w_list = []
        recovered_sigma_list = []
        # The two counter lists below are never filled in this method.
        rejection_counter_list =[]
        repetition_counter_list = []
        x_list_lists = []
        recovered_function_list = []
        for i in list(range(N)):
            x_list = []
            # Model built from everything recovered in the previous rounds.
            recovered_function = copy.deepcopy(kernelized_mixture_dirac(recovered_n_list,recovered_w_list,recovered_sigma_list,s))
            recovered_function_list.append(recovered_function)
            # r_count and R are assigned but not used afterwards.
            r_count = 0
            R = 1
            # Arg-max over the grid of the gap between the target and the current model.
            max_f = function(grid[0])-recovered_function.evaluate(grid[0])
            argmax_j = 0
            for j in list(range(n_grid)):

                if function(grid[j])-recovered_function.evaluate(grid[j])>max_f:
                    max_f = function(grid[j])-recovered_function.evaluate(grid[j])
                    argmax_j = j
            n = grid[argmax_j] 
            # The weight is the value of the target (not of the gap) at the selected node.
            w = function(n)
            if mode =='dirac':
                sigma = s*np.eye(d)

            elif mode =='gaussian':
                sigma = evaluate_sigma_in_node(function,n,0.00001)
            x_list = [n]
            x_list_lists.append(x_list)
            recovered_n_list.append(n)
            recovered_w_list.append(w)
            recovered_sigma_list.append(sigma)

    def centroids_recovery_grid(self, function, s, d, mode, grid):
        """Return the grid point maximising ``function`` together with a matrix for it.

        Parameters
        ----------
        function
            Callable returning a scalar at a grid point.
        s
            Scale of the identity matrix returned in ``'dirac'`` mode.
        d
            Size of that identity matrix.
        mode
            ``'dirac'``: the matrix is ``s * I_d``.
            ``'gaussian'``: the matrix is
            ``self.evaluate_sigma_in_node(function, n, 1e-5)``.
        grid
            Non-empty sequence of candidate points.

        Returns
        -------
            ``(n, sigma)``: the first grid point reaching the maximum, and the
            matrix described above.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``'dirac'`` nor ``'gaussian'``.
        """
        if mode not in ('dirac', 'gaussian'):
            raise ValueError(f"Unknown mode: {mode!r} (expected 'dirac' or 'gaussian')")

        # Strict comparison: on ties the earliest grid point is kept.
        best_index = 0
        best_value = function(grid[0])
        for index in range(1, len(grid)):
            value = function(grid[index])  # each candidate is evaluated once
            if value > best_value:
                best_value, best_index = value, index

        n = grid[best_index]
        if mode == 'dirac':
            sigma = s * np.eye(d)
        else:
            sigma = self.evaluate_sigma_in_node(function, n, 0.00001)

        return n, sigma





    def get_correlation_function(self):
        """Build a callable mapping ``theta`` to ``Re <sketch(theta), residual>``.

        ``self.residual`` is looked up each time the returned callable is
        invoked, so it always reflects the current residual.
        """
        def _correlate_with_residual(theta):
            atom_sketch = self.sketch_of_mixture_components(theta)
            inner_product = torch.vdot(atom_sketch, self.residual)
            return torch.real(inner_product)

        return _correlate_with_residual
    
    
    def get_correlation_function_0(self):
        """Build a callable mapping ``theta`` to ``Re <sketch(theta), sketch_reweighted>``.

        ``self.sketch_reweighted`` is looked up each time the returned
        callable is invoked.
        """
        def _correlate_with_sketch(theta):
            atom_sketch = self.sketch_of_mixture_components(theta)
            inner_product = torch.vdot(atom_sketch, self.sketch_reweighted)
            return torch.real(inner_product)

        return _correlate_with_sketch
    
    
    def get_iterative_function(self):
        """Build a callable mapping ``theta`` to ``Re <sketch(theta), sketch of the current solution>``.

        The sketch of the current solution is computed once, when this method
        is called, from ``self.current_solution`` read as
        ``(thetas, alphas, sigmas)`` at indices 0, 1 and 2. The returned
        callable yields ``0`` whenever ``self.current_solution`` is ``None``
        at the time it is invoked.
        """
        # Identity test: "== None" is evaluated element-wise by array-like
        # solutions and cannot be used as a condition.
        if self.current_solution is None:
            z = 0
        else:
            z = self.sketch_of_solution(alphas=self.current_solution[1],
                                        thetas=self.current_solution[0],
                                        sigmas=self.current_solution[2])

        def iterative_function_aux(theta):
            if self.current_solution is None:
                return 0
            return torch.real(torch.vdot(self.sketch_of_mixture_components(theta), z))

        return iterative_function_aux
    
    def get_gradient_correlation_function(self):
        """Build a callable returning the gradient of the correlation function at ``theta``.

        The returned callable back-propagates through
        ``self.get_correlation_function()`` and returns ``theta.grad``, so
        ``theta`` must be a tensor that requires gradients.
        """
        correlation = self.get_correlation_function()

        def _gradient_of_correlation(theta):
            correlation(theta).backward()
            return theta.grad

        return _gradient_of_correlation
    def find_sigmas(self, s, d, function):
        """Estimate one matrix per current mixture component.

        For each row of ``self.thetas`` the matrix returned by
        ``self.evaluate_sigma_in_node(function, theta, 1e-7)`` is used, unless
        its determinant is NaN, in which case ``s * I_d`` is used instead.

        Returns
        -------
            List of ``self.current_size_mixture`` double-precision tensors.
        """
        sigmas = []
        for index in range(self.current_size_mixture):
            estimate = self.evaluate_sigma_in_node(function, self.thetas[index, :], 0.0000001)
            if math.isnan(np.linalg.det(estimate)):
                # Fall back to an isotropic matrix when the estimate is unusable.
                estimate = s * np.eye(d)
            sigmas.append(torch.from_numpy(estimate).double())
        return sigmas


    @abstractmethod
    def projection_step(self, thetas):
        """
        Project mixture component parameters vector theta (or a set of thetas) on the constraint specified
        by self.centroid_project of class `Projector`.

        The modification is made in place.

        This base implementation only raises; subclasses must override it.

        Parameters
        ----------
        thetas
            (D,) or (current_size_mixture, D)-shaped tensor containing the parameters vector to project.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def sketch_of_solution(self, alphas, thetas, sigmas, phi_thetas=None):
        """
        Returns the sketch of the solution, i.e. the product ``phi_thetas @ alphas``.

        Parameters
        ----------
        alphas
            (current_size_mixture,) shaped tensor of the weights of the mixture in the solution.
        thetas
            (current_size_mixture, D) shaped tensor containing the component parameters of the solution.
            May be None only when phi_thetas is given.
        sigmas
            Per-component matrices forwarded to `sketch_of_mixture_components_with_sigmas`;
            only used when phi_thetas is None.
        phi_thetas
            (M, current_size_mixture) shaped tensor of each component sketched. If None, it is computed
            as the transpose of `sketch_of_mixture_components_with_sigmas(thetas, sigmas)`.

        Returns
        -------
            (M,)-shaped tensor containing the sketch of the mixture, in dtype `self.comp_dtype`.
        """
        assert (thetas is not None) or (phi_thetas is not None)

        if phi_thetas is None:
            component_sketches = self.sketch_of_mixture_components_with_sigmas(thetas, sigmas)
            phi_thetas = torch.transpose(component_sketches, 0, 1)

        atoms = phi_thetas.to(self.comp_dtype)
        weights = alphas.to(self.comp_dtype)
        return torch.matmul(atoms, weights)

    def add_atom___(self, new_theta) -> NoReturn:
        """
        Append a new component and its sketch to the current solution (variant without sigma).

        The component counter is incremented, ``new_theta`` becomes the last row of ``self.thetas``
        and ``sketch_of_mixture_components(new_theta)`` becomes the last column of ``self.phi_thetas``.

        Parameters
        ----------
        new_theta:
            (D, )- shaped tensor containing a new mixture component to add to the solution.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf

        """
        self.current_size_mixture += 1

        theta_row = new_theta.unsqueeze(0)
        self.thetas = torch.cat([self.thetas, theta_row], dim=0)

        atom_column = self.sketch_of_mixture_components(new_theta).unsqueeze(1)
        self.phi_thetas = torch.cat([self.phi_thetas, atom_column], dim=1)
   
        
   
    
    def add_atom(self, new_theta) -> NoReturn:
        """
        Append a new component, with an estimated sigma, and its sketch to the current solution.

        The component counter is incremented and ``new_theta`` becomes the last row of ``self.thetas``.
        A matrix is estimated at ``new_theta`` with ``evaluate_sigma_in_node`` applied to the function
        returned by ``get_correlation_function_0`` (step 1e-7); when its determinant is NaN it is
        replaced by ``s_sigma * I_D``. The sketch of the component with that matrix becomes the last
        column of ``self.phi_thetas``.

        Parameters
        ----------
        new_theta:
            (D, )- shaped tensor containing a new mixture component to add to the solution.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf

        """
        fallback_scale = self.dct_optim_method_hyperparameters.get("s_sigma")
        dimension = self.thetas_dimension_D

        self.current_size_mixture += 1
        self.thetas = torch.cat([self.thetas, new_theta.unsqueeze(0)], dim=0)

        reference_correlation = self.get_correlation_function_0()
        estimate = self.evaluate_sigma_in_node(reference_correlation, new_theta, 0.0000001)
        if math.isnan(np.linalg.det(estimate)):
            # Unusable estimate: fall back to an isotropic matrix.
            estimate = fallback_scale * np.eye(dimension)
        sigma = torch.from_numpy(estimate)

        atom_sketch = self.sketch_of_mixture_components_with_sigmas(
            new_theta.reshape((1, dimension)), [sigma]
        )
        atom_column = torch.squeeze(atom_sketch).unsqueeze(1)
        self.phi_thetas = torch.cat([self.phi_thetas, atom_column], dim=1)
        
        
        

    def remove_one_component(self, index_to_remove) -> NoReturn:
        """
        Drop one component from the current solution.

        The component counter is decremented and the entry at ``index_to_remove`` is removed from
        ``self.thetas`` (row), ``self.alphas`` (element), ``self.sigmas`` (list item) and
        ``self.phi_thetas`` (column).

        Parameters
        ----------
        index_to_remove:
            The index of the component to remove. The one with the smallest coefficient in the mixture.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf
        """
        before = slice(None, index_to_remove)
        after = slice(index_to_remove + 1, None)

        self.current_size_mixture -= 1
        self.thetas = torch.cat([self.thetas[before], self.thetas[after]], dim=0)
        self.alphas = torch.cat([self.alphas[before], self.alphas[after]], dim=0)
        del self.sigmas[index_to_remove]
        self.phi_thetas = torch.cat(
            [self.phi_thetas[:, before], self.phi_thetas[:, after]], dim=1
        )

    def loss_atom_correlation(self, theta):
        """
        Negated, norm-normalised correlation between the sketch of ``theta`` and the current residual.

        This is the objective function of the step 1 of the algorithm, written as a minimization
        problem: ``-Re <sketch(theta), residual> / ||sketch(theta)||``. The norm is replaced by
        ``self.minimum_phi_theta_norm`` when it is smaller than that value. When
        ``self.store_objective_values`` is set, the value is also recorded in ``ObjectiveValuesStorage``.

        Parameters
        ----------
        theta
            (D,) -shaped tensor containing the current location where to estimate the correlation with the residual.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf

        Returns
        -------
            The value of the objective function evaluated at theta.
        """
        atom_sketch = self.sketch_of_mixture_components(theta)

        atom_norm = torch.norm(atom_sketch)
        norm_floor = self.minimum_phi_theta_norm
        if atom_norm.item() < norm_floor:
            # Guard against a division by zero for (numerically) null atoms.
            atom_norm = torch.tensor(norm_floor).to(self.device)

        correlation = torch.real(torch.vdot(atom_sketch, self.residual))
        # Leading minus sign: the correlation is maximised by minimising this value.
        result = -1. / atom_norm * correlation

        if self.store_objective_values:
            ObjectiveValuesStorage().add(float(result), "loss_atom_correlation")
        return result

    def find_optimal_weights(self, normalize_phi_thetas=False, prefix="") -> torch.Tensor:
        """
        Returns the optimal weights for the current mixture components
        by solving the Non-Negative Least Square problem.

        This correspond to the third and fourth subproblem of the CLOMPR algorithm.

        The solver is selected by ``self.opt_method_step_34``: ``"nnls"`` dispatches to
        ``_find_optimal_weights_nnls`` and any entry of ``LST_OPT_METHODS_TORCH`` dispatches to
        ``_find_optimal_weights_torch`` (started from all-zero weights).

        Parameters
        ----------
        normalize_phi_thetas
            Tells to normalize the atoms before fitting the weights. This is useful to recover weight illustrating the
            importance of each atom in the mixture.
        prefix
            Prefix the identifier of the list of objective values, if they are stored.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf

        Returns
        -------
            Whatever the selected solver returns, unchanged: a (k,) shaped tensor of weights, or the
            string ``'error_Az'`` that the nnls solver uses to signal non-finite inputs.

        Raises
        ------
        ValueError
            If ``self.opt_method_step_34`` is not a known method.
        """
        init_alphas = torch.zeros(self.current_size_mixture, device=self.device)
        all_atoms = self.phi_thetas
        method = self.opt_method_step_34

        if method == "nnls":
            # The solver result (weights or the 'error_Az' sentinel) is passed through as-is.
            return self._find_optimal_weights_nnls(all_atoms, normalize_phi_thetas=normalize_phi_thetas)
        if method in self.LST_OPT_METHODS_TORCH:
            return self._find_optimal_weights_torch(
                init_alphas, all_atoms, prefix, normalize_phi_thetas=normalize_phi_thetas
            )
        raise ValueError(f"Unkown optimization method: {self.opt_method_step_34}")
    
    #herehere
    #herehere

    def _find_optimal_weights_nnls(self, phi_thetas, normalize_phi_thetas=False) -> torch.Tensor:
        """
        Returns the optimal weights for the input phi_thetas by solving the Non-Negative Least Square problem.

        This correspond to the third and fourth subproblem of the CLOMPR algorithm.

        This function uses scipy.optimize.nnls procedure to solve the problem. This means that
        the input tensor and the residual have to be cast to numpy object, hence inducing some latency
        if the tensors were stored on GPU.

        Parameters
        ----------
        phi_thetas
            (M, current_size_mixture)-shaped tensor containing the sketch of all components in the mixture.
        normalize_phi_thetas
            Tells to normalize the sketch of the components before fitting the weights.
            This is useful to recover the weights illustrating the importance of each component in the mixture.

        References
        ----------
        - Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf
        - scipy.optimize.nnls: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.nnls.html

        Returns
        -------
            (k,) shaped tensor of the weights of the least square solution, moved to `self.device`.
            If the system matrix or the target contains NaN/inf entries, the string 'error_Az' is returned
            instead so that the caller can detect the failure.
        """
        # scipy.optimize.nnls only works on real numpy arrays
        phi_thetas = phi_thetas.cpu().numpy()
        sketch = self.sketch_reweighted.cpu().numpy()

        # Stack real and imaginary parts as soon as one operand has a complex dtype.
        # The decision is taken on the dtype and not on the values: a complex array whose imaginary
        # parts all happen to be zero must still be converted to a real system before calling nnls.
        if np.iscomplexobj(phi_thetas) or np.iscomplexobj(sketch):
            _A = np.r_[phi_thetas.real, phi_thetas.imag]
            _z = np.r_[sketch.real, sketch.imag]
        else:
            _A = phi_thetas
            _z = sketch

        if normalize_phi_thetas:
            # One norm per column (i.e. per atom), computed on the original, possibly complex, atoms
            norms = np.linalg.norm(phi_thetas, axis=0)
            norm_too_small = np.where(norms < self.minimum_phi_theta_norm)[0]
            if norm_too_small.size > 0:  # Avoid division by zero
                logger.debug(f'norm of some atoms is too small (min. {norms.min()}), changed to {self.minimum_phi_theta_norm}.')
                norms[norm_too_small] = self.minimum_phi_theta_norm
            _A = _A / norms

        # nnls rejects non-finite inputs with an exception; report them with the sentinel instead
        if not (np.isfinite(_A).all() and np.isfinite(_z).all()):
            return 'error_Az'

        # Use non-negative least squares to find optimal weights
        (_alpha, _) = scipy.optimize.nnls(_A, _z)

        return torch.from_numpy(_alpha).to(self.device)

    def _find_optimal_weights_torch(self, init_alphas, phi_thetas, prefix, normalize_phi_thetas=False) -> torch.Tensor:
        """
        Returns the optimal weights for the input phi_thetas by minimizing the global loss over the weights only.

        This correspond to the third and fourth subproblem of the CLOMPR algorithm.

        This function uses torch to solve the problem. The weights are parameterized by their logarithm,
        so they stay positive by construction, and the returned weights are normalized to sum to one.

        Parameters
        ----------
        init_alphas
            (current_size_mixture,)-tensor used as the initial value of the optimized variable.
            Note that this variable is the *log* of the weights (a zero vector means uniform weights).
        phi_thetas
            (M, current_size_mixture)-shaped tensor containing the sketch of all components in the solution.
        prefix
            Prefix the identifier of the list of objective values, if they are stored.
        normalize_phi_thetas
            Tells to normalize the phi_thetas before fitting the weights.
            This is useful to recover weight illustrating the importance of each component in the mixture.

        References
        ----------
        - Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf

        Returns
        -------
            (k,) shaped, detached tensor of the weights, normalized so that they sum to one.
        """
        # Use the atoms handed in by the caller instead of silently re-reading self.phi_thetas
        if normalize_phi_thetas:
            # NOTE(review): dim=1 normalizes along the mixture axis of an (M, k) matrix, whereas
            # `_find_optimal_weights_nnls` normalizes each atom (axis 0) -- confirm which one is intended.
            phi_thetas = f.normalize(phi_thetas, dim=1, eps=self.minimum_phi_theta_norm)

        log_alphas = torch.nn.Parameter(init_alphas, requires_grad=True)
        optimizer = self._initialize_optimizer(self.opt_method_step_34, [log_alphas])

        def closure():
            optimizer.zero_grad()

            loss = self.loss_global(phi_thetas=phi_thetas, alphas=torch.exp(log_alphas).to(self.real_dtype))

            if self.store_objective_values:
                ObjectiveValuesStorage().add(float(loss), "{}/find_optimal_weights".format(prefix))
            if self.tensorboard:
                # `i` is the current iteration of the loop below (the closure is only called from there)
                self.writer.add_scalar(self.path_template_tensorboard_writer.format('step3-4'), loss.item(), i)

            loss.backward()
            return loss

        previous_loss = None
        for i in range(self.maxiter_inner_optimizations):
            self.count_iter_torch_find_optimal_weights += 1
            if self.opt_method_step_34 == "lbfgs":
                loss = optimizer.step(closure)  # bfgs takes the loss computation function as argument at each step.
            else:
                loss = closure()
                optimizer.step()

            # Stop as soon as the relative change of the loss falls under the tolerance
            if previous_loss is not None:
                relative_loss_diff = torch.abs(previous_loss - loss) / torch.abs(previous_loss)

                if relative_loss_diff.item() <= self.tol_inner_optimizations:
                    break

            previous_loss = torch.clone(loss)

        if self.tensorboard:
            self.writer.flush()
            self.writer.close()

        # this exp trick allows to keep the alphas positive by design
        alphas = torch.exp(log_alphas)
        normalized_alphas = alphas / torch.sum(alphas)

        return normalized_alphas.detach()

    def loss_global(self, alphas, all_thetas=None, phi_thetas=None, sigmas=None):
        """
        Objective function of the global optimization problem: fitting the moments to the sketch of the mixture.

        Parameters
        ----------
        alphas
            (current_size_mixture,) shaped tensor of the weights of the mixture in the solution.
        all_thetas
            (current_size_mixture, D) shaped tensor containing the centers of the solution.
        phi_thetas
            (M, current_size_mixture) shaped tensor of each center sketched. If None, then the sketch will be computed from alphas
            and thetas.
        sigmas
            Value forwarded as third positional argument of `self.sketch_of_solution`.
            If None, `self.sigmas` is used when the attribute exists (None otherwise).

        Returns
        -------
            The value of the objective evaluated at the provided parameters, i.e. the squared norm of
            `self.sketch_reweighted - sketch_of_solution(...)`.
        """
        assert all_thetas is not None or phi_thetas is not None, "Thetas and Atoms must not be both None"
        # `sigmas` used to be read from an undefined name, which raised a NameError on every call.
        # NOTE(review): falling back to the solver's current `self.sigmas` is assumed to be the intent -- confirm.
        if sigmas is None:
            sigmas = getattr(self, "sigmas", None)
        sketch_solution = self.sketch_of_solution(alphas, all_thetas, sigmas, phi_thetas=phi_thetas)
        loss = torch.linalg.norm(self.sketch_reweighted - sketch_solution) ** 2
        return loss

    def _stack_sol(self, alphas: np.ndarray = None, thetas: np.ndarray = None) -> np.ndarray:
        """
        Pack the whole solution (centers first, then weights) into a single flat vector.

        Only meant for numpy.ndarray inputs. Handy for optimizers whose objective takes one vector.

        Parameters
        ----------
        alphas:
            (current_size_mixture,) shaped ndarray of the weights of the mixture in the solution.
        thetas:
            (current_size_mixture, D) shaped ndarray containing the centers of the solution.

        Returns
        -------

        ((current_size_mixture*D) + current_size_mixture, ) shaped ndarray: the flattened thetas followed by the alphas.
        """
        # The arguments are used only when BOTH are given; otherwise the stored solution is stacked.
        use_stored = thetas is None or alphas is None
        centers = self.thetas if use_stored else thetas
        weights = self.alphas if use_stored else alphas

        flat_centers = centers.reshape(-1)
        return np.r_[flat_centers, weights]

    def _destack_sol(self, p):
        """
        Inverse of `self._stack_sol`: split a flat vector back into weights and centers.

        Only meant for numpy.ndarray inputs. Handy for optimizers whose objective takes one vector.

        Parameters
        ----------
        p
            ((current_size_mixture*D) + current_size_mixture, ) shaped ndarray holding the flattened thetas
            followed by the alphas. A 2D array with a single row is accepted as well.

        Returns
        -------
            (alphas, thetas) tuple:
                - alphas: (current_size_mixture,) shaped ndarray of the weights of the mixture in the solution.
                - thetas: (current_size_mixture, D) shaped ndarray containing the parameters of the solution components.

        Raises
        ------
        NotImplementedError
            If `p` holds several stacked solutions (more than one row).
        """
        n_atoms = self.current_size_mixture
        dim = self.thetas_dimension_D
        assert p.shape[-1] == n_atoms * (dim + 1)

        # Batches of solutions are not supported yet
        if p.ndim != 1 and p.shape[0] != 1:
            raise NotImplementedError(f"Impossible to destack p of shape {p.shape}.")

        flat = p.squeeze()
        n_center_entries = dim * n_atoms
        thetas = flat[:n_center_entries].reshape(n_atoms, dim)
        alphas = flat[-n_atoms:].reshape(n_atoms)
        return alphas, thetas

    def minimize_cost_from_current_sol(self, prefix="") -> NoReturn:
        """
        Minimise the global cost by tuning the whole solution (weights and centers).

        The values are updated in place by the selected backend (`self.opt_method_step_5`).

        Step 5 in CLOMP-R algorithm.

        Parameters
        ----------
        prefix
            Prefix the identifier of the list of objective values, if they are stored.

        Returns
        -------
            Whatever the selected backend returns (nothing for the backends of this class), or the
            string 'error_Az' when `self.alphas` is itself a string failure marker (the weight fitting
            step can return 'error_Az'), in which case no optimization is attempted.

        Raises
        ------
        ValueError
            If `self.opt_method_step_5` is neither "pdfo" nor one of `self.LST_OPT_METHODS_TORCH`.
        """
        method = self.opt_method_step_5
        use_pdfo = method == "pdfo"
        if not use_pdfo and method not in self.LST_OPT_METHODS_TORCH:
            raise ValueError(f"Unkown optimization method: {self.opt_method_step_5}")

        # The previous test was `type(self.alphas == 'str')`, which is always truthy: the torch branch
        # returned 'error_Az' unconditionally and never optimized anything. What must be detected is a
        # string sentinel stored in place of the weights.
        if isinstance(self.alphas, str):
            return 'error_Az'

        # Parameters, optimizer
        all_thetas = self.thetas

        if use_pdfo:
            return self._minimize_cost_from_current_sol_pdfo(self.alphas, all_thetas, prefix)

        # The torch backend optimizes the log of the weights to keep them positive
        log_alphas = torch.log(self.alphas)
        return self._minimize_cost_from_current_sol_torch(log_alphas, all_thetas, prefix)

    def _minimize_cost_from_current_sol_pdfo(self, alphas, thetas, prefix):
        """
        Fine-tune the whole solution (weights and centers) with the derivative-free pdfo solver.

        Nothing is returned: `self.thetas` and `self.alphas` are overwritten with the optimized values.

        Step 5 in CLOMP-R algorithm.

        pdfo works on numpy vectors, so the solution is moved to numpy and stacked into a single
        vector before the call, then destacked and converted back to tensors afterwards.

        Parameters
        ----------
        alphas:
            (current_size_mixture,) shaped tensor of the weights of the mixture in the solution.
        thetas:
            (current_size_mixture, D) shaped tensor containing the parameters of the solution components.
        prefix
            Prefix the identifier of the list of objective values, if they are stored.

        References
        ----------
        - pdfo software: https://www.pdfo.net/index.html

        """
        storage_key = f"minimize_cost_from_current_sol_pdfo/{prefix}"

        def objective(flat_solution):
            # loss_global expects separate weights and centers: unpack the flat vector first
            weights_np, centers_np = self._destack_sol(flat_solution)
            value = float(self.loss_global(all_thetas=torch.from_numpy(centers_np),
                                           alphas=torch.from_numpy(weights_np)))
            if self.store_objective_values:
                ObjectiveValuesStorage().add(float(value), storage_key)
            return value

        # Start from the current solution
        start_point = self._stack_sol(alphas=alphas.cpu().numpy(), thetas=thetas.cpu().numpy())

        # One copy of the atom bounds per component, followed by one weight interval per component
        n_atoms = self.current_size_mixture
        weight_bounds = [[self.weight_lower_bound, self.weight_upper_bound]] * n_atoms
        all_bounds = self.bounds_atom * n_atoms + weight_bounds

        max_evaluations = self.nb_iter_max_step_5 * start_point.size
        result = pdfo(objective,
                      x0=start_point,
                      bounds=all_bounds,
                      options={'maxfev': max_evaluations})

        new_alphas, new_thetas = self._destack_sol(result.x)

        self.thetas = torch.Tensor(new_thetas).to(self.real_dtype).to(self.device)
        self.alphas = torch.Tensor(new_alphas).to(self.real_dtype).to(self.device)

    def _minimize_cost_from_current_sol_torch(self, log_alphas, thetas, prefix) -> NoReturn:
        """
        Fine-tune the whole solution (weights and centers) with a torch optimizer.

        Nothing is returned: `self.thetas` and `self.alphas` are overwritten with the optimized values.

        Step 5 in CLOMP-R algorithm.

        Parameters
        ----------
        log_alphas:
            (current_size_mixture,) shaped tensor of the logs of the weights of the mixture in the solution.
            Optimizing the logs keeps the weights positive (they are recovered with exp).
        thetas:
            (current_size_mixture, D) shaped tensor containing the parameters of the solution components.
        prefix
            Prefix the identifier of the list of objective values, if they are stored.
        """
        # Both tensors are optimized jointly, so both have to track gradients
        log_alphas = log_alphas.requires_grad_()
        thetas = thetas.requires_grad_()

        optimizer = self._initialize_optimizer(self.opt_method_step_5, [log_alphas, thetas])
        storage_key = f"{prefix}/minimize_cost_from_current_sol"

        def evaluate_and_backprop():
            optimizer.zero_grad()

            weights = torch.exp(log_alphas).to(self.real_dtype)
            current = self.loss_global(all_thetas=thetas, alphas=weights)

            if self.tensorboard:
                # `step` is the index of the loop below, from which this function is called
                tag = self.path_template_tensorboard_writer.format('step5')
                self.writer.add_scalar(tag, current.item(), step)
            if self.store_objective_values:
                ObjectiveValuesStorage().add(float(current), storage_key)

            current.backward()
            return current

        last_loss = None
        for step in range(self.maxiter_inner_optimizations):
            self.count_iter_torch_minimize_cost_from_current_sol += 1

            if self.opt_method_step_5 == "lbfgs":
                # lbfgs evaluates the objective itself, through the function it is given
                loss = optimizer.step(evaluate_and_backprop)
            else:
                loss = evaluate_and_backprop()
                optimizer.step()

            # Bring the centers back into the admissible set, outside of the autograd graph
            with torch.no_grad():
                self.projection_step(thetas)

            # Stop once the relative change of the loss is below the tolerance
            if last_loss is not None:
                relative_change = torch.abs(last_loss - loss) / last_loss
                if relative_change.item() < self.tol_inner_optimizations:
                    break

            last_loss = torch.clone(loss)

        if self.tensorboard:
            self.writer.flush()
            self.writer.close()

        self.thetas = thetas.detach().to(self.device)
        self.alphas = torch.exp(log_alphas).detach().to(self.device)

    def do_step_4_5(self, prefix=""):
        """
        Run step 4 (weight fitting) then step 5 (joint fine-tuning) of the CLOMP-R algorithm
        and refresh the current solution afterwards. The duration of each step is logged.

        Parameters
        ----------
        prefix
            Prefix the identifier of the list of objective values, if they are stored.
        """
        # Step 4: project to find weights
        step_start = time.time()
        fitted_weights = self.find_optimal_weights(prefix=f"{prefix}b")
        self.alphas = fitted_weights
        logger.debug(f'Time for step 4: {time.time() - step_start}')

        # Step 5: fine-tune
        step_start = time.time()
        self.minimize_cost_from_current_sol(prefix=prefix)
        logger.debug(f'Time for step 5: {time.time() - step_start}')

        # The atoms have changed: we must re-compute their sketches matrix
        tuned_solution = (self.thetas, self.alphas)
        self.update_current_solution_and_cost(new_current_solution=tuned_solution)

    def final_fine_tuning(self) -> NoReturn:
        """
        Last fine-tuning pass of CLOMP: minimise the global cost over the whole solution
        (weights and centers), log the norms of the leading coordinates of each theta and
        refresh the current solution.
        """
        logger.info('Final fine-tuning...')
        self.minimize_cost_from_current_sol(prefix="final")
        # No extra self.projection_step(self.thetas) here: the projection was already made in the last method

        # Norm of the first `self.phi.d` coordinates of each row of thetas
        n_leading_coords = self.phi.d
        leading_norms = torch.norm(self.thetas[:, :n_leading_coords], dim=1)
        logger.debug(leading_norms)

        tuned_solution = (self.thetas, self.alphas)
        self.update_current_solution_and_cost(new_current_solution=tuned_solution)

    def _initialize_optimizer(self, opt_method: Literal["adam", "lbfgs"], params: list):
        """
        Build the torch optimizer matching `opt_method`.

        Parameters
        ----------
        opt_method:
            Name of the optimization method ("adam" or "lbfgs").
        params:
            List of torch parameters to track and update by the optimizer.

        Returns
        -------
            The torch optimizer object.

        Raises
        ------
        ValueError
            If `opt_method` is not a supported name.
        """
        if opt_method == "adam":
            return torch.optim.Adam(params, lr=self.lr_inner_optimizations)

        if opt_method == "lbfgs":
            # No learning rate is passed (torch default): the step length comes from the strong Wolfe line search
            return torch.optim.LBFGS(params, max_iter=1, line_search_fn="strong_wolfe")

        raise ValueError(f"Optimizer {opt_method} cannot be used for gradient descent. "
                         f"Choose one in {self.LST_OPT_METHODS_TORCH}")

    def maximize_atom_correlation(self, prefix=""):
        """
        Step 1 in CLOMP-R algorithm. Find the theta giving the most correlated atom to the residual.

        A randomly initialized component is refined by the backend selected with
        `self.opt_method_step_1` (pdfo or torch).

        Parameters
        ----------
        prefix
            Prefix the identifier of the list of objective values, if they are stored.

        Returns
        -------
            (D,)-shaped tensor corresponding to the new center.

        Raises
        ------
        ValueError
            If `self.opt_method_step_1` is neither "pdfo" nor one of `self.LST_OPT_METHODS_TORCH`.
        """
        # Draw a single component and drop the leading axis of size one
        initial_theta = self.randomly_initialize_several_mixture_components(1).squeeze(0)

        method = self.opt_method_step_1
        if method == "pdfo":
            return self._maximize_atom_correlation_pdfo(initial_theta)
        if method not in self.LST_OPT_METHODS_TORCH:
            raise ValueError(f"Unkown optimization method: {self.opt_method_step_1}")
        return self._maximize_atom_correlation_torch(initial_theta, prefix)

    def _maximize_atom_correlation_pdfo(self, new_theta: torch.Tensor) -> torch.Tensor:
        """
        Step 1 in CLOMP-R algorithm. Find the theta giving the most correlated atom to the residual.

        The search is done by minimizing `self.loss_atom_correlation` with the derivative-free pdfo
        solver, which requires moving the parameter to a numpy.ndarray.

        Parameters
        ----------
        new_theta
            (D,)-shaped tensor corresponding to the initial value for the mixture component parameter.

        References
        ----------
        - pdfo software: https://www.pdfo.net/index.html

        Returns
        -------
            (D,)-shaped tensor corresponding to the parameter of the new component.
        """
        start_point = new_theta.cpu().numpy()

        def objective(candidate):
            # pdfo passes numpy vectors and expects a plain float back
            return float(self.loss_atom_correlation(torch.from_numpy(candidate)))

        max_evaluations = self.nb_iter_max_step_1 * start_point.size
        result = pdfo(objective,
                      x0=start_point,  # Start at the provided initial value
                      bounds=self.bounds_atom,
                      options={'maxfev': max_evaluations})

        return torch.Tensor(result.x).to(self.real_dtype).to(self.device)

    def _maximize_atom_correlation_torch(self, new_theta, prefix):
        """
        Step 1 in CLOMP-R algorithm. Find the center giving the most correlated atom to the residual.

        The search is done by minimizing `self.loss_atom_correlation` with a torch optimizer, with a
        projection step after every update.

        Parameters
        ----------
        new_theta
            (D,)-shaped tensor corresponding to the initial value for the cluster center.
        prefix
            Prefix the identifier of the list of objective values, if they are stored.

        Returns
        -------
            (D,)-shaped tensor corresponding to the new center.
        """
        # The optimized parameter is built from `new_theta`, which is also the tensor projected
        # and returned below.
        theta_param = torch.nn.Parameter(new_theta, requires_grad=True)
        optimizer = self._initialize_optimizer(self.opt_method_step_1, [theta_param])
        storage_key = f"{prefix}/maximize_atom_correlation"

        def evaluate_and_backprop():
            optimizer.zero_grad()
            current = self.loss_atom_correlation(theta_param)

            if self.store_objective_values:
                ObjectiveValuesStorage().add(float(current), storage_key)
            if self.tensorboard:
                # `step` is the index of the loop below, from which this function is called
                tag = self.path_template_tensorboard_writer.format("step1/{}".format(prefix))
                self.writer.add_scalar(tag, current.item(), step)

            current.backward()
            return current

        last_loss = None
        for step in range(self.maxiter_inner_optimizations):
            self.count_iter_torch_maximize_atom_correlation += 1

            if self.opt_method_step_1 == "lbfgs":
                # lbfgs evaluates the objective itself, through the function it is given
                loss = optimizer.step(evaluate_and_backprop)
            else:
                loss = evaluate_and_backprop()
                optimizer.step()

            # Projection step
            with torch.no_grad():
                self.projection_step(new_theta)

            # Stop once the relative change of the loss is below the tolerance
            if last_loss is not None:
                relative_change = torch.abs(last_loss - loss) / torch.abs(last_loss)
                if relative_change.item() <= self.tol_inner_optimizations:
                    break
            last_loss = torch.clone(loss)

        if self.tensorboard:
            self.writer.flush()
            self.writer.close()

        return new_theta.data.detach()

    def transform_grid(self,function,grad_function,dom,grid,alpha,T):
        """
        Move every point of `grid` with `T` projected, normalized gradient steps.

        At each step the current point `x` is updated to
        `x + alpha * (1 / function(x)) * grad_function(x)` when `function(x) > 0` and to
        `x - alpha * (1 / function(x)) * grad_function(x)` otherwise, then mapped back with
        `dom.project`. The objective is evaluated exactly once per step.

        Parameters
        ----------
        function
            Callable returning the (scalar) objective value at a point.
        grad_function
            Callable returning the gradient at a point; it receives a copy of the point with
            `requires_grad=True`.
        dom
            Object exposing `project(x)`, applied after every step.
        grid
            Sequence of starting points (tensors).
        alpha
            Step size.
        T
            Number of steps applied to each starting point.

        Returns
        -------
            List with, for each starting point and in the same order, the point reached after `T` steps
            (the starting point itself when `T` is 0).
        """
        t_grid = []
        for start in grid:
            x_tmp = start
            for _ in range(T):
                # Single evaluation per step: the value drives both the branch and the scaling
                value = function(x_tmp)
                x_tmp_g = Variable(x_tmp, requires_grad=True)
                scaled_grad = alpha * (1 / value) * grad_function(x_tmp_g)
                if value <= 0:
                    # 1 / value is negative here: subtracting keeps the move along +gradient
                    x = x_tmp - scaled_grad
                else:
                    x = x_tmp + scaled_grad
                x_tmp = dom.project(x)
            t_grid.append(x_tmp)
        return t_grid

    def transform_grid_fast(self,function,grad_function,dom,grid,alpha,T):
        """
        Refine only the best point of `grid` with `T` projected, normalized gradient steps.

        The starting point is the grid point with the largest `function` value. At each step the
        current point `x` is updated to `x + alpha * (1 / function(x)) * grad_function(x)` when
        `function(x) > 0` and to `x - alpha * (1 / function(x)) * grad_function(x)` otherwise, then
        mapped back with `dom.project`. The objective is evaluated once per grid point for the
        selection, then exactly once per step.

        Parameters
        ----------
        function
            Callable returning the (scalar) objective value at a point.
        grad_function
            Callable returning the gradient at a point; it receives a copy of the point with
            `requires_grad=True`.
        dom
            Object exposing `project(x)`, applied after every step.
        grid
            Non-empty sequence of candidate starting points (tensors).
        alpha
            Step size.
        T
            Number of steps.

        Returns
        -------
            Single-element list holding the point reached after `T` steps.
        """
        # Select the most promising starting point; plain floats keep argmax independent of tensor details
        scores = [float(function(point)) for point in grid]
        n_star = np.argmax(scores)
        x_tmp = grid[n_star]

        for _ in range(T):
            # Single evaluation per step: the value drives both the branch and the scaling
            value = function(x_tmp)
            x_tmp_g = Variable(x_tmp, requires_grad=True)
            scaled_grad = alpha * (1 / value) * grad_function(x_tmp_g)
            if value <= 0:
                # 1 / value is negative here: subtracting keeps the move along +gradient
                x = x_tmp - scaled_grad
            else:
                x = x_tmp + scaled_grad
            x_tmp = dom.project(x)

        return [x_tmp]


    def transform_grid_gradient_fast(self,function,grad_function,dom,grid,alpha,T):
        """
        Refine only the best point of `grid` with `T` projected gradient-ascent steps.

        The starting point is the grid point with the largest `function` value. At each step the
        current point `x` is updated to `x + alpha * grad_function(x)`, then mapped back with
        `dom.project`. The objective is only evaluated for the selection of the starting point.

        Parameters
        ----------
        function
            Callable returning the (scalar) objective value at a point.
        grad_function
            Callable returning the gradient at a point; it receives a copy of the point with
            `requires_grad=True`.
        dom
            Object exposing `project(x)`, applied after every step.
        grid
            Non-empty sequence of candidate starting points (tensors).
        alpha
            Step size.
        T
            Number of steps.

        Returns
        -------
            Single-element list holding the point reached after `T` steps.
        """
        # Select the most promising starting point; plain floats keep argmax independent of tensor details
        scores = [float(function(point)) for point in grid]
        n_star = np.argmax(scores)
        x_tmp = grid[n_star]

        for _ in range(T):
            x_tmp_g = Variable(x_tmp, requires_grad=True)
            x = x_tmp + alpha * (grad_function(x_tmp_g))
            x_tmp = dom.project(x)

        return [x_tmp]
    
    
    def fit_once(self):
        """
        CLOMP-R algorithm implementation.

        References
        ----------
        Keriven, N., Bourrier, A., Gribonval, R., & Pérez, P. (2018). Sketching for large-scale learning of mixture models.
        Information and Inference: A Journal of the IMA, 7(3), 447-508.
        https://arxiv.org/pdf/1606.02838.pdf
        """
        # todo utiliser plutot une interface de type fit/transform
        n_iterations = self.dct_optim_method_hyperparameters.get("int_grid_size")
        #2*self.size_mixture_K
        nodes_list = []
        for i_iter in range(n_iterations):
            #print(self.residual)
            logger.debug(f'Iteration {i_iter + 1} / {n_iterations}')
            # Step 1: find new atom theta most correlated with residual
            since = time.time()
            
            #s=self.s_sigma
            D = self.thetas_dimension_D
            
            f_z = self.get_correlation_function()
            f_z_0 = self.get_correlation_function_0()
            grad_f_z = self.get_gradient_correlation_function()
            #print(self.sigmas)
            #print(self.thetas)
            #print(self.alphas)
            #print('herll')
            #print(self.dct_optim_method_hyperparameters.get("grid_mode"))
            #print(self.grid_mode)
            if self.dct_optim_method_hyperparameters.get("grid_mode") =='static_grid':
                #print('statiiiic')
                h_d = hypercube(-1,1,D)
                #print(self.residual)
                #if i_iter ==0:
                grid = h_d.grid(self.dct_optim_method_hyperparameters.get("grid_size"))
                grid_0 = grid
            if self.dct_optim_method_hyperparameters.get("grid_mode") =='static_grid_ball':
                #print('statiiiic')
                b_d = ball(np.zeros(D),1,D)
                #print(self.residual)
                #if i_iter ==0:
                grid = b_d.grid(self.dct_optim_method_hyperparameters.get("grid_size"))
                grid_0 = grid
                    
            if self.dct_optim_method_hyperparameters.get("grid_mode") =='shifted_grid':
                #print('sss')
                h_d = hypercube(-1,1,D)
                #n_grid = 500
                #if i_iter ==0:
                tmp_grid = h_d.grid(self.dct_optim_method_hyperparameters.get("grid_size"))
                alpha = self.dct_optim_method_hyperparameters.get("ms_step_size")
                T = self.dct_optim_method_hyperparameters.get("ms_iteration_number")
                #500
                if self.dct_optim_method_hyperparameters.get("opt_mode") =='normal':
                    grid = self.transform_grid(f_z,grad_f_z,h_d,tmp_grid,alpha,T)
                elif self.dct_optim_method_hyperparameters.get("opt_mode") =='fast':
                    grid = self.transform_grid_fast(f_z,grad_f_z,h_d,tmp_grid,alpha,T)
                elif self.dct_optim_method_hyperparameters.get("opt_mode") =='gradient_fast':
                    grid = self.transform_grid_gradient_fast(f_z,grad_f_z,h_d,tmp_grid,alpha,T)

                grid_0 = tmp_grid
                
                
                
            if self.dct_optim_method_hyperparameters.get("grid_mode") =='shifted_grid_ball':
                #print('sss')
                b_d = ball(np.zeros(D),1,D)
                #h_d = hypercube(-1,1,D)
                #n_grid = 500
                #if i_iter ==0:
                tmp_grid = b_d.grid(self.dct_optim_method_hyperparameters.get("grid_size"))
                alpha = self.dct_optim_method_hyperparameters.get("ms_step_size")
                T = self.dct_optim_method_hyperparameters.get("ms_iteration_number")
                #500
                if self.dct_optim_method_hyperparameters.get("opt_mode") =='normal':
                    grid = self.transform_grid(f_z,grad_f_z,b_d,tmp_grid,alpha,T)
                elif self.dct_optim_method_hyperparameters.get("opt_mode") =='fast':
                    grid = self.transform_grid_fast(f_z,grad_f_z,b_d,tmp_grid,alpha,T)
                elif self.dct_optim_method_hyperparameters.get("opt_mode") =='gradient_fast':
                    grid = self.transform_grid_gradient_fast(f_z,grad_f_z,b_d,tmp_grid,alpha,T)

                grid_0 = tmp_grid
                #grid    
            #print('greiiiiid')
            #print(tmp_grid)
            #print(grid)
            

            
            
            n,sigma = self.centroids_recovery_grid(f_z,self.s_sigma,D,'dirac',grid)
            f_i = self.get_correlation_function()
            f_0 = self.get_correlation_function_0()
            #plot_this_2D_function_and_nodes(f_0,x_list[L-1],x_list,-1,1,100)
            plot_this_2D_function_and_nodes(f_i,n,[n],-1,1,100)
            # #recovered_n_list,recovered_w_list,recovered_sigma_list,x_list_lists =self.centroids_recovery_ms(f_z,grad_f_z,0.005,0.01,500,h_d,'dirac')
            # if self.dct_optim_method_hyperparameters.get("model")=='diracs':
            #     n,sigma = self.centroids_recovery_grid(f_z,self.s_sigma,D,'dirac',grid)
            # elif self.dct_optim_method_hyperparameters.get("model")=='gaussians':
            #     n,sigma = self.centroids_recovery_grid(f_z,self.s_sigma,D,'gaussian',grid)


     
            #n = recovered_n_list[0]
            new_theta= n
            nodes_list.append(n)
            #new_theta = self.maximize_atom_correlation(prefix=str(i_iter))
            logger.debug(f'Time for step 1: {time.time() - since}')

            # Step 2: add it to the support
            self.add_atom(new_theta)
            
            
            # Step 2.5: project to find weights
            since = time.time()
            self.alphas = self.find_optimal_weights()
            #print('alphaaaas')
            #print(self.alphas)
            self.sigmas = self.find_sigmas(self.s_sigma,D,f_z_0)
            #print(self.sigmas)
            #print('aaaa')
            #print(self.s_sigma)
            logger.debug(f'Time for step 2.5: {time.time() - since}')
            # The atoms have changed: we must re-compute their sketches matrix
            self.update_current_solution_and_cost(new_current_solution=(self.thetas, self.alphas, self.sigmas))
        
        # Step 3: if necessary, hard-threshold to enforce sparsity
        if self.current_size_mixture > self.size_mixture_K:

            since = time.time()
            # atoms must be normalized so that the weights in beta reflect their importance. See Reference.
            print('beta')
            
            beta = self.find_optimal_weights(normalize_phi_thetas=True, prefix=f"{i_iter}a")
            print(beta)
            
            
            while self.current_size_mixture > self.size_mixture_K:
                # print(type(beta))
                # print(beta.shape)
                index_to_remove = torch.argmin(beta).to(torch.long)
                beta = torch.cat([beta[0:index_to_remove], beta[index_to_remove+1:]])
                self.remove_one_component(index_to_remove)
                logger.debug(f'Time for step 3: {time.time() - since}')
                if index_to_remove == self.size_mixture_K:
                    logger.debug(f"Removed atom is the last one added. Solution is not updated.")
                    #continue

        # Step 4 and 5
        #self.do_step_4_5(prefix=str(i_iter))

        # Final fine-tuning with increased optimization accuracy
        #self.final_fine_tuning()
        #self.alphas /= torch.sum(self.alphas)
        #print(self.thetas)
        #print(self.alphas)
        self.update_current_solution_and_cost(new_current_solution=(self.thetas, self.alphas,self.sigmas))
    
        return nodes_list,grid_0

