import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import numpy as np
# use tqdm to show progress bar
from tqdm import tqdm
from measures.privacy import make_valid, make_private
from measures.base import GradientBase
import scipy

# Beta Shapley
def get_beta_constant(a, b):
    """
    Return the Beta function value B(a, b) for an integer second argument.

    Starts from B(a, 1) = 1 / a and repeatedly applies the recurrence
    B(a, k + 1) = B(a, k) * k / (a + k) for k = 1, ..., b - 1.
    The second argument (b) must be an integer because it drives a `range`.
    """
    value = 1 / a
    # One factor per recurrence step; multiplied left to right.
    step_factors = [k / (a + k) for k in range(1, b)]
    for factor in step_factors:
        value = value * factor
    return value


def get_dataset_permutation(dataset_size: int, perm: np.array) -> np.array:
    """
    Expand a permutation of players into a permutation of sample indices.

    Player p is assumed to own the contiguous index block
    [p * dataset_size, (p + 1) * dataset_size). The result concatenates those
    blocks in the order in which the players appear in `perm`.
    """
    player_count = len(perm)
    # Index block owned by each player, addressed by player id.
    blocks = [
        range(owner * dataset_size, (owner + 1) * dataset_size)
        for owner in range(player_count)
    ]
    # Walk the players in permuted order and flatten their blocks.
    ordered_indices = [idx for player in perm for idx in blocks[player]]
    return np.array(ordered_indices)



