# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import warnings
import numpy as np
import matplotlib.ticker as ticker
import numpy.linalg as linalg
from sklearn import linear_model
from sklearn.linear_model import OrthogonalMatchingPursuit
import time
#in_local = True
#try:
import pylops
import os

    
def create_legend(dict_legend, size = 30, save_formats = ['.png','.svg'], 
                  save_addi = 'legend' , dict_legend_marker = {}, 
                  marker = '.', style = 'plot', s = 500, to_save = True, plot_params = {'lw':5},
                  save_path = os.getcwd(), params_leg = {}, figsize = (8,8)):
    """
    Draw a stand-alone legend figure from a label -> color mapping.

    One empty artist is added per key of `dict_legend`, so that only the
    legend is visible, and the figure is optionally saved once per extension.

    Parameters:
    - dict_legend (dict): Maps each legend label to its color.
    - size (int): Font size passed to the legend through `prop`.
    - save_formats (list): File extensions (including the dot) to save with.
    - save_addi (str): Suffix inserted in the file name
      ('legend_areas_<save_addi><extension>').
    - dict_legend_marker (dict): Optional label -> marker mapping. Labels that
      are missing from it get `None` as marker.
    - marker (str): Marker used in scatter style when `dict_legend_marker` is empty.
    - style (str): 'plot' uses `ax.plot`; any other value uses `ax.scatter`.
    - s (int): Marker size, used by the scatter style only.
    - to_save (bool): Whether to write the figure to disk.
    - plot_params (dict): Extra keyword arguments forwarded to plot/scatter.
    - save_path (str): Target directory. NOTE: the default is evaluated once,
      when the function is defined, not at call time.
    - params_leg (dict): Extra keyword arguments forwarded to `ax.legend`.
    - figsize (tuple): Figure size.

    Returns:
    None. The figure is not returned.
    """
    fig, ax = plt.subplots(figsize = figsize)

    if style == 'plot':
        for area in dict_legend:
            ax.plot([], [], c = dict_legend[area], label = area,
                    marker = dict_legend_marker.get(area), **plot_params)
    else:
        if len(dict_legend_marker) == 0:
            # no per-label markers: every entry shares `marker`
            for area in dict_legend:
                ax.scatter([], [], s = s, c = dict_legend.get(area), label = area,
                           marker = marker, **plot_params)
        else:
            for area in dict_legend:
                ax.scatter([], [], s = s, c = dict_legend[area], label = area,
                           marker = dict_legend_marker.get(area), **plot_params)

    ax.legend(prop = {'size': size}, **params_leg)
    # hide all four spines and the ticks so only the legend box remains
    remove_edges(ax, left = False, bottom = False)
    fig.tight_layout()

    if not to_save:
        return
    for type_save in save_formats:
        fig.savefig(save_path + os.sep + 'legend_areas_%s%s'%(save_addi,type_save))
        
        

def checkEmptyList(obj):
    """
    Tell whether `obj` is a list with no elements.

    Args:
        obj (object): Any object.

    Returns:
        bool: True only for an empty `list` instance; False for non-empty
        lists and for every non-list object (including empty tuples, strings
        and arrays).

    """
    if not isinstance(obj, list):
        return False
    return len(obj) == 0
def d3tod32(mat)        :
    """
    Stack the slices `mat[i, j, :]` of a 3D array on top of each other.

    Slices are visited with `i` as the outer index and `j` as the inner one,
    so for an input of shape (I, J, K) the output has shape (I*J, K) and row
    `i*J + j` equals `mat[i, j, :]`.

    Parameters:
    - mat (numpy.ndarray): Array indexed as `mat[i, j, :]`.

    Returns:
    numpy.ndarray: A new array holding the stacked slices.
    """
    n_outer, n_inner = mat.shape[0], mat.shape[1]
    slices = [mat[i, j, :] for i in range(n_outer) for j in range(n_inner)]
    return np.vstack(slices)


