from graph_learning.scripts import ScriptConfig, config_dispatch
import pandas as pd
import os
from pathlib import Path
import shutil
from sklearn import preprocessing
import numpy as np

@ScriptConfig.register('metrics-report')
class MetricsReportConfig(ScriptConfig):
    """Report, per experiment, the best mean value of every test metric.

    Expected on-disk layout (derived from the paths built in ``build``)::

        <experiments_dir>/<es>/<experiment>/<config>/<version>/<metrics_path>

    For each experiment one CSV is written to ``<output_dir>/<es>/`` with a
    row per metric: its highest per-config mean, the spread of that config
    and the name of the config that achieved it.

    NOTE(review): attributes such as ``self.es`` / ``self.el`` are presumably
    populated from the parsed arguments by ``ScriptConfig.__init__`` -- that
    base class is not visible here, confirm.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        # No explicit experiment list given: use every experiment directory
        # found in the experiment set.
        if self.el is None:
            self.el = self._list_subdirs(Path(self.experiments_dir) / self.es)

    @classmethod
    def define_parser(cls, parser):
        """Register this script's command-line options on *parser*."""
        super().define_parser(parser)
        # Name of the experiment set: a sub-directory of both
        # --experiments-dir and --output-dir.
        parser.add_argument('--es')
        # Experiments (sub-directories of the set) to report on; when omitted
        # every experiment directory in the set is used.
        parser.add_argument('--e', dest='el', nargs='+')
        parser.add_argument('--experiments-dir', default='results/experiments/')
        parser.add_argument('--output-dir', default='results/reports/')
        # Location of the metrics CSV relative to each version directory.
        parser.add_argument('--metrics-path', default='log/test_metrics.csv')

    @staticmethod
    def _list_subdirs(path):
        """Return the names of the directories directly under *path*, sorted.

        Sorting makes the report deterministic (``os.listdir`` order is
        arbitrary); filtering skips stray files such as ``.DS_Store`` that
        would otherwise be mistaken for experiments/configs/versions.
        """
        return sorted(p.name for p in Path(path).iterdir() if p.is_dir())

    @staticmethod
    def build(args, context):
        """Generate the per-experiment metric reports.

        The output directory of the experiment set is wiped and recreated,
        then ``<experiment>.csv`` is written for every selected experiment
        with the columns ``name``, ``mean``, ``var`` and ``config``.

        "Best" means the *largest* mean; on ties the alphabetically first
        config is kept.

        Raises:
            ValueError: if an experiment has no config directories, a config
                has no version directories, or the configs of one experiment
                do not report the same set of metrics.
            FileNotFoundError: if a directory or a metrics CSV is missing.
        """
        C = MetricsReportConfig(args, context)

        es_path = Path(C.experiments_dir) / C.es

        es_out_path = Path(C.output_dir) / C.es
        shutil.rmtree(es_out_path, ignore_errors=True)
        # exist_ok: rmtree above swallows errors, so the directory may
        # legitimately still be there.
        es_out_path.mkdir(parents=True, exist_ok=True)

        for e in C.el:
            e_path = es_path / e

            # Metric names of *this* experiment; kept separately so nothing
            # leaks over from a previously processed experiment.
            names = None
            means_b, sds_b, confs_b = None, None, None

            for conf in MetricsReportConfig._list_subdirs(e_path):
                conf_path = e_path / conf
                metrics_dfs = [
                    pd.read_csv(conf_path / version / C.metrics_path)
                    for version in MetricsReportConfig._list_subdirs(conf_path)
                ]
                if not metrics_dfs:
                    raise ValueError(f'no versions found in {conf_path}')
                metrics_df = pd.concat(metrics_dfs)

                if names is None:
                    names = metrics_df.columns
                elif set(metrics_df.columns) != set(names):
                    raise ValueError(
                        f'metrics of {conf_path} differ from those of the '
                        f'other configs in {e_path}'
                    )

                # Population statistics (ddof=0) over all versions, aligned
                # by metric name rather than by column position.  Computed
                # directly because StandardScaler.scale_ reports 1.0 -- not
                # 0.0 -- for a column with zero variance.
                means = metrics_df.mean().reindex(names).to_numpy(dtype=float)
                sds = metrics_df.std(ddof=0).reindex(names).to_numpy(dtype=float)
                confs = np.full(len(names), conf, dtype=object)

                if means_b is None:
                    means_b = means
                    sds_b = sds
                    confs_b = confs
                else:
                    # A real value also beats a NaN incumbent, which a plain
                    # ``>`` comparison would never displace.
                    update = (means > means_b) | (
                        np.isnan(means_b) & ~np.isnan(means)
                    )
                    means_b = np.where(update, means, means_b)
                    sds_b = np.where(update, sds, sds_b)
                    confs_b = np.where(update, confs, confs_b)

            if means_b is None:
                raise ValueError(f'no configurations found in {e_path}')

            output_df = pd.DataFrame.from_dict({
                'name': names,
                'mean': means_b,
                # NOTE(review): despite the column name this is the standard
                # deviation, not the variance; the key is kept so existing
                # consumers of the report keep working.
                'var': sds_b,
                'config': confs_b,
            })
            output_df.to_csv(es_out_path/f'{e}.csv', index=False)

