# -*- coding: utf-8 -*-
"""
Created on Thu Oct 27 15:12:10 2022

@author: anonymous
"""
import os

    
from main_functions_sib_synth import *
# Directory the script was launched from.
new_directory = os.getcwd()

# Experiment selector and mode flag, both read by the run loop at the bottom of
# the script: type_grannet picks the dataset branch ('synth...',
# 'trends_grannet' or 'neuro_bump_short_short'); when to_recreate is True every
# repeat instead calls run_GraFT with empty data and params_default.
type_grannet, to_recreate = 'synth', False


# Experiment-specific overrides; merged on top of params_default to build
# params_full (keys given here win).
#
# Many values go through decide_value(type_answer, <fixed>, <random>, 'val').
# Neither decide_value nor type_answer is defined in this file; presumably both
# come from the star import of main_functions_sib_synth, and decide_value picks
# <fixed> or <random> depending on type_answer -- TODO confirm. The
# "<fixed> or <random>" notes below assume that reading.
#
# NOTE: both arguments are evaluated eagerly, so every np.random.* call below
# runs (and advances NumPy's global RNG) even when the fixed value is used.
# This dict is built once at import time, so its random values are shared by
# all repeats of the run loop; reordering entries changes which draw each key gets.
params_update = {'max_learn': 500,                                                                # Maximum number of learning steps
    'mean_square_error': 0.1,

    'deal_nonneg': 'make_nonneg',                                                                 # can be 'make_nonneg' or 'GD'
    'epsilon' : decide_value(type_answer, 0.1, np.random.rand(), 'val'),                          # 0.1 or a U[0, 1) draw
    'l1': decide_value(type_answer, 0.1, np.random.rand(), 'val'),                                # lambda parameter: 0.1 or a U[0, 1) draw
    'l2': decide_value(type_answer, 0.1, np.random.rand(), 'val'),                                # Frobenius-norm parameter: 0.1 or a U[0, 1) draw
    'l3': decide_value(type_answer, 0.1, np.random.rand(), 'val'),                                # dictionary continuation term parameter: 0.1 or a U[0, 1) draw
    'lamContStp': decide_value(type_answer, 0.1, np.random.rand(), 'val')*0.9,                    # multiplicative change to the continuation parameter: 0.9 * (0.1 or U[0, 1))
    'l4': decide_value(type_answer, 0.1, np.random.rand(), 'val')*0.1,                            # dictionary correlation regularization parameter: 0.1 * (0.1 or U[0, 1))
    'beta': 0.9*decide_value(type_answer, 0.1, np.random.rand(), 'val'),                          # beta: 0.9 * (0.1 or U[0, 1))
    'maxiter': 0.01,                                                                              # older note: stop when Delta(Dictionary) < 0.01. NOTE(review): a float 'maxiter' is unusual -- confirm how it is consumed
    'numreps': 2,                                                                                 # number of repetitions for RWL1
    'tolerance': 1e-7*decide_value(type_answer, 0.5, 0.8*np.random.rand(), 'val'),                # TFOCS tolerance: 1e-7 * (0.5 or 0.8*U[0, 1)). NOTE(review): the random branch can be arbitrarily close to 0

    'likely_from' : decide_value(type_answer, 'gaussian' , 'gaussian', 'val') ,                   # likelihood, 'gaussian' in both branches (options: 'gaussian' or 'poisson')
    'step_s': decide_value(type_answer, 0.2, 0.1+0.9*np.random.rand(), 'val'),                    # step size: 0.2 or a draw in [0.1, 1.0)

    'step_decay': 0.995+ 0.05*decide_value(type_answer, 0.9, 0.999*np.random.rand(), 'val'),      # step-size factor: 1.04 or a value in [0.995, ~1.045). NOTE(review): a factor > 1 grows rather than decays the step -- confirm intended. Older note: only needed for grad_type='norm'

    'dict_max_error': 0.01,                                                                       # learning tolerance: stop when Delta(Dictionary) < 0.01
    'p': 10,                                                                                      # number of dictionary elements (fixed to 10 here)
    'verb': 1,                                                                                    # verbosity level. NOTE(review): an older note said "no verbose output" -- confirm what 1 means

    'GD_iters': 1* decide_value(type_answer, 1, np.random.randint(1,5), 'val'),                   # GD steps per iteration: 1 or randint(1, 5), i.e. 1..4
    'bshow': 0,                                                                                   # no plotting
    'nonneg': decide_value(type_answer, False, False,'val')   ,                                   # False in both branches: no non-negativity constraint on the coefficients
    'plot': False,                                                                                # do not plot spatial components during learning
    'updateEmbed' : False,                                                                        # do not update the graph embedding based on changes to the coefficients
    'mask': [],                                                                                   # for masked images (widefield data)
    'normalizeSpatial' : False,                                                                   # False: time-traces are unit norm. True: spatial maps normalized to max one and time-traces are not normalized
     'patchSize': 50, 
     'motion_correct': False, 
     'kernelType': 'embedding',
     'reduceDim': decide_value(type_answer, False, np.random.choice([False, True]), 'val'),       # False or a random boolean
     'w_time': 0,
     'n_neighbors': decide_value(type_answer,49, np.random.randint(5,50), 'val'),                 # 49 or randint(5, 50), i.e. 5..49
     'n_comps':5,
     'solver_qp':'quadprog',
     'solver': decide_value(type_answer,'inv', np.random.choice(['spgl1','inv', 'lasso']), 'val'),   # 'inv' or a random pick of 'spgl1' / 'inv' / 'lasso'
     'nullify_some': False , 
     'norm_by_lambdas_vec': decide_value(type_answer,False, np.random.choice([False, True]), 'val'),  # False or a random boolean
     'min_max_data': False,
     'GD_type': 'full_ls_cor', 
     'multi':'med',                                                                                  # can be 'med', 'sqrt' or 'none'
     'thresANullify':-50, # previously 0
     # Crop windows keyed by dataset name: sub-select a portion of the FOV to test
     # on a small section before running on the full dataset. Trailing "#151"-style
     # comments record earlier values. Upper bounds of -1 or 'n' presumably mean
     # "up to the end" -- TODO confirm how the consumer interprets them.
     'CI': {
     'xmin' : 151,#151
     'xmax' : 200,#350
     'ymin' : 151,#101
     'ymax' : 200,#300 
     },
     'VI_crop': {
     'xmin' : 120,#151
     'xmax' : 270,#350
     'ymin' : 0,#101
     'ymax' : -1,#300 
     },
     'VI_crop_very': {
     'xmin' : 120,#151
     'xmax' : 170,#350
     'ymin' : 20,#101
     'ymax' : 70,#300 
     },     
     'VI_full': {
     'xmin' : 0,#151
     'xmax' : -1,#350
     'ymin' : 0,#101
     'ymax' : -1,#300 
     },
     'area2': {
     'xmin' : 0,#151
     'xmax' : 'n',#350
     'ymin' : 0,#101
     'ymax' : 'n',#300 
     },
     'trends': {
     'xmin' : 0,#151
     'xmax' : 'n',#350
     'ymin' : 0,#101
     'ymax' : 'n',#300 
     },     
     'VI_HPC': {
     'xmin' : 0,#151
     'xmax' : 'n',#350
     'ymin' : 0,#101
     'ymax' : 'n',#300 
     },       
     'VI_crop_long': {
     'xmin' : 0,#151
     'xmax' : 128,#350
     'ymin' : 200,#101
     'ymax' : 328,#300 
     },     
     'VI_crop_long2': {
     'xmin' : 20,#151
     'xmax' : 100,#350
     'ymin' : 200,#101
     'ymax' : 280,#300 
     },  
    'synth': {
    'xmin' : 0,#151
    'xmax' : 70,#350
    'ymin' : 0,#101
    'ymax' : 70,#300 
     },   
    'synth_grannet':
        {
    'xmin' : 0,#151
    'xmax' : 'n',#350
    'ymin' : 0,#101
    'ymax' : 'n',#300 
            },
    'trends_grannet' :
    {
    'xmin' : 0,#151
    'xmax' : 'n',#350
    'ymin' : 0,#101
    'ymax' : 'n',#300 
            },
    'neuro_bump_short'  :
    {
    'xmin' : 0,#151
    'xmax' : 'n',#350
    'ymin' : 0,#101
    'ymax' : 'n',#300 
            },      
    'neuro_bump_short_short'  :
    {
    'xmin' : 0,#151
    'xmax' : 'n',#350
    'ymin' : 0,#101
    'ymax' : 'n',#300 
            },   
     'use_former_kernel' : False,
     'usePatch' : False, 
     'portion' :True,
     'divide_med' : False,
     'data_0_1' : False,
     'to_save' : True, 
     'default_path':  r'E:/CODES FROM GITHUB/GraFT-analysis/code/neurofinder.02.00/images',   # NOTE(review): machine-specific absolute Windows path
     'save_error_iterations': True,
     'max_images':800,
     'dist_init': 'randn',                                                        # initialization distribution: 'randn' (alternatives: 'rand', 'uniform')
     'to_sqrt':True,
     # Presumably options for the Poisson-likelihood solver (likely_from='poisson') -- TODO confirm.
     'Poisson':{'maxiter': 5, 'miniter':0.01, 'stopcriterion': 3,
                'tolerance': 1e-8, 'alphainit': 1, 'alphamin': 1e-30, 
                'alphamax': 1e30, 'alphaaccept': 1e30, 'logepsilon': 1e-10,
                'saveobjective':True, 'savereconerror':False, 'savecputime':True, 'penalty':'Canonical',
                'savesolutionpath':False, 'truth': False,'alphamethod':1, 'monotone' : 1, }, 
     'sigma_noise': 0.1,                                                          # std of noise to add to the dictionary if all its values are zero
     'grannet_params':{'lambda_grannet': 0.1, 
                       'distance_metric':'Euclidean',
                       'reg_different_nets': 'unified',
                       'num_free_nets':0,
                       'distance2proximity_trans': 'exp',
                       'initialize_by_other_nets':True,                            # True to initialize A from the other nets
                       'late_start': 5,
                       'include_Ai':False,                                         # whether NeuroInfer includes the distance from Ai to itself (if True, acts like a smoothness regularization)
                       'labels_indicative':False },
     'to_store_lambdas':False,
     'reorder_nets_by_importance': False,
     'uniform_vals': [0,1],
     'distance_metric': 'euclidean' ,                                             # NOTE(review): lower-case here vs 'Euclidean' in grannet_params -- confirm the expected spelling
      'graph_params':{'kernel_grannet_type':'combination', 'params_weights': 0.4}, # params_weights should be a vector if kernel_grannet_type is averaged and a scalar if combined
      'save_figs_compare': False,
      'save_plots' : False,
       'compare_to_synth': False

    }


