import os
import sys
sys.path.append(os.getcwd() + '/linreg/')
sys.path.append(os.getcwd() + '/logreg/')

import numpy as np
import bisect
from math import factorial
from itertools import combinations
from operator import itemgetter
from run_methods import run_methods
from NIG import NIG_rvs
from kl_estimation import compute_kl_gibbs, compute_kl_log
from copy import copy, deepcopy

import arviz as az
from cmdstanpy import CmdStanModel

from abc import ABC, abstractmethod, ABCMeta
from log_generate_data import generate_data_config

class RequiredMeta(ABCMeta):
    """Metaclass that validates required attributes after construction.

    When a class using this metaclass is instantiated, the instance is built
    as usual (``__new__`` + ``__init__``) and then every name listed in
    ``required_attributes`` is checked. An attribute that is missing or
    falsy (``None``, ``0``, empty container, ...) is treated as "not set".

    Raises:
        ValueError: if a required attribute is missing or falsy.
    """
    required_attributes = []

    def __call__(self, *args, **kwargs):
        obj = super(RequiredMeta, self).__call__(*args, **kwargs)
        # Instance attribute lookup never consults the metaclass, so fall back
        # to the class-level lookup (which does) when the class itself does
        # not declare ``required_attributes``.
        required = getattr(obj, 'required_attributes', self.required_attributes)
        for attr_name in required:
            # Default to None so that a never-assigned attribute produces the
            # ValueError below instead of leaking an AttributeError.
            if not getattr(obj, attr_name, None):
                raise ValueError('required attribute (%s) not set' % attr_name)
        return obj

