import os
import sys
# Extend the module search path with the `linreg/` and `logreg/` directories
# located under the *current working directory* (not under this file's
# directory).
# NOTE(review): presumably this is what makes the project-local imports
# further down resolvable, which would mean the program has to be launched
# from the directory that contains those two folders -- confirm.
sys.path.append(os.getcwd() + '/linreg/')
sys.path.append(os.getcwd() + '/logreg/')

import numpy as np
import copy
from operator import itemgetter
from run_methods import run_methods
from kl_estimation import compute_kl_gibbs, compute_kl_log
from scipy.optimize import root_scalar
from joblib import Parallel, delayed
from scipy.stats import sem

from abc import ABC, abstractmethod
import arviz as az
from cmdstanpy import CmdStanModel
from log_generate_data import generate_data_config

class AbstractRewardProtocol(ABC):
    """Common scaffolding for protocols that hand a reward to each party.

    On construction the protocol queries the supplied ``valuation`` object
    for the reward ratios and the target rewards at the given ``p`` and
    stores them.  Subclasses supply ``_generate_reward_for_i``; ``solve_all``
    simply invokes it for every party index in ``range(num)``.

    Attributes set by ``__init__``:
        valuation: the valuation object passed in.
        num: copied from ``valuation.num``.
        p: the ``p`` (rho) argument.
        shap_ratio: result of ``valuation.get_reward_ratio(p)``.
        target_rewards: result of ``valuation.get_reward(p)``.
        rewarded_posterior_i: initially empty dict; subclasses fill it,
            keyed by party index.
    """

    def __init__(self, valuation, p):  # p is rho
        """Store ``valuation`` and ``p`` and fetch the per-party targets.

        Args:
            valuation: object exposing ``num``, ``get_reward_ratio(p)`` and
                ``get_reward(p)`` (and ``each_kls`` for ``get_prior_kls``).
            p: the rho value forwarded to both valuation queries.
        """
        self.valuation = valuation
        self.num = valuation.num
        self.p = p

        # Ratio first, then reward: same call order as callers have always
        # observed on the valuation object.
        self.shap_ratio = valuation.get_reward_ratio(p)
        self.target_rewards = valuation.get_reward(p)
        print("Target rewards", self.target_rewards)

        # Filled in by the concrete protocol, one entry per party index.
        self.rewarded_posterior_i = {}

    def get_prior_kls(self):
        """Return ``each_kls`` as stored on the valuation object."""
        return self.valuation.each_kls

    def get_rewarded_kls(self):
        """Return the target rewards fetched at construction time."""
        return self.target_rewards

    @abstractmethod
    def _generate_reward_for_i(self, i):
        """Produce the reward for party ``i`` (implemented by subclasses)."""
        pass

    def solve_all(self):
        """Generate the reward for every party index ``0 .. num - 1``."""
        for party in range(self.num):
            self._generate_reward_for_i(party)

