from datetime import datetime
from glob import glob
import os

import causaldag as cd
import networkx as nx
import numpy as np
import pandas as pd

from utils.config import load_yaml_config
from utils.get_equiv_class import get_equiv_class
from utils.metrics import count_dag_accuracy
from utils.utils import is_dag


# Names of the accuracy metrics reported for every run. Each name is used
# twice in this module: to build the '<metric>_dag' / '<metric>_equiv' result
# columns, and as the key looked up in the dicts returned by
# count_dag_accuracy.
metrics = [
    'fdr',
    'tpr',
    'fpr',
    'shd',
    'nnz',
    'precision',
    'recall',
    'f1',
]


def to_numeric(result_df, except_cols=frozenset()):
    """Convert the columns of ``result_df`` to numeric dtypes, in place.

    Args:
        result_df: DataFrame to post-process. It is modified in place;
            nothing is returned.
        except_cols: Collection of column names to leave untouched.
            Defaults to an immutable empty set so that no mutable object is
            shared between calls.

    Raises:
        Whatever ``pd.to_numeric`` raises for a column it cannot convert is
        propagated to the caller.
    """
    # Snapshot the column labels first: columns are reassigned while looping.
    for col in list(result_df):
        if col not in except_cols:
            result_df[col] = pd.to_numeric(result_df[col])


def initialize_result_df(add_configs=()):
    """Create the empty DataFrame that collects one row per experiment run.

    The columns are, in order: the fixed bookkeeping columns, one
    ``'<metric>_dag'`` column per entry of the module-level ``metrics`` list,
    one ``'<metric>_equiv'`` column per entry of ``metrics``, and finally the
    names given in ``add_configs``.

    Args:
        add_configs: Iterable of extra column names to append after the
            metric columns. Defaults to an immutable empty tuple so that no
            mutable object is shared between calls.

    Returns:
        A ``pd.DataFrame`` with those columns and no rows.
    """
    fixed_cols = ['output_dir', 'is_dag', 'minutes', 'nnz_true', 'nnz_equiv_true', 'all_data_shape']
    dag_cols = ['{}_dag'.format(metric) for metric in metrics]
    equiv_cols = ['{}_equiv'.format(metric) for metric in metrics]
    return pd.DataFrame(columns=[*fixed_cols, *dag_cols, *equiv_cols, *add_configs])


def _arcs_to_adjacency(arcs, nodenum):
    """Return a (nodenum, nodenum) matrix with entry [u, v] = 1 for each arc u -> v."""
    adjacency = np.zeros((nodenum, nodenum))
    for u, v in arcs:
        adjacency[u, v] = 1
    return adjacency