class AbstractValuation(object, metaclass=RequiredMeta):
    """
    Computes the Shapley value of each party.
    Note that the reweighted Shapley values to prevent negative values is not included.

    Subclasses must assign ``self.num`` (the number of parties) before calling
    ``super().__init__()`` and must implement ``_compute_kl_save_posterior``
    and ``update_i``. The value of a coalition is whatever
    ``_compute_kl_save_posterior`` stores in ``self.all_kls`` for it.
    """
    required_attributes = ['num']

    def __init__(self):
        """To be called after subclass init, use abstract methods"""
        # Every non-empty subset of range(num), grouped by size (singletons
        # first, the grand coalition last); each subset is a sorted list.
        self.power_set = [list(j) for i in range(self.num) for j in combinations(list(range(self.num)), i+1)]
        self.all_kls = dict()

        # Exact enumeration is only done for small numbers of parties. For
        # num > 6 nothing is computed here, so each_kls / grand_kl / shapleys
        # stay undefined until set by other means.
        if self.num <= 6:
            self._compute_all_kls()
            self.update_v_and_shapleys()

        # else:
        #     self.each_kls =  np.array([self._compute_kl_save_posterior([i]) for i in range(self.num)])
        #     self.grand_kl = self._compute_kl_save_posterior(list(range(self.num)))
        #     self.shapleys = self.approx_shap(5000, write=True)

    def update_v_and_shapleys(self):
        """Refresh ``each_kls``, ``grand_kl`` and ``shapleys`` from ``all_kls``.

        Assumes ``all_kls`` was filled in ``power_set`` order, so that its
        first ``num`` values are the singletons and its last value is the
        grand coalition.
        """
        # This relies on Python 3.7+ python dictionary are guaranteed to be insertion ordered
        values = list(self.all_kls.values())
        self.each_kls = np.array(values[:self.num])
        self.grand_kl = values[-1]
        self.shapleys = self.calculate_shap(values)

    def _compute_all_kls(self):
        """Evaluate every coalition in ``power_set`` (printing each as progress)."""
        for subset in self.power_set:
            print(subset)
            if not subset:
                continue
            self._compute_kl_save_posterior(subset)

    @abstractmethod
    def _compute_kl_save_posterior(self, subset):
        """Compute the model/posterior corresponding to subset and save the posterior/kl as attributes"""
        pass

    @abstractmethod
    def update_i(self, i, curr_div_new_eps_ratio):
        """Rescale party ``i``'s privacy noise and refresh the affected values."""
        pass

    def calculate_shap(self, vs):
        """Return the exact Shapley value of every party.

        Args:
            vs: sequence of coalition values aligned with ``self.power_set``
                (``vs[k]`` is the value of ``self.power_set[k]``). The empty
                coalition is not part of ``power_set`` and is valued at 0.
                Any value function can be passed, not only the stored KLs.

        Returns:
            ``np.ndarray`` of length ``self.num``.
        """
        power_set = self.power_set
        n = self.num

        # Position of each coalition in power_set, built once so that the
        # loops below do not need a linear ``list.index`` scan per lookup.
        # setdefault keeps the first occurrence, like ``list.index`` does.
        position = {}
        for idx, subset in enumerate(power_set):
            position.setdefault(tuple(subset), idx)
        n_factorial = factorial(n)

        shapley_values = []
        for i in range(n):
            shapley = 0
            for j in power_set:
                if i not in j:
                    card_s = len(j)
                    # Coalition j with party i added, kept sorted to match power_set
                    Cui = j[:]
                    bisect.insort_left(Cui, i)
                    l = position[tuple(j)]
                    k = position[tuple(Cui)]
                    shapley += (vs[k] - vs[l]) * factorial(card_s) * factorial(n - card_s - 1) / n_factorial

            # Marginal contribution to the empty coalition (whose value is 0)
            card_s = 0
            k = position[(i,)]
            shapley += vs[k] * factorial(card_s) * factorial(n - card_s - 1) / n_factorial
            shapley_values.append(shapley)
        return np.array(shapley_values)

    def approx_shap(self, rounds=500, write=False):
        """Estimate Shapley values by averaging marginals over random orderings.

        Args:
            rounds: number of random permutations to sample.
            write: if true, ``self.all_kls`` is rebound to the internal cache
                (keyed by ``frozenset`` of party indices) after each round.

        Returns:
            ``np.ndarray`` of length ``self.num``.
        """
        approx_shap = np.zeros(self.num)
        kl_hash = dict()

        for r in range(rounds):
            order = np.random.permutation(self.num).tolist()
            old_kls = 0
            for i in range(1, self.num + 1):
                prefix = frozenset(order[:i])
                if prefix in kl_hash:
                    new_kls = kl_hash[prefix]
                else:
                    # NOTE(review): the subset is passed in permutation order,
                    # not sorted -- confirm implementations accept that.
                    subset = order[:i]
                    new_kls = self._compute_kl_save_posterior(subset)
                    kl_hash[prefix] = new_kls
                # Credit the party that was just added with its marginal gain
                approx_shap[order[i-1]] += new_kls - old_kls
                old_kls = new_kls

            if write:
                # NOTE(review): after this, all_kls and kl_hash are the same
                # dict, so any key written to self.all_kls by
                # _compute_kl_save_posterior lands in the cache too.
                self.all_kls = kl_hash

        return approx_shap / rounds

    def get_shapley_ratio_sum_to_1(self):
        """Shapley values divided by their sum."""
        return self.shapleys / np.sum(self.shapleys)

    def get_shapley_ratio_max_1(self):
        """Shapley values divided by their maximum."""
        return self.shapleys / np.max(self.shapleys)

    def get_reward_ratio(self, p):
        """Max-normalised Shapley ratios raised to the power ``p``."""
        return self.get_shapley_ratio_max_1() ** p

    def get_reward(self, p):
        """Reward per party: ``get_reward_ratio(p)`` scaled by ``grand_kl``."""
        return self.get_reward_ratio(p) * self.grand_kl

    def summarize(self, p=1):
        """Print the stored KLs, Shapley values and derived rewards."""
        print("KL", self.all_kls)
        print("Shapley: ", self.shapleys)
        print("Shapley Power ratio: ", self.get_reward_ratio(p))
        print("Reward", self.get_reward(p))
        print("Rational p", self.get_rational_p())

    def get_rational_p(self):
        """Return the smallest finite per-party bound on ``p``, capped at 1.

        For each party the bound is
        ``log(each_kl / grand_kl) / log(shapley / max_shapley)``; for a ratio
        strictly between 0 and 1 this is the value of ``p`` at which
        ``get_reward(p)`` equals that party's own ``each_kl``. Non-finite
        bounds are discarded. If no finite bound remains, 1 is returned.
        """
        target_v = self.each_kls
        # The top party always has ratio 1 (log == 0), so a division by zero
        # is expected here; the resulting inf/nan entries are filtered below.
        with np.errstate(divide='ignore', invalid='ignore'):
            ps = np.log(target_v/self.grand_kl) / np.log(self.get_shapley_ratio_max_1())
        ps = ps[np.isfinite(ps)]
        if ps.size == 0:
            # No party constrains p (e.g. a single party): use the cap.
            return np.float64(1)
        p = np.min(ps)
        return np.min([p, 1])

    def __copy__(self):
        """Copy the object, shallow-copying each attribute individually."""
        cls = self.__class__
        result = cls.__new__(cls)
        for k, v in self.__dict__.items():
            # do not want them to share the same self.all_posteriors / self.all_kls list
            # will be updated during test for monotonicity
            setattr(result, k, copy(v))
        return result


