import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

import math

def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Recursively collect submodules whose exact type is in *layers*.

    Args:
        module: root module to search.
        layers: layer classes matched by exact type (``type(m) in layers``),
            so subclasses are not matched. Any container supporting ``in``
            works; defaults to ``(nn.Conv2d, nn.Linear)``.
        name: dotted name of *module*, used as the prefix for child names.

    Returns:
        dict mapping dotted module names to matching modules. When *module*
        itself matches, the result is ``{name: module}`` and its children are
        not searched.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=full_name))
    return found

class LayerWrapper:
    """Accumulate running statistics of the inputs seen by a single layer.

    ``update_hessian_mean`` maintains the running mean ``M`` (``columns x 1``)
    and running second moment ``H = mean(x x^T)`` (``columns x columns``) of
    the layer's inputs, both float32 on the layer's weight device.
    ``substract_mean_from_hessian`` then turns ``H`` into the (population)
    covariance. ``set_input`` / ``reset_input`` instead keep a reference to the
    most recent input, flattened to 2-D for linear-like layers. Both
    ``update_hessian_mean`` and ``set_input`` take ``(inp, out)``; presumably
    they are driven from a forward hook (confirm against the caller).

    Args:
        layer: module with a ``weight``; ``weight.shape[1]`` is the input
            dimension tracked by ``H`` and ``M``.
        input_batchsize: accepted for compatibility; not used.
    """

    def __init__(self, layer, input_batchsize=16):
        self.layer = layer
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]
        self.input = None
        self.H = torch.zeros((self.columns, self.columns), device=self.dev, dtype=torch.float32)
        self.M = torch.zeros((self.columns, 1), device=self.dev, dtype=torch.float32)
        self.nsamples = 0

    @staticmethod
    def _is_linear_like(layer):
        """Return True for ``nn.Linear`` or a ``transformers`` ``Conv1D`` layer.

        ``transformers`` is looked up in ``sys.modules`` instead of imported:
        a ``Conv1D`` instance cannot exist unless that package has already been
        imported, and this avoids a hard dependency on it.
        """
        import sys

        if isinstance(layer, nn.Linear):
            return True
        conv1d_cls = getattr(sys.modules.get("transformers"), "Conv1D", None)
        return conv1d_cls is not None and isinstance(layer, conv1d_cls)

    def update_hessian_mean(self, inp, out):
        """Fold a new batch of inputs into the running ``M`` and ``H``.

        For linear-like layers, ``inp`` is flattened to ``(N, columns)`` and
        each of the N rows counts as one sample. The caller's tensor is never
        modified. ``out`` is ignored.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        if self._is_linear_like(self.layer):
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        batch = inp.shape[1]
        # Rescale the running averages so they remain averages once `batch`
        # new samples are added.
        decay = self.nsamples / (self.nsamples + batch)
        self.H *= decay
        self.M *= decay
        self.nsamples += batch
        inp = inp.to(torch.float32)
        self.M += torch.sum(inp, dim=1, keepdim=True) / self.nsamples
        # Out-of-place on purpose: `.to(torch.float32)` is a no-op for float32
        # inputs, so an in-place scale would overwrite the caller's activations.
        inp = inp * math.sqrt(1 / self.nsamples)
        self.H += inp.matmul(inp.t())

    def substract_mean_from_hessian(self):
        """Turn the second moment ``H`` into a covariance: ``H -= M M^T``."""
        self.H -= self.M.matmul(self.M.t())

    def set_input(self, inp, out):
        """Keep ``inp`` (flattened to 2-D for linear-like layers); ``nsamples`` = its rows."""
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        if self._is_linear_like(self.layer):
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
        self.input = inp
        self.nsamples = inp.shape[0]

    def reset_input(self):
        """Drop the stored input, reset the sample count and release cached CUDA memory."""
        self.input = None
        self.nsamples = 0
        torch.cuda.empty_cache()

import torch
from torch.utils.data import Dataset

class ActivationDataset(Dataset):
    """Collect per-layer input activations together with simple moments.

    For every name in ``name_list`` the dataset keeps:

    * ``input_dict[name]``: the activation batches passed to :meth:`update`.
      This is a list until :meth:`concat_lists_to_tensors` concatenates it
      along dim 0.
    * ``input_mean_dict[name]``: per-batch column means. After
      concatenation, this is their (unweighted) average of shape ``(1, d)``.
    * ``input_hessian_dict[name]``: per-batch ``X^T X / rows`` matrices. After
      concatenation, this is their (unweighted) average. The averages equal
      the global moments only when every batch has the same number of rows.
    * ``input_normalized_sorted_dict[name]``: filled by
      :meth:`normalize_and_sort`.
    """

    def __init__(self, name_list, device):
        self.input_dict = {}
        self.input_normalized_sorted_dict = {}
        self.input_mean_dict = {}
        self.input_hessian_dict = {}
        self.name_list = name_list
        self.device = device
        self.length = 0

        for name in name_list:
            self.input_dict[name] = []
            self.input_mean_dict[name] = []
            self.input_hessian_dict[name] = []

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{name: input_dict[name][idx]}`` for every tracked name."""
        sample = {}
        for name in self.name_list:
            sample[name] = self.input_dict[name][idx]
        return sample

    def __del__(self):
        for name in self.name_list:
            if name in self.input_dict and isinstance(self.input_dict[name], torch.Tensor):
                del self.input_dict[name]
                del self.input_mean_dict[name]
                del self.input_hessian_dict[name]
        torch.cuda.empty_cache()

    def update(self, new_input_dict):
        """Append one batch per name and record its mean and ``X^T X / rows``.

        Each ``new_input_dict[name]`` must be 2-D ``(rows, features)``: the
        second moment is computed with ``.t()``. The stored tensors and the
        caller's tensors are not modified.
        """
        for name in self.name_list:
            tmp_input = new_input_dict[name].to(self.device)
            num_rows = tmp_input.shape[0]
            self.input_dict[name].append(tmp_input)
            self.input_mean_dict[name].append(torch.mean(tmp_input, dim=0, keepdim=True).to(self.device))
            # Out-of-place on purpose: for float32 input `.to(torch.float32)`
            # returns the tensor stored just above, so an in-place scale would
            # corrupt the stored activations (and the caller's tensor).
            scaled = tmp_input.to(torch.float32) * math.sqrt(1 / num_rows)
            self.input_hessian_dict[name].append((scaled.t() @ scaled).to(self.device))
        self.length += new_input_dict[self.name_list[0]].size(0)

    def concat_lists_to_tensors(self):
        """Concatenate stored batches and average the per-batch moments."""
        with torch.no_grad():
            for name in self.name_list:
                self.input_dict[name] = torch.cat(self.input_dict[name], dim=0)
                self.input_mean_dict[name] = torch.cat(self.input_mean_dict[name], dim=0)
                self.input_mean_dict[name] = torch.mean(self.input_mean_dict[name], dim=0, keepdim=True)
                self.input_hessian_dict[name] = torch.stack(self.input_hessian_dict[name], dim=0)
                self.input_hessian_dict[name] = torch.mean(self.input_hessian_dict[name], dim=0)

    def normalize_and_sort(self):
        """Store the sorted, row-normalised squared activations per name.

        Each row of the (concatenated) activations is squared elementwise and
        divided by its sum. All values are then flattened, sorted ascending
        and cast back to the activation dtype. Must be called after
        :meth:`concat_lists_to_tensors`. ``input_dict`` is left untouched.
        """
        import gc

        use_cuda = torch.cuda.is_available()
        with torch.no_grad():
            for name in self.name_list:
                if use_cuda:
                    torch.cuda.synchronize()
                stored = self.input_dict[name]
                input_dtype = stored.dtype
                # copy=True: the in-place ops below must never alias the stored
                # activations (a plain `.to(float32)` is a no-op for float32).
                x = stored.detach().reshape(-1, stored.shape[-1]).to(torch.float32, copy=True)
                x.mul_(x)
                x.div_(x.sum(dim=-1, keepdim=True))
                sorted_value, _ = torch.sort(x.reshape(-1))
                self.input_normalized_sorted_dict[name] = sorted_value.to(input_dtype)
                del sorted_value, x
                gc.collect()
                torch.cuda.empty_cache()
                if use_cuda:
                    torch.cuda.synchronize()
            
