"""Normalization modules."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
def div_func(x, coef_list):
    """Evaluate the power-basis polynomial ``sum_i coef_list[i] * x**i``.

    Args:
        x: evaluation point (scalar or tensor).
        coef_list: coefficients ordered from degree 0 upwards.

    Returns:
        The polynomial value; ``0`` for an empty coefficient list.
    """
    terms = [coef * (x**power) for power, coef in enumerate(coef_list)]
    return sum(terms)

# Constants for the 1/sqrt starting-estimate line ``L2`` (``k2`` and ``x2`` are
# bound as its default arguments).
# NOTE(review): k1, x1, P, a and b are not referenced in this part of the
# file — confirm they are unused elsewhere before removing them.
k1, x1 = 0.128, 0.3111
k2, x2 = 1.6645, 343.6645
P, a = 1.0, 1.0
b = 500

def L2(x, k2=k2, x2=x2):
    """Return ``k2`` times the tangent line of ``1/sqrt`` at ``x2``, evaluated at ``x``.

    Equivalent to ``k2 * (1.5 / sqrt(x2) - 0.5 * x / x2**1.5)``; used as the
    starting estimate in ``newton_approximationL2short``.
    """
    slope = -0.5 * k2 * (x2**(-3/2))
    intercept = 1.5 * k2 * (1/(x2**0.5))
    return slope * x + intercept

def duble_iter(x, y0):
    """Apply two Newton-Raphson steps for ``1/sqrt(x)`` as a single polynomial.

    One step is ``y -> y * (3 - x * y**2) / 2``; this is that map composed with
    itself and expanded in powers of ``y0``.
    """
    deg9 = 0.0625 * x**4 * y0**9
    deg7 = 0.5625 * x**3 * y0**7
    deg5 = 1.6875 * x**2 * y0**5
    deg3 = 2.4375 * x * y0**3
    deg1 = 2.25 * y0
    return deg9 - deg7 + deg5 - deg3 + deg1

def newton_approximationL2short(x, num_iterations=3):
    """Approximate ``1/sqrt(x)`` by Newton's method seeded with ``L2(x)``.

    Each iteration calls ``duble_iter`` (two Newton steps), so the default of
    3 iterations performs six Newton steps.
    """
    estimate = L2(x)
    remaining = num_iterations
    for _ in range(remaining):
        estimate = duble_iter(x, estimate)
    return estimate
  

def initialGuess(x):
    """Starting estimate ``2 - x`` for the Goldschmidt 1/sqrt iterations."""
    return -x + 2

def gReciprocalOfSqrt(x, n):
    """Approximate ``1/sqrt(x)`` with ``n`` Goldschmidt iterations.

    Starting from ``b = x`` and ``Y = y = initialGuess(x)``, each iteration sets
    ``b <- b * Y**2``, ``Y <- (3 - b) / 2`` and ``y <- y * Y``; ``y`` is returned.
    NOTE(review): convergence presumably requires ``x`` to be in the range where
    ``initialGuess`` is adequate (the ``*Scaled`` wrappers shrink ``x`` first).
    """
    residual = x
    factor = initialGuess(x)
    estimate = factor
    for _ in range(n):
        residual = residual * factor**2
        factor = (3 - residual) / 2
        estimate = estimate * factor
    return estimate
def gReciprocalOfSqrtScaled(x, n=25, s=60000):
    """Approximate ``1/sqrt(x)`` with range scaling around ``gReciprocalOfSqrt``.

    Uses ``1/sqrt(x) = (1/sqrt(s)) * 1/sqrt(x/s)``: the input is multiplied by
    ``1/s`` before ``n`` Goldschmidt iterations and the result is rescaled.
    """
    shrink = 1.0 / s
    shrunk_input = x * shrink
    unscaled = gReciprocalOfSqrt(shrunk_input, n)
    return (1 / math.sqrt(s)) * unscaled

def calculate_b2_expanded(b0, Y0):
    """Closed form of ``b`` after two Goldschmidt steps.

    With ``A = b0 * Y0**2`` this equals ``A * ((3 - A) / 2)**2``, expanded as
    ``(9A - 6A**2 + A**3) / 4``.
    """
    linear = 9 * b0 * Y0**2
    quadratic = 6 * b0**2 * Y0**4
    cubic = b0**3 * Y0**6
    return (linear - quadratic + cubic) / 4

def calculate_Y2_expanded(b0, Y0):
    """Closed form of the Goldschmidt factor ``Y`` after two steps.

    Equals ``(3 - b2) / 2`` where ``b2`` is ``calculate_b2_expanded(b0, Y0)``,
    written as ``(12 - 9A + 6A**2 - A**3) / 8`` with ``A = b0 * Y0**2``.
    """
    prod = b0 * Y0**2
    numerator = 12 - 9 * prod + 6 * prod**2 - prod**3
    return numerator / 8

def calculate_y2_expanded(b0, Y0, y0):
    """Closed form of the running estimate ``y`` after two Goldschmidt steps.

    Equals ``y0 * Y1 * Y2`` with ``Y1 = (3 - A) / 2`` and ``Y2`` from
    ``calculate_Y2_expanded``; here ``A = b0 * Y0**2``.
    """
    prod = b0 * Y0**2
    shifted = prod - 3
    return y0 * shifted * (prod * shifted**2 - 12) / 16


def eff2gReciprocalOfSqrt(x, n):
    """Approximate ``1/sqrt(x)`` using ``n`` fused double Goldschmidt steps.

    Each loop iteration advances ``(b, Y, y)`` by two Goldschmidt steps via the
    ``calculate_*_expanded`` closed forms. This matches ``gReciprocalOfSqrt(x, 2 * n)``
    up to floating-point rounding.
    """
    residual = x
    factor = initialGuess(x)
    estimate = factor
    for _ in range(n):
        # All three right-hand sides use the values from before this step.
        residual, factor, estimate = (
            calculate_b2_expanded(residual, factor),
            calculate_Y2_expanded(residual, factor),
            calculate_y2_expanded(residual, factor, estimate),
        )
    return estimate

# Goldschmidt settings: bound as the defaults of eff2gReciprocalOfSqrtScaled and
# read by CustumOnnxOp when it emits its ONNX node attributes.
num_iter, scale = 24, 100

def eff2gReciprocalOfSqrtScaled(x, n=num_iter, s=scale):
    """Approximate ``1/sqrt(x)`` with ``n`` Goldschmidt steps and range scaling.

    ``n`` counts single Goldschmidt steps and must be even, because the steps
    are executed in pairs by ``eff2gReciprocalOfSqrt``. The input is multiplied
    by ``1/s`` first and the result is multiplied by ``1/sqrt(s)``.

    Raises:
        AssertionError: if ``n`` is odd.
    """
    assert n % 2 == 0
    shrunk = x * (1.0 / s)
    paired_steps = n // 2
    unscaled = eff2gReciprocalOfSqrt(shrunk, paired_steps)
    return (1 / math.sqrt(s)) * unscaled

 
class CustumOnnxOp(torch.autograd.Function):
    """Autograd/ONNX wrapper around the Goldschmidt ``1/sqrt`` approximation.

    ``forward`` evaluates ``eff2gReciprocalOfSqrtScaled`` with the module-level
    ``num_iter`` / ``scale`` settings. ``backward`` differentiates that same
    approximation. ``symbolic`` exports the op as a single ``PolyApprox`` ONNX
    node that carries those settings as attributes.
    """

    @staticmethod
    def forward(ctx, x):
        # Pass the settings explicitly (read at call time) so the computed values
        # always agree with the attributes written by ``symbolic`` below.
        ctx.save_for_backward(x)
        return eff2gReciprocalOfSqrtScaled(x=x, n=num_iter, s=scale)

    @staticmethod
    def backward(ctx, grad_output):
        # Without this method, calling .backward() through the op raises
        # NotImplementedError. Re-run the approximation under autograd to get the
        # exact gradient of the polynomial the forward pass evaluated.
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_leaf = x.detach().requires_grad_(True)
            y = eff2gReciprocalOfSqrtScaled(x=x_leaf, n=num_iter, s=scale)
        (grad_x,) = torch.autograd.grad(y, x_leaf, grad_output)
        return grad_x

    @staticmethod
    def symbolic(g, input):
        # Insert a single node named 'PolyApprox' in the ONNX graph.
        return g.op("PolyApprox", input,
                    function_s=[b"1/sqrt"],
                    approximationMethod_s=[b"goldschmidt"],
                    numIterations_i=num_iter,
                    scale_i=scale)


class CLN(nn.Module):  # Custom LayerNorm
    """LayerNorm over the last dimension with FHE-oriented variants.

    Mean and variance are computed here. ``eps`` and the optional affine
    ``weight``/``bias`` come from the internal ``nn.LayerNorm`` (``self.norm``).

    When ``my_ln`` is True (the default):
      * ``use_poly_div=False, no_sqrt=False``: ``(x - mean) / sqrt(var + eps)``.
      * ``use_poly_div=False, no_sqrt=True``: ``(x - mean) / (var + eps)``.
      * ``use_poly_div=True``: ``(x - mean) * g(var + eps)``, where ``g`` is the
        Goldschmidt approximation of ``1/sqrt`` (``eff2gReciprocalOfSqrtScaled``,
        or ``CustumOnnxOp`` when ``custom_op`` is set). ``no_sqrt`` is ignored
        on this path.
    When ``my_ln`` is False, ``self.norm`` is applied directly (reference path).

    Args:
        d: size of the normalized (last) dimension.
        no_sqrt: divide by the variance instead of the standard deviation.
        use_poly_div: use the polynomial ``1/sqrt`` approximation (default False).
        **norm_args: forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, d, no_sqrt, use_poly_div=False, **norm_args):
        super().__init__()
        self.no_monitor = False  # when True, forward skips update_stat
        self.custom_op = False  # poly path: use CustumOnnxOp (ONNX-exportable)
        self.no_sqrt = no_sqrt
        self.norm = nn.LayerNorm(d, **norm_args)  # supplies eps / weight / bias
        self.reset_stat()
        self.var_loss = torch.Tensor([0.0])
        # Power-basis polynomial coefficients, lowest degree first.
        # NOTE(review): presumably a fit of 1/sqrt (see the commented-out gen_poly
        # call in replace_to_poly) — confirm the fitting range.
        coef_list = [1.5985493307730192, -0.655509330502505, 0.14896211711842, -0.01848781853830923, 0.0013594068282550688, -6.206701120767051e-05, 1.7727888369315731e-06, -3.000287041377021e-08, 2.3129730657364298e-10, 8.400700205394146e-13, -2.507114885539959e-14, -2.6286707102225482e-17, 2.4639297961612453e-18, 3.457034978107797e-21, -2.5696272052664017e-22, -3.02518306837854e-25, 3.127294547401911e-26, -2.2241955390352476e-28, 5.056419870954855e-31]
        self.coef_list = coef_list[:15]
        # NOTE(review): div_func evaluates all 19 coefficients, as the previous
        # lambda did, while self.coef_list keeps only 15 — confirm which is intended.
        self._div_coef_list = coef_list
        # Debug flags
        self.use_poly_div = use_poly_div
        self.my_ln = True
        self.debug = False

    def div_func(self, x):
        """Evaluate the full stored polynomial at ``x``.

        This is a method rather than a lambda attribute so that the module can be
        pickled (e.g. by ``torch.save(model)``).
        """
        coefs = self._div_coef_list
        return sum([coefs[i] * (x**i) for i in range(len(coefs))])

    def replace_to_poly(self, min_range=-1,max_range=1, degree=16, custom_op=False):
        """Switch ``forward`` to the Goldschmidt polynomial ``1/sqrt`` path.

        ``min_range`` and ``max_range`` are only meaningful for the commented-out
        polynomial fit below. ``degree`` is stored on ``self.degree``.
        """
        self.use_poly_div = True
        self.custom_op=custom_op
        self.degree = degree
        print("--- replace LN to poly via G-ALG")
        #from poly_utils import gen_poly 
        #coef_list, _ = gen_poly(min_range=min_range, max_range=max_range, degree=degree , f = lambda x : 1/(x**0.5))
        #self.coef_list = coef_list
        #self.div_func = lambda x : sum([coef_list[i]*(x**i) for i in range(len(coef_list))])

    def get_stat(self):
        """Return the accumulated denominator statistics (``None`` before any update)."""
        return {'min' : self.min_denom_ob, 'max' : self.max_denom_ob , 'mean' : self.mean_denom_ob}

    def update_stat(self, denom):
        """Fold one batch of denominators into the running min / max / mean.

        ``num_elem`` counts ``denom.shape[0]`` entries per call, and the running
        mean is weighted by that count.
        """
        # Monitoring only: detach so the accumulated statistics (in particular the
        # running mean, which is built from its previous value) do not keep
        # autograd graphs alive across batches.
        denom = denom.detach()
        batch = denom.shape[0]
        if self.num_elem is None:
            self.min_denom_ob = denom.min()
            self.max_denom_ob = denom.max()
            self.mean_denom_ob = denom.mean()
            self.num_elem = batch
        else:
            self.min_denom_ob = min(self.min_denom_ob, denom.min())
            self.max_denom_ob = max(self.max_denom_ob, denom.max())
            self.mean_denom_ob = (self.mean_denom_ob * self.num_elem + denom.mean() * batch) * (1 / (self.num_elem + batch))
            self.num_elem = self.num_elem + batch

    def reset_stat(self):
        """Clear the accumulated denominator statistics.

        The statistics accumulate across calls until reset (presumably once per
        epoch), whereas ``self.loss`` is recomputed for every batch.
        """
        self.min_denom_ob = None
        self.max_denom_ob = None
        self.mean_denom_ob = None
        self.num_elem = None

    def forward(self, x):
        """Normalize ``x`` over its last dimension; the output has ``x``'s shape.

        Side effects (``my_ln`` path): sets ``self.loss`` to the max of the
        variance (``no_sqrt``) or of the std, and, unless ``no_monitor`` is set,
        updates the statistics with ``var + eps``.
        """
        shape = x.shape
        if not self.my_ln:  # reference path for testing
            return self.norm(x).view(shape)

        mean = x.mean(dim=-1, keepdim=True)
        variance = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        variance_eps = variance + self.norm.eps
        std = torch.sqrt(variance_eps)
        # Monitor:
        self.loss = (variance if self.no_sqrt else std).max()

        numerator = x - mean
        if self.use_poly_div:
            # The denominator is represented as 1/x (a multiplication) for FHE.
            if self.custom_op:
                denominator = CustumOnnxOp.apply(variance_eps)
            else:
                denominator = eff2gReciprocalOfSqrtScaled(variance_eps)
            normalized_input = numerator * denominator
        elif self.no_sqrt:
            normalized_input = numerator / variance_eps
        else:
            normalized_input = numerator / std

        if not self.no_monitor:
            self.update_stat(variance_eps)

        # weight / bias are None when LayerNorm was built with
        # elementwise_affine=False (or bias=False).
        out = normalized_input
        if self.norm.weight is not None:
            out = self.norm.weight * out
        if self.norm.bias is not None:
            out = out + self.norm.bias
        return out.view(shape)

