import ast
import json
import os

import hydra
import pandas as pd
from syntheseus.search.analysis.route_extraction import (
    iter_routes_time_order,
)
from syntheseus.search.graph.and_or import AndNode

from multiguide.dataset.helpers import compute_diversity, load_search_graph
from multiguide.helpers import PROJECT_ROOT


# TODO: think about diversity — how are the routes extracted?
# Metrics of interest:
#   - solved rate; diversity / number of unique routes
#   - statistics on route length (vs. ground-truth route length)
#   - round trip / reaction name for ground-truth routes and for generated routes
#   - class accuracy of generated routes
#   - route top-k accuracy; route similarity to ground truth

# Route-level quality metrics: each is computed once per extracted route,
# averaged per target, then averaged across targets that have routes.
_ROUTE_METRIC_KEYS = (
    'diversity',
    'ground_truth_matches',
    'round_trip_matches',
    'classifier_to_true_class_matches',
    'rxn_insight_to_true_class_matches',
    'classifier_to_rxn_insight_class_matches',
)


def _mean_or_none(values):
    """Return the arithmetic mean of ``values``, or ``None`` if it is empty."""
    return sum(values) / len(values) if values else None


def _summarize_route(evaluation_df, df_file):
    """Reduce one route-evaluation table to a dict of per-route scalar metrics.

    Args:
        evaluation_df: Table read from one route-evaluation CSV. The row count
            is used as the route length (presumably one row per reaction).
        df_file: File name, used only in the error message.

    Returns:
        Dict with a ``'lengths'`` entry and one entry per ``_ROUTE_METRIC_KEYS``
        key, all plain Python numbers (so the result is JSON-serializable).

    Raises:
        ValueError: If the ``diversity`` column does not hold exactly one
            distinct value.
    """
    # Diversity is stored on every row but must be the same for the route.
    if evaluation_df['diversity'].nunique() != 1:
        raise ValueError(f'diversity is not unique for {df_file}')
    diversity = evaluation_df['diversity'].unique()[0]

    # 'topk' and 'round_trip_accuracy' cells hold dict literals serialized as
    # text; literal_eval parses them without executing arbitrary code.
    # Key 1 of 'topk' is read as the top-1 ground-truth match; key 10 of
    # 'round_trip_accuracy' is read as-is (presumably top-10 — confirm).
    topk_1 = evaluation_df['topk'].apply(lambda cell: ast.literal_eval(cell)[1])
    round_trip = evaluation_df['round_trip_accuracy'].apply(
        lambda cell: ast.literal_eval(cell)[10]
    )

    classifier = evaluation_df['classifier_output']
    rxn_insight = evaluation_df['rxn_insight_class']
    true_class = evaluation_df['true_class']
    return {
        'lengths': int(evaluation_df.shape[0]),
        'diversity': float(diversity),
        'ground_truth_matches': float(topk_1.mean()),
        'round_trip_matches': float(round_trip.mean()),
        'classifier_to_true_class_matches': float((classifier == true_class).mean()),
        'rxn_insight_to_true_class_matches': float((rxn_insight == true_class).mean()),
        'classifier_to_rxn_insight_class_matches': float((classifier == rxn_insight).mean()),
    }


