import os
import json
from datetime import datetime
from collections import defaultdict
from itertools import chain

from GOOD import config_summoner
from GOOD.kernel.pipeline_manager import load_pipeline
from GOOD.ood_algorithms.ood_manager import load_ood_alg
from GOOD.utils.logger import load_logger
from GOOD.utils.metric import assign_dict
from GOOD.utils.loader import initialize_model_dataset
from GOOD.networks.models.DIRGNN import split_graph_node
import GOOD.kernel.pipelines.xai_metric_utils as xai_utils
from GOOD.definitions import ROOT_DIR

import numpy as np
import torch

from torch_geometric.utils import to_networkx, from_networkx, to_undirected
from torch_geometric.data import Batch
from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_add

import matplotlib.pyplot as plt



def print_metric(name, data, results_aggregated=None, key=None):
    """Print the mean and standard deviation of a metric and optionally record them.

    ``data`` is reduced along axis 0 with the NaN-aware mean and standard
    deviation, and one ``avg +- std`` pair is printed per remaining entry.

    Args:
        name: Label printed in front of the values (left-aligned, width 25).
        data: Array-like reduced along its first axis; the reduced result
            must be iterable (i.e. ``data`` needs at least two dimensions).
        results_aggregated: Optional container forwarded to ``assign_dict``.
            Nothing is recorded when it is ``None``.
        key: Sequence of keys forwarded to ``assign_dict``; required when
            ``results_aggregated`` is given. The mean is stored under ``key``
            and the standard deviation under the same key with ``"_std"``
            appended to its last element. ``key`` itself is left untouched.
    """
    avg = np.nanmean(data, axis=0)
    std = np.nanstd(data, axis=0)
    formatted = ", ".join(f"{a:.3f} +- {s:.3f}" for a, s in zip(avg, std))
    print(f"{name:<25}", " = ", formatted)

    if results_aggregated is None:
        return

    assign_dict(results_aggregated, key, avg.tolist())
    # Build a fresh key for the std entry instead of editing the caller's
    # list in place, so a key list reused by the caller is not corrupted.
    std_key = [*key[:-1], key[-1] + "_std"]
    assign_dict(results_aggregated, std_key, std.tolist())


def generate_panel(args):
    """Run ``pipeline.generate_panel`` once per seed and then once over all seeds.

    For every seed listed in ``args.seeds`` ("/"-separated) the configuration,
    model and pipeline are rebuilt and the checkpoint for the ``"id"`` load
    split is loaded. The per-seed results of ``pipeline.generate_panel()`` are
    collected and finally handed to ``generate_panel_all_seeds`` on the
    pipeline built for the last seed.

    Args:
        args: Parsed command-line arguments; ``random_seed`` and ``exp_round``
            are overwritten with each seed before the config is built.
    """
    for split_pos, load_split in enumerate(["id"]):
        print("\n\n" + "-"*50)

        per_seed_scores = []
        for seed_pos, raw_seed in enumerate(args.seeds.split("/")):
            # The banner shows the seed exactly as written on the command line
            print(f"GENERATING PLOT FOR LOAD SPLIT = {load_split} AND SEED {raw_seed}\n\n")
            seed_value = int(raw_seed)
            args.random_seed = seed_value
            args.exp_round = seed_value

            config = config_summoner(args)
            config["task"] = "test"
            config["load_split"] = load_split

            # The logger is set up only for the very first run
            is_first_run = split_pos == 0 and seed_pos == 0
            if is_first_run:
                load_logger(config)

            model, loader = initialize_model_dataset(config)
            ood_algorithm = load_ood_alg(config.ood.ood_alg, config)
            pipeline = load_pipeline(
                config.pipeline, config.task, model, loader, ood_algorithm, config
            )
            pipeline.load_task(load_param=True, load_split=load_split)

            per_seed_scores.append(pipeline.generate_panel())

        # Uses the pipeline instance created for the last seed
        pipeline.generate_panel_all_seeds(per_seed_scores)