# Full parameter set for this script: start from a shallow copy of the library
# defaults, then let the experiment overrides win for any shared key. Nested
# dicts are shared with their source, not copied.
params_full = dict(params_default)
params_full.update(params_update)

            
# Run the selected experiment num_repeats times. Each repeat derives a fresh
# seed from the clock and, for synthetic data, a fresh noise level.
# run_GraFT, datetime2 and np are not defined in this file (presumably they come
# from the star import of main_functions_sib_synth).
num_repeats = 10

# Per-repeat perturbation settings. They never change inside the loop, so they
# are defined once here instead of being re-assigned on every iteration.
to_apply_noise = True       # enable params_full['noise_params'] with a random std
max_std = 3                 # noise std ~ U(min_std, max_std)
min_std = 0.01
min_missing = 0.001         # perc_missing ~ U(min_missing, max_missing)
max_missing = 0.9
to_check_init = False       # pass the repeat's seed via params_full['init_params']
to_missing_samples = False  # enable params_full['missing_samples_params']
auto_scale = False          # NOTE(review): not referenced in the visible part of this file

for j in range(num_repeats):
    # Seed = microsecond field of the current time. The former
    # int(str(datetime2.now()).split('.')[-1]) raised ValueError whenever the
    # microsecond was exactly 0, because str() then omits the fractional part
    # and the whole timestamp string was handed to int().
    ss = datetime2.now().microsecond
    seed = ss

    if to_recreate:
        # NOTE(review): passes params_default rather than params_full -- confirm intended.
        run_GraFT(data=[], corr_kern=[], params=params_default, grannet=True, images=False)

    elif type_grannet.startswith('synth'):
        # params_full was built from params_default before this loop, so setting
        # only params_default['p'] (as before) could never affect this run.
        # Set params_full (what run_GraFT receives) and keep params_default in
        # sync in case the library reads its shared defaults.
        params_default['p'] = params_full['p'] = 10
        if to_apply_noise:
            params_full['noise_params'] = {**params_full['noise_params'],
                                           'apply_noise': True,
                                           'std': np.random.uniform(min_std, max_std)}
        if to_check_init:
            params_full['init_params'] = {**params_full['init_params'], 'seed': seed}
        if to_missing_samples:
            params_full['missing_samples_params'] = {
                **params_full['missing_samples_params'],
                'seed': seed,
                'apply_missing': True,
                'perc_missing': np.random.uniform(min_missing, max_missing),
            }
        full_A, full_phi, additional_return = run_GraFT(
            data='data_synth_grannet_xmin_0_xmax_n_ymin_0_ymax_n.npy', corr_kern=[],
            params=params_full, grannet=True, images=False, data_type='synth_grannet')

    elif type_grannet == 'trends_grannet':
        full_A, full_phi, additional_return = run_GraFT(
            data='data_trends_grannet_xmin_0_xmax_n_ymin_0_ymax_n.npy', corr_kern=[],
            params=params_full, grannet=True, images=False, data_type='trends_grannet')

    elif type_grannet == 'neuro_bump_short_short':
        full_A, full_phi, additional_return = run_GraFT(
            data='data_neuro_bump_short_short_xmin_0_xmax_n_ymin_0_ymax_n.npy', corr_kern=[],
            params=params_full, grannet=True, images=False, data_type='neuro_bump_short_short')

    else:
        raise ValueError('unknown type_grannet')
        

    