# -*- coding: utf-8 -*-

########################################################
# NAIVE FINDER
########################################################
import numpy as np
from scipy.optimize import least_squares
from  basic_function_lookahead  import *


def find_operators_h1_no_reg(data):
    """
    Fit one unregularized rank-1 operator per consecutive pair of columns.

    For each step t the operator is the outer product of ``data[:, t+1]`` with
    the pseudo-inverse of ``data[:, t]`` (both treated as column vectors), i.e.
    the minimum-norm matrix A_t with ``A_t @ data[:, t] ~= data[:, t+1]``.

    Parameters:
    - data (numpy.ndarray): 2-D array whose columns are consecutive states.

    Returns:
    numpy.ndarray: shape (rows, rows, cols - 1); slice [:, :, t] is A_t.
    """
    operators = []
    for col_now, col_next in zip(data[:, :-1].T, data[:, 1:].T):
        target = col_next.reshape((-1, 1))
        source_pinv = np.linalg.pinv(col_now.reshape((-1, 1))).reshape((1, -1))
        operators.append(target @ source_pinv)
    return np.dstack(operators)


def find_operators_h1_l2_reg(data, l2_w, smooth_w):
    """
    Fit one regularized operator per consecutive pair of columns of `data`.

    Each operator comes from `find_operator_under_l2_or_smoothness`. When
    `smooth_w` is non-zero the fits are chained: the operator fitted for step
    t - 1 is handed to the fit for step t as `A_minus` (an empty list is
    passed for the first step, and for every step when `smooth_w` is 0).

    Parameters:
    - data (numpy.ndarray): 2-D array whose columns are consecutive states.
    - l2_w (float): L2 regularization weight.
    - smooth_w (float): weight of the temporal-smoothness penalty.

    Returns:
    numpy.ndarray: the operators stacked along the third axis, one per step.
    """
    chained = smooth_w != 0
    previous = []
    operators = []
    for step in range(data.shape[1] - 1):
        current = find_operator_under_l2_or_smoothness(
            data[:, step + 1], data[:, step], l2_w, smooth_w,
            t=step, A_minus=previous if chained else [])
        operators.append(current)
        if chained:
            previous = current
    return np.dstack(operators)




def find_operators_under_l0( data, params_thres = {'thres':1, 'num':True, 'perc':False}):
    """
    Fit one rank-1 operator per step and sparsify it with ``keep_thres_only``.

    Each raw operator is the outer product of ``data[:, t+1]`` with the
    pseudo-inverse of ``data[:, t]``; it is then passed through
    ``keep_thres_only(..., direction='lower', **params_thres)``. When
    ``params_thres['thres'] == 0`` the unthresholded operators of
    ``find_operators_h1_no_reg`` are returned instead.

    Parameters:
    - data (numpy.ndarray): 2-D array whose columns are consecutive states.
    - params_thres (dict): keyword arguments forwarded to keep_thres_only;
      must contain 'thres'. The default dict is only read, never mutated.

    Returns:
    numpy.ndarray: operators stacked along the third axis, one per step.
    """
    if params_thres['thres'] == 0:
        # Nothing to threshold: identical to the plain rank-1 fit.
        return find_operators_h1_no_reg(data)

    sparse_ops = [
        keep_thres_only(
            data[:, t + 1].reshape((-1, 1))
            @ np.linalg.pinv(data[:, t].reshape((-1, 1))).reshape((1, -1)),
            direction='lower', **params_thres)
        for t in range(data.shape[1] - 1)
    ]
    return np.dstack(sparse_ops)

    


