'''
This script includes utility functions for evaluation of the models.
'''

import os
from tqdm import tqdm
import numpy as np
from modules.measures import MeasureCalculator
from train_eval_utils.utils_distance_matrix import get_EUC


class Multi_Evaluation:
    """
    Compare the pairwise-distance structure of data samples with that of
    their latent representations.

    - data: data samples as an array; the first axis indexes samples
    - latent: latent samples as an array; the first axis indexes samples
    """

    def __init__(self, data, latent):
        self.data = data
        self.latent = latent

    def define_ks(self, dist_mat_X):
        """
        Define the k values (neighbourhood sizes) used for evaluation.

        Ten values are logarithmically spaced (base 5) from 5 up to
        min(n/3, 200), where n is the number of rows of `dist_mat_X`; they
        are truncated to integers and de-duplicated, so fewer than ten
        values may be returned.
        """
        n_points = dist_mat_X.shape[0]
        upper_k = min(n_points / 3, 200)
        # logspace takes exponents, so convert the upper bound to a base-5 exponent
        stop_exponent = np.log(upper_k) / np.log(5)
        candidates = np.logspace(1, stop_exponent, num=10, base=5)
        return np.unique(candidates.astype(int))

    @staticmethod
    def _min_max_scale(dist_mat):
        """Rescale a distance matrix by its own minimum and range (min-max)."""
        lowest = dist_mat.min()
        return abs((dist_mat - lowest) / (dist_mat.max() - lowest))

    @staticmethod
    def _average_totals(totals, sample_count, prefix):
        """
        Divide accumulated sums by `sample_count` and prepend `prefix` to
        every key. With no samples there is nothing to average, so every
        value is NaN instead of raising ZeroDivisionError.
        """
        if sample_count == 0:
            return {prefix + key: float('nan') for key in totals}
        return {prefix + key: value / sample_count for key, value in totals.items()}

    def get_multi_evals(self, local=False, max_local_samples=500):
        """
        Performs multiple evaluations for nonlinear dimensionality
        reduction, using `self.data` and `self.latent`.

        - local: if True, build one pair of distance matrices per sample
          (between the N = data.shape[-2] entries of that sample) and
          average the measures over samples; if False, build a single pair
          of distance matrices between all flattened samples
        - max_local_samples: maximum number of non-degenerate samples used
          in local mode; None removes the cap. Ignored in global mode.

        Returns a dict of measures whose keys are prefixed with 'local_' or
        'global_'. In local mode all values are NaN (and a warning is
        issued) when no sample could be evaluated.
        """
        if local:
            return self._evaluate_local(max_local_samples)
        return self._evaluate_global()

    def _evaluate_local(self, max_local_samples):
        """Average the measures over per-sample distance matrices."""
        import warnings  # only needed for the no-valid-sample warning below

        rmse_total = {'distmat_rmse': 0.}
        indep_totals = {'density_kl_global_001': 0.,
                        'density_kl_global_01': 0.,
                        'density_kl_global_1': 0.,
                        'density_kl_global_10': 0.}
        dep_totals = {'mean_shared_neighbours': 0.,
                      'mean_dist_mrre': 0.,
                      'mean_trustworthiness': 0.,
                      'mean_continuity': 0.}

        # for UEA N is time series length, for MacroTraffic N is number of nodes, for MicroTraffic N is number of agents
        N = self.data.shape[-2]
        # NOTE(review): reshape(N, -1) regroups elements in memory order without
        # moving axes -- confirm that samples with more than two axes are laid
        # out so that each of the N rows really corresponds to one entity.
        sample_count = 0
        for sample_index in tqdm(np.arange(self.data.shape[0]), desc='Local evaluation', ascii=True, miniters=100):
            if max_local_samples is not None and sample_count >= max_local_samples:
                break

            dist_mat_X = get_EUC(self.data[sample_index].reshape(N, -1))
            dist_mat_Z = get_EUC(self.latent[sample_index].reshape(N, -1))
            # a constant distance matrix cannot be min-max scaled; skip the sample
            if dist_mat_X.max() - dist_mat_X.min() == 0 or dist_mat_Z.max() - dist_mat_Z.min() == 0:
                continue
            dist_mat_X = self._min_max_scale(dist_mat_X)
            dist_mat_Z = self._min_max_scale(dist_mat_Z)

            rmse_total['distmat_rmse'] += np.sqrt(np.mean((dist_mat_X - dist_mat_Z)**2))

            ks = self.define_ks(dist_mat_X)
            calc = MeasureCalculator(dist_mat_X, dist_mat_Z, max(ks))

            for key, value in calc.compute_k_independent_measures().items():
                indep_totals[key] += value
            # k-dependent measures are first averaged over all k values (ignoring NaN)
            for key, values in calc.compute_measures_for_ks(ks).items():
                dep_totals['mean_' + key] += np.nanmean(values)

            sample_count += 1

        if sample_count == 0:
            warnings.warn('Local evaluation found no sample with a non-degenerate '
                          'distance matrix; all local measures are NaN.')

        return {**self._average_totals(rmse_total, sample_count, 'local_'),
                **self._average_totals(indep_totals, sample_count, 'local_'),
                **self._average_totals(dep_totals, sample_count, 'local_')}

    def _evaluate_global(self):
        """Compute the measures on one distance matrix between all samples."""
        N = self.data.shape[0]
        print('Calculating global distance matrix...')
        dist_mat_X = get_EUC(self.data.reshape(N, -1))
        dist_mat_Z = get_EUC(self.latent.reshape(N, -1))
        dist_mat_X = self._min_max_scale(dist_mat_X)
        dist_mat_Z = self._min_max_scale(dist_mat_Z)
        print('Distance matrix calculated.')

        dist_mat_measure = {'global_distmat_rmse': np.sqrt(np.mean((dist_mat_X - dist_mat_Z)**2))}

        ks = self.define_ks(dist_mat_X)
        calc = MeasureCalculator(dist_mat_X, dist_mat_Z, max(ks))

        indep_measures = {'global_' + key: value
                          for key, value in calc.compute_k_independent_measures().items()}
        # k-dependent measures are averaged over all k values (ignoring NaN)
        mean_dep_measures = {'global_mean_' + key: np.nanmean(values)
                             for key, values in calc.compute_measures_for_ks(ks).items()}

        return {**dist_mat_measure, **indep_measures, **mean_dep_measures}


