import torch
import torch.nn as nn

# Earlier WrappedGPT implementation (disabled; kept for reference). The active class is defined further below.
# class WrappedGPT:
#     """
#     This class wraps a GPT layer for specific operations.
#     """

#     def __init__(self, layer, layer_id=0, layer_name="none"):
#         self.layer = layer
#         self.dev = self.layer.weight.device
#         self.rows = layer.weight.data.shape[0]
#         self.columns = layer.weight.data.shape[1]

#         self.scaler_row = torch.zeros((self.columns), device=self.dev)
#         self.nsamples = 0

#         self.layer_id = layer_id 
#         self.layer_name = layer_name

#     def add_batch(self, inp, out, weight):
#         if len(inp.shape) == 2:
#             inp = inp.unsqueeze(0)
#         tmp = inp.shape[0]
#         if isinstance(self.layer, nn.Linear):
#             if len(inp.shape) == 3:
#                 inp = inp.reshape((-1, inp.shape[-1]))
#             inp = inp.t()

#         self.scaler_row *= self.nsamples / (self.nsamples+tmp)
#         self.nsamples += tmp

#         inp = inp.type(torch.float32)
#         self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2  / self.nsamples


############### Active WrappedGPT implementation ("varp"): running mean/variance of squared inputs ###############

class WrappedGPT:
    """Wrap a layer and accumulate statistics of the inputs it receives.

    For every input column (``layer.weight.shape[1]`` features) this keeps a
    running mean and an unbiased sample variance of the *squared* input
    activations. It also keeps ``ss``, the running mean of all squared input
    entries. ``get_statistics`` combines the per-column statistics into one
    score per input column.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """Set up empty running statistics for ``layer``.

        Args:
            layer: Module with a 2-D ``weight``. Only ``nn.Linear`` inputs are
                flattened/transposed in ``add_batch``.
            layer_id: Identifier stored for the caller's bookkeeping.
            layer_name: Name stored for the caller's bookkeeping.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        # Number of tokens (input vectors) folded in so far.
        self.nsamples = 0
        # Running mean of all squared input entries (0 before the first batch).
        self.ss = 0
        # Per-column running mean of squared inputs.
        self.mean_inp = torch.zeros((self.columns), device=self.dev)
        # Per-column unbiased sample variance of squared inputs.
        self.var_inp = torch.zeros((self.columns), device=self.dev)

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out, weight):
        """Fold one batch of layer inputs into the running statistics.

        Statistics are merged pairwise (Chan et al.). The result equals the
        mean and unbiased variance over every token seen so far, without the
        cancellation of a raw sum-of-squares update.

        Args:
            inp: Input activations of the wrapped layer. A 2-D tensor is
                treated as a single sample and given a leading batch dim. For
                ``nn.Linear`` the leading dims are flattened, so every token
                counts as one sample.
            out: Unused by this implementation (kept for call compatibility).
            weight: Unused by this implementation (kept for call compatibility).
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            # (columns, num_tokens): one column per token.
            inp = inp.t()
        # NOTE(review): non-Linear layers keep a 3-D input here, which the
        # column-wise reductions below do not really support.

        # Count the tokens actually present after flattening, so batches with
        # more than one sample are averaged over the right number of tokens.
        tmp = inp.shape[1]
        if tmp == 0:
            # Nothing to fold in; avoids 0/0 when no data has been seen yet.
            return

        inp = inp.type(torch.float32)
        # Per-token mean of the squared features.
        inpk = torch.norm(inp, p=2, dim=0) ** 2 / inp.size(0)
        # The per-column statistics are taken over squared activations.
        inp = inp ** 2

        old_n = self.nsamples
        new_n = old_n + tmp

        # Statistics of this batch alone.
        batch_mean = torch.sum(inp, dim=1) / tmp
        batch_m2 = torch.sum((inp - batch_mean.reshape(-1, 1)) ** 2, dim=1)

        if old_n:
            delta = batch_mean - self.mean_inp
            new_mean = self.mean_inp + delta * (tmp / new_n)
            # Stored variance is unbiased, so M2 = var * (n - 1).
            old_m2 = self.var_inp * (old_n - 1)
            m2 = old_m2 + batch_m2 + delta ** 2 * (old_n * tmp / new_n)
            new_s = self.ss * old_n / new_n + torch.sum(inpk) / new_n
        else:
            new_mean = batch_mean
            m2 = batch_m2
            new_s = torch.sum(inpk) / new_n

        # Unbiased variance. With a single token, M2 is 0 and so is the result.
        self.var_inp = m2 / max(new_n - 1, 1)
        self.mean_inp = new_mean
        self.nsamples = new_n
        self.ss = new_s

    def get_statistics(self):
        """Return the per-column score as a ``(1, columns)`` tensor.

        The score is ``m**2 + m + var + 1``, where ``m`` is the running mean
        and ``var`` the running unbiased variance of the squared inputs. When
        at most one token has been seen, the variance term is left out.
        """
        mean = self.mean_inp.reshape((1, -1))
        if self.nsamples <= 1:
            return mean ** 2 + mean + 1
        return mean ** 2 + mean + self.var_inp.reshape((1, -1)) + 1
        