def remove_edges(ax, include_ticks = False, top = False, right = False, bottom = True, left = True):
    """
    Show or hide the four spines of an axes and optionally clear its ticks.

    Parameters:
    - ax: Axes whose spines are edited in place.
    - include_ticks (bool): If False (default), the x and y ticks are removed.
    - top, right, bottom, left (bool): Desired visibility of each spine.
    """
    spine_visibility = (('top', top), ('right', right),
                        ('bottom', bottom), ('left', left))
    for side, visible in spine_visibility:
        ax.spines[side].set_visible(visible)

    if include_ticks:
        return
    for get_axis in (ax.get_xaxis, ax.get_yaxis):
        get_axis().set_ticks([])

def add_labels(ax, xlabel='X', ylabel='Y', zlabel='', title='', xlim = None, ylim = None, zlim = None,xticklabels = np.array([None]),
               yticklabels = np.array([None] ), xticks = [], yticks = [], legend = [], 
               ylabel_params = {'fontsize':19},zlabel_params = {'fontsize':19}, xlabel_params = {'fontsize':19}, 
               title_params = {'fontsize':19}, format_xticks = 0, format_yticks = 0):
    """
    Add labels, title, limits, tick labels, legend and tick formatting to an axes.

    Every element is optional: it is applied only when its argument is set.

    Parameters:
    - ax: The subplot to edit (in place).
    - xlabel, ylabel, zlabel, title (str): Texts to set. An empty string or
      None skips the element; for a 2D axes leave `zlabel` as '' or None,
      since `set_zlabel` is only called for a non-empty value.
    - xlim, ylim, zlim: Axis limits; None skips the limit.
    - xticklabels, yticklabels: Tick labels. Applied when at least one entry
      is not None.
    - xticks, yticks: Tick positions for the labels above. When empty,
      positions default to 0..n-1 for x and 0.5..n-0.5 for y.
    - legend (list): Legend entries; an empty list skips the legend.
    - xlabel_params, ylabel_params, zlabel_params, title_params (dict):
      Keyword arguments forwarded to the matching `set_*` call.
    - format_xticks, format_yticks (int): When > 0, number of decimals used
      to format the major tick labels of that axis.
    """
    if xlabel is not None and xlabel != '':
        ax.set_xlabel(xlabel, **xlabel_params)
    if ylabel is not None and ylabel != '':
        ax.set_ylabel(ylabel, **ylabel_params)
    if zlabel is not None and zlabel != '':
        ax.set_zlabel(zlabel, **zlabel_params)
    if title is not None and title != '':
        ax.set_title(title, **title_params)

    # identity checks so that array-like limits do not trigger an
    # ambiguous element-wise comparison
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if zlim is not None:
        ax.set_zlim(zlim)

    # `!= None` is intentionally element-wise here: labels are applied as
    # soon as one entry of the array differs from None
    if (np.array(xticklabels) != None).any():
        if len(xticks) == 0:
            xticks = np.arange(len(xticklabels))
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)
    if (np.array(yticklabels) != None).any():
        if len(yticks) == 0:
            yticks = np.arange(len(yticklabels)) + 0.5
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels)

    if len(legend) > 0:
        ax.legend(legend)

    # '%%.%df' % 2 -> '%.2f'; the formatter class lives in matplotlib.ticker,
    # which this module imports as `ticker`
    if format_xticks > 0:
        ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%%.%df' % format_xticks))
    if format_yticks > 0:
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%%.%df' % format_yticks))
      


