import networkx as nx
import torch
from botorch.models import SingleTaskGP, SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf, optimize_acqf_discrete
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_model, fit_fully_bayesian_model_nuts
from gpytorch.mlls import ExactMarginalLogLikelihood
import matplotlib.pyplot as plt
import ndlib
import ndlib.models.epidemics as ep
import ndlib.models.ModelConfig as mc
import numpy as np
from itertools import combinations as comb
import pandas as pd
import heapdict as hd

from graphGeneration import *
import random
import math

import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn.inits import reset
from torch_geometric.data import Data
from sklearn.cluster import KMeans, SpectralClustering
import statistics as s

################################################
# Global parameters
################################################
diffusion_model = "ic" # "ic" or "lt"; NOTE(review): not read in this chunk, presumably selects the *IC vs *LT helpers below
graph_size = 1000 # NOTE(review): not read in this chunk; presumably the number of nodes of the generated graph
candidate_size = 50 # candidate pool size: number of highest-degree nodes whose combinations are considered
number_of_sources = 3 # budget for IM: number of seed nodes in each candidate set
num_iterations = 300 # budget for BO: number of optimization iterations
actual_time_step_size = 5 # diffusion iterations per simulation run (read as a global by simulationIC/LT and effectIC/LT)
allowed_shortest_distance = 1 # filtering threshold; presumably passed to create_candidate_set_pool_filtering, which drops sets whose sources are this close or closer
num_of_sims = 10 # NOTE(review): not read in this chunk; presumably the number of simulation repetitions
number_of_clusters = 20 # NOTE(review): not read in this chunk; presumably the cluster count for KMeans/SpectralClustering

from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
from gpytorch.models.exact_gp import ExactGP
from gpytorch.means.constant_mean import ConstantMean
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import RBFKernel, RFFKernel
from typing import Any, List, NoReturn, Optional, Union
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.module import Module
from gpytorch.means.mean import Mean
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import Log, OutcomeTransform
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.priors.torch_priors import GammaPrior

def get_gaussian_likelihood_with_gamma_prior(
    batch_shape: Optional[torch.Size] = None,
) -> GaussianLikelihood:
    r"""Build a GaussianLikelihood with a Gamma(1.1, 0.05) noise prior.

    The noise is constrained to stay above 1e-4 (no transform applied to the
    constraint) and is initialised at the mode of the prior,
    ``(concentration - 1) / rate``.

    Args:
        batch_shape: Batch shape of the likelihood; ``None`` means an empty
            ``torch.Size()``.

    Returns:
        The configured ``GaussianLikelihood``.
    """
    prior = GammaPrior(1.1, 0.05)
    # Start the optimisation of the noise level at the most likely prior value.
    start_value = (prior.concentration - 1) / prior.rate
    lower_bound = GreaterThan(1e-4, transform=None, initial_value=start_value)
    if batch_shape is None:
        batch_shape = torch.Size()
    return GaussianLikelihood(
        noise_prior=prior,
        batch_shape=batch_shape,
        noise_constraint=lower_bound,
    )

