# -*- coding: utf-8 -*-
"""
Connectome-informed reservoir - Echo-State Network
=======================================================================
This example demonstrates how to use the conn2res toolbox to perform
multiple tasks across dynamical regimes, using different types of local
dynamics.
"""
import warnings
import os
import numpy as np
import pandas as pd
from sklearn.base import is_classifier
from conn2res.tasks import NeuroGymTask
from conn2res.connectivity import Conn
from conn2res.reservoir import EchoStateNetwork
from conn2res.readout import Readout
from conn2res import readout, plotting

# Globally ignore these warning categories for the remainder of the script
# (filters are installed in this order: Future, Runtime, User).
for _ignored_category in (FutureWarning, RuntimeWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_ignored_category)
del _ignored_category

# #####################################################################
# First, let's initialize some constant variables
# #####################################################################

# project and figure directories: PROJ_DIR is the parent of the directory
# containing this file; figures are written below PROJ_DIR/figs, which is
# created if it does not exist yet
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUTPUT_DIR = os.path.join(PROJ_DIR, 'figs')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# number of runs for each task
N_RUNS = 1000

# weight assigned to every input-layer -> input-node connection in w_in
FACTOR = 0.0001

# name of the node set (looked up in RSN_MAPPING) that receives the input
INPUT = 'subctx'

# name of the tasks to be performed
TASKS = [
    'ContextDecisionMaking',
]

# define metrics to evaluate readout's model performance
CLASS_METRICS = [
    'balanced_accuracy_score',
    'f1_score',
]

# define alpha values to vary global reservoir dynamics; each alpha scales
# the normalized connectivity matrix (esn.w = alpha * conn.w)
ALPHAS = [0.8, 0.9, 0.95, 1.0, 1.1, 1.2]

# select different activation functions to vary local dynamics
ACT_FCNS = [
    'tanh',
]

# names of the node sets in RSN_MAPPING that are used, one at a time, as
# readout (output) nodes
REGIONS = ['VIS', 'SM', 'DA', 'VA', 'LIM', 'FP', 'DMN']

# node-to-network mapping passed to Conn.get_nodes to select input/output
# nodes
# NOTE(review): machine-specific absolute path -- adjust for your environment
RSN_MAPPING = '/home/bach/Documents/GitHub/conn2res/examples/data/rsn_mapping/rsn_human_250_DAx1p5.npy'

# #####################################################################
# Second, let's generate the datasets for every NeuroGym task. The data are
# generated once, up front, so that every output set evaluated later reuses
# exactly the same input/target data for a given run.
# #####################################################################

# number of trials requested from the task generator for each run
N_TRIALS = 1000

# task name -> list of per-run inputs / targets (indexed by run number)
data_dict_x = {}
data_dict_y = {}

for task_name in TASKS:
    task = NeuroGymTask(name=task_name)
    runs_x, runs_y = [], []
    for _ in range(N_RUNS):
        x, y = task.fetch_data(n_trials=N_TRIALS)
        runs_x.append(x)
        runs_y.append(y)
    data_dict_x[task_name] = runs_x
    data_dict_y[task_name] = runs_y


# rc parameters shared by every figure saved below
RC_PARAMS = {'figure.dpi': 300, 'savefig.dpi': 300}

# alpha value at which example reservoir states and diagnostic curves are
# plotted (first run only)
PLOT_ALPHA = 0.95

# performance metrics used to evaluate the readout; defined up front because
# they are needed both inside the simulation loop and for the final plots
metrics = CLASS_METRICS

