import numpy as np
import matplotlib.pyplot as plt
import torch

import os
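# Minimax composite-polynomial coefficients from https://arxiv.org/pdf/2105.10879.pdf
# ("Precise Approximation of Convolutional Neural Networks for Homomorphically Encrypted Data").
# Each polynomial is stored as two parallel lists: `*co` holds mantissas and `*po` holds
# base-10 exponents, so the coefficient of the degree-i term is co[i] * 10**po[i].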
# Mantissas of the degree-29 polynomial p29 (30 entries, degrees 0..29), the
# last of the three stages evaluated in CPReLUR.forward. The coefficient of the
# degree-i term is p29co[i] * 10**p29po[i]. Even-degree entries are paired with
# exponents <= -42 in p29po, so p29 is effectively an odd polynomial.
p29co = [6.72874968716530,
5.31755497689391,
5.68199275801086,
-3.54371531531577,
-1.35187813155454,
1.84122441329140,
1.05531766289589,
-6.55386830146253,
-4.14266518871760,
1.63878335428060,
9.63097361166316,
-2.95386237048226,
-1.44556688409360,
3.90806423362418,
1.47265013864485,
-3.83496739165131,
-1.04728251169615,
2.79960654766517,
5.26108728786276,
-1.51286231886692,
-1.86083902222546,
5.96160139340009,
4.53644110199468,
-1.66321739302958,
-7.25782287655313,
3.10988369739884,
6.85800520634485,
-3.49349374506190,
-2.89849811206637,
1.78142156956495]

# Base-10 exponents paired index-by-index with p29co: the degree-i coefficient
# of p29 is p29co[i] * 10**p29po[i]. Even indices hold exponents <= -42, i.e.
# the even-degree terms are numerically zero. The float entries (0.0, 1.0)
# give the same 10.0 ** p as the ints 0 and 1.
p29po = [-48,
0.0,
-46,
1.0,
-44,
2,
-43,
2,
-43,
3,
-43,
3,
-42,
3,
-42,
3,
-42,
3,
-43,
3,
-43,
2,
-44,
2,
-45,
1,
-46,
0.0,
-47,
-1]



# Mantissas of the degree-27 polynomial p27 (28 entries, degrees 0..27), the
# middle stage evaluated in CPReLUR.forward. The coefficient of the degree-i
# term is p27co[i] * 10**p27po[i]. Even-degree entries are paired with
# exponents <= -41 in p27po, so p27 is effectively an odd polynomial.
p27co = [-9.27991756967991,
1.68285511926011,
8.32408114686671,
-3.39811750495659,
-1.27756566625811,
2.79069998793847,
7.70152836729131,
-1.13514151573790,
-2.41159918805990,
2.66230010283745,
4.48807056213874,
-3.93840328661975,
-5.34821622972202,
3.87884230348060,
4.25722502798559,
-2.62395303844988,
-2.31146624263347,
1.23656207016532,
8.58571463533718,
-4.05336460089999,
-2.14564940301255,
9.06042880951087,
3.44803367899992,
-1.31687649208288,
-3.21717059336602,
1.12176079033623,
1.32425600403443,
-4.24938020467471
]

# Base-10 exponents paired index-by-index with p27co: the degree-i coefficient
# of p27 is p27co[i] * 10**p27po[i]. Even indices hold exponents <= -41, i.e.
# the even-degree terms are numerically zero.
p27po = [-46,
1,
-44,
2,
-42,
3,
-42,
4,
-41,
4,
-41,
4,
-41,
4,
-41,
4,
-41,
4,
-42,
3,
-42,
2,
-43,
2,
-44,
1,
-45,
-1]

# Mantissas of the degree-15 polynomial p15 (16 entries, degrees 0..15), the
# first stage evaluated in CPReLUR.forward (applied to x / range_val). The
# coefficient of the degree-i term is p15co[i] * 10**p15po[i]. Even-degree
# entries are paired with exponents <= -42 in p15po, so p15 is effectively odd.
# The odd-degree coefficients reach ~1e5 in magnitude and largely cancel,
# which is why CPReLUR evaluates in float64 by default.
p15co = [-3.38572283433492,
2.49052143193754,
7.67064296707865,
-6.82383057582430,
-1.33318527258859,
6.80942845390599,
9.19464568002043,
-3.12507100017105,
-3.02547883089949,
7.47659388363757,
5.02426027571770,
-9.65046838475839,
-4.05931240321443,
6.36977923778246,
1.26671427827897,
-1.68602621347190]


# Base-10 exponents paired index-by-index with p15co: the degree-i coefficient
# of p15 is p15co[i] * 10**p15po[i]. Even indices hold exponents <= -42, i.e.
# the even-degree terms are numerically zero.
p15po =[-47,
1,
-45,
2,
-43,
3,
-43,
4,
-42,
4,
-42,
4,
-42,
4,
-42,
4]

class CPReLUR(torch.nn.Module):
    """Composite-polynomial approximation of ReLU on ``[-range_val, range_val]``.

    The input is rescaled to ``u = x / range_val`` and the composite polynomial
    ``s = p29(p27(p15(u)))`` is evaluated, using the module-level coefficient
    tables ``p15co``/``p15po``, ``p27co``/``p27po`` and ``p29co``/``p29po``
    (arXiv:2105.10879). The result is ``range_val * 0.5 * (u + u * s)``; since
    ``relu(u) = 0.5 * (u + u * sign(u))``, ``s`` stands in for ``sign(u)``.

    Inputs with ``|x| > range_val`` lie outside the interval the polynomials
    were fitted on; the high-degree terms (the degree-15 coefficient of p15
    alone is about -1.7e4) make the output diverge quickly there.

    Args:
        range_val: Half-width of the input interval; must be positive.
        compute_in_64: If True, evaluate in float64 and cast the result back to
            the input dtype. The polynomial terms are large and cancel heavily,
            so lower-precision evaluation loses accuracy.

    Raises:
        ValueError: If ``range_val`` is not positive.
    """

    def __init__(self, range_val=20, compute_in_64=True):
        # A non-positive range would divide by zero or flip the sign of the
        # scaled input (turning the output into a ReLU of -x), so reject it.
        if range_val <= 0:
            raise ValueError(f"range_val must be positive, got {range_val!r}")
        super().__init__()
        self.r = range_val
        self.one_div_r = 1 / range_val
        # Per-stage scales: each stage computes scale * p(x / scale);
        # 1.0 means the stage is applied without rescaling.
        self.r15 = 1.0
        self.r27 = 1.0
        self.r29 = 1.0
        # Coefficient tables: the degree-i coefficient is c[i] * 10**p[i].
        self.p15c = p15co
        self.p15p = p15po
        self.p27c = p27co
        self.p27p = p27po
        self.p29c = p29co
        self.p29p = p29po
        self.compute_in_64 = compute_in_64
        print(f"create poly on: [-{range_val}, {range_val}], "
              f"degrees: 15 -> 27 -> 29, type: relu-minimax")

    @staticmethod
    def _eval_poly(x, coeffs, exps, scale=1.0):
        """Return ``scale * sum_i coeffs[i] * 10**exps[i] * (x / scale)**i``.

        Uses Horner's rule: a single running tensor and one multiply-add per
        coefficient, instead of materialising one power tensor per term.
        Like ``zip``, extra entries in the longer of ``coeffs``/``exps`` are
        ignored. An empty table yields zeros shaped like ``x``.
        """
        t = x / scale
        terms = [c * 10.0 ** p for c, p in zip(coeffs, exps)]
        acc = torch.zeros_like(t)
        for a in reversed(terms):
            acc = acc * t + a
        return scale * acc

    def forward(self, x):
        """Apply the ReLU approximation elementwise to the tensor ``x``."""
        in_dtype = x.dtype
        if self.compute_in_64:
            # High-precision evaluation; the polynomial terms cancel heavily.
            x = x.to(torch.float64)
        u = x * self.one_div_r  # map [-r, r] onto [-1, 1]
        s = self._eval_poly(u, self.p15c, self.p15p, self.r15)
        s = self._eval_poly(s, self.p27c, self.p27p, self.r27)
        s = self._eval_poly(s, self.p29c, self.p29p, self.r29)
        # relu(u) = 0.5 * (u + u * sign(u)), with s standing in for sign(u).
        out = 0.5 * (u + s * u) * self.r
        if self.compute_in_64:
            out = out.to(in_dtype)
        return out


if __name__ == '__main__':
    # Quick smoke test with the default approximation range [-20, 20].
    layer = CPReLUR()
    x = torch.empty(2, 3).uniform_(-5, 5)
    print(x)
    print(layer(x))

    # Half-width of the plotted interval; also the range CPReLUR is built for.
    B = 5

    with torch.no_grad():  # plotting only; no autograd graph needed
        # Dense grid of x values over [-B, B]
        x = torch.linspace(-B, B, 1000000)

        # Exact ReLU as the reference
        relu = torch.nn.ReLU()

        # Polynomial ReLU approximation scaled to [-B, B]
        cp_relu = CPReLUR(B)

        y_relu = relu(x)
        y_cp_relu = cp_relu(x)

    # Plot both curves
    plt.plot(x.numpy(), y_relu.numpy(), label='ReLU')
    plt.plot(x.numpy(), y_cp_relu.numpy(), label='CPRELUR')
    plt.legend()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('ReLU vs CPRELUR')
    plt.savefig("minimax relu approximation")
    plt.show()

    plt.figure()
    # Plot the pointwise absolute error
    abs_err = (y_relu - y_cp_relu).abs()
    plt.plot(x.numpy(), abs_err.numpy(), label='absolute error')
    plt.legend()
    plt.xlabel('x')
    plt.ylabel('|ReLU(x) - CPRELUR(x)|')
    plt.title('CPRELUR absolute error')
    plt.savefig("minimax relu err")
    plt.show()