import math
import torch
import torch.nn as nn

import crypten
import crypten.communicator as comm
import crypten.nn as cnn

from torch import Size
from typing import Union, List, Tuple
import numbers
from transformers import AutoModelForCausalLM


def copy_weight(model, model_name='gpt2'):
    """Copy pretrained Hugging Face causal-LM weights into ``model`` in place.

    Every parameter of ``model`` is looked up in the state dict of
    ``AutoModelForCausalLM.from_pretrained(model_name)``:

    * weights whose name contains ``c_fc``, ``c_proj``, ``c_attn`` or ``wte``
      are read from ``'transformer.' + name`` and transposed before copying;
    * names containing ``lm_head`` are read under the same name;
    * everything else is read from ``'transformer.' + name``.

    Args:
        model: plaintext module whose parameter names mirror the checkpoint's
            ``transformer.*`` / ``lm_head.*`` keys.
        model_name: checkpoint identifier passed to ``from_pretrained``.
            Defaults to ``'gpt2'``, the value previously hard-coded here.

    Returns:
        The same ``model`` object, with its parameters overwritten.

    Raises:
        KeyError: if a parameter has no matching checkpoint key. Keys are
            popped, so each checkpoint tensor can be consumed only once.
    """
    # Checkpoint weights containing these substrings are stored transposed
    # relative to the target parameters. Presumably this is HF Conv1D
    # (in, out) layout vs. Linear (out, in) layout — TODO confirm for ``wte``.
    transposed_keys = ('c_fc', 'c_proj', 'c_attn', 'wte')
    state_dict = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32
    ).state_dict()
    with crypten.no_grad(), torch.no_grad():
        for name, param in model.named_parameters():
            if any(key in name for key in transposed_keys) and 'weight' in name:
                source = state_dict.pop('transformer.' + name).transpose(0, 1)
            elif 'lm_head' in name:
                source = state_dict.pop(name)
            else:
                source = state_dict.pop('transformer.' + name)
            param.copy_(source.to(torch.float32))
    return model


def encrypt_tensor(input):
    """Secret-share ``input`` from party 1 in a two-party CrypTen session.

    Party 1 contributes the real values. The other party contributes only an
    uninitialised placeholder of the same size, so both sides share tensors
    of matching shape. The local tensor is moved to CUDA before sharing.

    Args:
        input: plaintext torch tensor. On the non-owner party only its
            ``size()`` is used.

    Returns:
        The CrypTen tensor produced by ``crypten.cryptensor(..., src=1)``.
    """
    communicator = comm.get()
    rank = communicator.get_rank()
    # This helper is written for exactly two parties.
    assert communicator.get_world_size() == 2
    owner = 1
    if rank != owner:
        plaintext = torch.empty(input.size()).cuda()
    else:
        plaintext = input.cuda()
    return crypten.cryptensor(plaintext, src=owner)


def encrypt_model(model, modelFunc, config):
    """Encrypt a model whose plaintext weights are owned by party 0.

    Party 0 fills ``model`` with pretrained weights via ``copy_weight``.
    Every other party builds ``modelFunc(config)`` as a placeholder, which is
    presumably structurally identical to ``model`` — confirm against callers.
    The chosen module is moved to CUDA and encrypted with ``src=0``.

    Returns:
        The result of ``.encrypt(src=0)`` on the CUDA module.
    """
    if comm.get().get_rank() != 0:
        plaintext_model = modelFunc(config)
    else:
        plaintext_model = copy_weight(model)
    return plaintext_model.cuda().encrypt(src=0)