def find_operator_under_l1(data, l1_w, seed = 0, params = {} ):
    """
    Fit one L1-regularized operator per consecutive pair of columns, row by row.

    For every step t and every row index `el`, `solve_Lasso_style` is called
    with the 1 x p design ``data[:, t]`` (as a row) and the 1 x 1 target
    ``data[el, t+1]``; its (row-reshaped) result becomes row `el` of the
    operator for step t.

    Parameters:
    - data (numpy.ndarray): 2-D array whose columns are consecutive states.
    - l1_w (float): L1 weight forwarded to solve_Lasso_style.
    - seed (int): forwarded to solve_Lasso_style as `random_state`.
    - params (dict): solver options merged over the defaults
      {'threshkind': 'soft', 'solver': 'spgl1', 'num_iters': 10}; the
      caller's keys win. The argument itself is never mutated.

    Returns:
    numpy.ndarray: per-step operators stacked along the third axis
    (presumably (p, p, T - 1) — depends on solve_Lasso_style's output length).
    """
    solver_params = {'threshkind': 'soft', 'solver': 'spgl1', 'num_iters': 10}
    solver_params.update(params)
    n_rows = data.shape[0]

    operators = []
    for t in range(data.shape[1] - 1):
        design = data[:, t].reshape((1, -1))
        rows = [
            solve_Lasso_style(design, np.array([data[el, t + 1]]).reshape((1, -1)), l1_w,
                              params=solver_params, lasso_params={},
                              random_state=seed).reshape((1, -1))
            for el in range(n_rows)
        ]
        operators.append(np.vstack(rows))
    return np.dstack(operators)

    
    
    
    
def find_operator_under_l2_or_smoothness( y_plus, y_minus, l2_w = 0, smooth_w = 0, t = 0, A_minus = None):
    """
    Fit one operator A with ``A @ y_minus ~= y_plus`` under optional penalties.

    The problem is solved as an augmented least-squares system

        A @ [y_minus, l2_w * I, smooth_w * I]  ~=  [y_plus, 0, smooth_w * A_minus]

    i.e. it minimizes
        ||A y_minus - y_plus||^2 + l2_w^2 ||A||_F^2 + smooth_w^2 ||A - A_minus||_F^2.
    The L2 block is only included when ``l2_w > 0`` and the smoothness block
    only when ``smooth_w > 0 and t > 0``. Without any active penalty the
    minimum-norm rank-1 solution is returned directly.

    This helper is normally reached through `find_operators_h1_l2_reg`
    rather than called directly.

    Parameters:
    - y_plus (array-like): target state; flattened to a column vector.
    - y_minus (array-like): source state; flattened to a column vector.
    - l2_w (float): L2 (ridge) weight, >= 0.
    - smooth_w (float): weight pulling A towards A_minus, >= 0.
    - t (int): time index; smoothness is only applied for t > 0.
    - A_minus (array-like or None): operator of the previous step, shape
      (len(y_plus), len(y_minus)); required when smooth_w > 0 and t > 0.
      None or an empty list means "not provided".

    Returns:
    numpy.ndarray: A of shape (len(y_plus), len(y_minus)), oriented so that
    ``A @ y_minus`` approximates ``y_plus`` in every branch.

    Raises:
    - ValueError: if a weight is negative, or smoothness is requested for
      t > 0 without A_minus.
    """
    y_plus = np.asarray(y_plus).reshape((-1, 1))
    y_minus = np.asarray(y_minus).reshape((-1, 1))
    if (l2_w == 0 and smooth_w == 0) or (t == 0 and l2_w == 0):
        print('pay attention - no reg')
        return y_plus @ np.linalg.pinv(y_minus).reshape((1, -1))

    if l2_w < 0 or smooth_w < 0:
        raise ValueError('l2_w and smooth_w must be non-negative')

    use_smooth = smooth_w > 0 and t > 0
    if use_smooth and (A_minus is None or np.size(A_minus) == 0):
        raise ValueError('you must provide A_minus')

    n_out, n_in = y_plus.shape[0], y_minus.shape[0]
    plus_blocks = [y_plus]
    minus_blocks = [y_minus]
    if l2_w > 0:
        # Ridge term: pulls A towards 0 with weight l2_w.
        plus_blocks.append(np.zeros((n_out, n_in)))
        minus_blocks.append(np.eye(n_in) * l2_w)
    if use_smooth:
        # Smoothness term: pulls A towards the previous operator with weight smooth_w.
        plus_blocks.append(np.asarray(A_minus) * smooth_w)
        minus_blocks.append(np.eye(n_in) * smooth_w)

    lhs = np.hstack(minus_blocks)
    rhs = np.hstack(plus_blocks)
    # lstsq solves lhs.T @ X = rhs.T, whose solution is X = A.T; transpose back so
    # the result has the same orientation as the unregularized branch above.
    return np.linalg.lstsq(lhs.T, rhs.T, rcond=None)[0].T
    