class Sparsifier(nn.Module):
    """Collection of activation-sparsification strategies.

    Every ``*_sparsify`` method first adds ``bias`` (when set) to ``x``. The
    addition is out of place, so the caller's tensor is never modified and the
    result keeps ``x``'s dtype. The method then zeroes a subset of the entries
    along the last dimension.

    Args:
        threshold: cut-off for :meth:`threshold_sparsify`; entries below it
            become 0.
        zero_ratio: fraction of entries dropped by the top-k variants.
        weight_norm: stored; not used by the methods here.
        bias: tensor added to ``x`` (after ``unsqueeze(0)``) before sparsifying.
        spectrum_threshold: cut-off on each entry's share of its row's squared
            L2 norm, used by :meth:`spectrum_sparsify`. ``None`` disables that
            pruning.
    """

    def __init__(self, threshold=0, zero_ratio=0, weight_norm=None, bias=None, spectrum_threshold=None):
        super().__init__()
        self.threshold = threshold
        self.zero_ratio = zero_ratio
        self.weight_norm = weight_norm
        self.bias = bias
        self.spectrum_threshold = spectrum_threshold

    def _add_bias(self, x):
        """Return ``x + bias`` in ``x``'s dtype without touching ``x`` in place."""
        if self.bias is None:
            return x
        return (x + self.bias.unsqueeze(0)).to(x.dtype)

    @staticmethod
    def _row_energy_share(rows):
        """Per-entry share of each row's squared L2 norm, computed in float32."""
        x_square = rows.to(torch.float32).pow(2)
        return x_square / x_square.sum(dim=-1, keepdim=True)

    def threshold_sparsify(self, x):
        """Zero every entry of ``x + bias`` that is below ``threshold``."""
        x = self._add_bias(x)
        return torch.where(x < self.threshold, torch.tensor(0.0, device=x.device), x)

    def spectrum_sparsify(self, x):
        """Keep entries whose share of the row's squared norm is >= ``spectrum_threshold``.

        With ``spectrum_threshold=None`` only the bias is applied.
        """
        x = self._add_bias(x)
        if self.spectrum_threshold is None:
            return x
        x_shape = x.shape
        rows = x.reshape(-1, x_shape[-1])
        keep = self._row_energy_share(rows) >= self.spectrum_threshold
        return (rows * keep).view(x_shape)

    def spectrum_topk_sparsify(self, x, k=None):
        """Keep the ``k`` entries per row with the largest squared-norm share.

        Args:
            x: input tensor; the last dimension is sparsified.
            k: entries kept per row; defaults to
                ``int(x.size(-1) * (1 - zero_ratio))``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self._add_bias(x)
        x_shape = x.shape
        if k is None:
            k = int(x_shape[-1] * (1 - self.zero_ratio))
        rows = x.reshape(-1, x_shape[-1])
        _, indices = torch.topk(self._row_energy_share(rows), k, dim=-1)
        mask = torch.zeros_like(rows)
        mask.scatter_(-1, indices, 1)
        return (rows * mask).view(x_shape)

    def topk_sparsify(self, x):
        """Keep the ``int(d * (1 - zero_ratio))`` largest-magnitude entries along the last dim."""
        x = self._add_bias(x)
        k = int(x.size(-1) * (1 - self.zero_ratio))
        _, indices = torch.topk(torch.abs(x), k, dim=-1)
        mask = torch.zeros_like(x)
        mask.scatter_(-1, indices, 1)
        return x * mask

    def block_topk_sparsify(self, x, block_size=32):
        """Top-k by magnitude inside consecutive blocks of the last dimension.

        The last dimension must be divisible by ``block_size`` (default 32);
        ``int(block_size * (1 - zero_ratio))`` entries are kept per block.
        """
        x = self._add_bias(x)
        x_shape = x.shape
        k = int(block_size * (1 - self.zero_ratio))
        blocks = x.reshape(-1, x_shape[-1] // block_size, block_size)
        _, indices = torch.topk(torch.abs(blocks), k, dim=-1)
        mask = torch.zeros_like(blocks)
        mask.scatter_(-1, indices, 1)
        return (blocks * mask).view(x_shape)


class SparseLinear(nn.Module):
    """Wrap a linear layer so its input is sparsified before the projection.

    When a sparsifier is attached, its ``spectrum_sparsify`` is applied to the
    input. The wrapped layer's ``forward`` is then called directly (not via
    ``__call__``), so hooks registered on ``linear`` do not fire.

    Args:
        linear: module performing the projection.
        sparsifier: object exposing ``spectrum_sparsify(x)``, or ``None`` to
            pass inputs through unchanged.
    """

    def __init__(self, linear=None, sparsifier=None):
        super().__init__()
        self.linear = linear
        self.sparsifier = sparsifier

    def forward(self, x):
        sparsifier = self.sparsifier
        # Other strategies available on Sparsifier: threshold_sparsify,
        # spectrum_topk_sparsify, topk_sparsify, block_topk_sparsify.
        projected_input = x if sparsifier is None else sparsifier.spectrum_sparsify(x)
        return self.linear.forward(projected_input)

def convert_to_sparse_linear(linear_layer, sparsifier=None):
    """Wrap *linear_layer* in a :class:`SparseLinear`.

    Args:
        linear_layer: module whose ``forward`` performs the projection.
        sparsifier: sparsifier applied to the layer's inputs; a
            default-constructed ``Sparsifier()`` is used when omitted.

    Returns:
        A new ``SparseLinear`` wrapping ``linear_layer``.
    """
    if sparsifier is None:
        sparsifier = Sparsifier()
    return SparseLinear(linear=linear_layer, sparsifier=sparsifier)

def create_max_variance_permutation_index(vector_size, subvector_size):
    """Build the index permutation that interleaves a vector into subvectors.

    Treat the vector as a row-major grid of ``subvector_size`` rows by
    ``num_subvectors = vector_size // subvector_size`` columns. Output
    subvector ``j`` (positions ``j * subvector_size`` to
    ``(j + 1) * subvector_size - 1``) is column ``j`` of that grid. It is read
    top-to-bottom for even ``j`` and bottom-to-top for odd ``j``.

    Returns:
        1-D ``torch.long`` tensor ``indices`` of length ``vector_size`` such
        that ``v[indices]`` is the permuted vector.
    """
    assert vector_size % subvector_size == 0, "vector_size must be divisible by subvector_size"
    num_subvectors = vector_size // subvector_size
    row = torch.arange(subvector_size).unsqueeze(1)   # (subvector_size, 1)
    col = torch.arange(num_subvectors).unsqueeze(0)   # (1, num_subvectors)
    # Odd columns walk the grid rows in reverse order.
    source_row = torch.where(col % 2 == 0, row, subvector_size - 1 - row)
    grid = source_row * num_subvectors + col          # grid[i, j] = source index
    # Output position j * subvector_size + i holds grid[i, j]: flatten column-major.
    return grid.t().reshape(-1)

def create_max_variance_permutation_matrix(vector_size, subvector_size):
    """Return the permutation matrix for the max-variance interleaving.

    Row ``r`` of the result is the standard basis row ``e_{indices[r]}``, where
    ``indices`` comes from ``create_max_variance_permutation_index``. Hence
    ``P @ v`` equals ``v[indices]``.
    """
    permutation = create_max_variance_permutation_index(vector_size, subvector_size)
    identity = torch.eye(vector_size)
    return identity.index_select(0, permutation)



class RotatorOptimizer(torch.nn.Module):
    """Learnable orthogonal rotations R1 (shared by all layers) and R2 (per layer).

    R1 is block-diagonal with ``num_piece`` blocks of size
    ``A_dim = r_dim // num_piece``. R2 holds one ``head_dim x head_dim`` block
    per key/value head for every layer. Each block is stored as an
    unconstrained float32 matrix and mapped to an orthogonal one with a QR
    decomposition. The ``forward_*`` methods return losses computed from the
    rotated calibration Hessians and, with ``with_weight=True``, from rotated
    weight Gram matrices.

    Note: ``__init__`` rewrites the caller's dictionaries in place. Every entry
    of ``weight_dict_list`` (a module exposing ``.weight``) is replaced by the
    Gram matrix ``W^T W``. Every Hessian in ``hessian_dict_list`` is detached
    and moved to ``device`` / ``dtype``.

    Args:
        weight_dict_list: one dict per layer mapping sub-module names
            (e.g. ``"self_attn.q_proj"``) to modules exposing ``.weight``.
        r_dim: size of R1.
        num_key_value_heads: number of R2 blocks per layer.
        head_dim: size of each R2 block.
        device: device for parameters and cached matrices.
        positive: stored as ``self.positive``; not read by the methods here.
        hessian_dict_list: optional per-layer dicts of Hessian matrices.
        num_piece: number of diagonal blocks in R1.
        dtype: dtype for cached matrices and returned rotations.
        with_weight: selects the weighted loss in :meth:`forward_h1_h3`.
    """

    def __init__(self, weight_dict_list, r_dim, num_key_value_heads, head_dim, device, positive=True, hessian_dict_list=None, num_piece=1, dtype=torch.bfloat16, with_weight=False):
        super().__init__()

        self.weight_dict_list = weight_dict_list
        self.num_piece = num_piece
        self.r_dim = r_dim
        self.device = device
        # Side length of each diagonal block of R1.
        self.A_dim = self.r_dim // self.num_piece
        self.positive = positive
        self.hessian_dict_list = hessian_dict_list
        self.num_layer = len(weight_dict_list)
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.with_weight = with_weight

        # Learnable R1 blocks, identity-initialised. They live in a plain list,
        # so they are exposed through the overridden parameters() methods
        # rather than nn.Module registration.
        self.A_list = [
            torch.nn.Parameter(torch.eye(self.A_dim, device=device, dtype=torch.float32))
            for _ in range(self.num_piece)
        ]
        # Learnable R2 blocks indexed [layer][kv_head], identity-initialised.
        self.B_list_list = [
            [
                torch.nn.Parameter(torch.eye(self.head_dim, dtype=torch.float32, device=self.device))
                for _ in range(self.num_key_value_heads)
            ]
            for _ in range(self.num_layer)
        ]

        for idx in range(self.num_layer):
            layer_weights = self.weight_dict_list[idx]
            for name in layer_weights:
                weight = layer_weights[name].weight.detach().to(self.device).to(self.dtype)
                # Cache the Gram matrix W^T W in place of the module.
                layer_weights[name] = weight.T @ weight
                layer_weights[name].requires_grad_(False)
            if hessian_dict_list is not None:
                layer_hessians = self.hessian_dict_list[idx]
                for name in layer_hessians:
                    layer_hessians[name] = layer_hessians[name].detach().to(self.device).to(self.dtype)
                    layer_hessians[name].requires_grad_(False)

    @staticmethod
    def _sum_losses(terms):
        """Sum an iterable of loss tensors; ``None`` when it is empty."""
        total = None
        for term in terms:
            total = term if total is None else total + term
        return total

    @staticmethod
    def _zero_diagonal(matrix):
        """Return a copy of a square matrix with its diagonal set to zero."""
        return matrix - torch.diag(torch.diag(matrix))

    def compute_cov_norm_squared(self):
        """Store the squared Frobenius norms of the q_proj, o_proj and up_proj
        Hessians, each summed over layers, as ``h1``/``h2``/``h3_cov_norm_squared``."""
        self.h1_cov_norm_squared = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        self.h2_cov_norm_squared = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        self.h3_cov_norm_squared = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        for i in range(self.num_layer):
            self.h1_cov_norm_squared += torch.sum(torch.pow(self.hessian_dict_list[i]["self_attn.q_proj"], 2))
            self.h2_cov_norm_squared += torch.sum(torch.pow(self.hessian_dict_list[i]["self_attn.o_proj"], 2))
            self.h3_cov_norm_squared += torch.sum(torch.pow(self.hessian_dict_list[i]["mlp.up_proj"], 2))

    def initialize_with_pca(self):
        """Initialise R1 and R2 with eigenvectors of the calibration Hessians.

        R1 is shared by every layer. Its ``q_proj`` and ``mlp.up_proj``
        Hessians are therefore summed over all layers, and each ``A_dim``
        diagonal block of the sum is eigendecomposed into the matching
        ``A_list`` entry. For R2, for every layer and key/value head, the
        ``o_proj`` Hessian blocks of the heads mapped to that key/value head
        are summed and eigendecomposed into ``B_list_list[layer][head]``.
        """
        # R1: one shared covariance over all layers, split into num_piece blocks.
        r1_cov = torch.zeros((self.r_dim, self.r_dim), dtype=torch.float32, device=self.device)
        for i in range(self.num_layer):
            for name in ("self_attn.q_proj", "mlp.up_proj"):
                r1_cov += self.hessian_dict_list[i][name]
        for piece in range(self.num_piece):
            lo, hi = piece * self.A_dim, (piece + 1) * self.A_dim
            _, eigenvectors = torch.linalg.eigh(r1_cov[lo:hi, lo:hi])
            self.A_list[piece].data = eigenvectors

        # R2: per layer and key/value head.
        for i in range(self.num_layer):
            o_hessian = self.hessian_dict_list[i]["self_attn.o_proj"]
            repeat_num = o_hessian.shape[0] // self.head_dim // self.num_key_value_heads
            for j in range(self.num_key_value_heads):
                head_cov = torch.zeros((self.head_dim, self.head_dim), dtype=torch.float32, device=self.device)
                for k in range(repeat_num):
                    start = (j * repeat_num + k) * self.head_dim
                    end = start + self.head_dim
                    head_cov += o_hessian[start:end, start:end]
                _, eigenvectors = torch.linalg.eigh(head_cov)
                self.B_list_list[i][j].data = eigenvectors

    def parameters(self, recurse=True):
        """Return all learnable blocks: R1 blocks followed by every R2 block."""
        return self.parameters_R1() + self.parameters_R2()

    def parameters_R1(self, recurse=True):
        """Return the learnable R1 blocks as a list."""
        return list(self.A_list)

    def parameters_R2(self, recurse=True):
        """Return every learnable R2 block, layer by layer, as a flat list."""
        res = []
        for layer_blocks in self.B_list_list:
            res += layer_blocks
        return res

    def get_orthogonal_matrix(self):
        """Return R1: block-diagonal matrix of the Q factors of ``A_list``, in ``dtype``."""
        return torch.block_diag(*[torch.linalg.qr(self.A_list[i])[0] for i in range(self.num_piece)]).to(dtype=self.dtype)

    def get_orthogonal_matrix_R2_list_list(self):
        """Return the Q factors of every R2 block, indexed [layer][kv_head], in ``dtype``."""
        return [
            [torch.linalg.qr(self.B_list_list[i][j])[0].to(dtype=self.dtype) for j in range(self.num_key_value_heads)]
            for i in range(self.num_layer)
        ]

    def get_R1_list(self):
        """Return the float32 Q factors of the R1 blocks."""
        return [torch.linalg.qr(self.A_list[i])[0] for i in range(self.num_piece)]

    def get_R2_list_list(self):
        """Alias of :meth:`get_orthogonal_matrix_R2_list_list`."""
        return self.get_orthogonal_matrix_R2_list_list()

    def compute_salience_RWX(self, weight, input, R):
        # NOTE(review): `compute_salience_RWX_contextual` is not defined in this
        # module's visible code; it must be provided at module scope.
        return compute_salience_RWX_contextual(weight, input, R)

    def compute_salience_WR_1RX(self, weight, hessian, R):
        """Loss ``-sum(diag(R^T H R)^2)``; ``weight`` is accepted but unused."""
        rotated_hessian = R.T @ hessian @ R
        return -torch.sum(torch.pow(torch.diag(rotated_hessian), 2))

    def compute_salience_R2WR_1RX(self, weight, input, R, R2_list):
        # NOTE(review): `compute_salience_R2WR_1RX_contextual` is not defined in
        # this module's visible code; it must be provided at module scope.
        return compute_salience_R2WR_1RX_contextual(weight, input, R, R2_list)

    def compute_loss_XR2_R2TW_v_proj(self, v_proj_weight, hessian, R2_list):
        """Loss ``-sum(diag(R2^T H R2)^2)`` for a block-diagonal R2.

        Each per-key/value-head block in ``R2_list`` is repeated so that the
        block-diagonal R2 spans ``v_proj_weight.shape[1]`` dimensions.
        """
        blocks = [
            r2
            for r2 in R2_list
            for _ in range(v_proj_weight.shape[1] // r2.shape[0] // len(R2_list))
        ]
        R2 = torch.block_diag(*blocks)
        rotated_hessian = R2.T @ hessian @ R2
        return -torch.sum(torch.pow(torch.diag(rotated_hessian), 2))

    def _forward_shared_rotation(self, names, R=None):
        """Sum :meth:`compute_salience_WR_1RX` over layers and ``names`` using R1."""
        if R is None:
            R = self.get_orthogonal_matrix()
        return self._sum_losses(
            self.compute_salience_WR_1RX(self.weight_dict_list[idx][name], self.hessian_dict_list[idx][name], R)
            for idx in range(self.num_layer)
            for name in names
        )

    def forward_h1_h3(self):
        """R1 loss over the attention-input and MLP-input Hessians.

        Without weights, this sums ``-sum(diag(R^T H R)^2)`` for the ``q_proj``
        and ``mlp.up_proj`` Hessians. With ``with_weight``, it sums
        ``||offdiag(R^T H R) * (R^T W^T W R)||^2``. The ``q_proj`` Hessian is
        paired with the q/k/v Gram matrices, and the ``mlp.up_proj`` Hessian
        with the up/gate Gram matrices.
        """
        R = self.get_orthogonal_matrix()
        if not self.with_weight:
            return self._forward_shared_rotation(("self_attn.q_proj", "mlp.up_proj"), R)
        groups = (
            ("self_attn.q_proj", ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")),
            ("mlp.up_proj", ("mlp.up_proj", "mlp.gate_proj")),
        )
        terms = []
        for idx in range(self.num_layer):
            for hessian_name, weight_names in groups:
                rotated_hessian = self._zero_diagonal(R.T @ self.hessian_dict_list[idx][hessian_name] @ R)
                for weight_name in weight_names:
                    rotated_weight = R.T @ self.weight_dict_list[idx][weight_name] @ R
                    terms.append(torch.sum(torch.pow(rotated_hessian * rotated_weight, 2)))
        return self._sum_losses(terms)

    def forward_h1(self):
        """R1 loss over the ``self_attn.q_proj`` Hessians only."""
        return self._forward_shared_rotation(("self_attn.q_proj",))

    def forward_h2(self):
        """R2 loss over the ``self_attn.o_proj`` Hessians of every layer."""
        R2_list_list = self.get_orthogonal_matrix_R2_list_list()
        return self._sum_losses(
            self.compute_loss_XR2_R2TW_v_proj(
                self.weight_dict_list[idx][name], self.hessian_dict_list[idx][name], R2_list_list[idx]
            )
            for idx in range(self.num_layer)
            for name in ("self_attn.o_proj",)
        )

    def forward_h3(self):
        """R1 loss over the ``mlp.up_proj`` Hessians only."""
        return self._forward_shared_rotation(("mlp.up_proj",))