def evaluate(data, labels, model, batch_size, local=False, save_latents=False, save_dir=None):
    """
    Encode `data` with `model` and evaluate how well the latent space
    preserves the distance structure of the data.

    - data: data samples as an array
    - labels: labels stored next to the latents when `save_latents` is True;
      not used for the evaluation itself
    - model: object with an `encode(data, batch_size=..., ...)` method whose
      result supports `.detach().cpu().numpy()`, and a `loader` attribute
    - batch_size: batch size passed to `model.encode`
    - local: whether to use local (per-sample) or global evaluation
    - save_latents: whether to save the latents and labels to
      'latents_local.npz' / 'latents_global.npz' inside `save_dir`
    - save_dir: target directory; required when `save_latents` is True

    Returns the dict produced by `Multi_Evaluation.get_multi_evals`.
    Raises ValueError if `save_latents` is True but `save_dir` is None.
    """
    # fail before the (potentially expensive) encoding if saving cannot succeed
    if save_latents and save_dir is None:
        raise ValueError('save_dir must be provided when save_latents is True')

    # encode data into latent space
    if local:
        latent = model.encode(data, batch_size=batch_size).detach().cpu().numpy() # (N, T, P)
    else:
        latent = model.encode(data, batch_size=batch_size, encoding_window='full_series').detach().cpu().numpy() # (N, P)

    # saving applies to both modes; it used to sit inside the global branch,
    # which made the 'latents_local.npz' name unreachable
    if save_latents:
        file_name = 'latents_local.npz' if local else 'latents_global.npz'
        np.savez(os.path.join(save_dir, file_name), latents=latent, labels=labels)

    # switch axes to (n_samples, n_timesteps, n_agents, n_features) for MicroTraffic data
    if model.loader == 'MicroTraffic':
        data = data.transpose(0, 2, 1, 3)

    # # add a time feature to each instance in data
    # ranges = np.max(data, axis=1).max(axis=0) - np.min(data, axis=1).min(axis=0)
    # time_range = np.median(ranges) * np.linspace(0, 1, data.shape[1])
    # if data.ndim == 3:
    #     data = np.concatenate((data, np.tile(time_range, (data.shape[0], 1)).reshape(data.shape[0], data.shape[1], 1)), axis=-1)
    # elif data.ndim == 4:
    #     data = np.concatenate([data, np.tile(time_range, (data.shape[0], 1, data.shape[2], 1)).reshape(data.shape[0], data.shape[1], data.shape[2], 1)], axis=-1)

    evaluator = Multi_Evaluation(data, latent)
    ev_result = evaluator.get_multi_evals(local)

    return ev_result