for task_name in TASKS:
    # Bind ``task`` to the task currently being evaluated. The ``task`` object
    # left over from dataset generation is the *last* entry of TASKS, so
    # reading ``task.n_features`` from it is wrong whenever TASKS has more
    # than one element.
    task = NeuroGymTask(name=task_name)

    for output_set in REGIONS:

        output_exp_name = f'{task_name}_Struct500_OUT_{output_set}_RIDGE_RD_DAx1p5'
        print(f'\n---------------TASK: {output_exp_name.upper()}----------------')

        # every (task, output set) experiment gets its own figure directory
        OUTPUT_DIR = os.path.join(PROJ_DIR, 'figs', output_exp_name)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        # #####################################################################
        # Next, we will simulate the dynamics of the reservoir. We will evaluate
        # the effect of local network dynamics by using different activation
        # functions. We will also evaluate network performance across dynamical
        # regimes by parametrically tuning alpha, which scales the (spectral
        # radius normalized) connectivity matrix.
        # #####################################################################
        df_subj = []
        for activation in ACT_FCNS:

            print(f'\n------ activation function = {activation} ------')

            df_runs = []
            for run in range(N_RUNS):

                print(f'\n\t\t--- run = {run} ---')

                # fetch the pre-generated dataset for this run
                x, y = data_dict_x[task_name][run], data_dict_y[task_name][run]

                # visualize task dataset (first run only)
                if run == 0:
                    plotting.plot_iodata(
                        x, y, title=task_name, savefig=True,
                        fname=os.path.join(OUTPUT_DIR, f'io_{task_name}_{activation}'),
                        rc_params=RC_PARAMS,
                        show=False
                    )

                # split data into training and test sets
                x_train, x_test, y_train, y_test = readout.train_test_split(x, y)

                # #############################################################
                # Third, load the connectivity matrix that defines the
                # connections of the reservoir. Each run uses its own
                # consensus connectome (consensus_<run>.npy).
                # #############################################################
                # NOTE(review): machine-specific absolute path
                filename = (
                    '/home/bach/Documents/GitHub/suarez_neuromorphicnetworks/'
                    'raw_results/conn_results/reliability/scale250/'
                    f'consensus_{run}.npy'
                )
                conn = Conn(subj_id=0, filename=filename)

                # scale connectivity weights to [0, 1] and normalize by the
                # spectral radius
                conn.scale_and_normalize()

                # input and output nodes are selected by name from the
                # node-to-network mapping in RSN_MAPPING: the INPUT set
                # receives the external signal, output_set is read out
                input_nodes = conn.get_nodes(INPUT, filename=RSN_MAPPING)
                output_nodes = conn.get_nodes(output_set, filename=RSN_MAPPING)

                # input connectivity matrix (features x reservoir nodes).
                # Input nodes are shared among features in an interleaved
                # fashion: feature k drives input_nodes[k], input_nodes[k + n],
                # input_nodes[k + 2n], ... (n = number of features), each
                # connection with weight FACTOR.
                n_features = task.n_features
                w_in = np.zeros((n_features, conn.n_nodes))
                for feature in range(n_features):
                    w_in[feature, input_nodes[feature::n_features]] = FACTOR

                # instantiate an Echo State Network object
                esn = EchoStateNetwork(w=conn.w, activation_function=activation)

                # instantiate a Readout object (model type selected from y)
                readout_module = Readout(estimator=readout.select_model(y))

                # iterate global dynamics using different alpha values
                df_alpha = []
                for alpha in ALPHAS:

                    print(f'\n\t\t\t----- alpha = {alpha} -----')

                    # scale connectivity matrix by alpha
                    esn.w = alpha * conn.w

                    # simulate reservoir states
                    rs_train = esn.simulate(
                        ext_input=x_train, w_in=w_in,
                        output_nodes=output_nodes
                    )
                    rs_test = esn.simulate(
                        ext_input=x_test, w_in=w_in,
                        output_nodes=output_nodes
                    )

                    # example figures are only produced once per experiment
                    plot_examples = run == 0 and alpha == PLOT_ALPHA

                    # visualize reservoir states
                    if plot_examples:
                        plotting.plot_reservoir_states(
                            x=x_train, reservoir_states=rs_train,
                            title=task_name,
                            savefig=True,
                            fname=os.path.join(OUTPUT_DIR, f'res_states_train_{task_name}_{activation}'),
                            rc_params=RC_PARAMS,
                            show=False
                        )
                        plotting.plot_reservoir_states(
                            x=x_test, reservoir_states=rs_test,
                            title=task_name,
                            savefig=True,
                            fname=os.path.join(OUTPUT_DIR, f'res_states_test_{task_name}_{activation}'),
                            rc_params=RC_PARAMS,
                            show=False
                        )

                    # perform task
                    df_res = readout_module.run_task(
                        X=(rs_train, rs_test), y=(y_train, y_test),
                        sample_weight='both', metric=metrics,
                        readout_modules=None, readout_nodes=None,
                    )

                    # tag results with the alpha value and collect them
                    df_res['alpha'] = np.round(alpha, 3)
                    df_alpha.append(df_res)

                    # visualize diagnostic curves (classifiers only); this
                    # uses the model just trained by run_task above
                    if plot_examples and is_classifier(readout_module.model):
                        plotting.plot_diagnostics(
                            x=x_train, y=y_train, reservoir_states=rs_train,
                            trained_model=readout_module.model, title=task_name,
                            savefig=True,
                            fname=os.path.join(OUTPUT_DIR, f'diag_train_{task_name}_{activation}'),
                            rc_params=RC_PARAMS,
                            show=False
                        )
                        plotting.plot_diagnostics(
                            x=x_test, y=y_test, reservoir_states=rs_test,
                            trained_model=readout_module.model, title=task_name,
                            savefig=True,
                            fname=os.path.join(OUTPUT_DIR, f'diag_test_{task_name}_{activation}'),
                            rc_params=RC_PARAMS,
                            show=False
                        )

                # concatenate results across alpha values and tag them with
                # the run number
                df_alpha = pd.concat(df_alpha, ignore_index=True)
                df_alpha['run'] = run
                df_runs.append(df_alpha)

            # concatenate results across runs, tag them with the activation
            # function and keep only the columns of interest
            df_runs = pd.concat(df_runs, ignore_index=True)
            df_runs['activation'] = activation
            if 'module' in df_runs.columns:
                columns = ['module', 'n_nodes', 'activation', 'run', 'alpha']
            else:
                columns = ['activation', 'run', 'alpha']
            df_subj.append(df_runs[columns + metrics])

        # concatenate results across activation functions and save them
        df_subj = pd.concat(df_subj, ignore_index=True)
        results_fname = os.path.join(OUTPUT_DIR, f'results_{task_name}.csv')
        df_subj.to_csv(results_fname, index=False)
        print(results_fname)

        ###########################################################################
        # visualize performance curves from the saved results
        df_subj = pd.read_csv(results_fname, index_col=False)

        for metric in metrics:
            plotting.plot_performance(
                df_subj, x='alpha', y=metric, hue='activation',
                title=task_name, savefig=True,
                fname=os.path.join(OUTPUT_DIR, f'perf_{task_name}_{metric}'),
                rc_params=RC_PARAMS,
                show=False
            )