class RewardProtocol(AbstractRewardProtocol):
    """Reward protocol that tunes a scalar ``k`` passed to ``run_methods``.

    For each party the protocol looks for the ``k`` at which the mean KL
    value (``compute_kl_gibbs`` over the posteriors returned by
    ``run_methods``) matches that party's target reward, then stores the
    posteriors produced with that ``k``.

    Attributes (in addition to those of the base class):
        rewarded_k_i: NumPy array of length ``num`` holding the ``k`` chosen
            for each party (zeros until solved).
        k_to_r: list of ``[k, kl]`` pairs, one per KL evaluation, recorded
            as a side effect of every call to ``_learn_k_to_r``.
    """

    def __init__(self, valuation, p, sensitivity): # p is rho
        """Initialise the protocol.

        Args:
            valuation: valuation object (see the base class).
            p: the rho value forwarded to the base class.
            sensitivity: accepted for interface compatibility only; it is
                not read anywhere in this class.
        """
        super().__init__(valuation, p)

        self.rewarded_k_i = np.zeros(self.num)
        self.k_to_r = []

    def _run_posteriors(self, k):
        """Run ``run_methods`` with the given ``k`` and attach the prior.

        Single place where the inference call is configured, so the root
        search and the final reward generation cannot drift apart.

        Returns:
            The mapping returned by ``run_methods`` with an extra
            ``'prior'`` entry set to ``valuation.prior``.
        """
        v = self.valuation
        posteriors = run_methods(
            v.data_prior_params, v.model_prior_params, tuple(v.S), tuple(v.Z),
            v.sigma_DP, v.nPoint, v.nSample, v.DP_method, [v.inf_method],
            k=k, chains=v.chains,
        )
        posteriors['prior'] = v.prior
        return posteriors

    def _learn_k_to_r(self, k):
        """Evaluate the mean KL obtained when inference is run with ``k``.

        Every individual KL value is also appended to ``self.k_to_r`` as a
        ``[k, kl]`` pair.

        Returns:
            Sum of the per-key KL values divided by ``valuation.n_dict_keys``.
        """
        print("k:", k)
        v = self.valuation
        posteriors = self._run_posteriors(k)

        kls = 0
        for dict_key in posteriors:
            # The prior entry is the reference, not something to score.
            if dict_key == 'prior':
                continue
            curr_kl = compute_kl_gibbs(posteriors, dict_key, mode=v.mode, model_prior_params=v.model_prior_params)
            kls += curr_kl
            self.k_to_r.append([k, curr_kl])
        mean_kl = kls / v.n_dict_keys
        print("k, kl:", k, mean_kl)

        return mean_kl

    def _solve_k_for_i(self, i, bounds=(0.01, 0.9999)):
        """Find the ``k`` whose mean KL equals party ``i``'s target reward.

        The search variable is the square root of ``k``: the objective
        squares its argument before evaluating, and the squared root is
        returned.  ``bounds`` (and the ``xtol=1e-2`` tolerance) therefore
        apply to the un-squared variable.

        Args:
            i: party index into ``self.target_rewards``.
            bounds: bracket handed to ``root_scalar`` (Brent's method); the
                objective must change sign across it.

        Returns:
            The squared root, i.e. the value of ``k``.
        """
        print("Solving for party {} to achieve {}".format(i, self.target_rewards[i]))
        def wrapper(k):
            return self._learn_k_to_r(k**2) - self.target_rewards[i]
        sol = root_scalar(wrapper, method='brentq', bracket=bounds, xtol=1e-2)
        print(sol)
        root = sol.root
        print("Party {} solved root for k".format(i), root)
        return root**2

    def _generate_reward_for_i(self, i, bounds=(0.01, 0.9999)):
        """Compute and store the rewarded posterior for party ``i``.

        If the party's reward ratio exceeds 0.99 no search is performed:
        ``k`` is recorded as 1 and the posterior stored under the key formed
        by concatenating all party indices (e.g. ``'012'`` for three
        parties) in ``valuation.all_posteriors`` is used.  Otherwise ``k``
        is solved for and inference is run once more with that value.

        Args:
            i: party index.
            bounds: bracket forwarded to ``_solve_k_for_i``.

        Returns:
            The posterior stored in ``self.rewarded_posterior_i[i]``.
        """
        if self.shap_ratio[i] > 0.99:
            self.rewarded_k_i[i] = 1.
            all_parties_key = ''.join(str(x) for x in range(self.num))
            self.rewarded_posterior_i[i] = self.valuation.all_posteriors[all_parties_key]
            return self.rewarded_posterior_i[i]

        k_root = self._solve_k_for_i(i, bounds=bounds)
        self.rewarded_k_i[i] = k_root

        posterior = self._run_posteriors(k_root)
        self.rewarded_posterior_i[i] = posterior
        return posterior

