import numpy as np
import matplotlib.pyplot as plt
import torch

import os
# Polynomial coefficients from https://arxiv.org/pdf/2105.10879.pdf
# ("Precise Approximation of Convolutional Neural Networks for Homomorphically
# Encrypted Data").
#
# Each of the two polynomials p_A and p_B is stored as a pair of parallel
# 16-entry lists: a mantissa list (``*co``) and a base-10 exponent list
# (``*po``). The coefficient of u**i is ``co[i] * 10.0 ** po[i]`` (see
# CPReLUR15.forward), so both polynomials have degree 15. CPReLUR15 evaluates
# p_A on the scaled input first and then evaluates p_B on p_A's output.
#
# NOTE: every even-index exponent below is -39 or smaller, so the even-degree
# terms are numerically negligible and both polynomials are effectively odd.

# p_A mantissas; entry i multiplies u**i.
p15Aco = [
    3.85169741234183,
    1.80966285718807,
    -4.59730416916377,
    -4.34038703274886,
    7.96299160375690,
    4.15497103545696,
    -5.28977110396316,
    -1.86846943613149,
    1.67219551148917,
    4.41657177889329,
    -2.69777424798506,
    -5.65527928983401,
    2.14124591383569,
    3.71156122725781,
    -6.61722455927198,
    -9.78241933892781
]
# p_B mantissas; entry i multiplies u**i.
p15Bco = [-1.04501074063854,
3.79753323360856,
4.22842209818016,
-1.17718157771192, 
-2.25571113936639, 
2.49771086678346, 
4.42462875106862, 
-3.15238841603993, 
-4.13554194411645, 
2.37294863126722, 
2.00060158783094, 
-1.04331800195923,
-4.86041132712796,
2.46743976260838,
4.71256214052049,
-2.42130100247617]

# p_A base-10 exponents, paired element-wise with p15Aco.
p15Apo = [
    -44,
    1,
    -42,
    2,
    -41,
    3,
    -40,
    4,
    -39,
    4,
    -39,
    4,
    -39,
    4,
    -40,
    3]
# p_B base-10 exponents, paired element-wise with p15Bco.
p15Bpo = [-46,
0,
-45,
1,
-44,
1,
-44,
1,
-44,
1,
-44,
1,
-45,
0,
-46,
-1]
class CPReLUR15(torch.nn.Module):
    """ReLU approximated by a composite pair of degree-15 polynomials.

    With ``t = x / range_val`` the module computes::

        a   = rA * p_A(t / rA)
        s   = rB * p_B(a / rB)
        out = range_val * 0.5 * (t + s * t)

    Because ``relu(t) == 0.5 * (t + t * sign(t))``, the composite ``s`` plays
    the role of ``sign(t)``; the final multiply by ``range_val`` undoes the
    input scaling. Coefficients are read from ``self.p15Aco``/``self.p15Apo``
    and ``self.p15Bco``/``self.p15Bpo`` (initialised from the module-level
    lists); the coefficient of ``u**i`` is ``co[i] * 10.0 ** po[i]``.

    Args:
        range_val: strictly positive half-width of the input interval; inputs
            are divided by it before the polynomials are evaluated. Presumably
            the approximation is only accurate for ``|x| <= range_val`` --
            confirm against the source paper.
        compute_in_64: if True, evaluate in float64 and cast the result back
            to the input dtype.

    Raises:
        ValueError: if ``range_val`` is not strictly positive.
    """

    def __init__(self, range_val=20, compute_in_64=True):
        super().__init__()
        # A non-positive range is never meaningful: 0 divides by zero, and a
        # negative value flips the sign of the scaled input, turning the
        # result into min(x, 0) instead of relu(x). `not (> 0)` also rejects NaN.
        if not (range_val > 0):
            raise ValueError(f"range_val must be strictly positive, got {range_val!r}")
        self.r = range_val
        self.one_div_r = 1/range_val
        self.rA = 1.0
        self.rB = 1.0
        self.p15Aco = p15Aco
        self.p15Bco = p15Bco
        self.p15Apo = p15Apo
        self.p15Bpo = p15Bpo
        self.compute_in_64 = compute_in_64
        print("create poly on: [-", range_val,range_val,"]", ",degree=", 15 ,"type:" ,"relu-minixmax low degree 15X15")

    @staticmethod
    def _horner(t, mantissas, exponents, scale):
        """Return ``scale * sum_i (m_i * 10**e_i) * (t / scale)**i`` via Horner's rule."""
        coeffs = [m * (10.0 ** e) for m, e in zip(mantissas, exponents)]
        if not coeffs:
            # Zero polynomial.
            return torch.zeros_like(t)
        u = t / scale
        # Horner: ((c_n * u + c_{n-1}) * u + ...) + c_0 -- n multiply-adds
        # instead of n separate powers, and better conditioned at degree 15.
        acc = torch.full_like(u, coeffs[-1])
        for c in reversed(coeffs[:-1]):
            acc = acc * u + c
        return scale * acc

    def forward(self, x):
        """Apply the polynomial ReLU approximation element-wise to ``x``."""
        orig_dtype = x.dtype
        if self.compute_in_64:
            # High-degree compositions are sensitive to rounding; evaluate in
            # double precision and cast back at the end.
            x = x.to(torch.float64)
        t = x * self.one_div_r
        a = self._horner(t, self.p15Aco, self.p15Apo, self.rA)
        s = self._horner(a, self.p15Bco, self.p15Bpo, self.rB)
        out = 0.5 * (t + s * t) * self.r
        if self.compute_in_64:
            out = out.to(orig_dtype)
        return out


if __name__ == '__main__':
    # Smoke test on a few random inputs.
    layer = CPReLUR15()
    x = torch.empty(2, 3).uniform_(-5, 5)
    print(x)
    print(layer(x))

    # Plotting half-range; also used as the approximation range of cp_relu.
    B = 5

    # Dense grid of x values for plotting.
    x = torch.linspace(-B, B, 1000000)

    # Exact ReLU reference.
    relu = torch.nn.ReLU()

    # Composite degree-15 x 15 minimax approximation, scaled to [-B, B].
    cp_relu = CPReLUR15(B)

    # Apply ReLU and the approximation to x.
    y_relu = relu(x)
    y_cp_relu = cp_relu(x)

    # Plot both curves.
    plt.plot(x.numpy(), y_relu.numpy(), label='ReLU')
    plt.plot(x.numpy(), y_cp_relu.detach().numpy(), label='CPRELUR')
    plt.legend()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('ReLU vs CPRELUR')
    plt.savefig("minimax relu approximation")
    plt.show()

    # Plot the pointwise absolute error (named to avoid shadowing builtin abs).
    plt.figure()
    abs_err = (y_relu - y_cp_relu.detach()).abs()
    print("max abs error:", abs_err.max().item())
    plt.plot(x.numpy(), abs_err.numpy(), label='absolute error')

    plt.legend()
    plt.xlabel('x')
    plt.ylabel('|ReLU - CPRELUR|')
    plt.title('ReLU vs CPRELUR: absolute error')
    plt.savefig("minimax relu err")
    plt.show()