"""
For each method, aggregate the results for every scenario, data-generating configuration and hyperparameter configuration.
Results are stored as pairs (data{i}.csv, groundtruth{i}.csv),
where {i} is an integer used as seed identifier.

Metrics
    - shd
    - tpr
    - fpr
    - tnr
    - fnr
    - f1
    - aupr
    - d_top and dtop_fnr (only for methods where has_order() is True)
    - time [s]

We record the mean and standard deviation over all the seeds.
Two options are provided:
    - Update of an existing log with the update_log() function
    - Creation of a new log.
"""

import os
import json
import argparse
import numpy as np
import pandas as pd
from shutil import copyfile
from utils._metrics import get_metrics, d_top, dtop_fnr
from utils._utils import tunable_parameters, has_order

# Root of the experiment outputs. The __main__ entry point processes
# LOGS_DIR/<task>/<graph_type>, reading predictions/metadata and writing the
# raw_<method>.csv / stats_<method>.csv logs in place.
LOGS_DIR = "/efs/tmp/causal-benchmark-logs"
# Root that receives a copy of the generated logs, under
# WORKSPACE/<task>/<graph_type>/<scenario>/<scenario_param>/<method>.
WORKSPACE="/home/ec2-user/causal-benchmark/tmp/logs"

def args_sanity_check(args):
    """
    Validate the scenarios, methods and graph type requested on the command line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments exposing ``scenarios`` (list), ``methods`` (list) and
        ``graph_type`` (str).

    Raises
    ------
    ValueError
        On the first unsupported entry, checking scenarios, then methods, then graph type.
    """
    def reject_unknown(kind, requested, supported):
        # Fail fast on the first unsupported entry, preserving the requested order.
        for item in requested:
            if item not in supported:
                raise ValueError(f"{kind} {item} not allowed!")

    reject_unknown(
        "Scenario",
        args.scenarios,
        ("vanilla", "confounded", "linear", "measure_err", "timino", "unfaithful", "pnl"),
    )
    reject_unknown(
        "Method",
        args.methods,
        (
            "ges", "das", "score", "nogam", "cam", "pc", "diffan", "lingam",
            "grandag", "resit", "random", "varsort", "scoresort",
        ),
    )
    reject_unknown("Graph", (args.graph_type,), ("ER", "SF", "GRP", "FC"))


def get_relative_metrics(base_dir, scenario, scenario_param, methods):
    """
    Compare each method's results under a scenario against the vanilla baseline.

    For every method, the rows of
    ``base_dir/<scenario>/<scenario_param>/<method>/raw_<method>.csv`` are matched
    to the rows of ``base_dir/vanilla/vanilla/<method>/raw_<method>.csv`` sharing
    the same seed_id, noise, samples, size, density and hyperparameter value.
    The relative degradation ``1 - scenario_metric / vanilla_metric`` is added as
    ``f1_relative`` and, when both logs contain it, ``dtop_fnr_relative``.

    Parameters
    ----------
    base_dir : str
        Directory containing ``<scenario>/<scenario_param>/<method>`` log folders.
    scenario : str
        Scenario to compare against vanilla.
    scenario_param : str
        Sub-directory of ``base_dir/<scenario>`` holding the scenario configuration.
    methods : List[str]
        Methods to process.

    Returns
    -------
    Dict[str, pd.DataFrame]
        Maps each method to its scenario raw log extended with the relative-metric
        columns. Rows without a vanilla counterpart, or whose vanilla metric is 0,
        get NaN.

    Raises
    ------
    pandas.errors.MergeError
        If the vanilla log holds several rows for the same key combination.
    """
    key_columns = ["seed_id", "noise", "samples", "size", "density"]
    relative_logs = {}

    for method in methods:
        # raw_<method>.csv names its hyperparameter column after tunable_parameters(method)
        hyperparam_name = tunable_parameters(method)
        keys = key_columns + [hyperparam_name]

        raw_vanilla = pd.read_csv(os.path.join(base_dir, "vanilla", "vanilla", method, f"raw_{method}.csv"))
        raw_scenario = pd.read_csv(os.path.join(base_dir, scenario, scenario_param, method, f"raw_{method}.csv"))

        # dtop_fnr is only logged for methods that output an order
        metric_columns = [
            column for column in ("f1", "dtop_fnr")
            if column in raw_scenario.columns and column in raw_vanilla.columns
        ]

        baseline = raw_vanilla[keys + metric_columns].rename(
            columns={column: f"{column}_vanilla" for column in metric_columns}
        )
        # Left merge keeps every scenario row (in order); unmatched rows get NaN baselines
        merged = raw_scenario.merge(baseline, on=keys, how="left", validate="many_to_one")

        for column in metric_columns:
            vanilla_values = merged[f"{column}_vanilla"]
            # Avoid inf: a zero baseline yields NaN instead
            merged[f"{column}_relative"] = 1 - merged[column] / vanilla_values.where(vanilla_values != 0)

        relative_logs[method] = merged.drop(columns=[f"{column}_vanilla" for column in metric_columns])

    return relative_logs