def evaluate_metric(args):
    """Compute explanation metrics for every seed/split/threshold and print aggregates.

    For each seed in ``args.seeds`` the model is loaded, binary explanation
    masks are generated, and each metric in ``args.metrics`` is computed for
    every (split, threshold) pair. The raw scores are printed, followed by
    their mean/std (via ``print_metric``). When ``config.save_metrics`` is
    set, the aggregates are merged into a JSON file under
    ``storage/metric_results`` and written back to disk.

    Args:
        args: Parsed command-line arguments. Reads the "/"-separated strings
            ``seeds``, ``metrics``, ``splits`` (``id_val`` when empty) and
            ``ratios`` (``0.5`` when empty). ``random_seed`` and ``exp_round``
            are overwritten with each seed before the config is built.
    """
    splits = args.splits.split("/") if args.splits != "" else ["id_val"]  # "id_val", "val", "test"
    print("Using splits = ", splits)

    thrs = [float(r) for r in args.ratios.split("/")] if args.ratios != "" else [0.5]
    print("Using thresholds = ", thrs)

    metric_names = args.metrics.split("/")

    startTime = datetime.now()
    metrics_score = {s: defaultdict(list) for s in splits}
    for seed_idx, seed in enumerate(args.seeds.split("/")):
        ##
        # SET UP THE CONFIGURATION AND LOAD MODEL
        ##
        seed = int(seed)
        args.random_seed = seed
        args.exp_round = seed

        config = config_summoner(args)
        config["task"] = "test"
        config["load_split"] = ""
        if seed_idx == 0:
            load_logger(config)

        model, loader = initialize_model_dataset(config)
        ood_algorithm = load_ood_alg(config.ood.ood_alg, config)
        pipeline = load_pipeline(config.pipeline, config.task, model, loader, ood_algorithm, config)
        pipeline.load_task(load_param=True, load_split="id")

        ##
        # GENERATE BINARY EXPLANATION MASKS
        ##
        samples, graphs_nx, avg_graph_size = pipeline.generate_binary_explanations(
            is_weight=True,
            thrs=thrs,
            splits=splits,
            convert_to_nx="interven_suff" in args.metrics,
            is_node_expl=not config.ood.extra_param[0]  # is the learn_edge_att parameter
        )

        for metric in metric_names:
            print(f"\n\nEvaluating {metric.upper()} for seed {seed}\n")

            for split in splits:
                for thr in thrs:
                    print(f"\n\n#D#Computing {metric.upper()} over {split} across ratios with thr={thr}")
                    score, acc_int = pipeline.compute_metric(
                        metric=metric,
                        graphs=samples[split][thr],
                        graphs_nx=graphs_nx[split],
                        avg_graph_size=avg_graph_size[split],
                    )
                    # One entry is appended per (seed, threshold) pair
                    metrics_score[split][metric].append(score)
                    metrics_score[split][metric + "_acc_int"].append(acc_int)

    if config.save_metrics:
        save_path = f"storage/metric_results/aggregated_id_results_necalpha{config.nec_alpha_1}" \
                    f"_numsamples{config.numsamples_budget}_randomexpl{config.random_expl}_ratios{args.ratios.replace('/','-')}" \
                    f"_metrics{args.metrics.replace('/','-')}" \
                    f"_{config.log_id}.json"
        # Make sure the target directory exists before touching the file
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if not os.path.exists(save_path):
            with open(save_path, 'w') as file:
                file.write("{}")
        with open(save_path, "r") as jsonFile:
            results_aggregated = json.load(jsonFile)
    else:
        save_path = None
        results_aggregated = None

    print("\n\n", "-"*50, f"\nPrinting evaluation results\n\n")
    for split in splits:
        print(f"\nEval split {split.upper()}")
        for metric in metric_names:
            print(f"{metric} = {metrics_score[split][metric]}")

    dataset_tag = config.dataset.dataset_name + " " + config.dataset.domain

    def result_key(split_name, leaf):
        # A fresh list per call: the key path handed to print_metric/assign_dict
        return [dataset_tag, config.complete_dirname, split_name, leaf]

    print("\n\n", "-"*50, "\nPrinting evaluation averaged per seed")
    for split in splits:
        print(f"\nEval split {split.upper()}")
        for metric in metric_names:
            runs = metrics_score[split][metric]
            for div in ["predicted"]:
                # NOTE(review): every class and the "all" aggregate are stored
                # under the same key; depending on assign_dict's semantics later
                # writes may replace earlier ones -- confirm this is intended.
                for c in range(10):
                    class_key = f"{c}_{div}"
                    # The last collected run decides which classes are reported
                    # (previously this relied on a loop index leaked from the
                    # seed loop above).
                    if class_key not in runs[-1]:
                        continue
                    # take values across runs, then print them
                    print_metric(
                        metric + f" class {class_key}",
                        [run[class_key] for run in runs],
                        results_aggregated,
                        key=result_key(split, metric + f"_{div}"),
                    )
                print_metric(
                    metric + f" class all_{div}",
                    [run[f"all_{div}"] for run in runs],
                    results_aggregated,
                    key=result_key(split, metric + f"_{div}"),
                )
            print_metric(
                metric + "_acc_int",
                metrics_score[split][metric + "_acc_int"],
                results_aggregated,
                key=result_key(split, metric + "_acc_int"),
            )
            print_metric(
                metric + " rejection",
                [run["rejection"] for run in runs],
                results_aggregated,
                key=result_key(split, metric + "_rejection"),
            )
            print()

    if results_aggregated is not None:
        # Persist the merged aggregates; without this the file loaded above
        # would never receive the results computed in this run.
        with open(save_path, "w") as jsonFile:
            json.dump(results_aggregated, jsonFile)

    print("\nCompleted in ", datetime.now() - startTime, f" for {config.complete_dirname} {config.dataset.dataset_name}/{config.dataset.domain}")
    print("\n\n")