class RBN(nn.Module):
    """``nn.BatchNorm1d`` wrapper that also records a running-variance regularizer.

    In training mode with running statistics enabled, ``var_loss`` is set to
    ``sum(|running_var - batch_var|) / B``. ``running_var`` is read before this
    batch updates it, and ``batch_var`` is the biased per-channel variance of
    the current batch.
    """

    def __init__(self, d, **norm_args):
        super().__init__()
        self.norm = nn.BatchNorm1d(d, **norm_args)
        self.var_loss = torch.Tensor([0.0])
        # NOTE(review): mean_loss is never updated by forward; kept for callers.
        self.mean_loss = torch.Tensor([0.0])

    def forward(self, x):
        B, _, _ = x.shape  # x is (B, H, L); unpacking enforces a 3-D input
        running_var = self.norm.running_var
        # running_var is None when track_running_stats=False; nothing to compare.
        if self.training and running_var is not None:
            global_var = running_var.detach()
            curr_var = x.var((0, 2), correction=0)
            self.var_loss = ((global_var - curr_var).abs()).sum() * (1 / B)
        return self.norm(x)


class Normalization(nn.Module):
    """Select a normalization layer by name and apply it to the right axis.

    ``forward`` flattens all middle axes, routes the tensor so the chosen layer
    normalizes the channel axis, and restores the original shape.

    Args:
        d: number of channels / features.
        transposed: if True the input is (B, D, ...) with channels second;
            otherwise (B, ..., D) with channels last.
        _name_: one of 'layer', 'cln', 'pcln', 'clnns', 'instance', 'batch',
            'rbatch', 'group', 'none'.
        **kwargs: forwarded to the underlying layer's constructor.

    Raises:
        NotImplementedError: for an unknown ``_name_``, or for a CLN variant
            combined with ``transposed=True``.
    """

    def __init__(
        self,
        d,
        transposed=False, # Length dimension is -1 or -2
        _name_='layer',
        **kwargs
    ):
        super().__init__()
        self.transposed = transposed
        self._name_ = _name_

        if _name_ == 'layer':
            self.channel = True # Normalize over channel dimension
            if self.transposed:
                self.norm = TransposedLN(d, **kwargs)
            else:
                self.norm = nn.LayerNorm(d, **kwargs)
        elif _name_ == 'cln':
            self.channel = True # Normalize over channel dimension
            if self.transposed:
                raise NotImplementedError("not supported currently")
            else:
                self.norm = CLN(d, no_sqrt=False, use_poly_div=False, **kwargs)
        elif _name_ == 'pcln':
            self.channel = True # Normalize over channel dimension
            if self.transposed:
                raise NotImplementedError("not supported currently")
            else:
                self.norm = CLN(d, no_sqrt=False, use_poly_div=True, **kwargs)
        elif _name_ == 'clnns': # Custom LayerNorm, no sqrt
            self.channel = True # Normalize over channel dimension
            if self.transposed:
                raise NotImplementedError("not supported currently")
            else:
                # use_poly_div was previously omitted here, so constructing
                # 'clnns' raised TypeError.
                self.norm = CLN(d, no_sqrt=True, use_poly_div=False, **kwargs)
        elif _name_ == 'instance':
            self.channel = False
            norm_args = {'affine': False, 'track_running_stats': False}
            norm_args.update(kwargs)
            self.norm = nn.InstanceNorm1d(d, **norm_args) # (True, True) performs very poorly
        elif _name_ == 'batch':
            self.channel = False
            norm_args = {'affine': True, 'track_running_stats': True}
            norm_args.update(kwargs)
            self.norm = nn.BatchNorm1d(d, **norm_args)
        elif _name_ == 'rbatch':
            self.channel = False
            norm_args = {'affine': True, 'track_running_stats': True}
            norm_args.update(kwargs)
            self.norm = RBN(d, **norm_args)
        elif _name_ == 'group':
            self.channel = False
            self.norm = nn.GroupNorm(1, d, **kwargs)
        elif _name_ == 'none':
            self.channel = True
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Normalization type '{_name_}' is not supported")

    def forward(self, x):
        """Normalize ``x`` and return a tensor with the same shape."""
        shape = x.shape

        # Collapse all middle axes so every layer sees a 3-D tensor.
        if self.transposed:
            x = rearrange(x, 'b d ... -> b d (...)')
        else:
            x = rearrange(x, 'b ... d -> b (...) d')

        # The cases of LayerNorm / no normalization are automatically handled in all cases
        # Instance/Batch Norm work automatically with transposed axes
        if self.channel or self.transposed:
            x = self.norm(x)
        else:
            # Channel-second layers need (B, D, L); move channels there and back.
            x = x.transpose(-1, -2)
            x = self.norm(x)
            x = x.transpose(-1, -2)

        x = x.view(shape)
        return x

    def step(self, x, **kwargs):
        """Normalize a single recurrent step (only 'layer' and 'none' supported)."""
        assert self._name_ in ["layer", "none"]
        if self.transposed: x = x.unsqueeze(-1)
        x = self.forward(x)
        if self.transposed: x = x.squeeze(-1)
        return x