def one_step_prediction(x, As):
    """
    Predict each column of `x` one step ahead from the previous given column.

    Column 0 is copied from `x`; column t + 1 is ``As[:, :, t] @ x[:, t]``.
    Every step starts from the supplied `x`, not from an earlier prediction.

    Parameters:
    - x (numpy.ndarray): 2-D array (states x time).
    - As (numpy.ndarray): 3-D operator stack; As[:, :, t] maps column t to t + 1.

    Returns:
    numpy.ndarray: array with the same number of columns as `x`.
    """
    columns = [x[:, 0].reshape((-1, 1))]
    for t in range(x.shape[1] - 1):
        stepped = As[:, :, t] @ x[:, t].reshape((-1, 1))
        columns.append(stepped.reshape((-1, 1)))
    return np.hstack(columns)

def k_step_prediction(x, As, K, store_mid = True):
    """
    Apply `one_step_prediction` K times, feeding each output into the next pass.

    After k passes, column t (t >= k) is the k-step rollout of column t - k
    through As[:, :, t-k], ..., As[:, :, t-1]; columns t < k are rolled out
    from column 0.

    Parameters:
    - x (numpy.ndarray): 2-D array (states x time); not modified.
    - As (numpy.ndarray): operator stack as used by one_step_prediction.
    - K (int): number of passes.
    - store_mid (bool): if True, also return the list of all K intermediates.

    Returns:
    The result after K passes, or ``(result, intermediates)`` when store_mid.
    """
    current = x.copy()
    history = [] if store_mid else None
    for _ in range(K):
        current = one_step_prediction(current, As)
        if history is not None:
            history.append(current)
    if store_mid:
        return current, history
    return current

import seaborn as sns


def plot_elements_of_mov( mat,  ax = None ,  fig = None, colors = None, legend_params = None, plot_params = None, to_legend = False,
                         type_plot = 'plot'):
    """
    Plot every (i, j) element of a 3-D array over its last axis.

    Parameters:
    - mat (numpy.ndarray): 3-D array (rows x cols x time).
    - ax: matplotlib axes to draw on; a new figure/axes is created when
      None or an empty list.
    - fig: figure belonging to `ax`; when missing it is taken from ``ax.figure``.
    - colors: one color per (i, j) element, flat or already shaped; generated
      with `create_colors` when None or empty.
    - legend_params (dict or None): keyword arguments for ``ax.legend``.
    - plot_params (dict or None): keyword arguments for ``ax.plot`` /
      ``sns.heatmap``.
    - to_legend (bool): draw a legend (only for type_plot == 'plot').
    - type_plot (str): 'plot' draws one line per element, labelled
      ``rows * j + i``; 'heatmap' stacks the element series row-major over
      (i, j) into a 2-D array and draws it with seaborn.

    Returns:
    tuple: (fig, ax) used for drawing (previously nothing was returned).
    """
    legend_params = {} if legend_params is None else legend_params
    plot_params = {} if plot_params is None else plot_params
    n_rows, n_cols = mat.shape[0], mat.shape[1]

    if colors is None or checkEmptyList(colors):
        colors = create_colors(n_rows * n_cols)
    colors = np.asarray(colors)
    if is_1d(colors):
        colors = colors.reshape((n_rows, n_cols))
    else:
        colors = colors.reshape((n_rows, n_cols, 3))

    if ax is None or checkEmptyList(ax):
        fig, ax = plt.subplots()
    elif fig is None or checkEmptyList(fig):
        fig = ax.figure

    if type_plot == 'plot':
        # Plain loops: these calls are made for their side effect (drawing).
        for i in range(n_rows):
            for j in range(n_cols):
                ax.plot(mat[i, j, :], color=colors[i, j],
                        label='%d' % (n_rows * j + i), **plot_params)
        if to_legend:
            ax.legend(**legend_params)
    elif type_plot == 'heatmap':
        # One heatmap row per element, ordered row-major over (i, j).
        mat_2d = np.vstack([mat[i, j, :] for i in range(n_rows) for j in range(n_cols)])
        sns.heatmap(mat_2d, ax=ax, **plot_params)
    return fig, ax
        
        
    