class LogRewardProtocol(AbstractRewardProtocol):
    """Reward protocol that tunes ``k`` for a Stan-sampled model.

    For each party the protocol looks for the ``k`` at which the mean of
    ``compute_kl_log`` over the split posterior draws matches that party's
    target reward, then stores freshly sampled split draws for that ``k``.

    Attributes (in addition to those of the base class):
        rewarded_k_i: NumPy array of length ``num`` holding the ``k`` chosen
            for each party (zeros until solved).
        k_to_r: list of ``(k, kl)`` tuples, one per split evaluated, recorded
            as a side effect of every call to ``_learn_k_to_r``.
    """

    def __init__(self, valuation, p): # p is rho
        """Initialise the protocol; see the base class for the arguments."""
        super().__init__(valuation, p)

        self.rewarded_k_i = np.zeros(self.num)
        self.k_to_r = []

    def _sample_split_posterior(self, k):
        """Sample the model with the given ``k`` and return split draws.

        Single place where the data config, sampler settings and chain
        filter are defined, so the root search and the final reward
        generation always use the same pipeline.

        Steps: build the data config with ``generate_data_config``, sample
        ``valuation.model``, keep only the chains selected by
        ``az.bfmi(...) > 0.3``, stack draws and chains into one sample
        dimension, take the transposed ``theta_DP_scaled`` values and split
        them into ``valuation.split_num`` pieces with ``np.array_split``.

        Returns:
            List of arrays as produced by ``np.array_split``.
        """
        v = self.valuation

        new_config = generate_data_config(v.N,
                                          v.d,
                                          v.num,
                                          v.l2_norm, v.pass_object,
                                          v.Z, v.variances, k=k)

        fit = v.model.sample(data=new_config,
                             iter_warmup=v.burn_in, iter_sampling=v.nsamples, chains=v.chains,
                             adapt_delta=0.86, show_progress=False)

        az_data = az.from_cmdstanpy(posterior=fit)
        fil_az_data = az_data.sel(chain=az.bfmi(az_data) > 0.3)
        flattened_theta = fil_az_data.stack(sample=["draw","chain"]).posterior.theta_DP_scaled.values.T
        return np.array_split(flattened_theta, v.split_num)

    def _learn_k_to_r(self, k):
        """Evaluate the mean KL obtained when the model is sampled with ``k``.

        One KL value is computed per split; each is also appended to
        ``self.k_to_r`` as a ``(k, kl)`` tuple.

        Returns:
            Mean of the per-split KL values.
        """
        split = self._sample_split_posterior(k)
        kls = np.array([compute_kl_log(samples) for samples in split])
        self.k_to_r.extend([(k, kl) for kl in kls])

        return kls.mean()

    def _solve_k_for_i(self, i, bounds=(0.02, 1.)):
        """Find the ``k`` whose mean KL equals party ``i``'s target reward.

        The search variable is the square root of ``k``: the objective
        squares its argument before evaluating, and the squared root is
        returned.  ``bounds`` (and the ``xtol=1e-2`` tolerance) therefore
        apply to the un-squared variable.

        Args:
            i: party index into ``self.target_rewards``.
            bounds: bracket handed to ``root_scalar`` (Brent's method); the
                objective must change sign across it.

        Returns:
            The squared root, i.e. the value of ``k``.
        """
        print("Solving for party {} to achieve {}".format(i, self.target_rewards[i]))
        def wrapper(k):
            return self._learn_k_to_r(k**2) - self.target_rewards[i]
        sol = root_scalar(wrapper, method='brentq', bracket=bounds, xtol=1e-2)
        root = sol.root
        print("Party {} solved root for k".format(i), root**2)
        return root**2

    def _generate_reward_for_i(self, i, bounds=(0.02, 1.)):
        """Compute and store the rewarded posterior draws for party ``i``.

        If the party's reward ratio exceeds 0.999 no search is performed:
        ``k`` is recorded as 1 and the posterior stored under the key formed
        by concatenating all party indices (e.g. ``'012'`` for three
        parties) in ``valuation.all_posteriors`` is used.  Otherwise ``k``
        is solved for and the model is sampled once more with that value.

        Args:
            i: party index.
            bounds: bracket forwarded to ``_solve_k_for_i``.

        Returns:
            The value stored in ``self.rewarded_posterior_i[i]``.
        """
        if self.shap_ratio[i] > 0.999:
            self.rewarded_k_i[i] = 1.
            all_parties_key = ''.join(str(x) for x in range(self.num))
            self.rewarded_posterior_i[i] = self.valuation.all_posteriors[all_parties_key]
            return self.rewarded_posterior_i[i]

        k_root = self._solve_k_for_i(i, bounds=bounds)
        self.rewarded_k_i[i] = k_root

        split = self._sample_split_posterior(k_root)

        self.rewarded_posterior_i[i] = split
        return split