def create_logs(base_dir, workspace_dir, scenarios, methods, allow_overwrite):
    """
    Score every logged prediction and write per-seed and aggregated log files.

    For each scenario in ``scenarios``, each sub-directory ``<scenario_param>`` of
    ``base_dir/<scenario>`` and each method in ``methods``, the directory
    ``base_dir/<scenario>/<scenario_param>/<method>`` is processed as follows:

    1. ``metadata.csv`` is built from the JSON records in its ``tmp`` sub-directory
       (an existing file is reused unless ``allow_overwrite`` is True).
    2. Every prediction listed in the metadata is scored against its ground truth
       and written, one row per seed, to ``raw_<method>.csv``.
    3. Rows sharing noise, samples, size, density and hyperparameter value are
       summarised as ``"mean +- std"`` (rounded to 2 decimals) in ``stats_<method>.csv``.
    4. Both files are sorted and copied to
       ``workspace_dir/<scenario>/<scenario_param>/<method>``.

    Parameters
    ----------
    base_dir : str
        Base directory with metadata and raw adjacency predictions for each method and scenario.
    workspace_dir : str
        Directory receiving a copy of the raw and stats logs, mirroring the layout of base_dir.
    scenarios : List[str]
        List of inference scenarios.
    methods : List[str]
        List of methods used in the experiments.
    allow_overwrite : bool
        If True, regenerate metadata.csv files even when they already exist.
    """
    ######################## Utilities ########################
    def get_order(json_location):
        """Return the ``"order"`` entry of the JSON file at ``json_location``."""
        with open(json_location, "r") as f:
            return json.load(f)["order"]

    def header(method, aggregate):
        """
        Return the CSV header of raw_<method>.csv (aggregate=False, with a leading
        seed_id column) or of stats_<method>.csv (aggregate=True).
        """
        param_name = tunable_parameters(method)
        columns = f"noise,samples,size,density,{param_name},"
        if has_order(method):
            columns += "d_top,dtop_fnr,"
        columns += "shd,tpr,fpr,tnr,fnr,f1,aupr,time [s]"
        if not aggregate:
            columns = "seed_id," + columns
        return columns

    def format_avg_metric(mean, std):
        """Format an aggregated metric as ``"<mean> +- <std>"``."""
        return f"{mean} +- {std}"

    def to_csv_line(values):
        """Join values into a comma-separated line terminated by a newline."""
        return ",".join(str(val) for val in values) + "\n"

    def split_group_key(group_key, param_name):
        """
        Return [noise, samples, size, density, param_value] for a groupby key.
        Methods without a tunable parameter (param_name == "none") get param_value 0.
        """
        if param_name != "none":
            noise, num_samples, graph_size, graph_density, param_value = group_key
        else:
            noise, num_samples, graph_size, graph_density = group_key
            param_value = 0
        return [noise, num_samples, graph_size, graph_density, param_value]

    def sort_logs(output_file, param_name):
        """Sort a raw or stats log in place; seed_id is a sort key only when present (raw logs)."""
        unsorted_df = pd.read_csv(output_file)
        by = ["noise", "samples", "size", "density", param_name]
        ascending = [True, True, False, False, True]
        if "seed_id" in unsorted_df.columns:
            by.append("seed_id")
            ascending.append(True)
        unsorted_df.sort_values(by=by, ascending=ascending).to_csv(output_file, index=False)

    def make_metadata_line(record_file_path, record, method_name):
        """
        Turn one open JSON record into a metadata.csv line.

        Values are joined in the record's key order, so the JSON keys are assumed to
        line up with the metadata header columns — NOTE(review): confirm against the
        code that writes these records.
        """
        reg_param = tunable_parameters(method_name)
        json_content = json.load(record)
        try:
            json_content["hyperparameters"] = json_content["hyperparameters"][reg_param]
        except KeyError:
            json_content["hyperparameters"] = 0  # Placeholder 0 for lingam
        if has_order(method_name):
            json_content["order"] = record_file_path  # Store the location for the order
        return ",".join(str(val) for val in json_content.values())

    def make_metadata(method_logs_dir, method_name, allow_overwrite):
        """
        Build method_logs_dir/metadata.csv from the JSON records in method_logs_dir/tmp
        and return its path. An existing file is reused unless allow_overwrite is True.
        """
        metadata_logs_dir = os.path.join(method_logs_dir, "tmp")
        metadata_filename = os.path.join(method_logs_dir, "metadata.csv")  # TODO: metadata.csv should be passed as argument!
        if os.path.exists(metadata_filename) and not allow_overwrite:
            return metadata_filename

        columns = (
            f"seed_id,noise,samples,size,density,num_nodes,{tunable_parameters(method_name)},"
            "time [s],pred_location,gt_location"
        )
        if has_order(method_name):
            columns += ",order_location"

        with open(metadata_filename, "w") as f:
            f.write(columns + "\n")
            # Sorted so that metadata.csv is reproducible across file systems
            for record_file in sorted(os.listdir(metadata_logs_dir)):
                record_file_path = os.path.join(metadata_logs_dir, record_file)
                with open(record_file_path, "r") as record:
                    f.write(make_metadata_line(record_file_path, record, method_name) + "\n")
        return metadata_filename

    def score_prediction(metadata_df, row_label, method):
        """
        Return the logged metrics of one metadata row: [d_top, dtop_fnr] (only when
        has_order(method)) followed by [shd, tpr, fpr, tnr, fnr, f1, aupr, time [s]].
        """
        pred = np.genfromtxt(metadata_df.at[row_label, "pred_location"], delimiter=",")
        gt = np.genfromtxt(metadata_df.at[row_label, "gt_location"], delimiter=",")
        metrics = []
        if has_order(method):
            order = get_order(metadata_df.at[row_label, "order_location"])
            order_err = d_top(order=order, A=gt)
            metrics += [order_err, dtop_fnr(dtop=order_err, A=gt)]
        shd, tpr, fpr, tnr, fnr, f1, aupr = get_metrics(pred, gt)
        metrics += [shd, tpr, fpr, tnr, fnr, f1, aupr, metadata_df.at[row_label, "time [s]"]]
        return metrics

    ######################## Processing ########################
    for scenario in scenarios:
        scenario_path = os.path.join(base_dir, scenario)
        for scenario_param in sorted(os.listdir(scenario_path)):
            # Only sub-directories hold experiment results; skip stray files
            if not os.path.isdir(os.path.join(scenario_path, scenario_param)):
                continue
            for method in methods:
                method_logs_dir = os.path.join(scenario_path, scenario_param, method)
                metadata_filename = make_metadata(method_logs_dir, method, allow_overwrite)

                # Methods without a tunable parameter report "none" and are grouped without it
                param_name = tunable_parameters(method)
                group_columns = ["noise", "samples", "size", "density"]
                if param_name != "none":
                    group_columns.append(param_name)

                raw_output_file = os.path.join(method_logs_dir, f"raw_{method}.csv")
                stats_output_file = os.path.join(method_logs_dir, f"stats_{method}.csv")
                metadata_df = pd.read_csv(metadata_filename)

                with open(raw_output_file, "w") as raw_f, open(stats_output_file, "w") as stats_f:
                    raw_f.write(header(method, aggregate=False) + "\n")
                    stats_f.write(header(method, aggregate=True) + "\n")

                    for group_key, group_df in metadata_df.groupby(group_columns):
                        configuration = split_group_key(group_key, param_name)
                        group_logs = []
                        for row_label in group_df.index:
                            metrics = score_prediction(group_df, row_label, method)
                            group_logs.append(metrics)
                            seed_id = group_df.at[row_label, "seed_id"]
                            raw_f.write(to_csv_line([seed_id] + configuration + metrics))

                        # Aggregate over seeds (np.std is the population std, ddof=0)
                        means = np.round(np.mean(group_logs, axis=0), 2)
                        stds = np.round(np.std(group_logs, axis=0), 2)
                        metrics_mean_std = [format_avg_metric(m, s) for m, s in zip(means, stds)]
                        stats_f.write(to_csv_line(configuration + metrics_mean_std))

                sort_logs(raw_output_file, param_name)
                sort_logs(stats_output_file, param_name)

                # Copy both logs into the workspace directory
                workspace_method_dir = os.path.join(workspace_dir, scenario, scenario_param, method)
                os.makedirs(workspace_method_dir, exist_ok=True)
                copyfile(raw_output_file, os.path.join(workspace_method_dir, f"raw_{method}.csv"))
                copyfile(stats_output_file, os.path.join(workspace_method_dir, f"stats_{method}.csv"))


