import torch
import torch.nn as nn
import pandas as pd


class OptimModule(nn.Module):
    """``nn.Module`` base with optimization and rank-diagnostic helpers.

    Provides per-tensor learning-rate registration (``register``), a one-call
    optimizer step (``update``), a gradient/update summary table
    (``summarize_grads_and_updates``) and per-layer rank tracking
    (``rank_record``).

    NOTE(review): ``rank_record`` reads ``self.layers``, which this class does
    not define — presumably subclasses provide it; confirm.
    """

    @torch.no_grad()
    def update(self, loss, optimizer, vis=False):
        """Zero gradients, back-propagate ``loss`` and step ``optimizer``.

        ``torch.no_grad`` only stops new ops from being recorded; ``backward``
        still traverses the graph that already produced ``loss``.

        Args:
            loss: Tensor to back-propagate (``backward()`` is called with no
                gradient argument, so it must be a scalar).
            optimizer: Optimizer to zero and step.
            vis: When true, summarize gradients before the step is applied.

        Returns:
            The DataFrame from ``summarize_grads_and_updates`` if ``vis`` is
            true, otherwise ``None``.
        """
        optimizer.zero_grad()
        loss.backward()
        # Take the snapshot before step() modifies the parameters in place.
        grad_info = self.summarize_grads_and_updates(optimizer) if vis else None
        optimizer.step()
        return grad_info

    def register(self, name, tensor, lr=0.):
        """Attach ``tensor`` to the module as a buffer or a parameter.

        Args:
            name: Attribute name to register under.
            tensor: Initial value.
            lr: ``0.0`` registers a buffer, which is not a parameter. Any
                other value, including ``None``, registers an ``nn.Parameter``.

        Registered parameters get an ``_optim`` dict attribute with
        ``weight_decay`` set to ``0.0``. The dict also holds ``lr`` unless
        ``lr`` is ``None``. NOTE(review): presumably read by optimizer-setup
        code elsewhere to build per-tensor param groups; confirm.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        optim = {"weight_decay": 0.0}
        if lr is not None:
            optim["lr"] = lr
        setattr(getattr(self, name), "_optim", optim)

    @torch.no_grad()
    def summarize_grads_and_updates(self, optimizer):
        """Tabulate parameter norms and estimated update norms.

        A parameter is reported only if it has a ``.grad`` and belongs to one
        of ``optimizer.param_groups``. "Update Norm" is ``||lr * grad||``. That
        is the step plain SGD would take; momentum or adaptive optimizers apply
        a different update.

        Returns:
            ``pandas.DataFrame`` with columns ``Parameter``, ``Current Norm``,
            ``Learning Rate`` and ``Update Norm``. The columns are present even
            when there are no rows.
        """
        columns = ['Parameter', 'Current Norm', 'Learning Rate', 'Update Norm']

        # Map parameter identity -> lr in one pass. The old per-parameter
        # rescan of every group was quadratic in the number of parameters.
        # torch.optim rejects a parameter that appears in two groups, so the
        # first match is the only one.
        lr_by_id = {}
        for group in optimizer.param_groups:
            for p in group['params']:
                lr_by_id.setdefault(id(p), group['lr'])

        summaries = []
        for name, param in self.named_parameters():
            if param.grad is None or id(param) not in lr_by_id:
                continue
            lr = lr_by_id[id(param)]
            summaries.append({'Parameter': name,
                              'Current Norm': param.norm().item(),
                              'Learning Rate': lr,
                              'Update Norm': (lr * param.grad).norm().item()})
        return pd.DataFrame(summaries, columns=columns)

    @torch.no_grad()
    def test_rank(self, x):
        """Return ``matrix_rank(x @ x.T)`` as a Python int, or ``None``.

        For real ``x`` this equals ``rank(x)`` in exact arithmetic. Forming the
        Gram matrix squares the condition number, though, so weak directions
        of ``x`` can fall under ``matrix_rank``'s default tolerance.

        Best effort: any ``Exception`` (e.g. an unsupported dtype or a bad
        shape) yields ``None``. ``KeyboardInterrupt`` and ``SystemExit`` are
        not swallowed.
        """
        try:
            gram = x @ x.T
            return torch.linalg.matrix_rank(gram).item()
        except Exception:
            return None

    @torch.no_grad()
    def rank_record(self, x):
        """Pass ``x`` through ``self.layers`` and record each activation's rank.

        Each activation is flattened to ``(x.shape[0], -1)`` and given to
        ``test_rank``. Index ``0`` is the input itself and index ``i + 1`` is
        the output of the ``i``-th layer. If a layer returns a tuple, its first
        element is used. Points where ``test_rank`` returns ``None`` are
        skipped, so ``layer_list`` can have gaps.

        Args:
            x: Tensor whose leading dimension is the batch; any trailing shape
               is accepted.

        Returns:
            ``(output, layer_list, rank_list)``: the final activation, the
            indices of the recorded points, and their ranks.
        """
        batch = x.shape[0]
        layer_list, rank_list = [], []

        def record(activation, index):
            # Record the rank of one flattened activation if it can be computed.
            rank = self.test_rank(activation.reshape(batch, -1))
            if rank is not None:
                rank_list.append(rank)
                layer_list.append(index)

        output = x
        record(output, 0)
        for index, layer in enumerate(self.layers, start=1):
            output = layer(output)
            if isinstance(output, tuple):
                output = output[0]
            record(output, index)

        return output, layer_list, rank_list