def print_r_ge_b_hist(args):
    """Plot per-label histograms of node relevance scores, split by node color.

    For every seed in ``args.seeds`` the model is loaded and node explanations
    for the ``id_val`` split are fetched from the pipeline. Nodes are mapped to
    colors through the dataset's ``color_map2`` (indexed by the argmax of each
    node's feature row). One figure is saved per seed with one row per label:
    a score histogram for each of the colors R, B, G, V, followed by a box plot
    of how many nodes of each color per sample reach the importance threshold
    (0.05 for model names containing "DIR", 0.5 otherwise).

    Args:
        args: Parsed command-line arguments; ``seeds`` is a "/"-separated
            string, ``random_seed`` and ``exp_round`` are overwritten with
            each seed before the config is built.
    """
    load_splits = ["id"]
    color_names = ["R", "B", "G", "V"]
    box_face_colors = ["red", "blue", "green", "violet"]
    for l, load_split in enumerate(load_splits):
        print("\n\n" + "-"*50)

        for seed_idx, seed in enumerate(args.seeds.split("/")):
            print(f"GENERATING PLOT FOR LOAD SPLIT = {load_split} AND SEED {seed}\n\n")
            seed = int(seed)
            args.random_seed = seed
            args.exp_round = seed

            config = config_summoner(args)
            config["task"] = "test"
            config["load_split"] = load_split
            if l == 0 and seed_idx == 0:
                load_logger(config)

            model, loader = initialize_model_dataset(config)
            ood_algorithm = load_ood_alg(config.ood.ood_alg, config)
            pipeline = load_pipeline(config.pipeline, config.task, model, loader, ood_algorithm, config)
            pipeline.load_task(load_param=True, load_split=load_split)

            if config.dataset.dataset_name in ("BAColor", "BAColorGV", "BAColorGVIsolated"):
                print(f"\n\nClassifier weights:")
                print(model.classifierS.classifier[0].weight.detach())

            is_dir_model = "DIR" in config.model.model_name

            # GET EXPLANATIONS
            ret = pipeline.get_node_explanations()
            samples = ret["id_val"]["samples"]
            sample_scores = ret["id_val"]["scores"]
            color_of = loader["id_val"].dataset.color_map2.get

            # AGGREGATE INFO BY LABEL
            list_of_labels = np.array([sample.y.item() for sample in samples])
            unique_labels = np.unique(list_of_labels)
            list_of_colors = {label: [] for label in unique_labels}
            count_of_relevant_colors = {label: defaultdict(list) for label in unique_labels}
            list_of_scores = {label: [] for label in unique_labels}

            importance_threshold = 0.05 if is_dir_model else 0.5

            for label in unique_labels:
                for sample, scores in zip(samples, sample_scores):
                    if sample.y.item() != label:
                        continue

                    # Map each node to a color via the index of its largest feature
                    node_colors = [color_of(int(r)) for r in sample.x.argmax(dim=1)]
                    list_of_colors[label].extend(node_colors)
                    list_of_scores[label].extend(scores)

                    # Count, per color, the nodes whose score reaches the threshold
                    relevant_values, relevant_counts = np.unique(
                        np.array(node_colors)[np.array(scores) >= importance_threshold],
                        return_counts=True
                    )
                    relevant = dict(zip(relevant_values, relevant_counts))
                    for color in np.unique(node_colors):
                        # Colors present in the graph but never relevant count as 0
                        count_of_relevant_colors[label][color].append(relevant.get(color, 0))

                list_of_colors[label] = np.array(list_of_colors[label])
                list_of_scores[label] = np.array(list_of_scores[label])

            # PLOT HISTOGRAMS
            # Size the grid from what is actually drawn: one row per label, one
            # column per plotted color plus the box plot. Sizing it from the
            # colors seen in class 0 broke whenever those differed from the
            # four plotted colors or label 0 was absent.
            n_row = len(unique_labels)
            n_col = len(color_names) + 1
            fig, axs = plt.subplots(n_row, n_col, figsize=(12,7), squeeze=False)

            for row, label in enumerate(unique_labels):
                label_scores = list_of_scores[label]
                print(f"\ny={int(label)}")
                print(f"\tStats of node scores: min={label_scores.min():.3f}, max={label_scores.max():.3f}, avg={np.mean(label_scores):.3f}, std={np.std(label_scores):.3f} ")
                print(f"\tAverage count of relevant colors:", {c: round(np.mean(count_of_relevant_colors[label][c]), 2) for c in count_of_relevant_colors[label].keys()})

                # per-color hist
                for col, color in enumerate(color_names):
                    color_scores = label_scores[list_of_colors[label] == color]
                    ax = axs[row, col]
                    ax.hist(
                        # small Gaussian jitter added to the scores before binning
                        color_scores + np.random.normal(0.0, scale=0.005, size=color_scores.shape),
                        density=True,
                        log=False,
                        bins=100,
                        label=color
                    )
                    if is_dir_model:
                        ax.set_xlim(-1.1, 1.1)
                    else:
                        ax.set_xlim(-0.1, 1.1)
                    ax.set_ylim(0.0, 100)
                    ax.set_title(f"color {color}")
                    if col == 0:
                        ax.set_ylabel(f"y={int(label)}")

                # per-sample boxplot
                box_ax = axs[row, -1]
                box_ax.set_title(f"per-sample avg count")
                bplot = box_ax.boxplot(
                    [count_of_relevant_colors[label][c] for c in color_names],
                    patch_artist=True,
                    labels=color_names,
                    showfliers=False,
                    showmeans=True
                )
                for patch, face_color in zip(bplot['boxes'], box_face_colors):
                    patch.set_facecolor(face_color)
                box_ax.axhline(y=0, color='red', linestyle='--', linewidth=1, alpha=0.5)

            fig.supxlabel('explanation relevance scores', fontsize=13)
            fig.supylabel('density', fontsize=13)
            fig.suptitle(f'{config.model.model_name} seed {seed}', fontsize=13)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])

            path = f'{ROOT_DIR}/GOOD/kernel/pipelines/plots/panels/{config.ood_dirname}/'
            os.makedirs(path, exist_ok=True)
            path += f"{config.load_split}_{config.dataset.dataset_name}_{config.dataset.domain}_{config.util_model_dirname}_{config.random_seed}"
            plt.savefig(path + ".png")
            plt.close()
            print("File saved in ", path)

