from typing import Any
from collections import OrderedDict

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only

from src.models.layers.sparse_fourier_linear import SparseFourierLinear
from src.models.modules.vision_common import AttentionSimple

import numpy as np

import torch
import torch.nn as nn

import os


class FreqSparsityMonitor(Callback):
    """Record FFT-magnitude spectra of selected layers' activations to disk.

    At the start of each validation / test epoch, forward hooks are attached to
    every ``SparseFourierLinear`` or ``AttentionSimple`` module of
    ``pl_module.model``, and to every module whose qualified name contains
    ``'mlp.fc'``. Each hook saves spectra of (up to) the first 16 samples of the
    module's first input and of its output via ``save_npy``. The hooks are
    removed at the end of the epoch so that training forward passes are not
    recorded and hooks do not pile up across epochs.

    Args:
        num_bins: Stored on the instance; not used by this class itself.
        save_dir: Directory the ``.npy`` files are written to (created on
            demand). Required.

    Raises:
        ValueError: If ``save_dir`` is None.
    """

    def __init__(self, num_bins=20, save_dir=None):
        super().__init__()
        self.num_bins = num_bins
        self.save_dir = save_dir
        # Handles of the forward hooks currently attached; removed at epoch end.
        self._hook_handles = []
        if self.save_dir is None:
            raise ValueError("save_dir must be specified.")

    def _append_npy(self, array, filename):
        """Append ``array`` along axis 0 to ``save_dir/filename`` (created if absent)."""
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, filename)
        if os.path.exists(path):
            # NOTE: the whole file is re-read and rewritten on every call, so the
            # cost grows with the amount of data already saved.
            array = np.concatenate([np.load(path), array], axis=0)
        np.save(path, array)

    def save_npy(self, inputs, name):
        """Append channel- and token-frequency magnitudes of ``inputs`` to disk.

        ``inputs`` is indexed as (B, N, C) (the code permutes dims 1 and 2 and
        treats the last one as channels) -- presumably (batch, tokens, channels);
        TODO confirm against the hooked modules.

        Writes/extends two files in ``save_dir``:
          * ``<name>_channel.npy``: ``|rfft|`` over the last dimension.
          * ``<name>_token.npy``: ``|rfft2|`` over a square grid built from
            tokens 1..N-1 (the first token is dropped, presumably a class token).

        Raises:
            ValueError: If N - 1 is not a positive perfect square, so the tokens
                cannot be laid out on a square grid.
        """
        with torch.no_grad():
            # 1D spectrum along the channel (last) dimension.
            freqs = torch.fft.rfft(inputs.to(torch.float32), norm='ortho')
            self._append_npy(torch.abs(freqs).detach().cpu().numpy(),
                             name + "_channel.npy")

            # 2D spectrum over the spatial token grid.
            tokens = inputs.permute(0, 2, 1)  # (B, C, N)
            B, C, N = tokens.size()
            num_patches = N - 1
            side = int(round(num_patches ** 0.5)) if num_patches > 0 else 0
            if side == 0 or side * side != num_patches:
                raise ValueError(
                    f"Cannot reshape {num_patches} tokens into a square grid "
                    f"(layer '{name}')."
                )
            # reshape (not view): the permuted tensor is non-contiguous.
            grid = tokens[:, :, 1:].reshape(B, C, side, side)
            freqs = torch.fft.rfft2(grid.to(torch.float32), norm='ortho')
            self._append_npy(torch.abs(freqs).detach().cpu().numpy(),
                             name + "_token.npy")

    def get_hook_fn(self, name):
        """Return a forward hook that saves spectra of a module's input/output.

        Files are named ``<name>_in_*`` / ``<name>_out_*``. Only the first
        positional input is recorded. The hook slices ``outputs`` directly, so it
        assumes the module returns a single 3-D tensor (not a tuple) -- TODO
        confirm for ``AttentionSimple``.
        """
        def hook(m, inputs, outputs):
            self.save_npy(inputs[0][:16, :, :], name + "_in")
            self.save_npy(outputs[:16, :, :], name + "_out")
        return hook

    def _register_hooks(self, model):
        """Attach recording hooks to all monitored submodules of ``model``."""
        # Drop any hooks left over from a previous epoch so they never stack.
        self._remove_hooks()
        for mn, m in model.named_modules():
            if isinstance(m, (SparseFourierLinear, AttentionSimple)) or 'mlp.fc' in mn:
                print(mn)
                self._hook_handles.append(
                    m.register_forward_hook(self.get_hook_fn(mn)))

    def _remove_hooks(self):
        """Detach every hook previously attached by ``_register_hooks``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    @rank_zero_only
    def on_validation_epoch_start(self, trainer, pl_module):
        self._register_hooks(pl_module.model)

    def on_validation_epoch_end(self, trainer, pl_module):
        # Without this, hooks would also fire during training forward passes.
        self._remove_hooks()

    @rank_zero_only
    def on_test_epoch_start(self, trainer, pl_module):
        self._register_hooks(pl_module.model)

    def on_test_epoch_end(self, trainer, pl_module):
        self._remove_hooks()


