from copy import deepcopy, copy

import pandas as pd
import torch
from relnet.utils.config_utils import local_np_seed

from itertools import product
import numpy as np

def generate_search_space(parameter_grid,
                          random_search=False,
                          random_search_num_options=20,
                          random_search_seed=42):
    """Enumerate every combination of ``parameter_grid`` values, optionally sub-sampled.

    Args:
        parameter_grid: mapping of parameter name -> iterable of candidate
            values. Only the values are used, in the mapping's iteration order.
        random_search: if true, keep only a random subset of the combinations.
        random_search_num_options: size of that subset. When it exceeds the
            number of combinations, the full space is returned unchanged.
        random_search_seed: seed passed to ``local_np_seed`` around the draw.

    Returns:
        dict mapping an integer index to a tuple of values (one per grid key,
        in key order). After sub-sampling, the keys are the indices drawn by
        ``np.random.choice``.
    """
    full_space = dict(enumerate(product(*parameter_grid.values())))

    subsample = random_search and random_search_num_options <= len(full_space)
    if not subsample:
        return full_space

    # Draw without replacement under a fixed seed so the subset is reproducible.
    with local_np_seed(random_search_seed):
        picked = np.random.choice(len(full_space), random_search_num_options, replace=False)
        return {pos: full_space[pos] for pos in picked}


def construct_search_spaces(experiment_conditions):
    """Build per-objective, per-agent hyper-parameter search spaces.

    Args:
        experiment_conditions: object exposing ``relevant_agents`` (items with
            ``algorithm_name``), ``objective_functions`` (items with ``name``)
            and ``hyperparam_grids`` (``{objective name: {algorithm name:
            {param name: candidate values}}}``).

    Returns:
        ``{objective name: {algorithm name: {"0": {param: value, ...}, "1": ...}}}``.
        Each grid is expanded to its full Cartesian product, keyed by the
        stringified combination index. Agents without a grid for an
        objective are omitted.
    """
    agents = experiment_conditions.relevant_agents
    objectives = experiment_conditions.objective_functions

    spaces = {}
    for objective in objectives:
        per_agent = {}
        for agent in agents:
            # Looked up lazily (per agent) so no grid access happens when there are no agents.
            grids_for_objective = experiment_conditions.hyperparam_grids[objective.name]
            algo = agent.algorithm_name
            if algo not in grids_for_objective:
                continue
            grid = grids_for_objective[algo]
            param_names = list(grid.keys())
            per_agent[algo] = {
                str(idx): dict(zip(param_names, combo))
                for idx, combo in enumerate(product(*grid.values()))
            }
        spaces[objective.name] = per_agent

    return spaces


def score_predictions(predictions, gt_values,
                      evaluation_metric='mse'):
    """Score predictions against ground-truth values.

    Args:
        predictions: predicted values. A ``torch.Tensor`` (including subclasses
            such as ``nn.Parameter``), a numpy array, or an array-like.
        gt_values: ground-truth values, broadcast-compatible with
            ``predictions``. Tensors are accepted too.
        evaluation_metric: ``'mse'`` (mean squared error) or ``'mae'``
            (mean absolute error).

    Returns:
        The metric value as returned by ``np.mean`` (a numpy scalar).

    Raises:
        ValueError: for an unknown ``evaluation_metric``.
    """
    if evaluation_metric not in ('mse', 'mae'):
        raise ValueError(f"unknown eval metric {evaluation_metric}.")

    # isinstance (not `type(...) is`) so tensor subclasses are converted too.
    # Detach first so no autograd-tracked copy is made; the numpy view is
    # only read, never mutated.
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(gt_values, torch.Tensor):
        gt_values = gt_values.detach().cpu().numpy()

    # asarray lets plain Python sequences be scored as well.
    errors = np.asarray(predictions) - np.asarray(gt_values)
    if evaluation_metric == 'mse':
        return np.mean(errors ** 2)
    return np.mean(np.abs(errors))

def eval_pred_scores(results):
    """Score every algorithm's predictions with MSE and MAE.

    Args:
        results: evaluation result dicts, converted via
            ``EvaluationStorage.result_dicts_to_np_arrays``. That helper
            presumably yields ``{algorithm: (predictions, ground_truth)}``,
            as unpacked below. NOTE(review): confirm against the storage
            module.

    Returns:
        list of ``{"algorithm", "metric", "perf"}`` dicts, grouped by
        algorithm with ``"mse"`` before ``"mae"``.
    """
    from relnet.io.storage import EvaluationStorage
    per_algorithm = EvaluationStorage.result_dicts_to_np_arrays(results)
    return [
        {"algorithm": algo, "metric": metric,
         "perf": score_predictions(preds, gts, evaluation_metric=metric)}
        for algo, (preds, gts) in per_algorithm.items()
        for metric in ("mse", "mae")
    ]