def create_lorenz(psi0 = [0.1, 0.1, 0], dt = 0.010, max_t = 3,
                  sigma = 10, beta = 8/3, rho = 28, return_operators = True, option = 1):
    """
    Simulate a Lorenz trajectory with state-dependent linear one-step operators.

    At every step the matrix A returned by `create_lorenz_mat` for the current
    state is turned into the one-step operator `A*dt + I`, and the next state
    is that operator applied to the current state.

    Parameters:
    - psi0 (list): Initial state.
    - dt (float): Time step.
    - max_t (float): One step is taken per element of `np.arange(0, max_t, dt)`.
    - sigma, beta, rho (float): Lorenz parameters forwarded to `create_lorenz_mat`.
    - return_operators (bool): If True, also return the list of one-step operators.
    - option (int): Forwarded to `create_lorenz_mat`.

    Returns:
    - psi (numpy.ndarray): States as columns; the first column is `psi0` and
      one column is appended per step.
    - As (list of numpy.ndarray): Only when `return_operators` is True; the
      operator `A*dt + I` used at each step.
    """
    # Example of another parameter set, kept for reference:
    # {'psi0': [0.2127086755529508, 2.6381260266588527, 1.0635433777647538],
    # 'dt': 0.024647819173282542,
    # 'max_t': 6,
    # 'sigma': 10.425417351105901,
    # 'beta': 2.546042008886284,
    # 'rho': 26.97792145774131,
    # 'return_operators': True,
    # 'option': 2}
    trajectory = np.array(psi0).reshape((-1,1))
    operators = []
    for _ in np.arange(0, max_t, dt):
        current_state = trajectory[:,-1]
        lorenz_mat = create_lorenz_mat(current_state, sigma, beta, rho, option = option)
        step_operator = lorenz_mat*dt + np.eye(lorenz_mat.shape[0])
        next_state = step_operator @ current_state
        trajectory = np.hstack([trajectory, next_state.reshape((-1,1))])
        if return_operators:
            operators.append(step_operator)

    if return_operators:
        return trajectory, operators
    return trajectory
    
def create_3d_ax(num_rows, num_cols, params = {}):
    """
    Create a figure whose subplots all use the '3d' projection.

    Parameters:
    - num_rows (int): Number of subplot rows.
    - num_cols (int): Number of subplot columns.
    - params (dict): Extra keyword arguments forwarded to `plt.subplots`.

    Returns:
    The `(fig, ax)` pair returned by `plt.subplots`.
    """
    subplot_kw = {'projection': '3d'}
    fig, ax = plt.subplots(num_rows, num_cols, subplot_kw = subplot_kw, **params)
    return fig, ax


def plot_3d(mat, params_fig = {}, fig = [], ax = [], params_plot = {}, type_plot = 'plot'):
    """
    Draw the first three rows of `mat` as x, y and z on a 3D axes.

    Parameters:
    - mat: Data indexed as `mat[0]`, `mat[1]`, `mat[2]` (x, y, z).
    - params_fig (dict): Forwarded to `create_3d_ax` when a new axes is made.
    - fig, ax: Existing figure/axes. A new single 3D axes is created only
      when `ax` is an empty list.
    - params_plot (dict): Keyword arguments forwarded to the drawing call.
    - type_plot (str): 'plot' draws a line; any other value draws a scatter.

    Returns:
    None. A figure created here is not returned.
    """
    if checkEmptyList(ax):
        fig, ax = create_3d_ax(1,1, params_fig)

    draw = ax.plot if type_plot == 'plot' else ax.scatter
    draw(mat[0], mat[1], mat[2], **params_plot)
    
def create_colors(len_colors, perm = [0,1,2], style = 'random', cmap  = 'viridis', seed = 0, reduce_green = 0.4):
    """
    Create a set of discrete colors.

    Input:
        len_colors   = number of different colors needed
        perm         = currently unused
        style        = 'random' draws random colors; any other value samples
                       the colormap `cmap` at evenly spaced positions in [0, 1]
        cmap         = colormap name, used only when style is not 'random'
        seed         = seed given to np.random.seed (this reseeds NumPy's
                       global random state on every call, for both styles)
        reduce_green = factor multiplying the second (green) row of the
                       random colors
    Output:
        style 'random': 3 X len_colors matrix with the colors in the columns
        otherwise:      list of len_colors colors as returned by the colormap
    """
    np.random.seed(seed)

    if style != 'random':
        colormap = plt.get_cmap(cmap)
        # evenly spaced positions along the colormap
        return [colormap(position) for position in np.linspace(0, 1, len_colors)]

    rgb = np.random.rand(3, len_colors)
    rgb[1] = rgb[1]*reduce_green
    return rgb
    

    
    