class MPCIdentity(cnn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and ignored.
        super(MPCIdentity, self).__init__()

    def forward(self, input):
        output = input
        return output
    

class MPCLinear(cnn.Module):
    """Affine layer computing ``x @ weight^T`` (+ ``bias``) with CrypTen ops.

    Parameters are initialised by constructing a ``torch.nn.Linear`` of the
    same size and registering its ``weight``, plus its ``bias`` when
    ``bias`` is true.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        # Borrow PyTorch's default Linear initialisation.
        template = torch.nn.Linear(in_channels, out_channels, bias=bias)
        self.register_parameter("weight", template.weight)
        if bias:
            self.register_parameter("bias", template.bias)

    def forward(self, x):
        """Return ``x @ weight^T``, plus ``bias`` if one was registered."""
        projected = x.matmul(self.weight.transpose(0, 1))
        if not hasattr(self, 'bias'):
            return projected
        return projected + self.bias


class PumaExp(cnn.Module):
    """Approximate ``exp(x)`` as ``(1 + x / 2**iters) ** (2**iters)``.

    The power is evaluated by squaring ``iters`` times. For ``x < -2**iters``
    the base ``1 + x / 2**iters`` would be negative, so those entries are
    zeroed before squaring and produce 0.
    """

    def __init__(self, iters=8):
        super().__init__()
        self.iters = iters

    def forward(self, x):
        scale = 2 ** self.iters
        # 1 where x < -scale and 0 where x > -scale. It is 0.5 exactly at
        # -scale, where the base below is already 0.
        below_range = (1 - (scale + x).sign()) / 2
        base = (1 + x.div(scale)) * (1 - below_range)
        for _ in range(self.iters):
            base = base.square()
        return base
        

class PumaReciprocal(cnn.Module):
    """Approximate ``1 / x`` with Newton-Raphson iterations.

    The starting point is ``3 * exp(1 - 2x) + 0.003``, with exp computed by
    ``PumaExp``. Each of the ``nr_iters`` steps applies
    ``y <- 2y - x * y**2``.
    """

    def __init__(self, exp_iters=8, nr_iters=10):
        super().__init__()
        self.nr_iters = nr_iters
        self.exp = PumaExp(iters=exp_iters)

    def forward(self, x):
        # NOTE(review): this Newton step only converges when 0 < y0 < 2/x, so
        # positive inputs of moderate magnitude are presumably assumed.
        estimate = 3 * self.exp(1 - 2 * x) + 0.003
        for _ in range(self.nr_iters):
            step = estimate - estimate.square().mul_(x)
            estimate += step
        return estimate
        

class PumaInvsqrt(cnn.Module):
    """Approximate ``1 / sqrt(x)`` with Newton-Raphson iterations.

    The initial guess is ``2.2 * exp(-(x/2 + 0.2)) + 0.2 - x/1024``, with exp
    computed by ``PumaExp``. It is refined ``iters`` times with
    ``y <- y * (3 - x * y**2) / 2``.
    """

    def __init__(self, iters=4, exp_iters=8):
        super().__init__()
        self.iters = iters
        self.exp = PumaExp(iters=exp_iters)

    def forward(self, x):
        exponent = x.div(2).add(0.2).neg()
        guess = self.exp(exponent).mul(2.2).add(0.2)
        guess -= x.div(1024)
        for _ in range(self.iters):
            correction = 3 - x * guess.square()
            guess = guess.mul_(correction).div_(2)
        return guess


class PumaSigmoid(cnn.Module):
    """Sigmoid computed on ``|x|`` and reflected with ``s(-z) = 1 - s(z)``.

    ``sigmoid(|x|) = 1 / (1 + exp(-|x|))`` uses ``PumaExp`` and
    ``PumaReciprocal``. Negative inputs then take ``1 - sigmoid(|x|)``.
    """

    def __init__(self):
        super().__init__()
        self.exp = PumaExp()
        self.reciprocal = PumaReciprocal()

    def forward(self, x):
        # Assumes x._ltz() yields 1 for negative entries and 0 otherwise.
        is_negative = x._ltz()
        sign = 1 - 2 * is_negative
        magnitude = x.mul(sign)
        sigmoid_of_magnitude = self.reciprocal(self.exp(magnitude.neg()).add(1))
        # Keep sigmoid(|x|) where x >= 0; use 1 - sigmoid(|x|) elsewhere.
        return sigmoid_of_magnitude.where(1 - is_negative, 1 - sigmoid_of_magnitude)
        

class PumaTanh(cnn.Module):
    """tanh via the identity ``tanh(x) = 2 * sigmoid(2x) - 1``."""

    def __init__(self):
        super().__init__()
        self.sigmoid = PumaSigmoid()

    def forward(self, x):
        sig = self.sigmoid(x.mul(2))
        return sig.mul(2).sub(1)


class PumaErf(cnn.Module):
    """erf via ``erf(x) ~= tanh(2/sqrt(pi) * (x + 11/123 * x**3))``."""

    def __init__(self):
        super().__init__()
        self.tanh = PumaTanh()
        self.pow = cnn.Pow()

    def forward(self, x):
        # As called here, cnn.Pow receives an (input, exponent) tuple.
        cubic_term = self.pow((x, 3)).mul(11 / 123)
        polynomial = x + cubic_term
        return self.tanh(polynomial.mul(2.0 / math.sqrt(math.pi)))


class PumaGeLU(cnn.Module):
    """Piecewise GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))`` on (-4, 4).

    Outputs ``x`` for ``x > 4`` and 0 for ``x < -4``. Exactly at ``x = +-4``
    the sign-based masks are 0.5, so the two pieces are blended.
    """

    def __init__(self):
        super().__init__()
        self.erf = PumaErf()
        self.half = 0.5
        self.one = 1.0

    def forward(self, x):
        above = (1 - (4 - x).sign()) / 2   # 1 where x > 4
        below = (1 - (4 + x).sign()) / 2   # 1 where x < -4
        inside = (1 - above) * (1 - below)
        half_x = self.half * x
        # Masking the erf argument keeps it bounded outside (-4, 4).
        erf_arg = x * inside * (1 / math.sqrt(2.0))
        erf_term = self.one + self.erf(erf_arg)
        smooth_part = inside * (half_x * erf_term)
        return smooth_part + above * x


class PumaSoftmax(cnn.Module):
    """Softmax along ``dim`` built from ``PumaExp`` and ``PumaReciprocal``.

    The per-slice maximum is subtracted first, so every exponent is <= 0.
    Each exponentiated value is then multiplied by the reciprocal of its
    slice sum.
    """

    def __init__(self, exp_iters=8, reci_exp_iters=8, nr_iters=10, dim=-1):
        super().__init__()
        self.dim = dim
        self.exp = PumaExp(iters=exp_iters)
        self.reciprocal = PumaReciprocal(exp_iters=reci_exp_iters, nr_iters=nr_iters)

    def forward(self, x):
        slice_max = x.max(dim=self.dim, keepdim=True)[0]
        numerators = self.exp(x - slice_max)
        inv_total = self.reciprocal(numerators.sum(dim=self.dim, keepdim=True))
        return numerators.mul_(inv_total)
        

_shape_t = Union[int, List[int], Size]


class PumaLayerNorm(cnn.Module):
    """LayerNorm over the last dimension with power-of-two variance rescaling.

    ``forward`` first rescales ``x`` by a power of two chosen from its
    variance band. Because ``(s*x - s*mean) / sqrt(s**2 * var)`` equals
    ``(x - mean) / sqrt(var)``, the output is unchanged apart from the
    relative effect of ``eps``. The rescaling presumably keeps ``var``
    inside the range where ``PumaInvsqrt`` is accurate — TODO confirm.
    It then normalizes and applies the optional affine parameters.

    Args:
        normalized_shape: size of the normalized (last) dimension, as an int
            or a sequence.
        eps: added to the variance before the inverse square root.
        elementwise_affine: if true, register a learnable ``weight``
            (initialised to ones).
        bias: if true, and ``elementwise_affine`` is true, also register a
            learnable ``bias`` (initialised to zeros).
        device, dtype: forwarded to the parameter factories.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Ones/zeros initialisation, matching the previous BatchNorm1d-derived
        # defaults. Parameters are only registered when requested.
        if elementwise_affine:
            self.register_parameter(
                "weight",
                torch.nn.Parameter(torch.ones(self.normalized_shape, **factory_kwargs)),
            )
            if bias:
                self.register_parameter(
                    "bias",
                    torch.nn.Parameter(torch.zeros(self.normalized_shape, **factory_kwargs)),
                )
        self.invsqrt = PumaInvsqrt()

    def forward(self, x, bit_range=3):
        """Normalize ``x`` over ``dim=-1``.

        Args:
            x: input tensor. Mean and variance are taken over the last dim.
            bit_range: log2 width of each variance band used for rescaling.

        Returns:
            ``(x' - mean) * invsqrt(var + eps)`` for the rescaled ``x'``,
            multiplied by ``weight`` and shifted by ``bias`` when registered.
        """
        var = (x - x.mean(dim=-1, keepdim=True)).square().mean(dim=-1, keepdim=True)

        # Band masks are 1 inside the band and 0 outside (0.5 exactly on a
        # threshold). min_mask_k means var < 2**(-k*bit_range);
        # max_mask_k means var > 2**(k*bit_range).
        min_mask = (1+(2**(-bit_range)-var).sign()) / 2
        min_mask_2 = (1+(2**(-bit_range*2)-var).sign()) / 2
        min_mask_3 = (1+(2**(-bit_range*3)-var).sign()) / 2

        max_mask = (1-(2**(bit_range)-var).sign()) / 2
        max_mask_2 = (1-(2**(bit_range*2)-var).sign()) / 2
        max_mask_3 = (1-(2**(bit_range*3)-var).sign()) / 2

        # The terms partition the bands: the in-range band is unchanged, low
        # variance is scaled up and high variance is scaled down.
        # Term order is kept as-is because MPC rounding depends on op order.
        x = (
            x * (1 - min_mask) * (1 - max_mask)
            + x * min_mask_3 * (2 ** (bit_range * 3))
            + x * (1 - min_mask_3) * min_mask_2 * (2 ** (bit_range * 2))
            + x * (1 - min_mask_2) * min_mask * (2 ** (bit_range))
            + x * max_mask_3 * (2 ** (-bit_range * 3))
            + x * (1 - max_mask_3) * max_mask_2 * (2 ** (-bit_range * 2))
            + x * (1 - max_mask_2) * max_mask * (2 ** (-bit_range))
        )

        mean = x.mean(dim=-1, keepdim=True)
        var = (x - mean).square().mean(dim=-1, keepdim=True)

        x = (x - mean) * self.invsqrt(var + self.eps)
        # For a one-element normalized_shape the affine parameters are 1-D and
        # broadcast over the last dimension for inputs of any rank. The old
        # reshape(1, 1, -1) silently added a leading dim for 2-D inputs.
        if hasattr(self, 'weight'):
            x = x * self.weight
        if hasattr(self, 'bias'):
            x = x + self.bias
        return x