import cdt
import os
import glob
import numpy as np
import pandas as pd
from modules.utils import get_data, edge_errors, num_errors, recall, precision
from modules.stein import graph_inference

# Directory that Experiment.clean_csv scans for "tmp*.csv" files to delete.
# NOTE(review): this is a hard-coded absolute path — adjust it per machine.
base_folder = "/home/francescom/Research/DAS-Extension/src"

class Experiment:
    def __init__(self, **kwargs):
        """Store the experiment configuration.

        Every setting is a required keyword argument; a missing one raises
        ``KeyError``.

        Keyword Args:
            algorithm: forwarded to ``graph_inference`` as ``algorithm``.
            num_runs: number of repetitions per parameter combination.
            N: forwarded to ``get_data`` as ``N`` and logged in column 'N'.
            GP: forwarded to ``get_data`` as ``GP``.
            graph_type: forwarded to ``get_data`` as ``graph_type``.
            noise_type: forwarded to ``get_data`` as ``noise_type``.
            nstd: forwarded to ``get_data`` as ``noise_std``.
            pruning: forwarded to ``graph_inference`` as ``pruning``.
            output_path: destination of the CSV written by ``run_experiment``.
        """
        # (attribute name, kwargs key) pairs, looked up in this order.
        settings = (
            ("algorithm", "algorithm"),
            ("n_runs", "num_runs"),
            ("N", "N"),
            ("GP", "GP"),
            ("graph_type", "graph_type"),
            ("noise_type", "noise_type"),
            ("noise_std", "nstd"),
            ("pruning", "pruning"),
            ("output_path", "output_path"),
        )
        for attr_name, key in settings:
            setattr(self, attr_name, kwargs[key])

        # Header of the results table: six parameter columns followed by
        # the per-run metrics and the two timing columns.
        parameter_columns = ['V', 'E', 'N', 'alpha', 'gamma', 'cv splits']
        metric_columns = [
            'fn', 'fp', 'reversed', 'precision', 'recall', 'SHD', 'SID', 'D_top',
        ]
        timing_columns = ['Order time [s]', 'Total time [s]']
        self.columns = parameter_columns + metric_columns + timing_columns

        self.verbose = True

    def get_params(self, params):
        """Unpack one entry of the parameter grid.

        Args:
            params: mapping with keys 'd', 's0', 'alpha', 'gamma' and 'n_cv'.

        Returns:
            Tuple ``(d, s0, alpha, gamma, n_cv)`` where ``s0`` is ``4 * d``
            when ``params['s0'] == '4d'`` and ``d`` for any other value.
        """
        d = params['d']
        s0 = d * 4 if params['s0'] == '4d' else d
        return d, s0, params['alpha'], params['gamma'], params['n_cv']

    def get_params_logs(self, d, s0, alpha, gamma, n_cv):
        mean_log, std_log = self.run_params(d, s0, alpha, gamma, n_cv)
        mean_log = np.around(mean_log, decimals=2)
        std_log = np.around(std_log, decimals=2)
        params_logs = [d, s0, self.N, alpha, gamma, n_cv] # A row in the pandas dataframe
        for i in range(6, mean_log.shape[0]):
            log_element = f'{mean_log[i]} +- {std_log[i]}'
            params_logs.append(log_element)

        return params_logs

    def run_experiment(self, params_grid):
        """Run n_run experiments for each combination of parameters and log the results"""
        logs = []
        for params in params_grid:
            d, s0, alpha, gamma, n_cv = self.get_params(params)
            params_logs = self.get_params_logs(d, s0, alpha, gamma, n_cv)
            logs.append(params_logs)

            df_logs = pd.DataFrame(data=logs, columns=self.columns)
            df_logs.to_csv(self.output_path, mode='w')


    def clean_csv(self):
        """Delete every file matching ``tmp*.csv`` directly inside ``base_folder``."""
        for stale_path in glob.glob(f"{base_folder}/tmp*.csv"):
            os.remove(stale_path)


    def run_params(self, d, s0, alpha, gamma, n_cv):
        """Run over a single combination of parameters n_runs times"""
        logs = []
        more_logs_params = [d, s0, self.N, alpha, gamma, n_cv]
        for i in range(self.n_runs):
            self.clean_csv()
            print(f"Iteration {i+1}/{self.n_runs}")
            metrics = list(self.single_run(d, s0, alpha, gamma, n_cv))
            logs.append(more_logs_params + metrics)
        
        return np.mean(logs, axis=0), np.std(logs, axis=0)

    def single_run(self, d, s0, alpha, gamma, n_cv):
        """Generate data, infer the graph once and score the prediction.

        Data comes from ``get_data`` and the prediction from
        ``graph_inference``, both configured from the instance settings. SID
        is skipped (``compute_SID=False`` passed to ``self.metrics``) when
        ``d`` exceeds 200.

        Returns:
            list: the output of ``self.metrics`` followed by the ordering time
            and the total time reported by ``graph_inference``.
        """
        # The second value returned by get_data is not used in this method.
        X, _, A_truth = get_data(
            graph_type=self.graph_type,
            d=d,
            s0=s0,
            N=self.N,
            GP=self.GP,
            noise_type=self.noise_type,
            noise_std=self.noise_std,
            verbose=self.verbose,
        )
        A_pred, top_order, order_time, tot_time = graph_inference(
            X,
            alpha=alpha,
            gamma=gamma,
            n_cv=n_cv,
            pruning=self.pruning,
            algorithm=self.algorithm,
        )

        with_sid = not d > 200
        scores = self.metrics(A_truth, A_pred, top_order, with_sid)
        scores.append(order_time)
        scores.append(tot_time)

        if self.verbose:
            print(self.pretty_evaluate(scores, d, s0, alpha, gamma, n_cv))

        return scores


    def metrics(self, A_truth, A_pred, top_order, compute_SID=True):
        """Score a predicted adjacency matrix against the ground truth.

        Args:
            A_truth: ground-truth adjacency matrix.
            A_pred: predicted adjacency matrix.
            top_order: inferred ordering, checked against ``A_truth`` with
                ``num_errors``.
            compute_SID: when false, SID is reported as ``-1`` instead of
                being computed with ``cdt.metrics.SID``.

        Returns:
            list: ``[fn, fp, reversed, precision, recall, SHD, SID,
            top_order_errors]``, where SHD is ``fn + fp + reversed``.
        """
        false_neg, false_pos, flipped = edge_errors(A_pred, A_truth)
        shd = sum((false_neg, false_pos, flipped))
        order_errors = num_errors(top_order, A_truth)

        true_edge_count = A_truth.sum()
        prec = precision(true_edge_count, false_neg, false_pos)
        rec = recall(true_edge_count, false_neg)

        # -1 marks "SID not computed".
        sid = int(cdt.metrics.SID(target=A_truth, pred=A_pred)) if compute_SID else -1

        return [false_neg, false_pos, flipped, prec, rec, shd, sid, order_errors]