def create_lorenz_mat(psi, sigma, beta, rho, option = 1):
    """
    Build the state-dependent 3x3 matrix A of the Lorenz system, such that
    A @ [x, y, z] gives the Lorenz time derivative at the state `psi`.

    The two options place the nonlinear term of the second row differently:
    option 1 uses `-x` in the z column, any other option uses `rho - z` in
    the x column. The first and third rows are the same for both.

    Parameters:
    - psi (sequence): Current state [x, y, z]. Only x is read for option 1;
      x and z are read otherwise.
    - sigma (float): Lorenz parameter sigma.
    - beta (float): Lorenz parameter beta.
    - rho (float): Lorenz parameter rho.
    - option (int): Selects the form of the second row (see above).

    Returns:
    numpy.ndarray: The 3x3 matrix
        option == 1:  [[-sigma, sigma, 0], [rho,     -1, -x], [0, x, -beta]]
        otherwise:    [[-sigma, sigma, 0], [rho - z, -1,  0], [0, x, -beta]]

    Example:
    For psi = [1, 2, 3], sigma = 10, beta = 8/3, rho = 28 and option = 1 the
    rows are [-10, 10, 0], [28, -1, -1] and [0, 1, -8/3].
    """
    x = psi[0]
    first_row = [-sigma, sigma, 0]
    third_row = [0, x, -beta]

    if option == 1:
        second_row = [rho, -1, -x]
    else:
        second_row = [rho - psi[2], -1, 0]

    return np.vstack([first_row, second_row, third_row])
    
    
def add_basic_axes(ax = [], fig = [], max_z = 1, max_x = 1, max_y = 1, 
                   min_z = 0, min_x = 0, min_y = 0,  
                   params_subplot = {},
                   params_plot = {'color' : 'black', 'w': 4,  'ls':'-' ,'mutation':20, 
                                  'arrowstyle':"-|>", 
                                  'linewidth': 2}, remove_back = True, 
                   remove_grid = True,  remove_axes = False, remove_ticks = False):    
    """
    Prepare a 3D axes by stripping its background, grid, ticks and/or frame.

    Drawing of the axis arrows themselves is currently disabled, so
    `params_plot` and the min/max extents do not affect the figure.

    Parameters:
    - ax, fig: Existing axes/figure. A new single 3D axes is created only
      when `ax` is an empty list.
    - max_x, max_y, max_z, min_x, min_y, min_z (float): Axis extents.
    - params_subplot (dict): Forwarded to `create_3d_ax` for a new axes.
    - params_plot (dict): Arrow styling (unused while arrows are disabled).
    - remove_back, remove_grid, remove_axes, remove_ticks (bool): Forwarded
      to `to_remove_back`.

    Returns:
    None. An axes created here is not returned.
    """
    if checkEmptyList(ax):
        fig, ax = create_3d_ax(1, 1, params_subplot)

    # Extent of each axis. These fed the disabled arrow code, which called
    # ax.arrow3D(min_x, min_y, min_z, dx, dy, dz, mutation_scale=params_plot['mutation'],
    #            arrowstyle=params_plot['arrowstyle'], linestyle=params_plot['ls'],
    #            color=params_plot['color'], linewidth=params_plot['linewidth'])
    # once per axis (x, y, z) with the same arguments each time.
    dx, dy, dz = max_x - min_x, max_y - min_y, max_z - min_z

    to_remove_back(ax, remove_back = remove_back, remove_grid = remove_grid,
                   remove_axes = remove_axes, remove_ticks = remove_ticks)
        
        