def print_hist(args):
    """Plot one histogram of all node relevance scores per seed.

    For every seed in ``args.seeds`` the model is loaded, node explanations for
    the ``id_val`` split are fetched from the pipeline, and the scores of all
    samples are pooled into a single 100-bin density histogram that is saved as
    a PNG under ``GOOD/kernel/pipelines/plots/panels/<ood_dirname>/``.

    Args:
        args: Parsed command-line arguments; ``seeds`` is a "/"-separated
            string, ``random_seed`` and ``exp_round`` are overwritten with
            each seed before the config is built.
    """
    load_splits = ["id"]
    for l, load_split in enumerate(load_splits):
        print("\n\n" + "-"*50)

        for seed_idx, seed in enumerate(args.seeds.split("/")):
            print(f"GENERATING PLOT FOR LOAD SPLIT = {load_split} AND SEED {seed}\n\n")
            seed = int(seed)
            args.random_seed = seed
            args.exp_round = seed

            config = config_summoner(args)
            config["task"] = "test"
            config["load_split"] = load_split
            if l == 0 and seed_idx == 0:
                load_logger(config)

            model, loader = initialize_model_dataset(config)
            ood_algorithm = load_ood_alg(config.ood.ood_alg, config)
            pipeline = load_pipeline(config.pipeline, config.task, model, loader, ood_algorithm, config)
            pipeline.load_task(load_param=True, load_split=load_split)

            if config.dataset.dataset_name in ("BAColor", "BAColorGV", "BAColorGVIsolated"):
                print(f"\n\nClassifier weights:")
                print(model.classifierS.classifier[0].weight.detach())

            # GET EXPLANATIONS
            ret = pipeline.get_node_explanations()

            # Pool the per-sample score lists into one flat array.
            # (The per-sample labels are not needed here, so they are no
            # longer extracted.)
            list_of_scores = np.array(
                [score for expl in ret["id_val"]["scores"] for score in expl]
            )

            # PLOT HISTOGRAM
            n_row = 1
            n_col = 1
            fig, axs = plt.subplots(n_row, n_col, figsize=(12,3.5*n_row))

            axs.hist(
                list_of_scores,
                density=True,
                log=False,
                bins=100
            )
            if "DIR" in config.model.model_name:
                axs.set_xlim(-1.1, 1.1)
            else:
                axs.set_xlim(-0.1, 1.1)

            fig.supxlabel('explanation relevance scores', fontsize=13)
            fig.supylabel('density', fontsize=13)
            fig.suptitle(f'{config.model.model_name} seed {seed}', fontsize=13)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])

            path = f'{ROOT_DIR}/GOOD/kernel/pipelines/plots/panels/{config.ood_dirname}/'
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs
            os.makedirs(path, exist_ok=True)
            path += f"{config.load_split}_{config.dataset.dataset_name}_{config.dataset.domain}_{config.util_model_dirname}_{config.random_seed}"
            print("Saving at ", path)
            plt.savefig(path + ".png")
            plt.close()