class TransposedLN(nn.Module):
    """LayerNorm over the second (channel) dimension.

    Assumes shape (B, D, L), where L can be one or more axes.
    This is slow, and a dedicated CUDA/Triton implementation should provide a
    substantial end-to-end speedup.

    With ``scalar=True`` a single learned shift ``m`` and gain ``s`` (both
    excluded from weight decay via their ``_optim`` attribute) are applied as
    ``(s / std) * (x - mean + m)``. Otherwise a full ``nn.LayerNorm(d)`` is
    applied with the channel axis moved last and then moved back.
    """

    def __init__(self, d, scalar=True):
        super().__init__()
        self.scalar = scalar
        if not self.scalar:
            self.ln = nn.LayerNorm(d)
            return
        self.m = nn.Parameter(torch.zeros(1))
        self.s = nn.Parameter(torch.ones(1))
        for param in (self.m, self.s):
            setattr(param, "_optim", {"weight_decay": 0.0})

    def forward(self, x):
        if not self.scalar:
            # Move channels last, apply layer_norm, then move them back to axis 1.
            channels_last = rearrange(x, 'b d ... -> b ... d')
            return rearrange(self.ln(channels_last), 'b ... d -> b d ...')
        # Statistics over the D dim / channels.
        std, mu = torch.std_mean(x, dim=1, unbiased=False, keepdim=True)
        gain = self.s / std
        shifted = x - mu + self.m
        return gain * shifted