def to_remove_back(ax, remove_back = True, remove_grid = True, remove_axes = False, remove_ticks = False):
    """
    Strip decorations from a 3D axes in place.

    Parameters:
    - ax: 3D axes to edit (it must expose `xaxis`, `yaxis`, `zaxis` panes
      and `set_zticks`).
    - remove_back (bool): Make the three background panes unfilled.
    - remove_grid (bool): Turn the grid off.
    - remove_axes (bool): Turn the whole axis frame off.
    - remove_ticks (bool): Remove the x, y and z ticks.
    """
    if remove_back:
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False

    if remove_grid:
        # `grid` is a method: it has to be called. Assigning to it would
        # shadow the method on this axes and leave the grid visible.
        ax.grid(False)
    if remove_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
    if remove_axes:
        ax.axis('off')
    


    
#def     


"""
3d plotting
"""
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d.proj3d import proj_transform
from matplotlib.patches import FancyArrowPatch
class Arrow3D(FancyArrowPatch):
    """
    Arrow patch defined by a 3D start point and a 3D displacement.

    The 3D end points are stored at construction and projected with the
    axes' matrix `self.axes.M` each time the arrow is drawn or projected.
    """

    def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
        # the 2D positions are placeholders; they are overwritten on projection
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._xyz = (x, y, z)
        self._dxdydz = (dx, dy, dz)

    def _project_endpoints(self):
        """Project tail and head, update the 2D positions, return projected z values."""
        x1, y1, z1 = self._xyz
        dx, dy, dz = self._dxdydz
        head = (x1 + dx, y1 + dy, z1 + dz)

        xs, ys, zs = proj_transform((x1, head[0]), (y1, head[1]), (z1, head[2]), self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def draw(self, renderer):
        self._project_endpoints()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # the smallest projected z of the two end points is returned
        return np.min(self._project_endpoints())
    
def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    '''Create an `Arrow3D` from (x, y, z) with displacement (dx, dy, dz) and add it to `ax`.

    Extra positional and keyword arguments are forwarded to `Arrow3D`.
    '''
    ax.add_artist(Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs))


# expose the helper as a method: ax.arrow3D(x, y, z, dx, dy, dz, ...)
Axes3D.arrow3D = _arrow3D


def is_1d(mat):
    """
    Tell whether `mat` has at most one axis longer than 1.

    The check compares the largest axis length with the total number of
    elements, so shapes such as (n,), (1, n) and (n, 1, 1) all count as 1D.

    Parameters:
    - mat (list or numpy.ndarray): Input data; lists are converted to arrays.

    Returns:
    True when the largest axis length equals the number of elements.

    Raises:
    - ValueError: If `mat` is neither a list nor a numpy array.
    """
    if not isinstance(mat, (list, np.ndarray)):
        raise ValueError('Mat must be numpy array or a list')
    arr = np.array(mat) if isinstance(mat, list) else mat
    return np.max(arr.shape) == arr.size



