import torchvision
import torch
import torch.nn as nn
import numpy as np 
from src.models.nn.activation import SquaredReLU,GELU2
import sys
# Default polynomial degree, used as the default of gen_poly and of
# polyAct.set_random_poly.
degree = 14


def gen_poly(min_range=-7, max_range=7, degree=degree, f=torch.nn.ReLU(), num_points=100000):
    """Least-squares fit a polynomial to activation ``f`` on [min_range, max_range].

    ``f`` is sampled at ``num_points`` evenly spaced float64 points and fitted
    with ``np.polyfit``.

    Args:
        min_range: left end of the fitting interval.
        max_range: right end of the fitting interval.
        degree: polynomial degree.
        f: callable applied to a float64 torch tensor (e.g. an nn.Module).
        num_points: number of sample points used for the fit.

    Returns:
        (coeffs, degree), where ``coeffs`` is a float64 tensor of length
        ``degree + 1`` in ascending order: ``coeffs[i]`` multiplies ``x**i``.
    """
    print("create poly on: [", min_range,max_range,"]", ",degree=", degree ,"type:" ,f)
    x_s = np.linspace(min_range, max_range, num_points)
    # No autograd needed for sampling; without this, an ``f`` with trainable
    # parameters yields a grad-requiring tensor and .numpy() raises.
    with torch.no_grad():
        y_s = f(torch.from_numpy(x_s)).detach().cpu().numpy()
    # np.polyfit returns highest power first; reverse to ascending order.
    z = np.polyfit(x_s, y_s, degree)[::-1]
    # .copy(): torch.from_numpy rejects the negative stride of the reversed view.
    coeffs = torch.from_numpy(z.copy())

    return coeffs, degree

def replace_activations(model):
    """Recursively swap supported activations in ``model`` for polyAct, in place.

    Children that are SquaredReLU / GELU2 (matched by class name) or instances
    of ``torch.nn.ReLU`` / ``torch.nn.GELU`` are replaced via ``setattr`` with a
    ``polyAct`` of the corresponding ``act_type``. The function descends into
    children that have sub-modules, except polyAct, SquaredReLU and GELU2
    modules, whose internals are left untouched.

    Returns:
        The total number of replacements performed.
    """
    replaced = 0
    for child_name, child in model.named_children():
        child_cls = child.__class__.__name__

        has_children = len(list(child.children())) > 0
        if has_children and type(child) is not polyAct and child_cls not in ('SquaredReLU', 'GELU2'):
            replaced += replace_activations(child)

        # Every predicate is evaluated on the original child, so later checks
        # are unaffected by an earlier replacement.
        candidates = (
            (child_cls == 'SquaredReLU', "relu2"),
            (child_cls == 'GELU2', "gelu2"),
            (isinstance(child, torch.nn.ReLU), "relu"),
            (isinstance(child, torch.nn.GELU), "gelu"),
        )
        for matched, act_type in candidates:
            if matched:
                setattr(model, child_name, polyAct(act_type=act_type))
                replaced += 1
    return replaced


def range_lst_to_global(lst):
    """Merge a list of (min, max) pairs into one enclosing range.

    The accumulation starts from (0, 0), so the result always contains 0:
    the returned min is never positive and the returned max never negative.
    A bound given as None is skipped for that side. For example,
    ``polyAct.get_min_max()`` returns (None, None) before any input has been
    seen.

    Args:
        lst: iterable of (min, max) pairs; either element may be None.

    Returns:
        (global_min, global_max).
    """
    global_min = 0
    global_max = 0
    for mi, mx in lst:
        if mi is not None:
            global_min = min(global_min, mi)
        if mx is not None:
            global_max = max(global_max, mx)
    return global_min, global_max