class Valuation(AbstractValuation):
    """Valuation whose coalition value is an averaged ``compute_kl_gibbs`` estimate.

    Posterior samples for a coalition come from ``run_methods``; a prior
    sample drawn once with ``NIG_rvs`` is stored alongside them.
    """

    def __init__(self, data_prior_params, model_prior_params, nParty, S, Z, sigma_DP, nPoint,
                 nSample=3000, DP_method='Gaussian',inf_method='naive',mode='nearestc',chains=4):
        """Store the configuration, draw the prior sample, then run the base init.

        ``S``, ``Z`` and ``sigma_DP`` are copied with ``copy`` so that
        ``update_i`` can replace per-party entries without rebinding the
        caller's containers. ``sigma_DP`` and ``nPoint`` are later indexed
        with a list of party ids (presumably numpy arrays -- confirm).
        """
        # Parties and their (perturbed) statistics
        self.num = nParty
        self.S = copy(S)
        self.Z = copy(Z)
        self.sigma_DP = copy(sigma_DP)
        self.nPoint = nPoint

        # Inference settings
        self.data_prior_params = data_prior_params
        self.model_prior_params = model_prior_params
        self.nSample = nSample
        self.DP_method = DP_method
        self.inf_method = inf_method
        self.mode = mode
        self.chains = chains

        self.prior = NIG_rvs(*self.model_prior_params, size=self.nSample)
        self.all_posteriors = dict()

        super().__init__()

    def _compute_kl_save_posterior(self, subset):
        """Run inference for ``subset`` and store its posteriors and mean KL.

        Both are stored under the key formed by concatenating the party ids.
        Returns the stored KL value.
        """
        key = ''.join(map(str, subset))
        pick = itemgetter(*subset)

        posteriors = run_methods(
            self.data_prior_params, self.model_prior_params,
            pick(self.S), pick(self.Z),
            self.sigma_DP[subset], self.nPoint[subset],
            self.nSample, self.DP_method, [self.inf_method],
            chains=self.chains,
        )
        # Count the posterior entries before the prior is added next to them
        self.n_dict_keys = len(posteriors)
        posteriors['prior'] = self.prior
        self.all_posteriors[key] = posteriors

        kl_table = self.all_kls
        kl_table[key] = 0
        for name in posteriors:
            if name == 'prior':
                continue
            kl_table[key] += compute_kl_gibbs(posteriors, name, mode=self.mode,
                                              model_prior_params=self.model_prior_params)
        kl_table[key] /= self.n_dict_keys

        return kl_table[key]

    def recompute_all_kls_and_shap(self, mode):
        """Re-estimate every KL from the stored posteriors using ``mode``.

        Results go to ``self.new_kls`` / ``self.new_shapleys`` and are printed;
        ``all_kls`` and ``shapleys`` are left untouched.
        """
        # Assume posterior have been computed
        recomputed = dict()
        self.new_kls = recomputed
        for subset in self.power_set:
            print(subset)
            if subset:
                key = ''.join(map(str, subset))
                recomputed[key] = compute_kl_gibbs(self.all_posteriors[key], self.inf_method,
                                                   mode=mode, model_prior_params=self.model_prior_params)

        self.new_shapleys = self.calculate_shap(list(recomputed.values()))
        print("KL", self.new_kls)
        print("{} Shapley: ".format(mode), self.new_shapleys)

    def update_i(self, i, curr_div_new_eps_ratio):
        """ curr_div_new_eps_ratio = curr_eps / target_eps. If >1, mean smaller target eps, higher privacy """
        scale = np.sqrt(curr_div_new_eps_ratio)
        self.sigma_DP[i] = self.sigma_DP[i] * scale

        # Stretch the perturbation (Z - S) of every statistic by the same factor
        current = self.Z[i]
        self.Z[i] = {name: (current[name] - exact) * scale + exact
                     for name, exact in self.S[i].items()}
        self.Z[i]['X'] = self.Z[i]['XX'][:, 0][:, None] # what is it used for?

        if self.num > 6:
            raise NotImplementedError("recompute from scratch?")

        # Only coalitions containing party i are affected
        for subset in self.power_set:
            print(subset)
            if subset and i in subset:
                self._compute_kl_save_posterior(subset)
        # This relies on Python 3.7+ python dictionary are guaranteed to be insertion ordered
        self.update_v_and_shapleys()

