import numpy as np
import matplotlib.pyplot as plt
import torch

import os
# Polynomials from https://arxiv.org/pdf/2105.10879.pdf
# (Precise Approximation of Convolutional Neural Networks for Homomorphically Encrypted Data)
#
# Each polynomial is stored as two parallel lists: mantissas (``*co``) and
# base-10 exponents (``*po``). The coefficient of y**i is
# ``co[i] * 10.0 ** po[i]``. Every even-degree coefficient has magnitude
# below ~1.2e-34, so both polynomials are effectively odd.

# Inner polynomial p_A (degree 7).
p7Aco = [
    3.60471572275560,
    7.30445164958251,
    -5.05471704202722,
    -3.46825871108659,
    1.16564665409095,
    5.98596518298826,
    -6.54298492839531,
    -3.18755225906466,
]
p7Apo = [-36, 0, -35, 1, -34, 1, -35, 1]

# Outer polynomial p_B (degree 7), applied to the output of p_A.
p7Bco = [
    -9.46491402344260,
    2.40085652217597,
    6.41744632725342,
    -2.63125454261783,
    -7.25338564676814,
    1.54912674773593,
    2.06916466421812,
    -3.31172956504304,
]
p7Bpo = [-49, 0, -48, 0, -48, 0, -48, -1]


class CPReLUR7(torch.nn.Module):
    """Polynomial approximation of ReLU on ``[-range_val, range_val]``.

    The input is scaled to ``y = x / range_val``. ``sign(y)`` is approximated by
    the composite polynomial ``p_B(p_A(y))``, built from the two degree-7
    polynomials defined above. ReLU is then rebuilt from the identity
    ``relu(y) = (y + y * sign(y)) / 2`` and scaled back by ``range_val``.
    All operations are elementwise, so any input shape is accepted.

    Args:
        range_val: Half-width ``r`` of the interval the approximation targets
            (the constructor logs it as ``[-r, r]``). Inputs outside that
            interval are not approximated well. Must be positive.
        compute_in_64: If True, evaluate in float64 and cast the result back
            to the input dtype.

    Raises:
        ValueError: If ``range_val`` is not positive. A negative range would
            silently compute ``min(x, 0)`` instead of ReLU.
    """

    def __init__(self, range_val=20, compute_in_64=True):
        super().__init__()
        # `not > 0` also rejects NaN.
        if not range_val > 0:
            raise ValueError(f"range_val must be positive, got {range_val!r}")
        self.r = range_val
        self.one_div_r = 1 / range_val
        # Extra inner scaling of each polynomial argument; 1.0 means none.
        self.rA = 1.0
        self.rB = 1.0
        self.p7Aco = p7Aco
        self.p7Bco = p7Bco
        self.p7Apo = p7Apo
        self.p7Bpo = p7Bpo
        self.compute_in_64 = compute_in_64
        print("create poly on: [-", range_val, range_val, "]", ",degree=", 7, "type:", "relu-minixmax low degree 7X7")

    @staticmethod
    def _eval_poly(mantissas, exponents, t, scale):
        """Return ``scale * sum_i (m_i * 10**e_i) * (t / scale)**i`` using Horner's rule.

        Horner's rule needs one multiply and one add per coefficient. It avoids
        materialising every power of the argument as a separate tensor.
        """
        u = t / scale
        acc = torch.zeros_like(u)
        # Iterate from the highest-degree coefficient down to the constant term.
        for m, e in reversed(list(zip(mantissas, exponents))):
            acc = acc * u + m * (10.0 ** e)
        return scale * acc

    def forward(self, x):
        """Apply the polynomial ReLU approximation elementwise.

        Args:
            x: Input tensor. It is expected to lie in ``[-self.r, self.r]``.

        Returns:
            Tensor with the same shape as ``x``. When ``compute_in_64`` is
            set, it also has the same dtype as ``x``.
        """
        orig_dtype = x.dtype
        if self.compute_in_64:
            # Evaluate the high-degree composite in double precision.
            x = x.to(torch.float64)
        y = x * self.one_div_r
        sign_y = self._eval_poly(self.p7Aco, self.p7Apo, y, self.rA)
        sign_y = self._eval_poly(self.p7Bco, self.p7Bpo, sign_y, self.rB)
        out = 0.5 * (y + sign_y * y)
        out = out * self.r
        if self.compute_in_64:
            out = out.to(orig_dtype)
        return out


if __name__ == '__main__':
    # Smoke test: a few random inputs in [-5, 5] with the default range (20).
    layer = CPReLUR7()
    x = torch.empty(2, 3, dtype=torch.float32).uniform_(-5, 5)
    print(x)
    print(layer(x))

    # Half-width of the plotted interval; also used as the approximation range.
    B = 5

    # Dense grid of x values for plotting.
    x = torch.linspace(-B, B, 1000000)

    # Reference ReLU.
    relu = torch.nn.ReLU()

    # Composite (degree 7 x degree 7) polynomial ReLU approximation on [-B, B].
    cp_relu = CPReLUR7(B)

    # Apply ReLU and the approximation to x; detach once for numpy/plotting.
    y_relu = relu(x)
    y_cp_relu = cp_relu(x).detach()

    # Plot both curves.
    plt.plot(x.numpy(), y_relu.numpy(), label='ReLU')
    plt.plot(x.numpy(), y_cp_relu.numpy(), label='CPRELUR')
    plt.legend()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('ReLU vs CPRELUR')
    plt.savefig("minimax relu approximation")
    plt.show()

    plt.figure()
    # Plot the pointwise absolute error (named to avoid shadowing builtin abs).
    abs_err = (y_relu - y_cp_relu).abs()
    plt.plot(x.numpy(), abs_err.numpy(), label='absolute error')

    plt.legend()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('ReLU vs CPRELUR')
    plt.savefig("minimax relu err")
    plt.show()