class polyAct(nn.Module):
    """Activation wrapper that can switch from the exact activation to a polynomial.

    Supported ``act_type`` values are "relu" (nn.ReLU), "relu2" (SquaredReLU),
    "gelu" (nn.GELU) and "gelu2" (GELU2). While ``use_poly`` is False, the
    exact activation is applied. After ``replace_to_poly`` or
    ``set_random_poly``, a polynomial (or, for "relu", a minimax ReLU module)
    is used instead. Each non-FHE forward pass also records the running input
    range and a magnitude "loss" of the current input.
    """

    def __init__(self, act_type="gelu", p_norm=1):
        super(polyAct, self).__init__()
        self.act_type = act_type
        if act_type == "relu":
            self.act = nn.ReLU()
        elif act_type == "relu2":
            self.act = SquaredReLU()
        elif act_type == "gelu":
            self.act = nn.GELU()
        elif act_type == "gelu2":
            self.act = GELU2()
        else:
            raise ValueError("unsupported activation in polyAct")
        self.min_val = None  # running min of inputs seen (for visualization)
        self.max_val = None  # running max of inputs seen (for visualization)
        self.curr_loss = 0  # max |x| (p_norm == 1) or max x**2 of the latest input
        self.p_norm = p_norm
        self.use_poly = False  # evaluate the approximation instead of self.act
        self.fhe = False  # polynomial-only path that skips range/loss tracking
        self.poly = None  # ascending coefficients: poly[i] multiplies x**i
        self.degree = None  # polynomial degree, set alongside self.poly

    def __repr__(self):
        return "poly-" + self.act_type + '| is_poly-' + str(self.use_poly)

    def __str__(self):
        return self.__repr__()

    def _eval_poly(self, x):
        """Evaluate sum_i poly[i] * x**i for i in 0..degree (power basis)."""
        return sum([self.poly[i] * x.pow(i) for i in range(0, self.degree + 1)])

    def forward(self, x):
        if self.fhe:
            return self._eval_poly(x)

        # Track the observed input range across calls.
        curr_x_max = x.max().detach().item()
        curr_x_min = x.min().detach().item()
        self.min_val = curr_x_min if self.min_val is None else min(self.min_val, curr_x_min)
        self.max_val = curr_x_max if self.max_val is None else max(self.max_val, curr_x_max)

        # Not detached. NOTE(review): presumably consumed via get_loss() as a
        # differentiable range penalty; confirm against callers.
        self.curr_loss = abs(x).max() if self.p_norm == 1 else (x ** 2).max()

        if not self.use_poly:
            return self.act(x)
        if self.act_type == "relu":
            # ReLU uses the minimax module built in replace_to_poly.
            return self.minimax_relu(x).to(x.dtype)
        # The coefficients are float64; cast back to the input dtype.
        return self._eval_poly(x).to(x.dtype)

    def get_loss(self):
        return self.curr_loss

    def replace_to_poly(self, degree=18, range_val=6):
        """Switch this activation to its polynomial approximation.

        For "relu", a CPReLUR15 minimax module is built with ``range_val``
        (and moved to CUDA). For the other types, a least-squares polynomial
        of ``degree`` is fitted on [-range_val, range_val]. If ``range_val``
        is None, the input range recorded by forward() is used instead.

        Returns:
            False if a polynomial is already in use (nothing changes), else True.

        Raises:
            ValueError: ``range_val`` is None for a non-relu activation and no
                input range has been recorded yet.
        """
        if self.use_poly:
            return False

        if self.act_type == "relu":
            print("create minixmax relu 15X15")
            from src.models.nn.polynomials.minimax_relu_low_degree15X15 import CPReLUR15 as CPReLUR
            # Alternative (disabled): from src.models.nn.polynomials.minimax_relu import CPReLUR
            # .cuda() requires a CUDA device to be available.
            self.minimax_relu = CPReLUR(range_val=range_val, compute_in_64=True).cuda()
            self.poly = None
            self.degree = None
        else:
            if range_val is not None:
                min_val, max_val = -1.0 * range_val, range_val
            elif self.min_val is not None and self.max_val is not None:
                # Previously an unbound local; fall back to the observed range.
                min_val, max_val = self.min_val, self.max_val
            else:
                raise ValueError(
                    "replace_to_poly: range_val is None and no input range has been recorded"
                )
            self.poly, self.degree = gen_poly(min_range=min_val, max_range=max_val, degree=degree, f=self.act)
        self.use_poly = True
        return True

    def set_random_poly(self, degree=degree):
        """Install random-normal coefficients of ``degree`` and enable the FHE path."""
        self.poly = torch.randn(degree + 1)
        self.degree = degree
        self.use_poly = True
        self.fhe = True

    def get_min_max(self):
        """Return the recorded (min, max) input range; (None, None) if none seen."""
        return (self.min_val, self.max_val)

    def reset_ranges(self):
        """Forget the recorded input range."""
        self.min_val = None
        self.max_val = None





