import torch
import os
from torch import nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric.data as pyg_data
from typing import Literal
from scipy.stats import entropy
from sklearn.decomposition import PCA
import torch_geometric
import numpy as np
from typing import Tuple, Dict, List
import logging
from gnnboundary import *
from gnnboundary.utils.boundary_generator import GraphGenerator
from gnn_xai_common import GraphSampler
from scripts.default_configs import DATASET_TO_MAX_NODES_INTERPRETER


class BoundaryEvaluator:
    """
    Evaluate the decision boundary between two classes of a graph discriminator.

    The class is built to reproduce the boundary metrics of the paper "GNNBoundary":
    margin, thickness, complexity and the average predicted probabilities of
    boundary graphs. Boundary graphs are either generated with ``GraphGenerator`` or
    loaded from saved ``GraphSampler`` state dicts in ``load_dir``. Decision-region
    samples (for the first class of ``order_of_pairs`` unless a label is given) are
    either taken from the dataset or loaded from GNNInterpreter ``GraphSampler``
    state dicts in ``interpreter_load_dir``.
    """
    def __init__(self, sampler, discriminator, dataset, trainer, cls_pair: Tuple[int, int],
                 num_generation_iterations: int = 2000, logger=None,
                 load_dir: str = None, interpreter_load_dir: str = None,
                 temperature: float = 0.2, max_nodes: int = 25,
                 order_of_pairs: Tuple[int, int] = None,
                 ):
        """
        params:
            sampler: sampler used by ``_sample_boundary_graphs`` and the boundary generator
            discriminator: model whose boundary is evaluated; called with a batch and
                ``edge_weight`` and expected to return a dict with ``"embeds_last"`` and ``"probs"``
            dataset: dataset providing ``NODE_CLS``, ``name``, ``convert`` and labelled graphs
            trainer: passed through to ``GraphGenerator``
            cls_pair: class pair whose probabilities are reported by ``average_probs`` / range splits
            num_generation_iterations: ``iterations`` forced onto every boundary generation run
            logger: logger to use; if None, root logging is configured to stream + ``experiment.log``
            load_dir: directory of boundary ``GraphSampler`` state dicts (skips generation)
            interpreter_load_dir: directory of GNNInterpreter ``GraphSampler`` state dicts
            temperature: sampler temperature used when re-creating loaded samplers
            max_nodes: ``max_nodes`` for re-created boundary samplers
            order_of_pairs: (c1, c2); c1 is the decision-region class, c2 the other side
        """
        self.sampler = sampler
        self.discriminator = discriminator
        self.dataset = dataset
        self.order_of_pairs = order_of_pairs

        self.load_dir = load_dir
        self.interpreter_load_dir = interpreter_load_dir
        self.temperature = temperature
        self.max_nodes = max_nodes
        self.cls_pair = cls_pair

        self.boundary_generator = GraphGenerator(sampler, dataset, trainer, model=discriminator)
        # Attribute name kept as-is (typo included) for callers that may read it.
        self.num_iteratiohn_generations = num_generation_iterations

        self.boundary_graphs = None

        self.boundary_embeddings: torch.Tensor = None
        self.decision_region_samples = None
        self.decision_region_embeddings: torch.Tensor = None

        if logger is None:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(levelname)s - %(message)s",
                handlers=[
                    logging.StreamHandler(),
                    logging.FileHandler("experiment.log", mode="w")
                ]
            )

            logger = logging.getLogger(__name__)

        self.logger = logger

    def _extract_boundary_embeddings(self, level: Literal['linear', 'conv'] = "linear", num_graphs=None):
        """
        Embed the first ``num_graphs`` boundary graphs into ``self.boundary_embeddings``.

        ``num_graphs=None`` embeds all boundary graphs. ``level`` is currently unused:
        the discriminator's ``"embeds_last"`` output is always taken.
        """
        graphs = pyg_data.Batch.from_data_list(self.boundary_graphs[:num_graphs])
        # The embeddings only feed metrics, so don't build an autograd graph.
        with torch.no_grad():
            output = self.discriminator(graphs, edge_weight=graphs.edge_weight)
        self.boundary_embeddings = output["embeds_last"]

    def _extract_decision_region_embeddings(self, level: Literal['linear', 'conv'] = "linear",
                                            num_graphs: int = None):
        """
        Embed the first ``num_graphs`` decision-region samples into ``self.decision_region_embeddings``.

        ``num_graphs=None`` embeds all samples. ``level`` is currently unused.
        """
        graphs = pyg_data.Batch.from_data_list(self.decision_region_samples[:num_graphs])
        with torch.no_grad():
            output = self.discriminator(graphs, edge_weight=graphs.edge_weight)
        self.decision_region_embeddings = output["embeds_last"]

    def _generate_boundary_graphs(self, num_graphs=10, **kwargs):
        """
        Populate ``self.boundary_graphs`` with (up to) ``num_graphs`` boundary graphs.

        If ``self.load_dir`` is set, files in it (sorted by name) are loaded as
        ``GraphSampler`` state dicts and each sampler is sampled once; ``num_graphs`` is
        capped at the number of files. Otherwise graphs are produced by
        ``self.boundary_generator`` with ``iterations`` forced to the value given at
        construction; the remaining ``kwargs`` are forwarded to the generator.
        """
        if self.load_dir is not None:
            self.logger.info("Loading boundary %s graphs from %s", num_graphs, self.load_dir)
            # Sorted so the chosen subset is reproducible across file systems.
            files = sorted(os.listdir(self.load_dir))
            if num_graphs > len(files):
                num_graphs = len(files)
                self.logger.warning("Number of graphs to load is greater than the number of files in the directory. Loading %s graphs", num_graphs)

            samplers = []
            for file_name in files[:num_graphs]:
                sampler = GraphSampler(max_nodes=self.max_nodes, temperature=self.temperature,
                                       num_node_cls=len(self.dataset.NODE_CLS), learn_node_feat=True)
                # NOTE(review): torch.load unpickles arbitrary objects; only use trusted directories.
                sampler.load_state_dict(torch.load(os.path.join(self.load_dir, file_name)))
                samplers.append(sampler)

            self.boundary_graphs = [sampler(k=1, mode='discrete', expected=True) for sampler in samplers]
            self.logger.info("Finished loading boundary %s graphs", num_graphs)

        else:
            self.logger.info("Generating boundary %s graphs....", num_graphs)
            kwargs['iterations'] = self.num_iteratiohn_generations
            self.boundary_graphs = self.boundary_generator(num_runs=1,
                                                          num_graphs=num_graphs,
                                                          logger=self.logger,
                                                          **kwargs)[0]
            self.logger.info("Finished generating boundary %s graphs", num_graphs)

    def inject_interpreter_path(self, interpreter_load_dir: str):
        """Switch the GNNInterpreter directory and drop cached decision-region samples/embeddings."""
        self.interpreter_load_dir = interpreter_load_dir
        self.decision_region_samples = None
        self.decision_region_embeddings = None

    def _generate_decision_region_samples(self, num_graphs=10, cls_label: int = None, **kwargs):
        """
        Populate ``self.decision_region_samples`` with graphs of class ``cls_label``.

        ``cls_label`` defaults to ``self.order_of_pairs[0]``. With
        ``self.interpreter_load_dir`` set, up to ``num_graphs`` GNNInterpreter samplers
        are loaded (sorted by file name; unreadable files are logged and skipped) and
        each is sampled once. Otherwise every dataset graph whose label equals
        ``cls_label`` is used (``num_graphs`` is ignored in that branch).
        ``kwargs`` are accepted for call-site compatibility and ignored.
        """
        if self.interpreter_load_dir is not None:
            label = self.order_of_pairs[0] if cls_label is None else cls_label
            self.logger.info("Loading decision region samples from %s using GNNInterpreter", self.interpreter_load_dir)
            files = sorted(os.listdir(self.interpreter_load_dir))
            if num_graphs > len(files):
                num_graphs = len(files)
                self.logger.warning(
                    "Number of graphs to load is greater than the number of files in the directory. Loading %s interpreter graphs",
                    num_graphs)

            samplers = []
            for file_name in files[:num_graphs]:
                sampler = GraphSampler(max_nodes=DATASET_TO_MAX_NODES_INTERPRETER[self.dataset.name][label],
                                       temperature=self.temperature,
                                       num_node_cls=len(self.dataset.NODE_CLS), learn_node_feat=True)
                try:
                    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted directories.
                    sampler.load_state_dict(torch.load(os.path.join(self.interpreter_load_dir, file_name)))
                except Exception:
                    # Best effort: skip a bad file but keep the traceback in the log.
                    self.logger.exception("Error loading file %s", file_name)
                    continue
                samplers.append(sampler)

            self.decision_region_samples = [self.dataset.convert(sampler.sample()) for sampler in samplers]
            self.logger.info("Finished loading decision region samples using GNNInterpreter")
        else:
            self.logger.info("Generating decision region samples from dataset....")
            if cls_label is None:
                assert self.order_of_pairs is not None, "order of pairs must be provided if interpreter_load_dir is not provided"
                cls_label = self.order_of_pairs[0]
            # Kept as a plain list: the extractor slices it and collates the slice itself.
            self.decision_region_samples = [data for data in self.dataset if data.y == cls_label]
            self.logger.info("Finished generating decision region samples from dataset")

    def _sample_boundary_graphs(self, num_graphs=10):
        """Fill ``self.boundary_graphs`` with ``num_graphs`` labelled samples drawn from ``self.sampler``."""
        self.boundary_graphs = [self.dataset.convert(self.sampler.sample(), generate_label=True) for _ in range(num_graphs)]

    def boundary_margin(self, num_graphs: int, **kwargs):
        """
        Computes the margin of the boundary for the decision-region class ``order_of_pairs[0]``.
        The margin is the minimum Euclidean distance between any boundary embedding and any
        decision-region embedding.
        params:
            num_graphs: int: number of boundary graphs (and decision-region samples) to use;
                the distance is computed for every (boundary, decision-region) pair
            **kwargs: dict: additional arguments for the boundary generation
        returns:
            float: the minimum pairwise distance
        """
        if self.boundary_graphs is None or len(self.boundary_graphs) != num_graphs:
            self._generate_boundary_graphs(num_graphs=num_graphs, **kwargs)
        # Embed all requested boundary graphs (not just the extractor's first few).
        self._extract_boundary_embeddings(level="linear", num_graphs=num_graphs)

        if self.decision_region_samples is None or len(self.decision_region_samples) != num_graphs:
            self._generate_decision_region_samples(num_graphs=num_graphs, cls_label=self.order_of_pairs[0], **kwargs)
        self._extract_decision_region_embeddings(level="linear", num_graphs=num_graphs)

        distances = torch.cdist(self.boundary_embeddings, self.decision_region_embeddings, p=2)
        return distances.min().item()

    def boundary_thickness(self, gamma: float = 0.75,
                           num_graphs=500,
                           num_interpolation_points: int = 1000, **kwargs):
        """
        Computes the boundary thickness between ``order_of_pairs[0]`` (c1) and ``order_of_pairs[1]`` (c2).

        For each (decision-region embedding h_c1, boundary embedding h_b) pair, the path
        h(t) = (1 - t) * h_c1 + t * h_b, t in [0, 1], is pushed through
        ``self.discriminator.out`` and the thickness is
        ||h_c1 - h_b|| * integral_0^1 1[gamma > p_c1(h(t)) - p_c2(h(t))] dt,
        where the integral is estimated on ``num_interpolation_points`` evenly spaced t.
        Pairs are formed positionally (``zip``) and the result is their mean.
        If no boundary graphs / decision-region samples are cached, a single one is generated.
        params:
            gamma: float: probability-gap threshold defining the boundary band
            num_graphs: int: maximum number of embeddings taken from each side
            num_interpolation_points: int: number of t values on each path
            **kwargs: dict: additional arguments for the boundary generation
        returns:
            float: mean thickness over pairs, or 0.0 if there are no pairs
        """
        if self.boundary_graphs is None or len(self.boundary_graphs) == 0:
            self._generate_boundary_graphs(num_graphs=1, **kwargs)
        self._extract_boundary_embeddings(level="linear", num_graphs=num_graphs)

        if self.decision_region_samples is None or len(self.decision_region_samples) == 0:
            self._generate_decision_region_samples(num_graphs=1, cls_label=self.order_of_pairs[0], **kwargs)
        self._extract_decision_region_embeddings(level="linear", num_graphs=num_graphs)

        c1, c2 = self.order_of_pairs[0], self.order_of_pairs[1]
        boundary_embeddings = self.boundary_embeddings
        # Same interpolation grid for every pair; created on the embeddings' device/dtype.
        t = torch.linspace(0, 1, num_interpolation_points,
                           device=boundary_embeddings.device,
                           dtype=boundary_embeddings.dtype).view(-1, 1)

        thicknesses = []
        with torch.no_grad():
            for region_embedding, boundary_embedding in zip(self.decision_region_embeddings, boundary_embeddings):
                h_t = (1 - t) * region_embedding + t * boundary_embedding
                probs = F.softmax(self.discriminator.out(h_t), dim=1)
                prob_diff = probs[:, c1] - probs[:, c2]
                # Riemann estimate of the integral: fraction of the path inside the band.
                fraction_in_band = (prob_diff < gamma).float().mean()
                distance = torch.norm(region_embedding - boundary_embedding, p=2)
                thicknesses.append(distance * fraction_in_band)

        if not thicknesses:
            return 0.0
        return torch.stack(thicknesses).mean().item()

    def boundary_complexity(self, num_graphs: int, use_pca: bool = True, **kwargs):
        """
        Computes the boundary complexity of the discriminator.
        The complexity is the entropy of the normalised eigenvalue spectrum of the covariance of
        the boundary embeddings, divided by the log of the number of spectrum entries
        (number of PCA components with ``use_pca``, embedding dimension otherwise).
        params:
            num_graphs: int: number of graphs to generate for the boundary (must be > 1)
            use_pca: bool: if True, sklearn PCA's ``explained_variance_ratio_`` is used.
                Otherwise, the eigenvalues of the (feature-wise) covariance matrix are used.
            **kwargs: dict: additional arguments for the boundary generation
        returns:
            float: the complexity ratio
        """

        assert num_graphs > 1, "Number of graphs must be greater than 1"

        if self.boundary_graphs is None or len(self.boundary_graphs) != num_graphs:
            self._generate_boundary_graphs(num_graphs=num_graphs, **kwargs)

        self._extract_boundary_embeddings(level="linear", num_graphs=num_graphs)
        # https://github.com/ShuyueG/decision-boundary-complexity-score/blob/main/compute_DBC_local.py

        if use_pca:
            pca = PCA()
            pca.fit(self.boundary_embeddings.squeeze().detach().cpu().numpy())
            cplx = pca.explained_variance_ratio_
            return entropy(cplx, base=2) / np.log2(len(cplx))

        embeddings = self.boundary_embeddings.detach()
        n_samples, dim = embeddings.shape
        # Centre each feature (column), not each sample, to get a covariance matrix.
        centered = embeddings - embeddings.mean(dim=0, keepdim=True)
        cov = centered.T @ centered / n_samples
        # The covariance is symmetric PSD: eigvalsh gives real eigenvalues; clamp round-off below 0.
        eigenvalues = torch.linalg.eigvalsh(cov).clamp(min=0)
        # Normalise to a probability distribution (sums to 1), like explained_variance_ratio_.
        spectrum = eigenvalues / eigenvalues.sum()
        H = -torch.dot(spectrum, torch.log(spectrum + 1e-8))
        complexity = H / torch.log(torch.tensor(float(dim)))
        return complexity.item()

    def average_probs(self, num_graphs: int, **kwargs):
        """
        Computes the mean and std of the discriminator's probabilities over the boundary graphs.
        returns:
            (mean, std): numpy arrays holding the values for ``cls_pair[0]`` and ``cls_pair[1]``
        """
        if self.boundary_graphs is None or len(self.boundary_graphs) != num_graphs:
            self._generate_boundary_graphs(num_graphs=num_graphs, **kwargs)

        data_batch = torch_geometric.data.Batch.from_data_list(self.boundary_graphs)
        with torch.no_grad():
            probs = self.discriminator(data_batch, edge_weight=data_batch.edge_weight)["probs"].detach().cpu()

        probs = np.array(probs)
        mean_probs = np.mean(probs, axis=0)
        std_probs = np.std(probs, axis=0)

        pair = [self.cls_pair[0], self.cls_pair[1]]
        return mean_probs[pair], std_probs[pair]

    def _split_graphs_based_on_probability_range(
            self, ranges: Tuple[Tuple[float, float]],
    ) -> "Dict[Tuple[float, float], List[torch_geometric.data.Data]]":
        """
        Group boundary graphs by probability range.

        A graph is put in range ``(lo, hi)`` when both its ``cls_pair[0]`` and
        ``cls_pair[1]`` probabilities lie in ``[lo, hi]``; a graph may land in several ranges.
        """
        range_dict = {r: [] for r in ranges}
        data_batch = torch_geometric.data.Batch.from_data_list(self.boundary_graphs)
        with torch.no_grad():
            probs = self.discriminator(data_batch, edge_weight=data_batch.edge_weight)["probs"].detach().cpu()

        for (graph, row) in zip(self.boundary_graphs, probs):
            x, y = row[self.cls_pair[0]], row[self.cls_pair[1]]
            for r in ranges:
                if r[0] <= x <= r[1] and r[0] <= y <= r[1]:
                    range_dict[r].append(graph)

        return range_dict

    def _evaluate_per_probability_range(self, ranges: Tuple[Tuple[float, float]], num_graphs: int,
                                        metric, **kwargs):
        """
        Apply ``metric`` to the boundary graphs falling into each probability range.

        ``metric`` is called as ``metric(num_graphs=<subset size>, **kwargs)`` while
        ``self.boundary_graphs`` temporarily holds that subset. Ranges with fewer than
        two graphs map to NaN. The full graph list is restored even if ``metric`` raises.
        """
        if self.boundary_graphs is None or len(self.boundary_graphs) != num_graphs:
            self._generate_boundary_graphs(num_graphs=num_graphs, **kwargs)

        all_graphs = self.boundary_graphs
        range_dict = self._split_graphs_based_on_probability_range(ranges)
        results = {}
        try:
            for r, graphs in range_dict.items():
                if len(graphs) < 2:
                    results[r] = float('nan')
                    continue
                self.boundary_graphs = graphs
                results[r] = metric(num_graphs=len(graphs), **kwargs)
        finally:
            self.boundary_graphs = all_graphs
        return results

    def boundary_complexity_for_probability_range(self, ranges: Tuple[Tuple[float, float]],
                                                  num_graphs: int, use_pca: bool = True, **kwargs):
        """
        Computes the boundary complexity for each probability range. The graphs are filtered based
        on the probability range and the complexity is computed for the filtered graphs.
        params:
            ranges: Tuple[Tuple[float, float]]: probability ranges for which the complexity is computed
            num_graphs: int: number of graphs to generate for the boundary
            use_pca: bool: if True, PCA is used to compute the complexity. Otherwise, the eigenvalues of the covariance matrix are used.
            **kwargs: dict: additional arguments for the boundary
        returns:
            dict mapping each range to its complexity (NaN if fewer than 2 graphs fall in it)
        """
        return self._evaluate_per_probability_range(
            ranges, num_graphs,
            lambda num_graphs, **kw: self.boundary_complexity(num_graphs=num_graphs, use_pca=use_pca, **kw),
            **kwargs)

    def boundary_margin_for_probability_range(self, ranges: Tuple[Tuple[float, float]], num_graphs: int, **kwargs):
        """
        Computes the boundary margin for each probability range. The graphs are filtered based
        on the probability range and the margin is computed for the filtered graphs.
        params:
            ranges: Tuple[Tuple[float, float]]: probability ranges for which the margin is computed
            num_graphs: int: number of graphs to generate for the boundary
            **kwargs: dict: additional arguments for the boundary
        returns:
            dict mapping each range to its margin (NaN if fewer than 2 graphs fall in it)
        """
        return self._evaluate_per_probability_range(ranges, num_graphs, self.boundary_margin, **kwargs)

    def boundary_thickness_for_probability_range(self, ranges: Tuple[Tuple[float, float]], num_graphs: int, **kwargs):
        """
        Computes the boundary thickness for each probability range. The graphs are filtered based
        on the probability range and the thickness is computed for the filtered graphs.
        params:
            ranges: Tuple[Tuple[float, float]]: probability ranges for which the thickness is computed
            num_graphs: int: number of graphs to generate for the boundary
            **kwargs: dict: additional arguments for the boundary
        returns:
            dict mapping each range to its thickness (NaN if fewer than 2 graphs fall in it)
        """
        return self._evaluate_per_probability_range(ranges, num_graphs, self.boundary_thickness, **kwargs)