# SingleTaskGP subclass whose default covariance module is an RBF kernel (another kernel, e.g. an RFFKernel, can be passed via covar_module)
class RBFSingleTaskGP(SingleTaskGP):
    r"""SingleTaskGP whose default covariance module is a plain ``RBFKernel``.

    The constructor validates and transforms the training data, then builds
    the likelihood (Gamma-prior Gaussian likelihood by default), a
    ``ConstantMean`` and, unless ``covar_module`` is supplied, an
    ``RBFKernel`` with a single lengthscale (no ARD, no output-scale wrapper).
    Mean, likelihood and default kernel are all batched over
    ``self._aug_batch_shape``.
    """

    def __init__(self, train_X: torch.Tensor,
        train_Y: torch.Tensor,
        likelihood: Optional[Likelihood] = None,
        covar_module: Optional[Module] = None,
        mean_module: Optional[Mean] = None,
        outcome_transform: Optional[OutcomeTransform] = None,
        input_transform: Optional[InputTransform] = None,
    ) -> None:
        r"""
        Args:
            train_X: A `batch_shape x n x d` tensor of training features.
            train_Y: A `batch_shape x n x m` tensor of training observations.
            likelihood: A likelihood. If omitted, use a standard
                GaussianLikelihood with inferred noise level.
            covar_module: The module computing the covariance (Kernel) matrix.
                If omitted, use an `RBFKernel` batched over the model's
                augmented batch shape.
            mean_module: The mean function to be used. If omitted, use a
                `ConstantMean`.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale).
            input_transform: An input transform that is applied in the model's
                forward pass.
        """
        with torch.no_grad():
            transformed_X = self.transform_inputs(
                X=train_X, input_transform=input_transform
            )
        if outcome_transform is not None:
            train_Y, _ = outcome_transform(train_Y)
        self._validate_tensor_args(X=transformed_X, Y=train_Y)
        ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)
        validate_input_scaling(
            train_X=transformed_X, train_Y=train_Y, ignore_X_dims=ignore_X_dims
        )
        self._set_dimensions(train_X=train_X, train_Y=train_Y)
        train_X, train_Y, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
        if likelihood is None:
            likelihood = get_gaussian_likelihood_with_gamma_prior(
                batch_shape=self._aug_batch_shape
            )
        else:
            self._is_custom_likelihood = True
        ExactGP.__init__(
            self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
        )
        if mean_module is None:
            mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
        self.mean_module = mean_module
        if covar_module is None:
            # Batch the kernel like the mean and likelihood above so that each
            # output gets its own lengthscale and subset_output() can slice it.
            # NOTE(review): for a single, non-batched output _aug_batch_shape is
            # presumably empty, making this identical to RBFKernel().
            covar_module = RBFKernel(batch_shape=self._aug_batch_shape)
            # Parameter name -> dimension that indexes the outputs. The bare
            # RBFKernel has no outputscale and no base_kernel; its raw
            # lengthscale has shape `aug_batch_shape x 1 x 1`, so the output
            # dimension is -3.
            self._subset_batch_dict = {
                "likelihood.noise_covar.raw_noise": -2,
                "mean_module.raw_constant": -1,
                "covar_module.raw_lengthscale": -3,
            }
        self.covar_module = covar_module
        # TODO: Allow subsetting of other covar modules
        if outcome_transform is not None:
            self.outcome_transform = outcome_transform
        if input_transform is not None:
            self.input_transform = input_transform
        self.to(train_X)

    def forward(self, x: torch.Tensor) -> MultivariateNormal:
        """Return the prior GP distribution at ``x``.

        The input transform is applied here only in training mode; presumably
        the eval-mode transform is applied elsewhere by BoTorch (TODO confirm).
        """
        if self.training:
            x = self.transform_inputs(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


# Candidate pools below are built from the candidate_size highest-degree nodes of the graph

def create_candidate_set_pool_filtering(G, candidate_size=100, number_of_sources=3, allowed_shortest_distance=2):
    """Enumerate seed sets from the top-degree nodes, dropping sets whose members are too close.

    The ``candidate_size`` nodes with the largest degree form the pool. Ties
    keep the order in which ``G.degree`` lists them, because ``sorted`` is
    stable. Every ``number_of_sources``-combination of the pool is kept only
    if *every* pair of its members is more than ``allowed_shortest_distance``
    hops apart.

    For a pair, the (unweighted) distance is measured from the member that
    comes first in the combination to the one that comes later. This only
    matters for directed graphs. A pair with no connecting path counts as
    infinitely far apart and never causes a rejection.

    Args:
        G: networkx graph.
        candidate_size: number of highest-degree nodes in the pool.
        number_of_sources: size of each candidate set.
        allowed_shortest_distance: a set is rejected if two of its members
            are at most this many hops apart.

    Returns:
        List of node tuples, in ``itertools.combinations`` order.
    """
    ranked = sorted(G.degree, key=lambda x: x[1], reverse=True)
    candidates = [node for node, _ in ranked[:candidate_size]]

    # One BFS per candidate, truncated at the threshold, replaces a full
    # shortest-path query for every pair of every combination. A node missing
    # from a candidate's map is farther than the threshold or unreachable.
    nearby = {
        node: nx.single_source_shortest_path_length(
            G, node, cutoff=allowed_shortest_distance
        )
        for node in candidates
    }

    def _far_enough(selected_set):
        # Check every pair; comb() keeps positional order, so `start` precedes `end`.
        for start, end in comb(selected_set, 2):
            distance = nearby[start].get(end)
            if distance is not None and distance <= allowed_shortest_distance:
                return False
        return True

    return [
        selected_set
        for selected_set in comb(candidates, number_of_sources)
        if _far_enough(selected_set)
    ]

def create_candidate_set_pool(G, candidate_size=100, number_of_sources=3):
    """Return every ``number_of_sources``-combination of the top-degree nodes.

    Nodes are ranked by ``G.degree`` in descending order (stable for ties),
    and the first ``candidate_size`` form the pool.

    Returns:
        List of node tuples in ``itertools.combinations`` order.
    """
    by_degree = sorted(G.degree, key=lambda pair: pair[1], reverse=True)
    pool = [node for node, _degree in by_degree[:candidate_size]]
    return list(comb(pool, number_of_sources))


def fourier_transfer_for_all_candidate_set(candidate_sets, UT):
    """Project the 0/1 indicator vector of each candidate set onto ``UT``.

    Each set's members are used as positions in an indicator vector of
    length ``len(UT)``. That vector is multiplied from the left with ``UT``.

    Returns:
        List with one ``np.matmul(indicator, UT)`` result per candidate set.
    """
    dim = len(UT)

    def to_signal(members):
        indicator = [0] * dim
        for member in members:
            indicator[member] = 1
        return np.matmul(indicator, UT)

    return [to_signal(members) for members in candidate_sets]

def create_signal_from_source_set(G, sampled_set, UT):
    """Return ``np.matmul(indicator, UT)`` for the 0/1 indicator of ``sampled_set``.

    ``G`` is accepted but not used. The members of ``sampled_set`` are used
    as positions in an indicator vector of length ``len(UT)``.
    """
    indicator = [0] * len(UT)
    for member in sampled_set:
        indicator[member] = 1
    projected = np.matmul(indicator, UT)
    return projected

def find_source_set_from_fourier(signal, number_of_sources, UT_inv):
    """Decode the source indices encoded in ``signal``.

    The signal is mapped back with ``UT_inv`` and rounded to the nearest
    integer. Every position whose rounded value equals 1 is reported as a
    source.

    Returns:
        List of source positions (ints), in ascending order.

    Raises:
        NameError: if the number of decoded sources differs from
            ``number_of_sources``.
    """
    rounded = np.around(np.matmul(signal, UT_inv))
    decoded = [position for position, value in enumerate(rounded) if value == 1]
    if len(decoded) != number_of_sources:
        raise NameError('length of source set is not the estimated number')
    return decoded

################################################
# Mostly for Sobol
################################################

def combinations(alist):
    """Return the power set of ``alist`` as lists, ordered by subset size.

    Subsets are generated by doubling: each element is appended to every
    subset built so far. The result is then stably sorted by length. Within a
    size the order is therefore the generation order, which differs from
    ``itertools.combinations`` for four or more elements. Both the empty list
    and the full list are included.
    """
    power_set = [[]]
    for element in alist:
        power_set.extend([existing + [element] for existing in power_set])
    return sorted(power_set, key=len)

def subcombs(alist):
    """Return the proper, non-empty subsets of ``alist``, ordered by size.

    This is ``combinations(alist)`` without the empty subset and without the
    full set, in the same order. Filtering by size works for any sized input.
    An empty ``alist`` yields ``[]``, and a tuple input is handled as well;
    both previously raised ValueError from ``list.remove``.
    """
    size = len(alist)
    # combinations() is already sorted by length, so filtering keeps that order.
    return [subset for subset in combinations(alist) if 0 < len(subset) < size]

def substract(alist, blist):
    """Return a copy of ``alist`` with one occurrence of each item of ``blist`` removed.

    ``alist`` itself is not modified.

    Raises:
        ValueError: if an item of ``blist`` is not (or no longer) present.
    """
    remaining = list(alist)
    for unwanted in blist:
        remaining.remove(unwanted)
    return remaining

def diff(rank, order):
    """Return the total positional displacement between two rankings.

    For every item of ``rank``, the absolute difference between its index in
    ``rank`` and its first index in ``order`` is summed. When both sequences
    are permutations of the same items, this is the Spearman footrule
    distance.

    Raises:
        ValueError: if the sequences have different lengths, or an item of
            ``rank`` is missing from ``order``. Previously a length mismatch
            printed a message and then crashed with UnboundLocalError.
    """
    if len(order) != len(rank):
        raise ValueError('the lengths do not match')
    return sum(abs(order.index(item) - position) for position, item in enumerate(rank))

def simulationIC(r, g, result, config):
    """Run ``r`` Independent Cascades simulations for every subset of ``result``.

    For each subset produced by ``combinations(result)``, including the empty
    and the full set, the subset is used as the initially infected nodes of
    an ndlib ``IndependentCascadesModel``. The model runs on a fresh
    structural copy of ``g`` whose edge thresholds are read from ``config``.

    Args:
        r: number of repetitions per subset.
        g: graph to simulate on (assumed networkx; copied each run).
        result: candidate seed nodes; one indicator column per node.
        config: object whose ``config["edges"]['threshold']`` maps every edge
            ``(a, b)`` of the copy to its threshold.

    Returns:
        DataFrame with one row per (subset, repetition) and a RangeIndex.
        Each node of ``result`` has a 0/1 column marking whether it was a
        seed. The ``'result'`` column holds the status-1 node count summed
        over the ``actual_time_step_size`` iterations (status 1 presumably
        being ndlib's "Infected").
    """
    title = list(result) + ['result']
    rows = []

    for combs in combinations(result):
        indicator = [1 if node in combs else 0 for node in result]

        for _ in range(r):
            # Fresh structural copy for every run; edge weights are re-applied below.
            g_mid = g.__class__()
            g_mid.add_nodes_from(g)
            g_mid.add_edges_from(g.edges)

            model_mid = ep.IndependentCascadesModel(g_mid)
            config_mid = mc.Configuration()
            config_mid.add_model_initial_configuration('Infected', combs)

            for a, b in g_mid.edges():
                weight = config.config["edges"]['threshold'][(a, b)]
                g_mid[a][b]['weight'] = weight
                config_mid.add_edge_configuration('threshold', (a, b), weight)

            model_mid.set_initial_status(config_mid)
            iterations = model_mid.iteration_bunch(actual_time_step_size)

            total_no = sum(
                iterations[j]['node_count'][1] for j in range(actual_time_step_size)
            )
            rows.append(indicator + [total_no])

    # Build the frame once: pd.concat inside the loop copies the whole frame
    # on every row (quadratic), and concatenating onto an empty frame triggers
    # pandas' dtype FutureWarning.
    return pd.DataFrame(rows, columns=title)

def simulationLT(r, g, result, config):
    """Run ``r`` Threshold-model simulations for every subset of ``result``.

    For each subset produced by ``combinations(result)``, including the empty
    and the full set, the subset is used as the initially infected nodes of
    an ndlib ``ThresholdModel``. The model runs on a fresh structural copy of
    ``g``. Edge thresholds are read from ``config``, and every node gets a
    fresh random threshold in {0.01, ..., 0.19} for each run.

    Args:
        r: number of repetitions per subset.
        g: graph to simulate on (assumed networkx; copied each run).
        result: candidate seed nodes; one indicator column per node.
        config: object whose ``config["edges"]['threshold']`` maps every edge
            ``(a, b)`` of the copy to its threshold.

    Returns:
        DataFrame with one row per (subset, repetition) and a RangeIndex.
        Each node of ``result`` has a 0/1 column marking whether it was a
        seed. The ``'result'`` column holds the status-1 node count at the
        last of the ``actual_time_step_size`` iterations (status 1 presumably
        being ndlib's "Infected").
    """
    title = list(result) + ['result']
    rows = []

    for combs in combinations(result):
        indicator = [1 if node in combs else 0 for node in result]

        for _ in range(r):
            # Fresh structural copy for every run; edge weights are re-applied below.
            g_mid = g.__class__()
            g_mid.add_nodes_from(g)
            g_mid.add_edges_from(g.edges)

            model_mid = ep.ThresholdModel(g_mid)
            config_mid = mc.Configuration()
            config_mid.add_model_initial_configuration('Infected', combs)

            for a, b in g_mid.edges():
                weight = config.config["edges"]['threshold'][(a, b)]
                g_mid[a][b]['weight'] = weight
                config_mid.add_edge_configuration('threshold', (a, b), weight)

            # randrange(1, 20) -> 1..19, i.e. thresholds 0.01..0.19.
            for node in g_mid.nodes():
                threshold = round(random.randrange(1, 20) / 100, 2)
                config_mid.add_node_configuration("threshold", node, threshold)

            model_mid.set_initial_status(config_mid)
            iterations = model_mid.iteration_bunch(actual_time_step_size)

            total_no = iterations[actual_time_step_size - 1]['node_count'][1]
            rows.append(indicator + [total_no])

    # Build the frame once: pd.concat inside the loop copies the whole frame
    # on every row (quadratic), and concatenating onto an empty frame triggers
    # pandas' dtype FutureWarning.
    return pd.DataFrame(rows, columns=title)

def SobolT(df, result):
    """Compute a Sobol-style sensitivity index for every node in ``result``.

    The name suggests a total-effect index. For a node ``x`` and every subset
    ``S`` of the other nodes, two groups of rows are selected: those whose
    indicator columns match exactly ``S``, and those matching exactly
    ``S + [x]``. The population variance of the two group means of
    ``'result'`` is then computed. The index of ``x`` is the mean of these
    variances over all ``S``.

    Args:
        df: DataFrame with a 0/1 column per node of ``result`` plus a
            ``'result'`` column (the layout built by simulationIC/simulationLT).
        result: list of nodes.

    Returns:
        Dict mapping each node to its index.

    Raises:
        statistics.StatisticsError: if some indicator pattern has no rows.
    """
    def rows_matching(seeds):
        # Keep only rows whose indicator is 1 exactly for the nodes in `seeds`.
        selected = df
        for column in result:
            wanted = 1 if column in seeds else 0
            selected = selected[selected[column] == wanted]
        return selected

    indices = {}
    for node in result:
        others = list(result)
        others.remove(node)

        per_subset_variance = []
        for base in combinations(others):
            group_means = [
                s.mean(rows_matching(base + extra)['result'])
                for extra in combinations([node])
            ]
            per_subset_variance.append(s.pvariance(group_means))

        indices[node] = s.mean(per_subset_variance)

    return indices

def effectIC(g, config, sources, rounds=10):
    """Estimate the Independent Cascades spread of ``sources`` by simulation.

    Each round builds a fresh structural copy of ``g`` and applies the edge
    thresholds from ``config``. It then runs an ndlib
    ``IndependentCascadesModel`` for ``actual_time_step_size`` iterations,
    with ``sources`` initially infected. The spread of a round is the
    status-1 node count summed over those iterations.

    Args:
        g: graph to simulate on (assumed networkx; copied each round).
        config: object whose ``config["edges"]['threshold']`` maps every edge
            ``(a, b)`` of the copy to its threshold.
        sources: initially infected nodes.
        rounds: number of simulation rounds (at least 2).

    Returns:
        ``(mean, std)`` of the per-round spread, where ``std`` is the sample
        standard deviation (``statistics.stdev``).

    Raises:
        statistics.StatisticsError: if ``rounds < 2``. A sample standard
            deviation needs two data points; this is checked before any
            simulation runs.
    """
    if rounds < 2:
        raise s.StatisticsError('effectIC needs rounds >= 2 to estimate a standard deviation')

    spreads = []
    for _ in range(rounds):
        g_mid = g.__class__()
        g_mid.add_nodes_from(g)
        g_mid.add_edges_from(g.edges)

        model_mid = ep.IndependentCascadesModel(g_mid)
        config_mid = mc.Configuration()
        config_mid.add_model_initial_configuration('Infected', sources)

        for a, b in g_mid.edges():
            weight = config.config["edges"]['threshold'][(a, b)]
            g_mid[a][b]['weight'] = weight
            config_mid.add_edge_configuration('threshold', (a, b), weight)

        model_mid.set_initial_status(config_mid)
        iterations = model_mid.iteration_bunch(actual_time_step_size)

        spreads.append(
            sum(iterations[j]['node_count'][1] for j in range(actual_time_step_size))
        )

    return s.mean(spreads), s.stdev(spreads)

def effectLT(g, config, sources, rounds=10):
    """Estimate the Threshold-model spread of ``sources`` by simulation.

    Each round builds a fresh structural copy of ``g`` and applies the edge
    thresholds from ``config``. It draws a random node threshold in
    {0.01, ..., 0.19} for every node, then runs an ndlib ``ThresholdModel``
    for ``actual_time_step_size`` iterations with ``sources`` initially
    infected. The spread of a round is the status-1 node count at the last
    iteration.

    Args:
        g: graph to simulate on (assumed networkx; copied each round).
        config: object whose ``config["edges"]['threshold']`` maps every edge
            ``(a, b)`` of the copy to its threshold.
        sources: initially infected nodes.
        rounds: number of simulation rounds (at least 2).

    Returns:
        ``(mean, std)`` of the per-round spread, where ``std`` is the sample
        standard deviation (``statistics.stdev``).

    Raises:
        statistics.StatisticsError: if ``rounds < 2``. A sample standard
            deviation needs two data points; this is checked before any
            simulation runs.
    """
    if rounds < 2:
        raise s.StatisticsError('effectLT needs rounds >= 2 to estimate a standard deviation')

    spreads = []
    for _ in range(rounds):
        g_mid = g.__class__()
        g_mid.add_nodes_from(g)
        g_mid.add_edges_from(g.edges)

        model_mid = ep.ThresholdModel(g_mid)
        config_mid = mc.Configuration()
        config_mid.add_model_initial_configuration('Infected', sources)

        for a, b in g_mid.edges():
            weight = config.config["edges"]['threshold'][(a, b)]
            g_mid[a][b]['weight'] = weight
            config_mid.add_edge_configuration('threshold', (a, b), weight)

        # randrange(1, 20) -> 1..19, i.e. thresholds 0.01..0.19.
        for node in g.nodes():
            threshold = round(random.randrange(1, 20) / 100, 2)
            config_mid.add_node_configuration("threshold", node, threshold)

        model_mid.set_initial_status(config_mid)
        iterations = model_mid.iteration_bunch(actual_time_step_size)

        spreads.append(iterations[actual_time_step_size - 1]['node_count'][1])

    return s.mean(spreads), s.stdev(spreads)