class LogValuation(AbstractValuation):
    """Valuation whose coalition value is a mean ``compute_kl_log`` estimate
    over splits of posterior draws sampled with a CmdStan model."""

    def __init__(self, model, pass_object, l2_norm, d,
                 nParty, multiN, approx_ss, perturbed_ss, variances,
                 DP_method='Gaussian', chains=25, burn_in=400, nsamples=2000, split_num=5):
        """
        model = path passed to CmdStanModel as stan_file
        pass_object = PASS
        l2_norm, d = norm and dimension of a single input
        multiN = np array of number of data points per party
        approx_ss, perturbed_ss, variances = per-party values; each is copied with .copy()
        chains, burn_in, nsamples = sampler settings (burn_in is used as iter_warmup)
        split_num = number of chunks the posterior draws are split into
        """
        # Sampler and its settings
        self.model = CmdStanModel(stan_file=model)
        self.chains = chains
        self.burn_in = burn_in
        self.nsamples = nsamples
        self.split_num = split_num

        # Problem description
        self.pass_object = pass_object
        self.l2_norm = l2_norm
        self.d = d
        self.DP_method = DP_method

        # Per-party quantities
        self.num = nParty
        self.N = multiN
        self.S = approx_ss.copy()
        self.Z = perturbed_ss.copy()
        self.variances = variances.copy()

        self.all_posteriors = dict()

        super().__init__()

    def _compute_kl_save_posterior(self, subset):
        """Sample the posterior for ``subset`` and store the draws and mean KL.

        Both are stored under the key formed by concatenating the party ids.
        Returns the stored KL value.
        """
        key = ''.join(map(str, subset))

        n_parties = len(subset)
        stan_data = generate_data_config(self.N[subset], self.d, n_parties,
                                         self.l2_norm, self.pass_object,
                                         self.Z[subset], self.variances[subset])

        fit = self.model.sample(
            data=stan_data,
            iter_warmup=self.burn_in,
            iter_sampling=self.nsamples,
            chains=self.chains,
            adapt_delta=0.86,
            show_progress=False,
        )

        inference = az.from_cmdstanpy(posterior=fit)
        # Keep only the chains whose BFMI is above 0.3
        chain_mask = az.bfmi(inference) > 0.3
        kept = inference.sel(chain=chain_mask)
        stacked = kept.stack(sample=["draw","chain"])
        draws = stacked.posterior.theta_DP_scaled.values.T

        chunks = np.array_split(draws, self.split_num)
        self.all_posteriors[key] = chunks

        # One KL estimate per chunk, averaged
        per_chunk_kl = np.array([compute_kl_log(chunk) for chunk in chunks])
        self.all_kls[key] = per_chunk_kl.mean()

        return self.all_kls[key]

    def update_i(self, i, curr_div_new_eps_ratio):
        """ curr_div_new_eps_ratio = curr_eps / target_eps. If >1, mean smaller target eps, higher privacy """
        self.variances[i] = self.variances[i] * curr_div_new_eps_ratio
        # Stretch the perturbation (Z - S) by sqrt(ratio), matching the variance scaling
        exact = self.S[i]
        self.Z[i] = (self.Z[i]- exact) * np.sqrt(curr_div_new_eps_ratio) + exact

        if self.num > 6:
            raise NotImplementedError("recompute from scratch?")

        # Only coalitions containing party i are affected
        for subset in self.power_set:
            print(subset)
            if subset and i in subset:
                self._compute_kl_save_posterior(subset)
        # This relies on Python 3.7+ python dictionary are guaranteed to be insertion ordered
        self.update_v_and_shapleys()
