import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import numpy as np
# use tqdm to show progress bar
from tqdm import tqdm
from measures.privacy import make_valid, make_private
from measures.base import GradientBase
import scipy

# Beta Shapley
def get_beta_constant(a, b):
    """
    Evaluate the Beta function B(a, b) through its finite-product form

        B(a, b) = (1 / a) * prod_{i = 1}^{b - 1} i / (a + i)

    The second argument (b; beta) must be an integer because it is used as
    the upper bound of a ``range``. For b <= 1 the product is empty and the
    result is simply 1 / a.
    """
    # Start from the leading 1/a term, then fold in the factors one by one,
    # left to right.
    acc = 1 / a
    factors = (k / (a + k) for k in range(1, b))
    for factor in factors:
        acc = acc * factor
    return acc


def get_dataset_permutation(dataset_size: int, perm: np.array) -> np.array:
    """
    Expand a permutation of players into a permutation of sample indices.

    Player ``k`` owns the contiguous index block
    ``[k * dataset_size, (k + 1) * dataset_size)``. The blocks are emitted in
    the order given by ``perm`` and concatenated into one flat array.
    """
    block_count = len(perm)
    # One contiguous index block per player, in natural player order.
    blocks = [
        list(range(k * dataset_size, (k + 1) * dataset_size))
        for k in range(block_count)
    ]
    # Walk the permutation and flatten the selected blocks in that order.
    flat = [idx for player in perm for idx in blocks[player]]
    return np.array(flat)