if __name__ == "__main__":
    # Command-line entry point: aggregate logs under LOGS_DIR/<task>/<graph_type>
    # and mirror them into WORKSPACE/<task>/<graph_type>.
    parser = argparse.ArgumentParser(description="Experiments logs aggregation")
    parser.add_argument(
        "--task",
        default="inference",
        type=str,
        help="Whether to consider inference or tuning experiments. Accepted values ['inference', 'tuning', 'standardized']",
    )
    parser.add_argument(
        "--graph_type",
        default="ER",
        type=str,
        help="Algorithm for generation of synthetic graphs. Accepted values are ['ER', 'SF', 'GRP', 'FC']",
    )
    parser.add_argument(
        "--methods",
        nargs="+",
        help="Methods for which to aggregate logs",
        type=str,
    )
    parser.add_argument(
        "--scenarios",
        nargs="+",
        help="Scenarios for which to aggregate data",
        type=str,
    )
    # store_false: args.not_overwrite_meta defaults to True (overwrite allowed) and
    # becomes False when the flag is passed.
    parser.add_argument(
        "--not_overwrite_meta",
        action="store_false",
        help="If set, keep existing metadata.csv files instead of regenerating them (by default they are overwritten)",
    )

    args = parser.parse_args()
    if args.methods is None:
        args.methods = [
            "ges",
            "pc",
            "das",
            "score",
            "cam",
            "nogam",
            "diffan",
            "grandag",
            "lingam",
            "resit",
            "random",
            "scoresort",
        ]

    if args.scenarios is None:
        args.scenarios = [
            "confounded",
            "vanilla",
            "linear",
            "measure_err",
            "timino",
            "unfaithful",
            "pnl",
        ]

    args_sanity_check(args)
    create_logs(
        base_dir=os.path.join(LOGS_DIR, args.task, args.graph_type),
        workspace_dir=os.path.join(WORKSPACE, args.task, args.graph_type),
        scenarios=args.scenarios,
        methods=args.methods,
        allow_overwrite=args.not_overwrite_meta,
    )