import torch


class GPM:
    def __init__(self, gpm_name_list: list, log_txt):
        self.gpm_dict = {name: [] for name in gpm_name_list}
        self.log_txt = log_txt

    def update_GPM(self, mat_list_dict: dict, threshold_dict: dict | float):
        if isinstance(threshold_dict, float):
            threshold = threshold_dict
            threshold_dict = {name: threshold for name in self.gpm_dict.keys()}
        if set(mat_list_dict.keys()) != set(self.gpm_dict.keys()) or set(threshold_dict.keys()) != set(self.gpm_dict.keys()):
            raise ValueError("Keys in mat_list_dict or threshold_dict do not match gpm_dict")
        for k in self.gpm_dict.keys():
            if threshold_dict[k] > 1e-8:
                self._update_GPM(mat_list_dict[k], self.gpm_dict[k], threshold_dict[k])

    def _update_GPM(self, mat_list, gpm_layers, threshold):
        if len(gpm_layers) == 0:
            for i in range(len(mat_list)):
                activation = mat_list[i].clone().detach().float()
                U, S, Vh = torch.linalg.svd(activation, full_matrices=False)
                sval_total = (S**2).sum()
                sval_ratio = (S**2)/sval_total
                r = torch.sum(torch.cumsum(sval_ratio, dim=0) < threshold).item()
                gpm_layers.append(U[:,0:max(r,1)])
        else:
            for i in range(len(mat_list)):
                activation = mat_list[i].clone().detach().float()
                U1, S1, Vh1 = torch.linalg.svd(activation, full_matrices=False)
                sval_total = (S1**2).sum()
                feature_tensor = gpm_layers[i].clone().detach().float()
                act_hat = activation - torch.matmul(torch.matmul(feature_tensor, feature_tensor.t()), activation)
                U, S, Vh = torch.linalg.svd(act_hat, full_matrices=False)
                sval_hat = (S**2).sum()
                sval_ratio = (S**2)/sval_total               
                accumulated_sval = (sval_total-sval_hat)/sval_total
                
                r = 0
                for ii in range(sval_ratio.shape[0]):
                    if accumulated_sval < threshold:
                        accumulated_sval += sval_ratio[ii]
                        r += 1
                    else:
                        break
                if r == 0:
                    print('Skip Updating GPM for layer: {}'.format(i+1)) 
                    continue
                Ui = torch.hstack((feature_tensor, U[:,0:r]))
                if Ui.shape[1] > Ui.shape[0]:
                    gpm_layers[i] = Ui[:,0:Ui.shape[0]]
                else:
                    gpm_layers[i] = Ui
        
        print('-'*40)
        print('Gradient Constraints Summary')
        for i in range(len(gpm_layers)):
            log = 'Layer {} : {}/{}'.format(i+1, gpm_layers[i].shape[1], gpm_layers[i].shape[0])
            print(log)
            self.log_txt(log)
        print('-'*40)