class FL(GradientBase):
    """
    Federated-learning style contribution scoring.

    In every round each player trains its own copy of the global model on a
    contiguous slice of ``self.datas``. The player models are then averaged
    in randomly sampled orders, and the change in test accuracy caused by
    adding each player is recorded as that player's marginal contribution.

    This class reads attributes it does not define itself -- presumably they
    come from ``GradientBase`` (confirm there): ``num_players``,
    ``dataset_size``, ``datas``, ``device``, ``use_momentum``, ``test_data``,
    ``test_labels``, ``random_seed``, ``model``, ``lr``, ``_init_weights``
    and ``get_optimizer``.
    """

    def get_banzhaf_coeff(self, prev_coalition_size: int):
        """Return the binomial coefficient C(num_players - 1, prev_coalition_size)."""
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.special`` is loaded on every SciPy version.
        from scipy.special import comb

        return comb(self.num_players - 1, prev_coalition_size)

    def get_beta_coeff(self, prev_coalition_size: int, alpha=4, beta=1):
        """
        Return B(j + alpha, n - j + beta - 1) / B(j + 1, n - j), where
        ``n = self.num_players`` and ``j = prev_coalition_size``.

        ``beta`` must be an integer because ``get_beta_constant`` requires an
        integer second argument.
        """
        n = self.num_players
        j = prev_coalition_size
        return get_beta_constant(j + alpha, n - j + beta - 1) / get_beta_constant(j + 1, n - j)

    def get_shapley_coeff(self, prev_coalition_size: int):
        """Return the constant weight 1.0, independent of the coalition size."""
        return 1.0

    def get_gradient(self, global_model, models, optimizers, schedulers, step):
        """
        Run one local training step for every player, starting from the
        current global weights.

        Args:
            global_model: model whose parameters are copied into each player model.
            models: one model per player; updated in place.
            optimizers: one optimizer per player, matching ``models``.
            schedulers: one LR scheduler per player; stepped once per call.
            step: 1-based round index, forwarded to ``step_momentum`` when
                ``self.use_momentum`` is set.
        """
        for player in range(self.num_players):
            model = models[player]
            optimizer = optimizers[player]
            # Reset the player model to the current global weights.
            for p1, p2 in zip(model.parameters(), global_model.parameters()):
                p1.data = p2.data.clone()
            optimizer.zero_grad()
            # The player's data is the contiguous slice
            # [player * dataset_size, (player + 1) * dataset_size).
            indices = list(range(self.dataset_size * player, self.dataset_size * (player + 1)))
            subset = torch.utils.data.Subset(self.datas, indices)
            # batch_size == len(subset): the loader yields one full batch.
            data_loader = torch.utils.data.DataLoader(subset, batch_size=len(subset))
            for player_data, labels in data_loader:
                player_data = player_data.to(self.device)
                labels = labels.to(self.device)
                model_pred = model(player_data)
                model_loss = F.cross_entropy(model_pred, labels)
                model_loss.backward()
                if self.use_momentum:
                    # ``step_momentum`` is not a stock torch optimizer method;
                    # presumably supplied by ``self.get_optimizer`` -- confirm.
                    optimizer.step_momentum(player=player, step=step)
                else:
                    optimizer.step()
            schedulers[player].step()
        print(f"LR: {schedulers[0].get_last_lr()[0]}")

    def run_iter(self, step: int, global_model: nn.Module, global_optimizer, models: List[nn.Module], optimizers: List[torch.optim.Optimizer], num_permutations: int = 100) -> Tuple[np.ndarray, nn.Module]:
        """
        Estimate each player's marginal contribution for the current round.

        For every sampled permutation the player models are folded one at a
        time into a running average (starting from a copy of the global
        model), and the change in test accuracy after each addition is
        credited to the player that was just added.

        Args:
            step: round index (not used by the computation).
            global_model: starting point that is deep-copied per permutation.
            global_optimizer: not used by the computation.
            models: per-player models produced by ``get_gradient``.
            optimizers: per-player optimizers (not used by the computation).
            num_permutations: number of random permutations to average over.

        Returns:
            ``(scores, model)``: ``scores`` has shape ``(num_players,)`` and
            holds the mean marginal contribution of each player; ``model`` is
            the averaged model built for the last sampled permutation.

        Raises:
            ValueError: if ``num_permutations`` is smaller than 1.
        """
        if num_permutations < 1:
            raise ValueError("num_permutations must be at least 1")

        def test_accuracy(model):
            # Fraction of test samples whose arg-max prediction matches the label.
            with torch.no_grad():
                model_pred = model(self.test_data)
                return (model_pred.argmax(dim=1) == self.test_labels).sum().item() / len(self.test_labels)

        marginal_contributions = np.zeros((self.num_players, num_permutations))
        # Bound up front so the summary print below cannot hit an unbound name
        # when there are no players.
        utility = None

        for i in range(num_permutations):
            cur_model = copy.deepcopy(global_model)
            permutation = np.random.permutation(self.num_players)
            # Accuracy of the unmodified global model is the baseline for the
            # first player of the permutation.
            last_utility = test_accuracy(cur_model)
            utility = last_utility

            for j, player in enumerate(permutation):
                # Running mean over the first j + 1 player models; at j == 0
                # the global weights are fully replaced by the first player's.
                for p1, p2 in zip(cur_model.parameters(), models[player].parameters()):
                    if p1.requires_grad:
                        p1.data = (p1.data * j + p2.data.clone()) / (j + 1)
                utility = test_accuracy(cur_model)
                marginal_contributions[player][i] = utility - last_utility
                last_utility = utility

        print(f"test acc: {utility}")
        # Average over all sampled permutations.
        marginal_contributions = np.mean(marginal_contributions, axis=1)
        return marginal_contributions, cur_model

    def run(self, num_iters: int, clipping_norm: Optional[float] = None, epsilon: Optional[float] = 20, delta: Optional[float] = 1e-5) -> np.ndarray:
        """
        Run ``num_iters`` rounds of local training and contribution scoring.

        Args:
            num_iters: number of rounds.
            clipping_norm: when positive, every player's model/optimizer pair
                is passed through ``make_private`` with this clipping norm.
                ``None``, zero or a negative value disables that step.
            epsilon: target epsilon forwarded to ``make_private``.
            delta: delta forwarded to ``make_private`` and used when querying
                the spent epsilon.

        Returns:
            Array of shape ``(num_iters, num_players)`` with the per-round
            scores returned by ``run_iter``.
        """
        # ``clipping_norm`` defaults to None, which cannot be compared with 0.
        use_dp = clipping_norm is not None and clipping_norm > 0

        scores_shap = np.zeros((num_iters, self.num_players))

        # Fix the numpy random seed so permutation sampling is reproducible.
        np.random.seed(self.random_seed)
        global_model = make_valid(copy.deepcopy(self.model))
        global_model = global_model.apply(self._init_weights)
        global_optimizer = torch.optim.Adam(global_model.parameters(), lr=self.lr)

        # One model, optimizer and scheduler per player.
        models = []
        optimizers = []
        schedulers = []
        for _ in range(self.num_players):
            model = copy.deepcopy(global_model)
            optimizer = self.get_optimizer(model, momentum=0, weight_decay=0.0, num_iters=num_iters)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iters, eta_min=self.lr / 20.0)
            models.append(model)
            optimizers.append(optimizer)
            schedulers.append(scheduler)

        privacy_engine = None
        if use_dp:
            for player in range(self.num_players):
                data_indice = list(range(self.dataset_size * player, self.dataset_size * (player + 1)))
                subset = torch.utils.data.Subset(self.datas, data_indice)
                train_data_loader = torch.utils.data.DataLoader(subset, batch_size=self.dataset_size)
                # NOTE(review): only the engine of the last player is kept and
                # later queried for epsilon -- confirm all players share the
                # same privacy accounting.
                privacy_engine, models[player], optimizers[player], _ = make_private(models[player], optimizers[player], train_data_loader, num_iters, epsilon=epsilon, delta=delta, clipping_norm=clipping_norm)

        # Players 0 .. num_target - 1 form the reference group for the ranking
        # diagnostic printed every round.
        num_target = int(self.num_players * 0.3)
        target_idx = np.arange(num_target)
        if not use_dp:
            name = "no DP"
        elif self.use_momentum:
            name = "corr"
        else:
            name = "iid"

        for i in tqdm(range(num_iters)):
            self.get_gradient(global_model, models, optimizers, schedulers, i + 1)
            print(f"Finished gradient update @ iteration {i+1}")
            cur_res, cur_model = self.run_iter(i + 1, global_model, global_optimizer, models, optimizers)
            print(f"Finished marginal contribution computation @ iteration {i+1}; scores: {cur_res}")
            # The averaged model becomes the next round's global model.
            global_model.load_state_dict(cur_model.state_dict())
            scores_shap[i] = np.array(cur_res)

            # With fewer than 4 players the 30% group is empty; skip the ratio
            # instead of dividing by zero.
            if num_target > 0:
                # Mean score per player over all rounds so far.
                average_shap = np.mean(scores_shap[:i + 1], axis=0)
                # Players sorted by ascending average score.
                rank = np.argsort(average_shap)
                # How many of the lowest-scoring 30% belong to the reference group.
                num_in_target = len(np.intersect1d(rank[:num_target], target_idx))
                print(f"{name}: ratio of lowest 30% in target: {num_in_target / len(target_idx)}")

            # Report the privacy budget spent so far.
            if use_dp and privacy_engine is not None:
                spent_epsilon = privacy_engine.get_epsilon(delta)
                print(f"(ε = {spent_epsilon:.2f}, δ = {delta})")

        return scores_shap

