import torch
import torch.nn as nn


# Define WrappedGPT class
class WrappedGPT:
    """
    Collects running input/output statistics for a wrapped layer.

    Every ``add_batch`` call folds one batch of activations into a running
    input mean, a running output mean and a running unbiased output
    covariance.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """
        Allocate zero-initialised buffers on the device of ``layer.weight``.

        Args:
            layer: Object with a 2-D ``weight`` tensor; ``weight.shape[0]``
                becomes ``rows`` and ``weight.shape[1]`` becomes ``columns``.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        # Allocated for callers; nothing in this class updates these three.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.out = torch.zeros((self.rows, 2048), device=self.dev)
        self.inp = torch.zeros((2048, self.columns), device=self.dev)

        # Running statistics maintained by add_batch.
        self.out_cov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_mean = torch.zeros((self.rows, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.columns, 1), device=self.dev)

        # Allocated but not updated by this class.
        self.out_avgcov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_avgmean = torch.zeros((self.rows, 1), device=self.dev)
        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Fold one batch into the running means and the output covariance.

        Args:
            inp: Input activations shaped ``(seq, columns)`` or
                ``(1, seq, columns)``.
            out: Output activations shaped ``(seq, rows)`` or
                ``(1, seq, rows)``.

        The batch covariance comes from ``torch.cov(..., correction=1)``, so
        a batch needs at least two samples for a finite result.
        """
        if len(out.shape) == 2:
            out = out.unsqueeze(0)
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)

        # Work with one column per sample: (features, seq).
        out = out.squeeze(0).T.type(torch.float32)
        inp = inp.squeeze(0).T.type(torch.float32)

        batch_cov = torch.cov(out, correction=1)
        batch_out_mean = torch.mean(out, dim=-1).unsqueeze(1)
        batch_inp_mean = torch.mean(inp, dim=-1).unsqueeze(1)

        n_old = self.nsamples
        n_out = out.shape[1]
        n_inp = inp.shape[1]

        merged_out_mean = (n_old * self.out_mean + n_out * batch_out_mean) / (n_old + n_out)
        merged_inp_mean = (n_old * self.inp_mean + n_inp * batch_inp_mean) / (n_old + n_inp)

        # Pooled unbiased covariance. The between-batch term must use the
        # difference between the batch mean and the *previous* running mean;
        # using the merged mean here under-weights it.
        delta = batch_out_mean - self.out_mean
        cross = (n_old * n_out / (n_old + n_out)) * (delta @ delta.T)
        merged_cov = ((n_old - 1) * self.out_cov + (n_out - 1) * batch_cov + cross) / (
            n_old + n_out - 1)

        self.out_mean = merged_out_mean
        self.inp_mean = merged_inp_mean
        self.out_cov = merged_cov
        self.nsamples += n_out


class WrappedGPT_avg:
    """
    Collects running output statistics for a wrapped layer.

    ``add_batch`` maintains a running output mean and a running unbiased
    output covariance. ``inp_mean`` is overwritten with the input mean of the
    most recent batch only (it is not a running average).
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """
        Allocate zero-initialised buffers on the device of ``layer.weight``.

        Args:
            layer: Object with a 2-D ``weight`` tensor; ``weight.shape[0]``
                becomes ``rows`` and ``weight.shape[1]`` becomes ``columns``.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        # Allocated for callers; nothing in this class updates these.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.out = torch.zeros((self.rows, 2048), device=self.dev)
        self.inp = torch.zeros((2048, self.columns), device=self.dev)
        self.inp_list = []
        self.out_list = []

        # Statistics maintained by add_batch.
        self.out_cov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_mean = torch.zeros((self.rows, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.columns, 1), device=self.dev)

        # Allocated but not updated by this class.
        self.out_avgcov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_avgmean = torch.zeros((self.rows, 1), device=self.dev)
        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Fold one batch into the running output mean and covariance.

        Args:
            inp: Input activations shaped ``(seq, columns)`` or
                ``(1, seq, columns)``; only its per-feature mean is kept.
            out: Output activations shaped ``(seq, rows)`` or
                ``(1, seq, rows)``.

        The batch covariance comes from ``torch.cov(..., correction=1)``, so
        a batch needs at least two samples for a finite result.
        """
        if len(out.shape) == 2:
            out = out.unsqueeze(0)
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)

        # Work with one column per sample: (features, seq).
        out = out.squeeze(0).T.type(torch.float32)
        inp = inp.squeeze(0).T.type(torch.float32)

        batch_cov = torch.cov(out, correction=1)
        batch_mean = torch.mean(out, dim=-1).unsqueeze(1)

        n_old = self.nsamples
        n_new = out.shape[1]
        total = n_old + n_new

        merged_mean = (n_old * self.out_mean + n_new * batch_mean) / total

        # Pooled unbiased covariance; the between-batch term is built from
        # the offset of the batch mean to the previous running mean.
        delta = batch_mean - self.out_mean
        cross = (n_old * n_new / total) * (delta @ delta.T)
        merged_cov = ((n_old - 1) * self.out_cov + (n_new - 1) * batch_cov + cross) / (total - 1)

        self.out_mean = merged_mean
        # Last-batch input mean, not a running mean.
        self.inp_mean = torch.mean(inp, dim=-1).unsqueeze(1)
        self.out_cov = merged_cov
        self.nsamples = total