def plot_explanations(args):
    """Draw example graphs with their node explanations for every seed.

    For every seed in ``args.seeds`` the model is loaded and node explanations
    for 20 samples of the ``id_val`` split are requested from the pipeline.
    Each graph is annotated with its label and prediction and drawn through
    ``xai_utils`` (``plot_sentence_graph`` for dataset names containing "SST2",
    ``draw_colored`` otherwise). Nodes are marked as relevant when their score
    is at least 0.5; for model names containing "DIR" the threshold is 0.0 and
    the mask is further restricted to the nodes kept by ``split_graph_node``.

    Args:
        args: Parsed command-line arguments; ``seeds`` is a "/"-separated
            string, ``random_seed`` and ``exp_round`` are overwritten with
            each seed before the config is built.
    """
    load_splits = ["id"]
    split = "id_val"
    for l, load_split in enumerate(load_splits):
        print("\n\n" + "-"*50)

        for seed_idx, seed in enumerate(args.seeds.split("/")):
            print(f"GENERATING PLOT FOR LOAD SPLIT = {load_split} AND SEED {seed}\n\n")
            seed = int(seed)
            args.random_seed = seed
            args.exp_round = seed

            config = config_summoner(args)
            config["task"] = "test"
            config["load_split"] = load_split
            if l == 0 and seed_idx == 0:
                load_logger(config)

            model, loader = initialize_model_dataset(config)
            ood_algorithm = load_ood_alg(config.ood.ood_alg, config)
            pipeline = load_pipeline(config.pipeline, config.task, model, loader, ood_algorithm, config)
            pipeline.load_task(load_param=True, load_split=load_split)

            if config.dataset.dataset_name in ("BAColor", "BAColorGV", "BAColorGVIsolated"):
                print(f"\n\nClassifier weights:")
                print(model.classifierS.classifier[0].weight.detach())

            normalize = False
            if "SMGNN" in config.model.model_name and config.dataset.dataset_name == "MNIST":
                print("\n\nNormalizing scores in [0,1]\n\n")
                normalize = True

            is_dir_model = "DIR" in config.model.model_name
            subfolder = f"plots_of_explanation_examples/{config.prefix}{config.ood_dirname}/{config.dataset.dataset_name}_{config.dataset.domain}"

            # GET EXPLANATIONS
            N = 20
            ret = pipeline.get_node_explanations(num_samples=N)

            # DEFINE IMPORTANCE THRESHOLD
            if is_dir_model:
                thr = 0.0  # just filter based on topK
            else:
                thr = 0.5

            # PLOT GRAPHS
            compute_fid = False
            topK_nodes_kept = None
            # A dedicated index name keeps the seed index above from being shadowed
            for graph_idx in range(len(ret[split]["samples"])):
                data = ret[split]["samples"][graph_idx].cpu()
                expl = ret[split]["scores"][graph_idx]

                if normalize:
                    expl = np.array(expl)
                    lowest, highest = min(expl), max(expl)
                    if highest == lowest:
                        # All scores equal: min-max scaling would divide by
                        # zero and yield NaNs, so use all-zero scores instead.
                        expl = np.zeros_like(expl, dtype=float)
                    else:
                        expl = (expl - lowest) / (highest - lowest)

                data.node_expl = torch.tensor(expl, device=data.x.device)
                data.node_mask = data.node_expl >= thr
                data.edge_mask = torch.ones_like(data.edge_index[0])

                raw_pred = ret[split]["pred"][graph_idx]
                if "MNIST" in config.dataset.dataset_name:
                    pred = raw_pred.argmax(dim=0)
                elif config.metric.dataset_task == 'Multi-label classification':
                    pred = round(raw_pred.softmax(dim=0)[1].item(), 3)
                else:
                    pred = round(raw_pred.sigmoid().item(), 3)

                title = f"Class={int(data.y.item())} Prediction={pred:<5}"

                if is_dir_model:
                    if graph_idx == 0:
                        print(f"Highlightinh nodes in the Top-{config.ood.ood_param}%")

                    # Highlight TopK nodes; only the kept/removed node indices
                    # of the returned triple are used here.
                    _causal, _conf, (topK_nodes_kept, topK_nodes_removed) = split_graph_node(
                        data,
                        torch.tensor(expl, device=data.x.device),
                        config.ood.ood_param,
                        embed=None,
                        use_input_feat=True
                    )
                    topK_nodes_kept = topK_nodes_kept.cpu().tolist()

                    # remove nodes not in the TopK
                    assert len(topK_nodes_kept) + topK_nodes_removed.shape[0] == data.x.shape[0]
                    data.node_mask[topK_nodes_removed] = False

                if compute_fid:
                    for metric in ["fidm", "fidp", "rfidm", "rfidp", "suff", "suff_cause"]:
                        val = pipeline.compute_metric(metric=metric, graphs=[data], graphs_nx=None, avg_graph_size=None, log_info=False)[0]["all_predicted"][0]
                        title += f" {metric}={val:.2f}"

                if "SST2" in config.dataset.dataset_name:
                    xai_utils.plot_sentence_graph(
                        G=data,
                        name=f"graph_{split}_{graph_idx}",
                        subfolder=subfolder,
                        config=config,
                        title=title,
                    )
                else:
                    g = to_networkx(data, node_attrs=["x", "node_mask"], to_undirected=True)
                    xai_utils.draw_colored(
                        config,
                        g,
                        node_expl=expl,
                        subfolder=subfolder,
                        name=f"graph_{split}_{graph_idx}",
                        title=title,
                        with_labels=False,
                        figsize=(12,10) if "AIDS" in config.dataset.dataset_name else (6.4, 4.8),
                        topk=topK_nodes_kept
                    )
                print(f"graph {title}")
            print("Plotted in ", subfolder)