def find_weight(k, e, sigma1 = 12, sigma2 = 1.1):
    """
    Weight of a k-step lookahead term with error e.

        weight = exp(-k * sigma1) * (1 + |e|) ** sigma2

    The first factor decays exponentially with the horizon k (sigma1 is a
    decay rate, not a standard deviation); the second grows with the error
    magnitude (sigma2 is its exponent). Works elementwise on numpy arrays.

    Parameters:
    - k (float or numpy.ndarray): lookahead horizon(s).
    - e (float or numpy.ndarray): error value(s), broadcast against k.
    - sigma1 (float, optional): decay rate in k. Default is 12.
    - sigma2 (float, optional): exponent applied to (1 + |e|). Default is 1.1.

    Returns:
    float or numpy.ndarray: the weight(s).
    """
    decay = np.exp(-k * sigma1)
    error_boost = (1 + np.abs(e)) ** sigma2
    return decay * error_boost

def reevaluate_A_under_mask(y_plus, y_minus, A_mask):
    """
    Refit operator A by least squares restricted to the support given by a mask.

    Solves ``A @ y_minus ~= y_plus`` with ``A[row, col]`` forced to 0 wherever
    ``A_mask[row, col]`` is zero/False. Each row is fitted independently by
    minimum-norm least squares (pseudo-inverse) on its allowed columns; a
    fully unmasked A is fitted in one pseudo-inverse.

    Parameters:
    - y_plus (numpy.ndarray): targets, shape (p,) or (p, n).
    - y_minus (numpy.ndarray): sources, shape (q,) or (q, n).
    - A_mask (array-like): shape (p, q); any non-zero / True entry is free.

    Returns:
    numpy.ndarray: A of shape (p, q).
    """
    # Always normalize to a boolean mask: a numeric mask must never be used
    # as integer (fancy) indices below.
    A_mask = np.asarray(A_mask) != 0

    # Treat 1-D vectors as single-sample column vectors so one formula covers
    # both cases (np.linalg.pinv requires 2-D input).
    y_plus = np.asarray(y_plus)
    y_minus = np.asarray(y_minus)
    if y_plus.ndim == 1:
        y_plus = y_plus.reshape((-1, 1))
    if y_minus.ndim == 1:
        y_minus = y_minus.reshape((-1, 1))

    if A_mask.all():
        return y_plus @ np.linalg.pinv(y_minus)

    A_new = np.zeros(A_mask.shape)
    for row, row_mask in enumerate(A_mask):
        cols = np.flatnonzero(row_mask)
        if cols.size == 0:
            continue  # fully masked row stays zero
        A_new[row, cols] = y_plus[row] @ np.linalg.pinv(y_minus[cols])
    return A_new
    