class WrappedGPT_avg_add:
    """
    Accumulates the un-normalised second moment of a layer's outputs.

    ``out_cov`` holds the running sum of ``out @ out.T`` over all batches and
    ``nsamples`` counts the samples that went into it, so a caller can
    normalise the sum afterwards.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """
        Allocate zero-initialised buffers on the device of ``layer.weight``.

        Args:
            layer: Object with a 2-D ``weight`` tensor; ``weight.shape[0]``
                becomes ``rows`` and ``weight.shape[1]`` becomes ``columns``.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        # Allocated for callers; nothing in this class updates these.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.out = torch.zeros((self.rows, 2048), device=self.dev)
        self.inp = torch.zeros((2048, self.columns), device=self.dev)
        self.out_mean = torch.zeros((self.rows, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.columns, 1), device=self.dev)
        self.inp_list = []
        self.out_list = []
        self.out_avgcov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_avgmean = torch.zeros((self.rows, 1), device=self.dev)
        self.avg_step = 0

        # Maintained by add_batch.
        self.out_cov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.nsamples = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Add ``out @ out.T`` for one batch to ``out_cov``.

        Args:
            inp: Unused; accepted so the signature matches the other wrappers.
            out: Output activations shaped ``(seq, rows)`` or
                ``(1, seq, rows)``.
        """
        if len(out.shape) == 2:
            out = out.unsqueeze(0)

        # One column per sample: (rows, seq).
        out = out.squeeze(0).T.type(torch.float32)

        self.out_cov += out @ out.T
        # Track how many samples were summed, as WrappedGPT_after_rope_add
        # does, so the sum can be normalised by the caller.
        self.nsamples += out.shape[1]




class WrappedGPT_after_rope:
    """
    Collects running mean/covariance of multi-head tensors, heads flattened.

    Both ``inp`` and ``out`` are flattened from ``(heads, seq, head_dim)`` to
    ``heads * head_dim`` features per token, and a running mean and unbiased
    covariance over tokens is kept for each of them.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", num_head=32, head_dim=128, device=None):
        """
        Allocate zero-initialised statistics buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            num_head: Number of heads; ``hidden_size = num_head * head_dim``.
            head_dim: Per-head feature size.
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = num_head * head_dim

        self.out_cov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
        self.inp_cov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
        self.out_mean = torch.zeros((self.hidden_size, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.hidden_size, 1), device=self.dev)

        # Allocated but not updated by this class.
        self.out_avgcov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
        self.out_avgmean = torch.zeros((self.hidden_size, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    @staticmethod
    def _merge_moments(n_old, mean_old, cov_old, samples):
        """
        Pool a running (mean, unbiased covariance) with a new batch.

        Args:
            n_old: Number of samples already summarised.
            mean_old: Running mean, shape ``(features, 1)``.
            cov_old: Running covariance, shape ``(features, features)``.
            samples: New batch, shape ``(features, n_new)``.

        Returns:
            ``(mean, cov)`` describing all ``n_old + n_new`` samples.
        """
        n_new = samples.shape[1]
        total = n_old + n_new
        batch_mean = torch.mean(samples, dim=-1).unsqueeze(1)
        batch_cov = torch.cov(samples, correction=1)

        mean = (n_old * mean_old + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - mean_old
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * cov_old + (n_new - 1) * batch_cov + cross) / (total - 1)
        return mean, cov

    def add_batch(self, inp, out):
        """
        Fold one batch into the running input and output statistics.

        Args:
            inp: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``; must match ``out``'s shape.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``.

        Each batch needs at least two tokens for ``torch.cov`` to be finite.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)
        if len(inp.shape) == 3:
            inp = inp.unsqueeze(0)

        out = out.squeeze(0)
        inp = inp.squeeze(0)
        head_num, seq_len, head_dim = out.shape
        width = head_num * head_dim

        # (heads, seq, dim) -> (seq, heads*dim) -> (heads*dim, seq)
        out = out.transpose(0, 1).contiguous().reshape(seq_len, width).T.type(torch.float32)
        inp = inp.transpose(0, 1).contiguous().reshape(seq_len, width).T.type(torch.float32)

        n_old = self.nsamples
        self.out_mean, self.out_cov = self._merge_moments(n_old, self.out_mean, self.out_cov, out)
        self.inp_mean, self.inp_cov = self._merge_moments(n_old, self.inp_mean, self.inp_cov, inp)
        self.nsamples = n_old + out.shape[1]