class TSNormalization(nn.Module):
    """Scale a (B, L, D) series by a per-(batch, feature) magnitude.

    Args:
        method: 'mean' uses the mean of |x| over the first ``L - horizon`` steps;
            'last' uses |x| at step ``L - horizon - 1``; any other value is a
            pass-through. The scale is stored on ``self.scale`` with shape
            (B, 1, D) so an inverse transform can read it.
        horizon: number of trailing steps excluded from the statistic.
    """

    def __init__(self, method, horizon):
        super().__init__()

        self.method = method
        self.horizon = horizon

    def forward(self, x):
        # x must be BLD
        if self.method == 'mean':
            # Use an explicit end index: ``x[:, :-0]`` is empty, so horizon == 0
            # previously produced a NaN scale.
            context_len = max(x.shape[1] - self.horizon, 0)
            self.scale = x.abs()[:, :context_len].mean(dim=1)[:, None, :]
            return x / self.scale
        elif self.method == 'last':
            self.scale = x.abs()[:, -self.horizon-1][:, None, :]
            return x / self.scale
        return x

class TSInverseNormalization(nn.Module):
    """Undo a time-series normalization by multiplying by its stored scale.

    Args:
        method: only 'mean' and 'last' rescale; any other value is a pass-through.
        normalizer: object exposing ``.scale``, read on every forward call.
    """

    def __init__(self, method, normalizer):
        super().__init__()
        self.method = method
        self.normalizer = normalizer

    def forward(self, x):
        if self.method not in ('mean', 'last'):
            return x
        return x * self.normalizer.scale