def  infer_A_under_constraint(y_plus, y_minus, constraint = 'l0', w_reg = 3, params = None, reeval = True , is1d_dir = 0):
    """
    Estimate an operator A with ``A @ y_minus ~= y_plus`` and sparsify it.

    Steps:
    1. minimum-norm least squares via ``y_plus @ pinv(y_minus)``;
    2. ``keep_thres_only(A, direction='lower', thres=w_reg, num=True, perc=False)``;
    3. if `reeval`, a least-squares refit restricted to the surviving non-zero
       entries via `reevaluate_A_under_mask`.

    Parameters:
    - y_plus, y_minus (numpy.ndarray): targets / sources. Inputs for which
      `is_1d` is true are reshaped to a column (is1d_dir == 0) or a row
      (otherwise).
    - constraint (str): currently ignored; only the thresholding ('l0'-style)
      path is implemented.
    - w_reg: threshold passed to keep_thres_only as `thres`.
    - params: currently unused.
    - reeval (bool): refit on the thresholded support.
    - is1d_dir (int): orientation used for 1-D inputs (see above).

    Returns:
    numpy.ndarray: the estimated operator.

    Raises:
    - ValueError: if the pseudo-inverse of y_minus cannot be computed.
    """
    if is_1d(y_plus):
        y_plus = y_plus.reshape((-1, 1)) if is1d_dir == 0 else y_plus.reshape((1, -1))
    if is_1d(y_minus):
        y_minus = y_minus.reshape((-1, 1)) if is1d_dir == 0 else y_minus.reshape((1, -1))

    try:
        y_minus_pinv = np.linalg.pinv(y_minus)
    except np.linalg.LinAlgError as err:
        # Fail fast with context instead of prompting interactively and then
        # crashing on an unbound result.
        raise ValueError('could not pseudo-invert y_minus of shape %s'
                         % (np.shape(y_minus),)) from err
    A_hat = y_plus @ y_minus_pinv

    A_hat = keep_thres_only(A_hat, direction='lower', thres=w_reg, num=True, perc=False)
    if reeval:
        A_hat = reevaluate_A_under_mask(y_plus, y_minus, A_hat != 0)
    return A_hat
    

def find_propagations(As, t, data, k, constraint = 'l0', w_reg = 3, shape_return = (-1)):
    """
    Build one (target, k-step source) pair used to refit the operator into time t.

    The observation at time t - k is propagated k - 1 steps through
    As[:, :, t-k], ..., As[:, :, t-2], giving a prediction of the state at
    time t - 1. The caller then fits A (stored as As[:, :, t-1]) so that
    ``A @ source ~= data[:, t]``. For k == 1 the pair is simply
    (data[:, t], data[:, t-1]).

    Assumes `propagate_dyn_based_on_operator` returns the trajectory with the
    final propagated state in its last column — TODO confirm.

    Parameters:
    - As (numpy.ndarray): operator stack; As[:, :, s] maps time s to s + 1.
    - t (int): time index whose incoming operator is being updated (t >= 1).
    - data (numpy.ndarray): observations (states x time).
    - k (int): lookahead depth, 1 <= k <= t.
    - constraint, w_reg: currently unused.
    - shape_return: shape both returned vectors are reshaped to.

    Returns:
    tuple: (data[:, t], propagated state at t - 1), each reshaped to shape_return.

    Raises:
    - ValueError: if k is not in [1, t].
    """
    if not 1 <= k <= t:
        raise ValueError('k must satisfy 1 <= k <= t (got k=%d, t=%d)' % (k, t))
    # Start k steps before t: the k - 1 operators As[:, :, t-k : t-1] map
    # time t - k to t - 1. (Starting at t - k + 1 would make the k == 1 pair
    # (x_t, x_t) and apply every operator to the wrong state.)
    start = t - k
    data_0 = data[:, start]
    if k > 1:
        data_prop_t_minus_1 = propagate_dyn_based_on_operator(
            data_0, As[:, :, start:t - 1], max_t=t - 1 - start)[:, -1]
    else:
        data_prop_t_minus_1 = data_0.copy()

    data_t = data[:, t]
    return data_t.reshape(shape_return), data_prop_t_minus_1.reshape(shape_return)
    
    
    