def compute_results(result_path, result_df, add_configs, load_equiv_true=False):
    """Evaluate every run found under ``result_path`` and append it to ``result_df``.

    For each entry matched by ``'<result_path>/*'`` this reads ``config.yaml``,
    ``params_true.npy``, ``params_est.npy`` and ``training.log``, scores the
    estimated graph against both the true DAG and the directed edges of the
    true equivalence class, and appends one row to ``result_df``.

    Args:
        result_path: Directory whose direct children are the run directories.
        result_df: DataFrame to append rows to; modified in place. The row
            layout matches the columns built by ``initialize_result_df`` when
            it is called with the same ``add_configs``.
        add_configs: Iterable of config keys whose values are appended to the
            row (``None`` when a key is missing from the config).
        load_equiv_true: If true, load a precomputed equivalence class from
            the notebooks directory instead of calling ``get_equiv_class``.

    Returns:
        None. A run that raises is reported on stdout and skipped, so one
        broken run does not abort the aggregation.
    """
    for path in glob('{}/*'.format(result_path)):
        try:
            # Loaded inside the try so that a stray entry without a readable
            # config.yaml is reported and skipped like any other failed run,
            # instead of aborting the whole aggregation.
            config = load_yaml_config('{}/config.yaml'.format(path))
            output_dir = os.path.basename(path)

            # NOTE: allow_pickle=True unpickles arbitrary objects; only point
            # this at result directories you trust.
            params_true = np.load('{}/params_true.npy'.format(path), allow_pickle=True).item()
            params_est = np.load('{}/params_est.npy'.format(path), allow_pickle=True).item()
            nnz_true = int(params_true['dag'].sum())
            nodenum = config['nodenum']

            # Get directed edges that we aim to identify
            if load_equiv_true:
                # NOTE: relative path, resolved against the current working
                # directory of the calling process.
                equiv_true = np.load('../../../notebooks/generate_equiv_class/synthetic_data/equiv_true_{}nodes_{}seed.npy'.format(
                    nodenum, config['seed']
                ), allow_pickle=True).item()
            else:
                dag_edges_true = [(u, v) for u, v in np.array(np.where(params_true['dag'])).T.tolist()]
                equiv_true = get_equiv_class(nodenum, dag_edges_true,
                                             params_true['interv_targets'],
                                             params_true['selection_parents'])
            # Only the directed edges (u -> v) of the equivalence class count
            # as the goal.
            dag_goal_true = _arcs_to_adjacency(equiv_true['->'], nodenum)
            nnz_equiv_true = int(dag_goal_true.sum())

            method_type = config['method_type']
            if method_type in {'ut_igsp', 'igsp'}:
                g = cd.DAG.from_amat(params_est['dag'])
                cpdag = g.cpdag()
                # 'igsp' runs are scored with the true intervention targets;
                # 'ut_igsp' runs with the targets stored in the estimate.
                if method_type == 'igsp':
                    targets = params_true['interv_targets']
                else:
                    targets = params_est['targets_list']
                icpdag = g.interventional_cpdag(targets, cpdag=cpdag)
                # Keep only the directed arcs of the interventional CPDAG.
                dag_est = _arcs_to_adjacency(icpdag.arcs, nodenum)
            elif method_type == 'gies':
                g = cd.PDAG.from_amat(params_est['pdag'])
                # Keep only the directed arcs of the estimated PDAG.
                dag_est = _arcs_to_adjacency(g.arcs, nodenum)
            else:
                # Every other method: use the stored estimate as is.
                dag_est = params_est['dag']

            # Calculate DAG and equivalence results
            results_dag = count_dag_accuracy(params_true['dag'], dag_est)
            results_equiv = count_dag_accuracy(dag_goal_true, dag_est)
            minutes = get_total_time('{}/training.log'.format(path))

            # Save to df
            result_df.loc[len(result_df)] = [output_dir, is_dag(dag_est), minutes,
                                             nnz_true, nnz_equiv_true, params_true['all_data_shape'],
                                             *[results_dag.get(metric, None) for metric in metrics],
                                             *[results_equiv.get(metric, None) for metric in metrics],
                                             *[config.get(c, None) for c in add_configs]]
        except Exception as e:
            # Deliberate best-effort: report the failing run and move on.
            print("Error for path {}.".format(path))
            print("Error message: {}.".format(e))

            
def get_total_time(log_path):
    """Return the minutes elapsed between the first and last line of a log.

    The first 19 characters of the first and of the last line are parsed as
    a ``'%Y-%m-%d %H:%M:%S'`` timestamp.

    Args:
        log_path: Path of the log file to read.

    Returns:
        The elapsed time in minutes as a float, or ``None`` when the file
        cannot be read, is empty, or its first/last line does not start with
        a timestamp in the expected format.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    try:
        with open(log_path) as f:
            lines = f.readlines()
        start = datetime.strptime(lines[0][:19], timestamp_format)
        end = datetime.strptime(lines[-1][:19], timestamp_format)
    except (OSError, ValueError, IndexError):
        # Unreadable file, malformed timestamp, or empty log: the duration is
        # simply unknown for this run.
        return None
    # total_seconds() includes whole days; timedelta.seconds would silently
    # wrap around for runs longer than 24 hours.
    return (end - start).total_seconds() / 60    # Minutes