class ReversibleInstanceNorm1dInput(nn.Module):
    """Per-instance standardization whose statistics are kept for inversion.

    Each series is standardized over its length axis. The std (plus ``1e-4``)
    and the mean are stored on ``self.s`` / ``self.m`` so that
    ``ReversibleInstanceNorm1dOutput`` can undo the transform.

    Args:
        d: number of channels.
        transposed: False for (B, L, D) inputs, True for (B, D, L).
    """

    def __init__(self, d, transposed=False):
        super().__init__()
        # BLD if transposed is False, otherwise BDL
        self.transposed = transposed
        # NOTE(review): these affine parameters are created (and referenced by the
        # output module) but their use is commented out in forward.
        self.norm = nn.InstanceNorm1d(d, affine=True, track_running_stats=False)

    def forward(self, x):
        if not self.transposed:
            x = x.transpose(-1, -2)

        std, mean = torch.std_mean(x, dim=-1, unbiased=False, keepdim=True)
        # Out-of-place on purpose: autograd saves std_mean's std output for its
        # backward, so the previous in-place ``+=`` raised during backward
        # whenever x required grad.
        self.s = std + 1e-4
        self.m = mean

        x = (x - self.m) / self.s
        # x = self.norm.weight.unsqueeze(-1) * x + self.norm.bias.unsqueeze(-1)

        if not self.transposed:
            return x.transpose(-1, -2)
        return x

class ReversibleInstanceNorm1dOutput(nn.Module):
    """Invert a ReversibleInstanceNorm1dInput using its last stored statistics.

    Computes ``x * norm_input.s + norm_input.m`` along the length axis, with
    the same layout convention (``transposed``) as the paired input module.
    The input module's affine ``weight``/``bias`` are also referenced here,
    although their inversion is currently commented out.
    """

    def __init__(self, norm_input):
        super().__init__()
        self.transposed = norm_input.transposed
        affine = norm_input.norm
        self.weight = affine.weight
        self.bias = affine.bias
        self.norm_input = norm_input

    def forward(self, x):
        channels_last = not self.transposed
        if channels_last:
            x = x.transpose(-1, -2)

        # x = (x - self.bias.unsqueeze(-1))/self.weight.unsqueeze(-1)
        restored = x * self.norm_input.s + self.norm_input.m

        return restored.transpose(-1, -2) if channels_last else restored