def train_LOOKAHEAD(data, K_f = 10, As = None, constraint = 'l0', w_reg = 1, max_error = 100, sigma1 = 7, sigma2 = 12, norm_w = True, reeval = True, backprop = False, max_iter = 500,
                    seed = 0, store_As = True):
    """
    Iteratively refit a time-varying linear operator stack with lookahead weighting.

    Each outer iteration sweeps t = 1 .. T-1. For each t it
      1. rolls the current operators out K = min(K_f, t) passes
         (`k_step_prediction`) and records the squared error per horizon
         and time;
      2. weights the horizons k = 1..K with `find_weight`, using the errors at
         time t (normalized to sum to 1 if `norm_w`);
      3. builds K (target, propagated source) pairs via `find_propagations`,
         scales them by the weights and refits As[:, :, t-1] with
         `infer_A_under_constraint`.
    Iteration stops once the last recorded error is <= max_error, or after
    max_iter iterations.

    Parameters:
    - data (numpy.ndarray): observations (states x time).
    - K_f (int): maximum lookahead horizon.
    - As (numpy.ndarray or None): initial operators, shape (p, p, T-1). The
      caller's array is copied, never modified. If None/empty, they are drawn
      uniformly at random after seeding.
    - constraint: forwarded to find_propagations.
    - w_reg: forwarded to find_propagations and infer_A_under_constraint.
    - max_error (float): stopping threshold on the K-step squared error.
    - sigma1, sigma2: forwarded to find_weight.
    - norm_w (bool): normalize the horizon weights to sum to 1.
    - reeval (bool): forwarded to infer_A_under_constraint.
    - backprop (bool): not implemented; True raises ValueError after the first sweep.
    - max_iter (int): maximum number of outer iterations.
    - seed (int): seed for numpy's global RNG.
    - store_As (bool): also return a snapshot of the operators after each iteration.

    Returns:
    (As, errors, x, elapsed_time[, store_As_dict]) where `errors` has one entry
    per (iteration, t), `x` is the last K-step rollout (a copy of `data` if no
    sweep ran), `elapsed_time` is in seconds and `store_As_dict` maps the
    iteration number to an independent copy of As.
    """
    np.random.seed(seed)
    if As is None or checkEmptyList(As):
        As = np.random.rand(data.shape[0], data.shape[0], data.shape[1]-1)
    else:
        # Work on a float copy: the sweep below writes into As in place.
        As = np.array(As, dtype=float)

    iter_num = 0
    error = np.inf
    T = data.shape[1]
    x = data.copy()  # returned as-is if no sweep runs (T < 2 or max_iter == 0)
    start_time = time.time()
    errors = []
    store_As_dict = {} if store_As else None
    while error > max_error and iter_num < max_iter:
        print('iter num %d'%iter_num)
        for t in range(1, T):
            K = np.min([K_f, t])
            # Lookahead: roll the current operators out K passes.
            x, stores = k_step_prediction(data, As, K, store_mid = True)
            # es[k-1, s]: squared error of the k-pass rollout at time s (K x T).
            es = np.vstack([np.sum((x_i - data)**2, 0) for x_i in stores])
            # One weight per horizon k = 1..K, based on the errors at time t.
            w_t = find_weight(np.arange(1, K+1), es[:, t], sigma1 = sigma1, sigma2 = sigma2)
            if norm_w:
                w_t /= w_t.sum()
            # (p, K, 2): [:, k-1, 0] is data[:, t], [:, k-1, 1] the k-step source.
            data_plus_data_minus = np.hstack([
                np.dstack(find_propagations(As, t, data, k, constraint = constraint, w_reg = w_reg,
                                            shape_return = (-1, 1, 1)))
                for k in range(1, K+1)])
            data_plus_data_minus_w = data_plus_data_minus * w_t.reshape((1, -1, 1))
            A_t = infer_A_under_constraint(data_plus_data_minus_w[:, :, 0], data_plus_data_minus_w[:, :, 1],
                                           reeval = reeval, w_reg = w_reg)
            As[:, :, t-1] = A_t
            error = np.sum((x - data)**2)
            errors.append(error)

        if backprop:
            raise ValueError('future imp')
        if store_As:
            # Snapshot: As is updated in place, so storing the array itself
            # would make every entry alias the final operators.
            store_As_dict[iter_num] = As.copy()
        iter_num += 1

    elapsed_time = time.time() - start_time  # in sec
    if store_As:
        return As, errors, x, elapsed_time, store_As_dict
    return As, errors, x, elapsed_time
    
