from abc import ABC, abstractmethod
import numpy as np
try:
    import torch
except ImportError:
    torch = None
from math import fmod


class SecretSharing(ABC):
    """
    Abstract interface shared by every secret-sharing scheme in this module.

    A concrete scheme must provide two operations:
    ``secret_split`` turns a secret into a sequence of shares, and
    ``secret_reconstruct`` combines such a sequence back into the secret.
    The base class holds no state of its own.
    """

    def __init__(self):
        """Nothing to initialise at the interface level."""

    @abstractmethod
    def secret_split(self, secret):
        """Split ``secret`` into shares; implemented by subclasses."""

    @abstractmethod
    def secret_reconstruct(self, secret_seq):
        """Rebuild the secret from the shares in ``secret_seq``; implemented by subclasses."""


class AdditiveSecretSharing(SecretSharing):
    """
    AdditiveSecretSharing class, which can split a number into frames and recover it by summing up.

    A real-valued secret (scalar, list/tuple, ``np.ndarray``, torch tensor, or a
    possibly nested dict of those) is first encoded as fixed-point integers:
    ``round(x * epsilon)`` reduced modulo ``mod_number``. It is then split into
    ``shared_party_num`` shares. All but the last share are random residues, and
    the last is chosen so that all shares sum to the encoded secret modulo
    ``mod_number``. Summing every share and decoding recovers the secret.
    """

    def __init__(self, shared_party_num, size=60):
        """
        :param shared_party_num: number of shares to produce (must be > 1).
        :param size: bit width of the representable magnitude; encoded values
            must satisfy ``|round(x * epsilon)| < 2**size``. Shares are stored
            as ``np.int64``, so ``size`` must be at most 61.
        """
        super().__init__()
        assert shared_party_num > 1, "AdditiveSecretSharing require shared_party_num > 1"
        self.shared_party_num = shared_party_num
        self.maximum = 2**size
        # Odd modulus: residues in (maximum, mod_number) decode to negative values.
        self.mod_number = 2 * self.maximum + 1
        # Fixed-point scale factor (not a tolerance): x is stored as round(x * 1e8).
        self.epsilon = 1e8
        # Kept for backward compatibility; not used internally.
        self.mod_funs = np.vectorize(lambda x: x % self.mod_number)
        # otypes=[object] keeps exact Python ints (float64 would round residues
        # near 2**61); explicit otypes also let empty inputs through.
        self.float2fixedpoint = np.vectorize(self._float2fixedpoint, otypes=[object])
        self.fixedpoint2float = np.vectorize(self._fixedpoint2float, otypes=[float])

    def secret_split(self, secret):
        """
        To split the secret into frames according to the shared_party_num.

        :param secret: a real scalar, list/tuple, ``np.ndarray``, torch tensor,
            or a (possibly nested) dict whose values are any of these.
        :return: for a dict, a list of ``shared_party_num`` dicts with the same
            keys, where the i-th dict holds the i-th share of every value.
            Otherwise, an ``np.int64`` array of shape
            ``(shared_party_num,) + shape(secret)`` whose entries are residues
            in ``[0, mod_number)``.
        :raises AssertionError: if a value is too large (or non-finite) to encode.
        """
        if isinstance(secret, dict):
            secret_list = [dict() for _ in range(self.shared_party_num)]
            for key, value in secret.items():
                for share_dict, share in zip(secret_list, self.secret_split(value)):
                    share_dict[key] = share
            return secret_list

        # torch is optional (guarded import); skip the tensor check without it.
        if torch is not None and isinstance(secret, torch.Tensor):
            # detach/cpu so tensors that require grad or live on a GPU convert.
            secret = secret.detach().cpu().numpy()

        encoded = self.float2fixedpoint(secret)  # object array of exact Python ints
        shape = (self.shared_party_num - 1,) + np.shape(encoded)
        # NOTE(review): np.random is not a cryptographically secure generator;
        # in adversarial settings the masks should come from e.g. `secrets`.
        random_shares = np.random.randint(
            low=0, high=self.mod_number, size=shape, dtype=np.int64
        )
        # Sum as Python ints: a few int64 residues near 2**61 overflow int64.
        masked_sum = random_shares.astype(object).sum(axis=0)
        last_share = np.asarray((encoded - masked_sum) % self.mod_number, dtype=np.int64)
        return np.concatenate([random_shares, last_share[np.newaxis]], axis=0)

    def secret_reconstruct(self, secret_seq):
        """
        To recover the secret.

        :param secret_seq: exactly ``shared_party_num`` shares as produced by
            ``secret_split``. This is either a sequence of dicts or an
            array/sequence of equally shaped share arrays.
        :return: for dict shares, a copy of the first share's mapping with each
            value replaced by its decoded float64 array (nested dicts are handled
            recursively). Otherwise, a float64 array with the secret's shape.
            The input shares are not modified.
        """
        assert len(secret_seq) == self.shared_party_num
        first = secret_seq[0]
        if isinstance(first, dict):
            # .copy() keeps the mapping type of the shares (e.g. OrderedDict)
            # while leaving the caller's share objects untouched.
            merge_model = first.copy()
            for key in merge_model:
                merge_model[key] = self.secret_reconstruct(
                    [share[key] for share in secret_seq]
                )
            return merge_model

        # Add the shares as Python ints so the sum is exact (no int64 overflow,
        # no float rounding); decoding reduces it modulo mod_number.
        total = np.asarray(secret_seq).astype(object).sum(axis=0)
        return self.fixedpoint2float(total)

    def _float2fixedpoint(self, x):
        """Encode one real number as a residue in ``[0, mod_number)``."""
        scaled = round(float(x) * self.epsilon, 0)
        # NaN/inf also fail this check, so they raise AssertionError here.
        assert abs(scaled) < self.maximum
        # int() is exact here (|scaled| < 2**size); reducing a float modulo a
        # number near 2**61 would silently drop low-order bits.
        return int(scaled) % self.mod_number

    def _fixedpoint2float(self, x):
        """Decode one residue back to a real number; the upper half maps to negatives."""
        x = x % self.mod_number
        if x > self.maximum:
            return -1 * (self.mod_number - x) / self.epsilon
        return x / self.epsilon