def solve_Lasso_style(A, b, l1, params = {}, lasso_params = {},random_state = 0, nouter = 50,
                      ):
    """
    Solve the regression A x ~ b with a selectable (sparse) solver.

    The reference problem is the l1-regularized least squares
        minimize (1/2)*norm( A * x - b )^2 + l1 * norm( x, 1 )
    but `l1` is handed to each backend through a different argument, so its
    exact meaning depends on the solver (see "Solvers").

    Parameters
    ----------
    A : np.ndarray
        Regression matrix.
    b : np.ndarray
        Target. If it has a single non-singleton axis it is reshaped to a column.
    l1 : float
        Regularization value. `l1 == 0` always selects the 'inv' solver.
    params : dict, optional
        Solver settings, merged over the defaults
        {'threshkind': 'soft', 'solver': 'spgl1', 'num_iters': 10}.
        The solver name is matched case-insensitively.
    lasso_params : dict, optional
        Extra keyword arguments for sklearn's Lasso ('lasso' solver only).
    random_state : int, optional
        Random state for the 'lasso' solver. The default is 0.
    nouter : int, optional
        Number of outer iterations for the 'irls' solver. The default is 50.

    Raises
    ------
    NameError
        If the solver name is not one of the options below.

    Returns
    -------
    x : np.ndarray
        The coefficients as returned by the chosen backend. 'inv' returns the
        matrix product pinv(A) @ b with b as a column; the layout for the
        other solvers is whatever the backend returns.

    Solvers
    -------
    - 'inv'   : pseudo-inverse least squares, no l1 regularization.
    - 'lasso' : sklearn Lasso with alpha = l1.
    - 'fista' : pylops FISTA with eps = l1, niter = num_iters and the given threshkind
                (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.FISTA.html)
    - 'ista'  : pylops ISTA with eps = l1, niter = num_iters and the given threshkind
                (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.ISTA.html)
    - 'omp'   : sklearn OrthogonalMatchingPursuit without intercept, with
                n_nonzero_coefs = A.shape[1] - l1 (so here l1 is a count of
                coefficients to leave out, not a weight).
    - 'spgl1' : pylops SPGL1 with tau = l1 and iter_lim = num_iters
                (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.SPGL1.html)
    - 'irls'  : pylops IRLS with epsI = l1 and `nouter` outer iterations
                (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.IRLS.html)
    """
    # merge into a new dict so the caller's `params` is never modified;
    # after the merge every default key (including 'solver') is present
    settings = {**{'threshkind':'soft','solver':'spgl1','num_iters':10}, **params}

    if np.isnan(A).any():
        print('there is a nan in A')
    if len(b.flatten()) == np.max(b.shape):
        b = b.reshape((-1,1))

    solver = settings['solver']
    if isinstance(solver, str):
        solver = solver.lower()

    if solver == 'inv' or l1 == 0:
        if is_1d(A):
            # A with a single non-singleton axis: the pseudo-inverse is
            # turned into a column before multiplying
            pinv_A = linalg.pinv(A).reshape((-1,1))
        else:
            pinv_A = linalg.pinv(A)
        x = pinv_A @ b.reshape((-1,1))

    elif solver == 'lasso':
        clf = linear_model.Lasso(alpha=l1, random_state=random_state, **lasso_params)
        clf.fit(A, b.flatten())
        x = np.array(clf.coef_)

    elif solver == 'fista':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.FISTA(Aop, b.flatten(), niter=settings['num_iters'],
                                               eps = l1, threshkind = settings.get('threshkind'))[0]

    elif solver == 'ista':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.ISTA(Aop, b.flatten(), niter=settings['num_iters'],
                                              eps = l1, threshkind = settings.get('threshkind'))[0]

    elif solver == 'omp':
        omp = OrthogonalMatchingPursuit(n_nonzero_coefs=A.shape[1] - l1, fit_intercept   = False)
        omp.fit(A,b)
        x = omp.coef_

    elif solver == 'spgl1':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.SPGL1(Aop, b.flatten(), iter_lim = settings['num_iters'], tau = l1)[0]

    elif solver == 'irls':
        Aop = pylops.MatrixMult(A)
        # the pylops keyword is `epsI` (damping of the inverted operator)
        x = pylops.optimization.sparsity.IRLS(Aop, b.flatten(), nouter = nouter, epsI = l1)[0]

    else:
        raise NameError('Unknown update c type')
    return x

import pandas as pd    
def spec_corr(v1,v2, to_abs = True):
    """
    Pearson correlation between the flattened `v1` and `v2`.

    Parameters:
    - v1, v2 (numpy.ndarray): Arrays with the same number of elements.
    - to_abs (bool): If True (default) return the absolute value of the
      correlation, otherwise the signed value.

    Returns:
    The off-diagonal entry of `np.corrcoef` for the two flattened inputs.
    """
    coefficient = np.corrcoef(v1.flatten(), v2.flatten())[0,1]
    return np.abs(coefficient) if to_abs else coefficient
    
    
    
