import torch
import torch.nn as nn

class ModuleSI(nn.Module):
    """``nn.Module`` base that partitions parameters into an "SI" group and the rest.

    Subclasses override :meth:`in_si_group` to decide membership. The other
    methods compute norms, gradient norms and Hessian-alignment diagnostics
    per group. (The expansion of "SI" is not stated here — presumably
    scale-invariant parameters; confirm against subclasses.)
    """

    def in_si_group(self, param):
        """Return whether ``param`` belongs to the SI group.

        The base implementation places every parameter outside the group;
        subclasses are expected to override this hook.
        """
        return False

    @staticmethod
    def _sq_norm(tensor) -> float:
        """Return the squared L2 norm of ``tensor`` as a Python float.

        The tensor is cast to float32 *before* squaring so that half-precision
        tensors cannot overflow during the reduction (fp16 saturates at 65504).
        """
        return tensor.detach().float().square().sum().item()

    @torch.no_grad()
    def si_norm2(self) -> float:
        """Return the squared L2 norm of all SI-group parameters taken together."""
        return sum((self._sq_norm(p) for p in self.si_parameters()), 0.)

    @torch.no_grad()
    def si_norm(self) -> float:
        """Return the L2 norm of all SI-group parameters taken together."""
        return self.si_norm2() ** 0.5

    @torch.no_grad()
    def log_dict(self, log=None) -> dict:
        """Collect per-parameter and per-group norm statistics.

        Args:
            log: Dict to update in place; a new dict is created when ``None``.

        Returns:
            ``log``, updated with:

            * ``param/<name>/val-<k>``: element values (Python floats) of every
              1-D parameter with at most 10 entries;
            * ``normp/{si,nsi}/<name>``: L2 norm of each parameter;
            * ``gnormp/{si,nsi}/<name>``: L2 norm of each parameter's gradient,
              only for parameters whose ``.grad`` is set;
            * ``norm/si``, ``norm/nsi``, ``gnorm/si``, ``gnorm/nsi``: L2 norms
              over each group (missing gradients contribute 0);
            * ``sgnorm/si``: product of ``gnorm/si`` and ``norm/si``.
        """
        if log is None:
            log = {}
        norm2 = {'si': 0., 'nsi': 0.}
        gnorm2 = {'si': 0., 'nsi': 0.}

        for name, param in self.named_parameters():
            if param.dim() == 1 and param.shape[0] <= 10:
                # tolist() copies the values out; storing param[k] would keep a
                # view that silently follows later in-place weight updates.
                for k, value in enumerate(param.detach().tolist()):
                    log[f'param/{name}/val-{k}'] = value

            group = 'si' if self.in_si_group(param) else 'nsi'

            p_norm2 = self._sq_norm(param)
            norm2[group] += p_norm2
            log[f'normp/{group}/{name}'] = p_norm2 ** 0.5

            if param.grad is not None:
                g_norm2 = self._sq_norm(param.grad)
                gnorm2[group] += g_norm2
                log[f'gnormp/{group}/{name}'] = g_norm2 ** 0.5

        log['norm/si'] = norm2['si'] ** 0.5
        log['norm/nsi'] = norm2['nsi'] ** 0.5
        log['gnorm/si'] = gnorm2['si'] ** 0.5
        log['sgnorm/si'] = (gnorm2['si'] * norm2['si']) ** 0.5
        log['gnorm/nsi'] = gnorm2['nsi'] ** 0.5
        return log

    @torch.no_grad()
    def log_eigg(self, log, seigv, hess_grad_prod) -> dict:
        """Log how the SI-group gradient aligns with supplied eigenvectors.

        Args:
            log: Dict from :meth:`log_dict` (provides ``gnorm/si``) that also
                holds ``seigs/0`` — presumably the eigenvalue belonging to
                ``seigv[:, 0]``; confirm with the caller.
            seigv: 2-D numpy array whose columns are unit-norm vectors laid out
                in the flattened order of :meth:`si_parameters`. Its row count
                must equal the total number of SI-parameter elements.
            hess_grad_prod: Indexable with one tensor per SI parameter, in
                :meth:`si_parameters` order — presumably the Hessian-gradient
                product H·g shaped like each parameter; verify with caller.

        Returns:
            ``log``, updated with ``seigg/<k>`` (``|g·v_k|``), ``seigr/<k>``
            (``|g·v_k| / gnorm/si``) and the ``top-align/*`` entries that
            compare ``g`` with ``H·g / seigs/0``. Values are 0-d tensors.

        Every SI parameter must have ``.grad`` set.
        """
        assert 'norm/si' in log
        assert 'seigs/0' in log

        si_params = list(self.si_parameters())
        # Follow the model's device rather than assuming CUDA is available.
        device = si_params[0].device if si_params else torch.device('cpu')

        for k in range(seigv.shape[1]):
            v = torch.from_numpy(seigv[:, k])
            assert abs(torch.linalg.norm(v) - 1) < 1e-5
            v = v.to(device)

            eigg = 0.
            ptr = 0
            for p in si_params:
                n = p.numel()
                g = p.grad.reshape(-1)
                seg = v[ptr : ptr + n]
                # numpy eigenvectors are typically float64 while gradients are
                # often float32; matmul requires a common dtype.
                dtype = torch.promote_types(g.dtype, seg.dtype)
                eigg += g.to(dtype) @ seg.to(dtype)
                ptr += n
            assert ptr == v.numel(), 'eigenvector length does not match SI parameter count'

            log[f'seigg/{k}'] = abs(eigg)
            log[f'seigr/{k}'] = abs(eigg) / log['gnorm/si']

        top_eig = log['seigs/0']
        diff = 0.
        gHg = 0.
        for i, p in enumerate(si_params):
            hgp = hess_grad_prod[i]
            diff += (p.grad - hgp / top_eig).square().sum()
            gHg += p.grad.reshape(-1) @ hgp.reshape(-1)
        diff = diff ** 0.5

        log['top-align/abs_error'] = diff
        log['top-align/rel_error'] = diff / log['gnorm/si']
        log['top-align/abs_gHg'] = gHg
        log['top-align/rel_gHg'] = gHg / log['gnorm/si'] ** 2 / log['seigs/0']

        return log

    def si_parameters(self):
        """Yield every parameter for which :meth:`in_si_group` is true."""
        for param in self.parameters():
            if self.in_si_group(param):
                yield param

    def nsi_parameters(self):
        """Yield trainable (``requires_grad``) parameters outside the SI group."""
        for param in self.parameters():
            if not self.in_si_group(param) and param.requires_grad:
                yield param

    def trainable_parameters(self):
        """Yield every parameter with ``requires_grad`` set, regardless of group."""
        for param in self.parameters():
            if param.requires_grad:
                yield param