class GShapley(GradientBase):
    def get_banzhaf_coeff(self, prev_coalition_size: int):
        return scipy.special.comb(self.num_players - 1, prev_coalition_size)

    def get_beta_coeff(self, prev_coalition_size: int, alpha=4, beta=1):
        """
        Return the weight applied to a marginal contribution for a Beta(alpha, beta) score.

        With n = self.num_players and j = prev_coalition_size the weight is
            get_beta_constant(j + alpha, n - j + beta - 1) / get_beta_constant(j + 1, n - j)
        (no further normalisation is applied here).
        """
        total_players = self.num_players
        remaining = total_players - prev_coalition_size
        numerator = get_beta_constant(prev_coalition_size + alpha, remaining + beta - 1)
        denominator = get_beta_constant(prev_coalition_size + 1, remaining)
        return numerator / denominator

    def get_shapley_coeff(self, prev_coalition_size: int):
        return 1.0

    def run_iter(self, step: int, original_weight, model: nn.Module, optimizer: torch.optim.Optimizer, perm: np.array, permutation: np.array) -> np.array:
        """
        Train through one permutation of players and record every marginal contribution.

        The model takes one optimizer step per player (one batch of
        `self.dataset_size` samples each, in the order given by `permutation`).
        A player's marginal contribution is the utility on
        `self.test_data` / `self.test_labels` after its step minus the utility
        before it.

        Args:
            step: 1-based iteration counter; also forwarded to `step_momentum`.
            original_weight: state dict loaded into the model when it is reset.
            model: network that is trained in place.
            optimizer: optimizer bound to `model`.
            perm: order of the players for this iteration.
            permutation: sample indices matching `perm`, block by block.

        Returns:
            Array of length `self.num_players` indexed by player id.

        Raises:
            NotImplementedError: if `self.args.utility` is not None,
                "negated_loss" or "accuracy".
        """
        contributions = np.zeros(self.num_players)

        # Reset the model every iteration; resnet models are only reset every
        # 10th step because they are hard to converge from scratch.
        is_resnet = self.args.model.startswith("resnet")
        if not (is_resnet and step % 10 != 0):
            model.load_state_dict(original_weight)
            model = model.apply(self._init_weights)

        # One batch per player, visited in permuted order.
        ordered_data = torch.utils.data.Subset(self.datas, permutation)
        loader = torch.utils.data.DataLoader(ordered_data, batch_size=self.dataset_size)

        def measure_utility():
            # Score the current model on the test tensors without tracking gradients.
            with torch.no_grad():
                test_pred = model(self.test_data)
                if self.args.utility is None or self.args.utility == "negated_loss":
                    return -F.cross_entropy(test_pred, self.test_labels).item()
                if self.args.utility == "accuracy":
                    return (test_pred.argmax(dim=1) == self.test_labels).float().mean().item()
                raise NotImplementedError(f"Utility {self.args.utility} not implemented")

        # Utility of the model before any player has contributed.
        previous_utility = measure_utility()

        for position, (batch_inputs, batch_labels) in enumerate(loader):
            optimizer.zero_grad()
            player = perm[position]
            # Move the player's batch to the compute device.
            batch_inputs = batch_inputs.to(self.device)
            batch_labels = batch_labels.to(self.device)
            # Single gradient step on this player's data.
            train_loss = F.cross_entropy(model(batch_inputs), batch_labels)
            train_loss.backward()
            if self.use_momentum:
                optimizer.step_momentum(player=player, step=step)
            else:
                optimizer.step()
            # Marginal contribution = utility after the step minus utility before it.
            current_utility = measure_utility()
            contributions[player] = -previous_utility + current_utility
            previous_utility = current_utility

        return contributions

    def run(self, num_iters: int, clipping_norm: Optional[float] = None, epsilon: Optional[float] = 20, delta: Optional[float] = 1e-5) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Estimate per-player scores from `num_iters` randomly permuted training passes.

        Each iteration draws a random permutation of the players, calls
        `run_iter` to obtain the marginal contributions, and stores them four
        times, each scaled by a different position-dependent coefficient
        (Shapley, Banzhaf, Beta(4, 1), Beta(16, 1)).

        Args:
            num_iters: number of permutations to process.
            clipping_norm: when a positive number, model and optimizer are passed
                through `make_private` with this clipping norm. `None` (the
                default) or a non-positive value skips that step.
            epsilon: forwarded to `make_private`.
            delta: forwarded to `make_private` and used to query the privacy
                engine after every iteration.

        Returns:
            `(scores_shap, scores_banzhaf, scores_beta41, scores_beta161)`, each
            of shape `(num_iters, self.num_players)`; row `i` holds the weighted
            marginal contributions of iteration `i`.
        """
        # `clipping_norm` defaults to None, and `None > 0` raises TypeError,
        # so decide once whether the private path is requested.
        use_privacy = clipping_norm is not None and clipping_norm > 0

        score_shape = (num_iters, self.num_players)
        scores_shap = np.zeros(score_shape)
        scores_banzhaf = np.zeros(score_shape)
        scores_beta41 = np.zeros(score_shape)
        scores_beta161 = np.zeros(score_shape)

        # fix numpy random seed
        np.random.seed(self.random_seed)
        model = make_valid(copy.deepcopy(self.model))
        model = model.apply(self._init_weights)
        optimizer = self.get_optimizer(model, momentum=0, weight_decay=0.0, num_iters=num_iters)

        if use_privacy:
            train_data_loader = torch.utils.data.DataLoader(self.datas, batch_size=self.dataset_size)
            privacy_engine, model, optimizer, _ = make_private(model, optimizer, train_data_loader, num_iters, epsilon=epsilon, delta=delta, clipping_norm=clipping_norm)

        original_weight = copy.deepcopy(model.state_dict())

        # The coefficients depend only on the position inside the permutation,
        # so compute them once instead of once per iteration.
        positions = range(self.num_players)
        shapley_weights = np.array([self.get_shapley_coeff(j) for j in positions])
        banzhaf_weights = np.array([self.get_banzhaf_coeff(j) for j in positions])
        beta41_weights = np.array([self.get_beta_coeff(j, 4, 1) for j in positions])
        beta161_weights = np.array([self.get_beta_coeff(j, 16, 1) for j in positions])

        # Diagnostic target group: the first 30% of player indices.
        # NOTE(review): presumably these are the deliberately corrupted players -
        # confirm against the data setup.
        num_target = int(self.num_players * 0.3)
        target_idx = np.arange(num_target)

        for i in tqdm(range(num_iters)):
            perm = np.random.permutation(self.num_players)
            permutation = get_dataset_permutation(self.dataset_size, perm)
            cur_res = np.asarray(self.run_iter(i + 1, original_weight, model, optimizer, perm, permutation))
            scores_shap[i] = cur_res

            # Progress diagnostic; skipped when the target group is empty to
            # avoid dividing by zero.
            if i > 10 and num_target > 0:
                average_shap = np.mean(scores_shap[:i + 1], axis=0)
                rank = np.argsort(average_shap)
                # how many of the lowest-ranked 30% fall inside the target group
                num_in_target = len(np.intersect1d(rank[:num_target], target_idx))
                print(f"ratio of lowest 30% in target: {num_in_target / len(target_idx)}")

            # perm[j] is the player at position j, so the player's contribution
            # is scaled by the coefficient of position j.
            ordered_res = cur_res[perm]
            scores_shap[i, perm] = ordered_res * shapley_weights
            scores_banzhaf[i, perm] = ordered_res * banzhaf_weights
            scores_beta41[i, perm] = ordered_res * beta41_weights
            scores_beta161[i, perm] = ordered_res * beta161_weights

            # print current epsilon and delta
            if use_privacy:
                spent_epsilon = privacy_engine.get_epsilon(delta)
                print(f"(ε = {spent_epsilon:.2f}, δ = {delta})")

        return scores_shap, scores_banzhaf, scores_beta41, scores_beta161

    def run_loo(self, clipping_norm: Optional[float] = None, epsilon: Optional[float] = 20, delta: Optional[float] = 1e-5) -> np.ndarray:
        """
        Estimate leave-one-out style scores, one training pass per player.

        For player `i` a random permutation is drawn, player `i` is moved to the
        end of it, and `run_iter` is executed; the marginal contribution of
        player `i` (added last) is stored at `scores_loo[i, i]`.

        Args:
            clipping_norm: when a positive number, model and optimizer are passed
                through `make_private` with this clipping norm. `None` (the
                default) or a non-positive value skips that step.
            epsilon: forwarded to `make_private`.
            delta: forwarded to `make_private` and used to query the privacy
                engine after every iteration.

        Returns:
            Array of shape `(self.num_players, self.num_players)` whose diagonal
            holds the scores; all other entries stay zero.
        """
        # `clipping_norm` defaults to None, and `None > 0` raises TypeError,
        # so decide once whether the private path is requested.
        use_privacy = clipping_norm is not None and clipping_norm > 0

        num_iters = self.num_players
        scores_loo = np.zeros((num_iters, self.num_players))

        # fix numpy random seed
        np.random.seed(self.random_seed)
        model = make_valid(copy.deepcopy(self.model))
        model = model.apply(self._init_weights)
        optimizer = self.get_optimizer(model, momentum=0, weight_decay=0.0, num_iters=num_iters)

        if use_privacy:
            train_data_loader = torch.utils.data.DataLoader(self.datas, batch_size=self.dataset_size)
            privacy_engine, model, optimizer, _ = make_private(model, optimizer, train_data_loader, num_iters, epsilon=epsilon, delta=delta, clipping_norm=clipping_norm)

        original_weight = copy.deepcopy(model.state_dict())

        # Diagnostic target group: the first 30% of player indices.
        # NOTE(review): presumably these are the deliberately corrupted players -
        # confirm against the data setup.
        num_target = int(self.num_players * 0.3)
        target_idx = np.arange(num_target)

        for i in tqdm(range(self.num_players)):
            perm = np.random.permutation(self.num_players)
            # move the i-th player to the end so it is trained on last
            perm = np.concatenate([perm[perm != i], perm[perm == i]])
            permutation = get_dataset_permutation(self.dataset_size, perm)
            cur_res = np.asarray(self.run_iter(i + 1, original_weight, model, optimizer, perm, permutation))
            scores_loo[i, i] = cur_res[i]

            # Progress diagnostic; skipped when the target group is empty to
            # avoid dividing by zero.
            if i > 10 and num_target > 0:
                average_shap = np.mean(scores_loo[:i + 1], axis=0)
                rank = np.argsort(average_shap)
                # how many of the lowest-ranked 30% fall inside the target group
                num_in_target = len(np.intersect1d(rank[:num_target], target_idx))
                print(f"ratio of lowest 30% in target: {num_in_target / len(target_idx)}")

            # print current epsilon and delta
            if use_privacy:
                spent_epsilon = privacy_engine.get_epsilon(delta)
                print(f"(ε = {spent_epsilon:.2f}, δ = {delta})")

        return scores_loo