def propagate_dyn_based_on_operator(x0, As, max_t): # - MULTI STEP PREDICTION
    """
    Multi-step prediction: repeatedly apply per-step operators to a state.

    Starting from `x0`, step t appends `As[:, :, t] @ x_t` as a new column.

    Parameters:
    - x0 (numpy.ndarray): Initial state vector.
    - As (numpy.ndarray): Either one 2D operator, which is reused for every
      step, or a 3D stack whose third axis indexes the time step.
    - max_t (int): Number of steps to propagate.

    Returns:
    - numpy.ndarray: States as columns; the first column is `x0`, followed by
      one column per step (max_t + 1 columns in total).

    Raises:
    - ValueError: If `As` is not 2D and its third dimension differs from max_t.

    """
    if As.ndim == 2:
        # same operator at every step
        As = np.dstack([As]*max_t)
    elif As.shape[2] != max_t:
        raise ValueError('Max t does not fit A')

    trajectory = x0.reshape((-1,1))
    for step in range(max_t):
        last_state = trajectory[:,-1].reshape((-1,1))
        next_state = (As[:,:,step] @ last_state).reshape((-1,1))
        trajectory = np.hstack([trajectory, next_state])
    return trajectory
    
    
    





def keep_thres_only(mat, thres, direction = 'lower', perc = False, num = False):
    """
    Return a copy of `mat` in which elements are zeroed according to a
    threshold on their absolute value.

    Parameters:
    - mat (numpy.ndarray): The input matrix (never modified).
    - thres (float): The threshold; its meaning depends on `perc` / `num`.
      `thres == 0` returns an unchanged copy.
    - direction (str, optional): 'lower' (default) zeroes elements whose
      absolute value is <= the threshold; any other value zeroes elements
      whose absolute value is >= the threshold.
    - perc (bool, optional): If True, `thres` is a percentile of the absolute
      values. A value below 1 is read as a fraction (0.5 -> 50th percentile);
      a value of 1 or more is read as a percentile on the 0-100 scale.
    - num (bool, optional): If True and `thres` > 0, the threshold becomes the
      `int(thres)`-th smallest absolute value, so with direction 'lower' the
      `int(thres)` smallest-magnitude elements (and any ties) are zeroed.

    Returns:
    - numpy.ndarray: A new matrix with the selected elements set to zero.

    Raises:
    - ValueError: If both perc and num are True (checked only for a non-zero
      `thres`), or if the percentile requested through `perc` is rejected by
      `np.percentile`.
    - IndexError: If `num` is True and `int(thres)` exceeds the number of elements.
    """
    result = mat.copy()
    if thres == 0:
        return result
    if perc and num:
        raise ValueError('you must provide only perc OR  num, or neither')

    magnitudes = np.abs(result)
    if perc:
        # fractions are rescaled to the 0-100 scale expected by np.percentile
        percentile = thres * 100 if thres < 1 else thres
        thres = np.percentile(magnitudes, percentile)
    elif num and thres > 0:
        thres = np.sort(magnitudes, axis = None)[int(thres) - 1]

    if direction == 'lower':
        result[magnitudes <= thres] = 0
    else:
        result[magnitudes >= thres] = 0
    return result














# from matplotlib.text import Annotation
# from mpl_toolkits.mplot3d.axes3d import Axes3D
# from mpl_toolkits.mplot3d.proj3d import proj_transform
# # https://gist.github.com/WetHat/1d6cd0f7309535311a539b42cccca89c

# class Annotation3D(Annotation):

#     def __init__(self, text, xyz, *args, **kwargs):
#         super().__init__(text, xy=(0, 0), *args, **kwargs)
#         self._xyz = xyz

#     def draw(self, renderer):
#         x2, y2, z2 = proj_transform(*self._xyz, self.axes.M)
#         self.xyz = (x2, y2, z2)
#         super().draw(renderer)
     
# def _annotate3D(ax, text, xyz, *args, **kwargs):
#     '''Add anotation `text` to an `Axes3d` instance.'''
    
#     annotation = Annotation3D(text, xyz, *args, **kwargs)
#     ax.add_artist(annotation)
# setattr(Axes3D, 'annotate3D', _annotate3D)       
    
    
    
    
    # ax.annotate3D('point 2', (0, 1, 0),
    #               xytext=(-30, -30),
    #               textcoords='offset points',
    #               arrowprops=dict(ec='black', fc='white', shrink=2.5))
    