def print_pred_scores(results):
    """Print MSE and MAE for every algorithm in ``results`` to stdout.

    Each algorithm's section is framed by a line of 20 ``=`` characters and
    lists one ``"<METRIC>: <value>"`` line per metric.
    """
    from relnet.io.storage import EvaluationStorage
    separator = "=" * 20
    per_algorithm = EvaluationStorage.result_dicts_to_np_arrays(results)
    for algo, (preds, gts) in per_algorithm.items():
        print(separator)
        print(f"Method <<{algo}>>:")
        for metric in ("mse", "mae"):
            value = score_predictions(preds, gts, evaluation_metric=metric)
            print(f"{metric.upper()}: {value}")
        print(separator)

def get_model_seed(run_number):
    """Return the model seed for a run: ``run_number * 42``, truncated to ``int``."""
    seed = run_number * 42
    return int(seed)

def get_run_number(model_seed):
    """Invert ``get_model_seed``: true-divide the seed by 42, then truncate to ``int``."""
    run = model_seed / 42
    return int(run)

def find_max_property_in_list(graph_list, fn_to_apply):
    """Return the largest ``fn_to_apply(g)`` over ``graph_list``.

    Returns ``float("-inf")`` for an empty list.
    """
    # Seeding the candidates with -inf gives the empty-input result and keeps
    # the builtin's "replace only if strictly greater" semantics of a running max.
    candidates = [float("-inf")]
    candidates.extend(map(fn_to_apply, graph_list))
    return max(candidates)

def find_max_nodes(graph_list):
    """Largest ``num_nodes`` among ``graph_list`` (``-inf`` when empty)."""
    return find_max_property_in_list(graph_list, lambda graph: graph.num_nodes)

def find_max_edges(graph_list):
    """Largest ``num_edges`` among ``graph_list`` (``-inf`` when empty)."""
    return find_max_property_in_list(graph_list, lambda graph: graph.num_edges)

def find_max_diameter(graph_list):
    """Largest ``get_diameter()`` result among ``graph_list`` (``-inf`` when empty)."""
    return find_max_property_in_list(graph_list, lambda graph: graph.get_diameter())

def find_max_colors(graph_list):
    """Largest ``num_colors`` among ``graph_list`` (``-inf`` when empty)."""
    return find_max_property_in_list(graph_list, lambda graph: graph.num_colors)


