import copy
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Tuple, Optional
import numpy as np
# use tqdm to show progress bar
from tqdm import tqdm
from measures.privacy import make_valid, make_private
from measures.base import GradientBase
from opacus.optimizers.optimizer import DPOptimizer


# monkey patch the step method
def step_similarity(self, closure: Optional[Callable[[], float]] = None, player=None, step=None, mode="cosine") -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Take one optimizer step and score how close the raw gradients are to what was applied.

    Meant to be attached to an optimizer class as a method, so ``self`` is the
    optimizer. The gradients currently stored on the parameters are
    snapshotted, ``self.pre_step()`` is run, and -- if it reports that a real
    step should happen -- ``self.original_optimizer`` is stepped. Two scores
    are accumulated over every parameter that had a gradient:

    * snapshot vs. ``p.grad`` as left by ``pre_step()``;
    * snapshot vs. ``self.state[p]['mom_<player>']`` after the wrapped step.

    Args:
        closure: Optional callable that is invoked (with grad enabled) before
            anything else; its return value is ignored.
        player: Forwarded to ``original_optimizer.step`` and used to select
            the ``mom_<player>`` entry of the per-parameter state.
        step: Forwarded to ``original_optimizer.step``.
        mode: ``"cosine"`` sums cosine similarities (larger means closer);
            ``"l2"`` sums Euclidean distances divided by the norm of the
            snapshot (smaller means closer).

    Returns:
        ``(sim_with_updated_grad, sim_with_momentum)``, or ``None`` when
        ``self.pre_step()`` is falsy (no step is taken in that case). If no
        parameter has a gradient both entries are the integer ``0``.

    Raises:
        ValueError: If ``mode`` is neither ``"cosine"`` nor ``"l2"``.
    """
    # Reject unknown modes up front: otherwise pre_step() would run its side
    # effects and the wrapped optimizer would silently never be stepped.
    if mode not in ("cosine", "l2"):
        raise ValueError(f"Unsupported similarity mode: {mode!r} (expected 'cosine' or 'l2')")

    if closure is not None:
        with torch.enable_grad():
            closure()

    # Snapshot the gradients together with their parameter. Pairing them
    # (instead of indexing a flat list with a per-group index) keeps every
    # snapshot aligned with its own parameter across multiple param groups.
    # Parameters without a gradient are skipped.
    snapshots = [
        (p, p.grad.detach().clone())
        for group in self.param_groups
        for p in group['params']
        if p.grad is not None
    ]

    # NOTE(review): for Opacus' DPOptimizer, pre_step() presumably clips and
    # noises the gradients and returns False on skipped steps -- confirm
    # against the Opacus version in use.
    if not self.pre_step():
        return None

    # Compare the raw gradients with what pre_step() left on the parameters.
    sim_with_updated_grad = 0
    for p, original_grad in snapshots:
        sim_with_updated_grad += _gradient_similarity(original_grad, p.grad.detach(), mode)

    self.original_optimizer.step(player=player, step=step)

    # Compare the raw gradients with the per-player momentum buffer, which is
    # read only after the wrapped optimizer has stepped.
    sim_with_momentum = 0
    for p, original_grad in snapshots:
        sim_with_momentum += _gradient_similarity(original_grad, self.state[p][f'mom_{player}'], mode)

    return sim_with_updated_grad, sim_with_momentum


def _gradient_similarity(reference: torch.Tensor, other: torch.Tensor, mode: str) -> torch.Tensor:
    """Score two same-sized tensors after flattening each to a single row."""
    # reshape (rather than view) also accepts non-contiguous tensors.
    flat_reference = reference.reshape(1, -1)
    flat_other = other.reshape(1, -1)
    if mode == "cosine":
        return torch.cosine_similarity(flat_reference, flat_other)
    # "l2": distance relative to the reference norm (a zero-norm reference
    # makes this division produce inf/nan).
    return torch.dist(flat_reference, flat_other) / torch.norm(flat_reference)


class Similarity(GradientBase):
    """Gradient-similarity measure: scores each player by how close its raw
    gradient is to the update the optimizer actually applied.

    Relies on attributes and helpers that are not defined in this class
    (``num_players``, ``datas``, ``device``, ``model``, ``random_seed``,
    ``warmup_ratio``, ``_init_weights``, ``get_optimizer``) -- presumably
    provided by ``GradientBase``.
    """

    def run_iter(self, step: int, model: nn.Module, optimizer: torch.optim.Optimizer, permutation: np.array, mode: str = "cosine") -> Tuple[np.ndarray, np.ndarray]:
        """Re-initialise the model and take one optimizer step per player.

        Args:
            step: Iteration counter forwarded to ``optimizer.step_similarity``.
            model: Model to train; its weights are reset via
                ``self._init_weights`` at the start of the pass.
            optimizer: Optimizer exposing ``step_similarity``.
            permutation: Order in which the players' samples are visited.
            mode: Similarity mode forwarded to ``step_similarity``.

        Returns:
            ``(similarities_updated_grad, similarities_momentum)``, two arrays
            of length ``self.num_players`` indexed by player id. A player
            whose step was skipped keeps a score of 0.
        """
        similarities_momentum = np.zeros(self.num_players)
        similarities_updated_grad = np.zeros(self.num_players)
        model.apply(self._init_weights)
        # Visit the dataset in permutation order, one sample per batch, so
        # the k-th batch belongs to player permutation[k].
        datas = torch.utils.data.Subset(self.datas, permutation)
        data_loader = torch.utils.data.DataLoader(datas, batch_size=1)
        for player, (player_data, labels) in zip(permutation, data_loader):
            optimizer.zero_grad()
            # Move the sample to the configured device.
            player_data = player_data.to(self.device)
            labels = labels.to(self.device)
            # Forward / backward on this player's sample.
            model_pred = model(player_data)
            model_loss = F.cross_entropy(model_pred, labels)
            model_loss.backward()
            # Update the model and collect both similarity scores.
            scores = optimizer.step_similarity(player=player, step=step, mode=mode)
            if scores is None:
                # step_similarity returns None when no optimizer step was
                # taken; there is nothing to record for this player.
                continue
            # float() extracts the scalar regardless of the tensor's device,
            # instead of relying on implicit tensor -> NumPy conversion.
            similarities_updated_grad[player] = float(scores[0])
            similarities_momentum[player] = float(scores[1])
        return similarities_updated_grad, similarities_momentum

    def run(self, num_iters: int, clipping_norm: Optional[float] = None, epsilon: Optional[float] = 20, delta: Optional[float] = 1e-5, mode: str = "cosine") -> Tuple[np.array, np.array]:
        """Run ``num_iters`` random-permutation passes and average the scores.

        Args:
            num_iters: Number of passes; each uses a fresh random permutation.
            clipping_norm: Clipping norm handed to ``make_private``. ``None``
                or a non-positive value skips the ``make_private`` wrapping.
            epsilon: Epsilon handed to ``make_private``.
            delta: Delta handed to ``make_private`` and used when reporting
                the spent epsilon.
            mode: Similarity mode, ``"cosine"`` or ``"l2"``.

        Returns:
            ``(all_scores_updated_grad, all_scores_momentum)``, each of shape
            ``(1, self.num_players)``: the per-player mean over the passes
            that remain after dropping the warmup prefix.
        """
        scores_updated_grad = np.zeros((num_iters, self.num_players))
        scores_momentum = np.zeros((num_iters, self.num_players))

        # Fix the NumPy seed so the permutations are reproducible.
        np.random.seed(self.random_seed)
        model = make_valid(copy.deepcopy(self.model))
        model = model.apply(self._init_weights)
        # Monkey patch the similarity-aware step onto DPOptimizer.
        DPOptimizer.step_similarity = step_similarity
        optimizer = self.get_optimizer(model, momentum=0, weight_decay=0.0, num_iters=num_iters)

        # clipping_norm defaults to None, which cannot be compared with 0.
        # NOTE(review): step_similarity is patched onto DPOptimizer only, so
        # without make_private the optimizer returned by get_optimizer must
        # itself provide step_similarity -- confirm.
        use_dp = clipping_norm is not None and clipping_norm > 0
        if use_dp:
            train_data_loader = torch.utils.data.DataLoader(self.datas, batch_size=1)
            privacy_engine, model, optimizer, _ = make_private(model, optimizer, train_data_loader, num_iters, epsilon=epsilon, delta=delta, clipping_norm=clipping_norm)

        for i in tqdm(range(num_iters)):
            permutation = np.random.permutation(self.num_players)
            scores_updated_grad[i], scores_momentum[i] = self.run_iter(i+1, model, optimizer, permutation, mode=mode)
            if use_dp:
                # Report the privacy budget spent so far.
                spent_epsilon = privacy_engine.get_epsilon(delta)
                print(f"(ε = {spent_epsilon:.2f}, δ = {delta})")

        # Discard the first warmup_ratio * num_iters passes.
        warmup = int(self.warmup_ratio * num_iters)
        scores_updated_grad = scores_updated_grad[warmup:]
        scores_momentum = scores_momentum[warmup:]

        # Average over the remaining passes; keepdims preserves the
        # (1, num_players) result shape.
        all_scores_updated_grad = np.mean(scores_updated_grad, axis=0, keepdims=True)
        all_scores_momentum = np.mean(scores_momentum, axis=0, keepdims=True)

        return all_scores_updated_grad, all_scores_momentum