def choose_random_min_max(min_v, max_v, seed = 0):
    """
    Draw one uniform value from [min_v, max_v), reseeding numpy's global RNG first.

    The global ``np.random`` state is reset to `seed` on every call, so calls
    with the same seed always return the same fraction of the interval.

    Parameters:
    - min_v (float): lower bound.
    - max_v (float): upper bound.
    - seed (int, optional): value passed to ``np.random.seed``. Default is 0.

    Returns:
    float: ``u * (max_v - min_v) + min_v`` with u from ``np.random.rand()``.
    """
    np.random.seed(seed)
    fraction = np.random.rand()
    span = max_v - min_v
    return fraction * span + min_v
    
    
    
def find_lorenz_with_uncorr_x_and_y(xmin = -1, xmax = 1, ymin = -1, ymax = 5, zmin = -5, zmax = 5, dt_min = 0.0010,dt_max = 0.04, max_t_max = 10,max_t_min = 1,
                  sigma_min = 8, sigma_max = 12, beta_min = 4/3, beta_max = 10/3,  rho_min = 10, rho_max = 38, return_operators = True, option = 2, num_iters = 100,
                  seed = 0):
    """
    Sample random Lorenz runs and record the correlation of their first two rows.

    For each of `num_iters` trials, the initial condition psi0, dt, max_t,
    sigma, beta and rho are drawn independently and uniformly from their
    [min, max) ranges, `create_lorenz` is run, and trials whose trajectory
    stays below 1000 in absolute value keep ``corrcoef(traj)[0, 1]`` and
    their parameter dict.

    Parameters:
    - xmin..zmax: ranges for the three initial-condition coordinates.
    - dt_min, dt_max: range for the time step.
    - max_t_min, max_t_max: range for max_t (truncated to int).
    - sigma_*, beta_*, rho_*: ranges for the Lorenz parameters (recorded in
      the returned dicts; see the NOTE below).
    - return_operators: recorded in the returned dicts.
    - option: forwarded to create_lorenz.
    - num_iters (int): number of trials.
    - seed (int): base seed; together with the trial index it selects the
      random stream, so results are reproducible per (seed, trial).

    Returns:
    tuple: (corrs, params, elapsed_time) — correlations and parameter dicts of
    the kept trials, and the elapsed wall time in seconds.
    """
    params = []
    corrs = []
    start_time = time.time()
    for trial in range(num_iters):
        # One generator per trial so draws within a trial are independent
        # (reseeding with the same seed before every draw made all of them
        # share one uniform value) and `seed` actually selects the stream.
        rng = np.random.default_rng([seed, trial])
        psi0 = [rng.uniform(xmin, xmax), rng.uniform(ymin, ymax), rng.uniform(zmin, zmax)]
        rho = rng.uniform(rho_min, rho_max)
        sigma = rng.uniform(sigma_min, sigma_max)
        beta = rng.uniform(beta_min, beta_max)
        dt = rng.uniform(dt_min, dt_max)
        max_t = int(rng.uniform(max_t_min, max_t_max))
        params_dict = {
            'psi0': psi0,
            'dt': dt,
            'max_t': max_t,
            'sigma': sigma,
            'beta': beta,
            'rho': rho,
            'return_operators': return_operators,
            'option': option
        }
        # NOTE(review): sigma, beta, rho and return_operators are recorded in
        # params_dict but not passed to create_lorenz (its signature is not
        # visible here) — confirm whether they should be forwarded.
        lorenz_mat1, _ = create_lorenz(psi0, max_t = max_t, option = option, dt = dt)
        # Keep only runs that did not blow up numerically.
        if np.max(np.abs(lorenz_mat1)) < 1000:
            corrs.append(np.corrcoef(lorenz_mat1)[0, 1])
            params.append(params_dict)

    elapsed_time = time.time() - start_time  # in sec
    return corrs, params, elapsed_time
    


    
    
    
    
    



########################################################
# Multi-Step Finder
########################################################

def identify_system_from_data(max_apply, max_predict ):
    """
    Placeholder for the multi-step system-identification entry point.

    Not implemented yet: the arguments are accepted and ignored.

    Parameters:
    - max_apply: horizon to integrate into the model optimization.
    - max_predict: how far ahead to predict.

    Returns:
    None.
    """
    return None