class WrappedGPT_after_rope_perhead:
    """
    Collects per-head output statistics and flattened input statistics.

    ``out_mean``/``out_cov`` hold one running mean and unbiased covariance
    per head (``head_dim`` features each). ``inp_mean``/``inp_cov`` hold the
    same quantities for the input with all heads flattened into
    ``num_head * head_dim`` features.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", num_head=32, head_dim=128, device=None):
        """
        Allocate zero-initialised statistics buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            num_head: Number of heads.
            head_dim: Per-head feature size.
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = num_head * head_dim
        self.num_head = num_head
        # Keep the attribute in sync with the buffers below (was fixed at 128).
        self.head_dim = head_dim

        self.out_cov = torch.zeros((num_head, head_dim, head_dim), device=self.dev)
        self.inp_cov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
        self.out_mean = torch.zeros((num_head, head_dim, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.hidden_size, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    @staticmethod
    def _merge_moments(n_old, mean_old, cov_old, samples):
        """
        Pool a running (mean, unbiased covariance) with a new batch.

        Args:
            n_old: Number of samples already summarised.
            mean_old: Running mean, shape ``(features, 1)``.
            cov_old: Running covariance, shape ``(features, features)``.
            samples: New batch, shape ``(features, n_new)``.

        Returns:
            ``(mean, cov)`` describing all ``n_old + n_new`` samples.
        """
        n_new = samples.shape[1]
        total = n_old + n_new
        batch_mean = torch.mean(samples, dim=-1).unsqueeze(1)
        batch_cov = torch.cov(samples, correction=1)

        mean = (n_old * mean_old + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - mean_old
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * cov_old + (n_new - 1) * batch_cov + cross) / (total - 1)
        return mean, cov

    def add_batch(self, inp, out):
        """
        Fold one batch into the running statistics.

        Args:
            inp: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``; must match ``out``'s shape.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``.

        Each batch needs at least two tokens for ``torch.cov`` to be finite.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)
        if len(inp.shape) == 3:
            inp = inp.unsqueeze(0)

        out = out.squeeze(0)
        inp = inp.squeeze(0)
        head_num, seq_len, head_dim = out.shape

        # Input: all heads flattened, one column per token.
        inp = inp.transpose(0, 1).contiguous().reshape(seq_len, head_num * head_dim).T.type(torch.float32)
        # Output: kept per head as (heads, head_dim, seq).
        out = out.transpose(1, 2).type(torch.float32)

        n_old = self.nsamples
        n_new = inp.shape[1]

        self.inp_mean, self.inp_cov = self._merge_moments(n_old, self.inp_mean, self.inp_cov, inp)

        for head in range(self.num_head):
            mean, cov = self._merge_moments(n_old, self.out_mean[head], self.out_cov[head], out[head])
            self.out_mean[head] = mean
            self.out_cov[head] = cov

        self.nsamples = n_old + n_new

class WrappedGPT_rope_perhead:
    """
    Collects per-head running statistics for both inputs and outputs.

    For every head a running mean and unbiased covariance over tokens is
    kept separately for ``inp`` and ``out`` (``head_dim`` features each).
    """

    def __init__(self, layer, layer_id=0, layer_name="none", num_head=32, head_dim=128, device=None):
        """
        Allocate zero-initialised statistics buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            num_head: Number of heads.
            head_dim: Per-head feature size.
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = num_head * head_dim
        self.num_head = num_head
        # Keep the attribute in sync with the buffers below (was fixed at 128).
        self.head_dim = head_dim

        self.out_cov = torch.zeros((num_head, head_dim, head_dim), device=self.dev)
        self.inp_cov = torch.zeros((num_head, head_dim, head_dim), device=self.dev)
        self.out_mean = torch.zeros((num_head, head_dim, 1), device=self.dev)
        self.inp_mean = torch.zeros((num_head, head_dim, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    @staticmethod
    def _merge_moments(n_old, mean_old, cov_old, samples):
        """
        Pool a running (mean, unbiased covariance) with a new batch.

        Args:
            n_old: Number of samples already summarised.
            mean_old: Running mean, shape ``(features, 1)``.
            cov_old: Running covariance, shape ``(features, features)``.
            samples: New batch, shape ``(features, n_new)``.

        Returns:
            ``(mean, cov)`` describing all ``n_old + n_new`` samples.
        """
        n_new = samples.shape[1]
        total = n_old + n_new
        batch_mean = torch.mean(samples, dim=-1).unsqueeze(1)
        batch_cov = torch.cov(samples, correction=1)

        mean = (n_old * mean_old + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - mean_old
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * cov_old + (n_new - 1) * batch_cov + cross) / (total - 1)
        return mean, cov

    def add_batch(self, inp, out):
        """
        Fold one batch into the per-head running statistics.

        Args:
            inp: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``.

        Each batch needs at least two tokens for ``torch.cov`` to be finite.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)
        if len(inp.shape) == 3:
            inp = inp.unsqueeze(0)

        out = out.squeeze(0)
        inp = inp.squeeze(0)
        head_num, seq_len, head_dim = out.shape

        # Both become (heads, head_dim, seq): one column per token.
        inp = inp.transpose(1, 2).type(torch.float32)
        out = out.transpose(1, 2).type(torch.float32)

        n_old = self.nsamples
        for head in range(self.num_head):
            out_mean, out_cov = self._merge_moments(
                n_old, self.out_mean[head], self.out_cov[head], out[head])
            inp_mean, inp_cov = self._merge_moments(
                n_old, self.inp_mean[head], self.inp_cov[head], inp[head])
            self.out_mean[head] = out_mean
            self.out_cov[head] = out_cov
            self.inp_mean[head] = inp_mean
            self.inp_cov[head] = inp_cov

        # The sample count is the number of tokens (last axis after the
        # transpose), not head_dim.
        self.nsamples = n_old + seq_len


class WrappedGPT_after_rope_gqa:
    """
    Collects output statistics per key/value head group.

    The ``num_head`` output heads are split into ``num_kv_head`` groups of
    ``kv_group`` consecutive heads; each group gets a running mean and
    unbiased covariance over ``kv_group * head_dim`` features. The input is
    summarised with all heads flattened into ``num_head * head_dim`` features.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", num_head=32, num_kv_head=8, head_dim=128, device=None):
        """
        Allocate zero-initialised statistics buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            num_head: Number of heads in ``inp``/``out``.
            num_kv_head: Number of groups the heads are split into.
            head_dim: Per-head feature size.
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = num_head * head_dim
        self.num_head = num_head
        self.num_kv_head = num_kv_head
        self.kv_group = num_head // num_kv_head
        # Keep the attribute in sync with the buffers below (was fixed at 128).
        self.head_dim = head_dim

        group_width = self.kv_group * head_dim
        self.out_cov = torch.zeros((num_kv_head, group_width, group_width), device=self.dev)
        self.inp_cov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
        self.out_mean = torch.zeros((num_kv_head, group_width, 1), device=self.dev)
        self.inp_mean = torch.zeros((self.hidden_size, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    @staticmethod
    def _merge_moments(n_old, mean_old, cov_old, samples):
        """
        Pool a running (mean, unbiased covariance) with a new batch.

        Args:
            n_old: Number of samples already summarised.
            mean_old: Running mean, shape ``(features, 1)``.
            cov_old: Running covariance, shape ``(features, features)``.
            samples: New batch, shape ``(features, n_new)``.

        Returns:
            ``(mean, cov)`` describing all ``n_old + n_new`` samples.
        """
        n_new = samples.shape[1]
        total = n_old + n_new
        batch_mean = torch.mean(samples, dim=-1).unsqueeze(1)
        batch_cov = torch.cov(samples, correction=1)

        mean = (n_old * mean_old + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - mean_old
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * cov_old + (n_new - 1) * batch_cov + cross) / (total - 1)
        return mean, cov

    def add_batch(self, inp, out):
        """
        Fold one batch into the running statistics.

        Args:
            inp: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``; must match ``out``'s shape.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)`` with ``heads == num_head``.

        Each batch needs at least two tokens for ``torch.cov`` to be finite.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)
        if len(inp.shape) == 3:
            inp = inp.unsqueeze(0)

        out = out.squeeze(0)
        inp = inp.squeeze(0)
        head_num, seq_len, head_dim = out.shape

        # Input: all heads flattened, one column per token.
        inp = inp.transpose(0, 1).contiguous().reshape(seq_len, head_num * head_dim).T.type(torch.float32)
        # Output: (heads, head_dim, seq) regrouped so consecutive heads share
        # a group -> (num_kv_head, kv_group * head_dim, seq).
        out = out.transpose(1, 2).reshape(
            self.num_kv_head, self.kv_group * head_dim, seq_len).type(torch.float32)

        n_old = self.nsamples
        n_new = inp.shape[1]

        self.inp_mean, self.inp_cov = self._merge_moments(n_old, self.inp_mean, self.inp_cov, inp)

        for group in range(self.num_kv_head):
            mean, cov = self._merge_moments(
                n_old, self.out_mean[group], self.out_cov[group], out[group])
            self.out_mean[group] = mean
            self.out_cov[group] = cov

        self.nsamples = n_old + n_new


class WrappedGPT_after_rope_with_group:
    """
    Collects output statistics separately for consecutive token groups.

    The sequence axis is cut into chunks of ``group_len`` tokens. Chunk ``i``
    of every batch is pooled into ``out_mean[i]`` / ``out_cov[i]``, a running
    mean and unbiased covariance over ``hidden_size`` features.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", hidden_size=4096, device=None,
                 group_len=128, max_seq_len=4096):
        """
        Allocate zero-initialised statistics buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            hidden_size: Feature count (heads * head_dim of ``out``).
            device: Device for all buffers.
            group_len: Tokens per group (default matches the previous
                hard-coded 128).
            max_seq_len: Longest sequence supported; fixes the number of
                group slots to ``max_seq_len // group_len`` (default matches
                the previous hard-coded 4096).
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = hidden_size

        self.group_len = group_len
        self.group_num = max_seq_len // self.group_len
        self.out_cov = torch.zeros((self.group_num, hidden_size, hidden_size), device=self.dev)
        self.out_mean = torch.zeros((self.group_num, hidden_size, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Fold one batch into the per-group running statistics.

        Args:
            inp: Unused; accepted so the signature matches the other wrappers.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``; ``seq`` must be a multiple of
                ``group_len`` (the reshape fails otherwise).

        Raises:
            ValueError: If the batch contains more groups than were allocated.

        ``nsamples`` counts samples per group, i.e. it grows by ``group_len``
        per call.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)

        out = out.squeeze(0)
        head_num, seq_len, head_dim = out.shape
        # (heads, seq, dim) -> (groups, group_len, hidden) -> (groups, hidden, group_len)
        groups = out.transpose(0, 1).reshape(
            seq_len // self.group_len, self.group_len, head_num * head_dim
        ).transpose(1, 2).type(torch.float32)

        if groups.shape[0] > self.out_mean.shape[0]:
            raise ValueError(
                f"batch has {groups.shape[0]} groups but only "
                f"{self.out_mean.shape[0]} were allocated")

        n_old = self.nsamples
        # Samples per group are the columns (tokens), not the feature rows.
        n_new = groups.shape[2]
        total = n_old + n_new

        for idx in range(groups.shape[0]):
            batch_cov = torch.cov(groups[idx], correction=1)
            batch_mean = torch.mean(groups[idx], dim=-1).unsqueeze(1)

            mean = (n_old * self.out_mean[idx] + n_new * batch_mean) / total
            # Between-batch scatter uses the offset to the *previous* mean.
            delta = batch_mean - self.out_mean[idx]
            cross = (n_old * n_new / total) * (delta @ delta.T)
            cov = ((n_old - 1) * self.out_cov[idx] + (n_new - 1) * batch_cov + cross) / (total - 1)

            self.out_mean[idx] = mean
            self.out_cov[idx] = cov

        self.nsamples = total

class WrappedGPT_after_rope_add:
    """
    Accumulates the un-normalised second moment of multi-head outputs.

    ``out_cov`` holds the running sum of ``X @ X.T`` where ``X`` is the
    output flattened to ``(heads * head_dim, seq)``; ``nsamples`` counts the
    tokens summed so far.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", hidden_size=4096, device=None):
        """
        Allocate zero-initialised buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            hidden_size: Feature count (heads * head_dim of ``out``).
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = hidden_size

        # Maintained by add_batch.
        self.out_cov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.nsamples = 0

        # Allocated but not updated by this class.
        self.inp_cov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.out_mean = torch.zeros((hidden_size, 1), device=self.dev)
        self.inp_mean = torch.zeros((hidden_size, 1), device=self.dev)
        self.out_avgcov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.out_avgmean = torch.zeros((hidden_size, 1), device=self.dev)
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Add ``X @ X.T`` for one batch to ``out_cov``.

        Args:
            inp: Unused. It was previously reshaped and copied to float32 on
                every call without the result being read.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)

        out = out.squeeze(0)
        head_num, seq_len, head_dim = out.shape
        # (heads, seq, dim) -> (seq, heads*dim) -> (heads*dim, seq)
        flat = out.transpose(0, 1).contiguous().reshape(
            seq_len, head_num * head_dim).T.type(torch.float32)

        self.out_cov += flat @ flat.T
        self.nsamples += flat.shape[1]

class WrappedGPT_after_rope_entire:
    """
    Buffers whole batches, then computes statistics position by position.

    ``add_batch`` only stores the flattened outputs in ``out_tensor``.
    ``calculate`` then walks over sequence positions and, for each one, pools
    the samples from all stored batches into a running mean and unbiased
    covariance via ``cal_batch_cov``.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", hidden_size=4096, device=None):
        """
        Allocate zero-initialised buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            hidden_size: Feature count (heads * head_dim of ``out``).
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device
        self.hidden_size = hidden_size

        # Statistics of the current accumulation window.
        self.out_cov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.out_mean = torch.zeros((hidden_size, 1), device=self.dev)

        # Average over completed windows (see cal_batch_cov).
        self.out_avgcov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.out_avgmean = torch.zeros((hidden_size, 1), device=self.dev)

        # Stored outputs, shape (hidden, seq, n_batches) once populated.
        self.out_tensor = None

        self.layer_id = layer_id
        self.layer_name = layer_name

        self.nsamples = 0
        self.avg_step = 0

    def cal_batch_cov(self, out):
        """
        Pool one ``(features, n_samples)`` slice into the running statistics.

        Whenever the running sample count is a multiple of
        ``128 * n_samples``, the current window is folded into
        ``out_avgcov`` / ``out_avgmean`` and the window is reset.

        Args:
            out: Tensor shaped ``(hidden_size, n_samples)``; needs at least
                two samples for ``torch.cov`` to be finite.
        """
        batch_cov = torch.cov(out, correction=1)
        batch_mean = torch.mean(out, dim=-1).unsqueeze(1)

        n_old = self.nsamples
        n_new = out.shape[1]
        total = n_old + n_new

        mean = (n_old * self.out_mean + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - self.out_mean
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * self.out_cov + (n_new - 1) * batch_cov + cross) / (total - 1)

        self.out_mean = mean
        self.out_cov = cov
        self.nsamples = total

        if self.nsamples % (128 * n_new) == 0:
            # Fold the finished window into the across-window average with
            # unit weight, then start a fresh window.
            done = self.avg_step
            self.out_avgcov = (done * self.out_avgcov + self.out_cov) / (done + 1)
            self.out_avgmean = (done * self.out_avgmean + self.out_mean) / (done + 1)
            self.out_cov = torch.zeros((self.hidden_size, self.hidden_size), device=self.dev)
            self.out_mean = torch.zeros((self.hidden_size, 1), device=self.dev)
            self.avg_step += 1
            self.nsamples = 0

    def add_batch(self, inp, out):
        """
        Store one batch of outputs for later processing by ``calculate``.

        Args:
            inp: Unused; accepted so the signature matches the other wrappers.
            out: Tensor shaped ``(heads, seq, head_dim)`` or
                ``(1, heads, seq, head_dim)``. All batches must share the
                same ``seq`` because they are concatenated on a new last axis.
        """
        if len(out.shape) == 3:
            out = out.unsqueeze(0)

        out = out.squeeze(0)
        head_num, seq_len, head_dim = out.shape
        # (heads, seq, dim) -> (heads*dim, seq, 1)
        column = out.transpose(0, 1).contiguous().reshape(
            seq_len, head_num * head_dim).T.type(torch.float32).unsqueeze(-1)

        if self.out_tensor is None:
            self.out_tensor = column
        else:
            self.out_tensor = torch.cat([self.out_tensor, column], dim=-1)

    def calculate(self):
        """
        Run ``cal_batch_cov`` once per sequence position over stored batches.

        Does nothing when no batch has been stored. ``out_tensor`` is left in
        its ``(hidden, seq, n_batches)`` layout; calling this twice pools the
        same data twice.
        """
        if self.out_tensor is None:
            return

        # View as (hidden, n_batches, seq) without overwriting out_tensor.
        per_position = self.out_tensor.transpose(1, 2)
        for pos in range(per_position.shape[2]):
            self.cal_batch_cov(per_position[:, :, pos])
        

class WrappedGPT_new:
    """
    Running output mean and scatter matrix with an explicit ``get_cov``.

    ``out_cov`` stores the un-normalised scatter (sum of centred outer
    products); ``get_cov`` divides it by ``nsamples - 1`` to obtain the
    unbiased covariance.
    """

    def __init__(self, layer, layer_id=0, layer_name="none", hidden_size=4096, device=None):
        """
        Allocate zero-initialised buffers.

        Args:
            layer: Wrapped object; only stored.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
            hidden_size: Feature count of ``out``.
            device: Device for all buffers.
        """
        self.layer = layer
        self.dev = device

        self.out_cov = torch.zeros((hidden_size, hidden_size), device=self.dev)
        self.out_mean = torch.zeros((hidden_size, 1), device=self.dev)

        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Fold one batch into the running mean and scatter matrix (in place).

        Args:
            inp: Unused.
            out: Output activations shaped ``(seq, hidden_size)`` or
                ``(1, seq, hidden_size)``; needs at least two samples for
                ``torch.cov`` to be finite.
        """
        if len(out.shape) == 2:
            out = out.unsqueeze(0)

        # One column per sample: (hidden, seq).
        out = out.squeeze(0).T.type(torch.float32)
        batch_cov = torch.cov(out, correction=1)
        batch_mean = torch.mean(out, dim=-1).unsqueeze(1)

        n_old = self.nsamples
        n_new = out.shape[-1]
        total = n_old + n_new

        # Offset of the batch mean from the mean *before* this update; the
        # between-batch scatter n_old*n_new/total * delta delta^T needs this
        # one, not the offset to the already-updated mean.
        delta = batch_mean - self.out_mean
        self.out_mean += delta * n_new / total
        self.out_cov += (n_old * n_new / total) * (delta @ delta.T) + (n_new - 1) * batch_cov
        self.nsamples = total

    def get_cov(self):
        """
        Return the unbiased covariance, or ``None`` with fewer than 2 samples.
        """
        if self.nsamples < 2:
            return None
        return self.out_cov / (self.nsamples - 1)

class WrappedGPT_out:
    """
    Running output statistics with periodic averaging over windows.

    ``add_batch`` pools batches into ``out_mean`` / ``out_cov``. Whenever the
    running sample count is a multiple of ``32 *`` (current batch size), the
    window is folded into ``out_avgmean`` / ``out_avgcov`` and reset.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """
        Allocate zero-initialised buffers on the device of ``layer.weight``.

        Args:
            layer: Object with a 2-D ``weight`` tensor; ``weight.shape[0]``
                becomes ``rows`` and ``weight.shape[1]`` becomes ``columns``.
            layer_id: Identifier stored as ``self.layer_id``.
            layer_name: Name stored as ``self.layer_name``.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        # Allocated for callers; nothing in this class updates these three.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.out = torch.zeros((self.rows, 2048), device=self.dev)
        self.inp = torch.zeros((2048, self.columns), device=self.dev)

        # Statistics of the current accumulation window.
        self.out_cov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_mean = torch.zeros((self.rows, 1), device=self.dev)

        # Average over completed windows.
        self.out_avgcov = torch.zeros((self.rows, self.rows), device=self.dev)
        self.out_avgmean = torch.zeros((self.rows, 1), device=self.dev)
        self.nsamples = 0
        self.avg_step = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """
        Fold one batch into the current window, averaging when it is full.

        Args:
            inp: Unused; accepted so the signature matches the other wrappers.
            out: Output activations shaped ``(seq, rows)`` or
                ``(1, seq, rows)``; needs at least two samples for
                ``torch.cov`` to be finite.
        """
        if len(out.shape) == 2:
            out = out.unsqueeze(0)

        # One column per sample: (rows, seq).
        out = out.squeeze(0).T.type(torch.float32)

        batch_cov = torch.cov(out, correction=1)
        batch_mean = torch.mean(out, dim=-1).unsqueeze(1)

        n_old = self.nsamples
        n_new = out.shape[1]
        total = n_old + n_new

        mean = (n_old * self.out_mean + n_new * batch_mean) / total
        # Between-batch scatter uses the offset to the *previous* mean.
        delta = batch_mean - self.out_mean
        cross = (n_old * n_new / total) * (delta @ delta.T)
        cov = ((n_old - 1) * self.out_cov + (n_new - 1) * batch_cov + cross) / (total - 1)

        self.out_mean = mean
        self.out_cov = cov
        self.nsamples = total

        if self.nsamples % (32 * n_new) == 0:
            # Fold the finished window into the across-window average with
            # unit weight, then start a fresh window.
            done = self.avg_step
            self.out_avgcov = (done * self.out_avgcov + self.out_cov) / (done + 1)
            self.out_avgmean = (done * self.out_avgmean + self.out_mean) / (done + 1)
            self.out_cov = torch.zeros((self.rows, self.rows), device=self.dev)
            self.out_mean = torch.zeros((self.rows, 1), device=self.dev)
            self.avg_step += 1
            self.nsamples = 0