def compute_metrics_per_dataset(experiment_name, route_start_idx, route_end_idx,
                                project_root=None):
    """Aggregate route-evaluation metrics for one experiment over a target range.

    For every target index ``i`` in ``[route_start_idx, route_end_idx)``, read
    each ``*.csv`` in
    ``<project_root>/experiments/<experiment_name>/graphs_for_mol<i>/route_evaluations``.
    Each CSV is one extracted route. Per-target metrics, including
    ``<key>_mean`` summaries, are written to ``metrics_per_target.json`` in
    that same directory.

    A target counts as solved when it has at least one route CSV. For targets
    without routes, the ``<key>_mean`` entries are ``None``. The dataset-level
    route metrics are averaged only over targets that have routes. ``solved``
    and ``num_unique_routes_extracted`` are averaged over all targets.

    Args:
        experiment_name: Sub-directory of ``<project_root>/experiments``.
        route_start_idx: First target index (inclusive).
        route_end_idx: Last target index (exclusive).
        project_root: Root directory; defaults to ``PROJECT_ROOT``.

    Returns:
        ``(metrics_per_dataset, all_metrics_per_target)``: a dict of
        dataset-level averages and the list of per-target metric dicts.

    Raises:
        FileNotFoundError: If a target's ``route_evaluations`` directory is
            missing.
        ValueError: If a route CSV has more than one distinct diversity value.
    """
    if project_root is None:
        project_root = PROJECT_ROOT

    all_metrics_per_target = []
    for route_idx in range(route_start_idx, route_end_idx):
        out_dir = os.path.join(project_root,
                               'experiments',
                               experiment_name,
                               f'graphs_for_mol{route_idx}',
                               'route_evaluations')
        # Sorted so the per-route lists and the saved JSON are deterministic.
        df_files = sorted(f for f in os.listdir(out_dir) if f.endswith('.csv'))
        metrics_per_target = {
            'lengths': [],
            'diversity': [],
            'solved': len(df_files) > 0,
            'num_unique_routes_extracted': len(df_files),
            'ground_truth_matches': [],  # fraction of the route's reactions matching ground truth
            'round_trip_matches': [],
            'classifier_to_true_class_matches': [],
            'rxn_insight_to_true_class_matches': [],
            'classifier_to_rxn_insight_class_matches': [],
        }
        for df_file in df_files:
            evaluation_df = pd.read_csv(os.path.join(out_dir, df_file))
            for key, value in _summarize_route(evaluation_df, df_file).items():
                metrics_per_target[key].append(value)

        # Iterate over a fixed key tuple: adding '<key>_mean' entries while
        # iterating the dict itself would raise RuntimeError.
        for key in ('lengths', *_ROUTE_METRIC_KEYS):
            metrics_per_target[f'{key}_mean'] = _mean_or_none(metrics_per_target[key])

        with open(os.path.join(out_dir, 'metrics_per_target.json'), 'w') as f:
            json.dump(metrics_per_target, f)
        all_metrics_per_target.append(metrics_per_target)

    metrics_per_dataset = {
        'diversity': 0,
        'solved': 0,
        'num_unique_routes_extracted': 0,
        'ground_truth_matches': 0,
        'round_trip_matches': 0,
        'classifier_to_true_class_matches': 0,
        'rxn_insight_to_true_class_matches': 0,
        'classifier_to_rxn_insight_class_matches': 0,
    }
    # Coverage metrics: averaged over every target in the range.
    for key in ('solved', 'num_unique_routes_extracted'):
        value = _mean_or_none([metrics[key] for metrics in all_metrics_per_target])
        if value is not None:
            metrics_per_dataset[key] = value
    # Route-quality metrics: averaged over targets that produced routes.
    for key in _ROUTE_METRIC_KEYS:
        target_means = [metrics[f'{key}_mean'] for metrics in all_metrics_per_target
                        if metrics[f'{key}_mean'] is not None]
        value = _mean_or_none(target_means)
        if value is not None:
            metrics_per_dataset[key] = value

    print(f'metrics_per_dataset: {metrics_per_dataset}')

    return metrics_per_dataset, all_metrics_per_target


@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    """Compute route-evaluation metrics for the 'guided' and 'unguided' experiments.

    Both experiments are evaluated over target indices 0..134. ``config`` is
    supplied by the ``@hydra.main`` decorator and is not read here yet.
    """
    results = {}
    for experiment in ('guided', 'unguided'):
        dataset_metrics, per_target_metrics = compute_metrics_per_dataset(
            experiment_name=experiment,
            route_start_idx=0,
            route_end_idx=135,
        )
        results[experiment] = (dataset_metrics, per_target_metrics)
    # TODO: compare the two datasets (results['guided'] vs results['unguided']).


if __name__ == '__main__':
    # `config` is supplied by the @hydra.main decorator, so main() is called
    # here without arguments.
    main()