def get_results_table(storage, exp_ids, filter_best_demand_rep=True, leave_out=None):
    """Collect per-seed MSE scores of best-hyperparameter evaluations into a DataFrame.

    For every experiment id:
    - Evaluation records flagged ``is_best_hyps`` are split by ``agent_seed``,
      one group per seed in the experiment's ``model_seeds``.
    - Each group is scored with ``eval_pred_scores``, and only the ``mse``
      metric is kept.
    - Algorithms whose name contains any excluded substring (built-in list
      plus ``leave_out``) are dropped.

    Each row has ``algorithm``, ``metric``, ``perf``, ``graph_name`` and
    ``routing_model`` (the first and second ``_``-separated fields of the
    experiment id), ``agent_seed``, and ``dms_mult`` (from the experiment
    conditions, defaulting to ``1.``).

    Args:
        storage: object exposing ``get_evaluation_data(exp_id)`` and
            ``get_experiment_details(exp_id)``.
        exp_ids: experiment ids to tabulate.
        filter_best_demand_rep: if true, each "sum"/"raw" algorithm pair
            (same name fields 0 and 2) is reduced to a single variant. The
            kept variant is the one with the lower mean ``perf`` per
            (routing_model, graph_name); its rows are renamed to
            ``"<field0>_<field2>"``.
        leave_out: extra algorithm-name substrings to exclude.

    Returns:
        pandas.DataFrame with one row per kept (experiment, seed, algorithm).
        When no rows survive, the DataFrame is empty.

    Raises:
        ValueError: if the "sum" and "raw" algorithm names cannot be paired
            one-to-one by name fields 0 and 2.
    """
    alg_filter_out = ["edgeapprox", "random", "predict_median", "labelorder"]  # , "uniqueedge"]
    if leave_out is not None:
        alg_filter_out.extend(leave_out)
    metrics_keep = ["mse"]

    rows = []
    for exp_id in exp_ids:
        results = storage.get_evaluation_data(exp_id)
        results = [r for r in results if r["is_best_hyps"]]

        print(f"exp_id {exp_id}")

        conditions = storage.get_experiment_details(exp_id)['experiment_conditions']
        seeds = conditions['experiment_params']['model_seeds']
        # Per-experiment invariants, computed once instead of per score.
        id_parts = exp_id.split("_")
        graph_name = id_parts[0]
        dms_mult = conditions["dms_mult"] if "dms_mult" in conditions else 1.

        for seed in seeds:
            seed_results = [r for r in results if r["agent_seed"] == seed]
            for score in eval_pred_scores(seed_results):
                if score["metric"] not in metrics_keep:
                    continue
                if any(alg in score["algorithm"] for alg in alg_filter_out):
                    continue
                score["graph_name"] = graph_name
                score["routing_model"] = id_parts[1]
                score["agent_seed"] = seed
                score["dms_mult"] = dms_mult
                rows.append(score)

    rows_df = pd.DataFrame(rows)

    # With no rows there is no 'algorithm' column to inspect and nothing to filter.
    if filter_best_demand_rep and rows:
        best_demand_reps = {}
        algorithms = set(rows_df['algorithm'])
        sum_agents = sorted(a for a in algorithms if "sum" in a)
        raw_agents = sorted(a for a in algorithms if "raw" in a)
        if len(sum_agents) != len(raw_agents):
            raise ValueError(
                f"cannot pair 'sum' and 'raw' algorithms: {sum_agents} vs {raw_agents}")

        for sa, ra in zip(sum_agents, raw_agents):
            sa_parts = sa.split("_")
            ra_parts = ra.split("_")
            arch, demand_rep = sa_parts[0], sa_parts[2]
            # Pairing is by sorted position. Verify it matched like with like;
            # otherwise the best choice would be stored under the wrong key.
            if (ra_parts[0], ra_parts[2]) != (arch, demand_rep):
                raise ValueError(f"mismatched algorithm pair: {sa!r} vs {ra!r}")

            rel_data = rows_df[rows_df["algorithm"].isin([sa, ra])]
            means = pd.pivot_table(rel_data, values='perf',
                                   index=["routing_model", "graph_name"], columns=["algorithm"])
            # Column access (not getattr on itertuples) works for any column name.
            for idx, sum_perf, raw_perf in zip(means.index, means[sa], means[ra]):
                which_best = sa if sum_perf < raw_perf else ra
                best_demand_reps[(arch, *idx, demand_rep)] = which_best

        clean_rows = []
        for row in rows:
            alg = row['algorithm']
            if 'sum' not in alg and 'raw' not in alg:
                clean_rows.append(row)
                continue
            alg_name_parts = alg.split("_")
            row_idx = (alg_name_parts[0], row['routing_model'], row['graph_name'], alg_name_parts[2])
            if best_demand_reps[row_idx] == alg:
                # Copy so the row dicts in `rows` are left untouched.
                clean_rows.append({**row, 'algorithm': "_".join([alg_name_parts[0], alg_name_parts[2]])})

        rows_df = pd.DataFrame(clean_rows)

    return rows_df


def normalize_results(plot_df, how_norm='divide'):
    """Normalise each row's ``perf`` against its ``predict_mean`` baseline.

    A row's baseline is the ``perf`` of the first ``predict_mean`` row with the
    same ``routing_model``, ``graph_name`` and ``agent_seed``, plus the same
    ``dms_mult`` when that column exists. ``predict_mean`` rows are not
    included in the output.

    Args:
        plot_df: results table with at least ``algorithm``, ``perf``,
            ``routing_model``, ``graph_name`` and ``agent_seed`` columns.
        how_norm: ``'divide'`` yields ``perf / baseline``. Any other value
            yields ``baseline - perf``.

    Returns:
        New DataFrame built from ``itertuples()`` rows. It therefore also
        carries the original index as an ``Index`` column.

    Raises:
        IndexError: if a row has no matching ``predict_mean`` baseline.
    """
    use_dms_mult = "dms_mult" in plot_df.columns
    if use_dms_mult:
        print("normalizing WITH dm multiplier.")
    else:
        print("normalizing WITHOUT dm multiplier.")

    key_fields = ["routing_model", "graph_name", "agent_seed"]
    if use_dms_mult:
        key_fields.append("dms_mult")

    def _key(row):
        # Tuple of the columns a row must share with its baseline.
        return tuple(getattr(row, field) for field in key_fields)

    # Index baselines in one pass (first match wins, as `.iloc[0]` did).
    # This replaces re-filtering the whole frame for every row, which was O(n^2).
    baselines = {}
    for row in plot_df.itertuples():
        if row.algorithm == "predict_mean":
            baselines.setdefault(_key(row), row.perf)

    norm_rows = []
    for row in plot_df.itertuples():
        if row.algorithm == "predict_mean":
            continue

        key = _key(row)
        if key not in baselines:
            raise IndexError(f"no predict_mean baseline for {dict(zip(key_fields, key))}")
        norm_factor = baselines[key]

        row_cp = row._asdict()
        if how_norm == 'divide':
            row_cp['perf'] = row_cp['perf'] / norm_factor
        else:
            row_cp['perf'] = norm_factor - row_cp['perf']
        norm_rows.append(row_cp)

    return pd.DataFrame(norm_rows)

















































