import torch
import os
import torch.nn as nn
import torch.optim as optim
from torch.fft import fft, ifft
import torch.nn.functional as F
import logging
from pathlib import Path

class ToDCT(nn.Module):
    """Sum a per-parameter frequency regularisation term over a model's trainable parameters.

    ``forward(model)`` visits every parameter of ``model`` that has
    ``requires_grad`` set. It calls ``self.frequency_regularization(param, name)``,
    which must return a ``(loss_term, preserved_count)`` pair, and returns the
    sums of both.

    NOTE(review): ``frequency_regularization`` is not defined in this class, so
    ``forward`` raises AttributeError unless a subclass (or a later assignment
    on the class) provides it. Confirm where it is meant to live.
    """

    def __init__(self, keep_ratio=0.4, savepth=None, model_name='cnn3'):
        super().__init__()
        # Stored configuration; not read by the code in this class
        # (presumably consumed by frequency_regularization — confirm).
        self.keep_ratio = keep_ratio
        self.model_name = model_name
        # A falsy savepth (None, "") means "no output location".
        self.savepth = None if not savepth else Path(savepth)

        # Per-parameter caches keyed by parameter name. forward() empties them
        # in place at the start of every call; presumably frequency_regularization
        # fills them.
        self.layer_blocks_norm = {}
        self.layer_blocks_freq = {}
        self.layer_presevered_freq = {}

    def forward(self, model):
        """Return ``(total_loss, total_preserved)`` accumulated over ``model``'s trainable parameters.

        Both totals start at the integer ``0``, so a model with no trainable
        parameters yields ``(0, 0)``.
        """
        # Empty (not rebind) the caches so outside references to them stay valid.
        for cache in (self.layer_blocks_norm,
                      self.layer_blocks_freq,
                      self.layer_presevered_freq):
            cache.clear()

        total_loss, total_preserved = 0, 0
        trainable = ((pname, tensor)
                     for pname, tensor in model.named_parameters()
                     if tensor.requires_grad)
        for pname, tensor in trainable:
            loss_term, preserved = self.frequency_regularization(tensor, pname)
            total_loss += loss_term
            total_preserved += preserved

        return total_loss, total_preserved

def reconstruct_weights(model, layer_blocks_freq, layer_blocks_norm, layer_freq,use_dct= True, use_ratio=False,select_n=None,mode='zero'):
    """Overwrite a model's selected 2-D parameters with weights rebuilt from stored coefficients.

    For every parameter whose name is a key of ``layer_blocks_freq``:

    1. Take the first stored block's ``'freq_coeffs'``, shape ``(N, M)``.
    2. Zero-pad or crop it to the parameter's shape ``(h, w)``.
    3. Apply ``idct`` along one axis, transpose, apply it again, and transpose
       back.
    4. Optionally scale by ``(h / N) * (w / M)``.
    5. De-normalise with the first entry of ``layer_blocks_norm[name]`` as
       ``x * std + mean``.
    6. Copy the result into the parameter in place.

    Args:
        model: module whose parameters are updated in place.
        layer_blocks_freq: parameter name -> sequence of dicts; only element 0's
            ``'freq_coeffs'`` (2-D tensor or array-like) is used.
        layer_blocks_norm: parameter name -> sequence of dicts; only element 0's
            ``'mean'`` and ``'std'`` are used. Every key of ``layer_blocks_freq``
            that names a model parameter must also be present here.
        layer_freq: unused; kept for call-site compatibility.
        use_dct: unused; kept for call-site compatibility.
        use_ratio: if True, multiply the inverse transform by ``h / N`` and ``w / M``.
        select_n: unused; kept for call-site compatibility.
        mode: unused; kept for call-site compatibility.

    Returns:
        The same ``model`` object, with the reconstructed parameters written in place.

    Raises:
        ValueError: if a parameter selected for reconstruction is not 2-D.
    """
    # NOTE(review): ``idct`` is not among the visible imports; it must be defined
    # at module level elsewhere. Presumably it is a 1-D inverse DCT over the last
    # dimension — confirm.
    # Writing into parameters must not be recorded by autograd, and the
    # reconstruction itself does not need a graph.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in layer_blocks_freq:
                continue
            if param.dim() != 2:
                raise ValueError(
                    f"reconstruct_weights: parameter {name!r} has shape "
                    f"{tuple(param.shape)}; only 2-D parameters are supported"
                )

            h, w = param.shape
            freq_coeffs = torch.as_tensor(layer_blocks_freq[name][0]['freq_coeffs'])
            stats = layer_blocks_norm[name][0]
            n_rows, n_cols = freq_coeffs.shape

            # Allocate on the coefficients' device/dtype (the old CPU-only
            # torch.zeros broke for GPU coefficients). Keep the overlapping
            # low-index corner and zero the rest.
            padded_freq = freq_coeffs.new_zeros((h, w))
            keep_h, keep_w = min(n_rows, h), min(n_cols, w)
            padded_freq[:keep_h, :keep_w] = freq_coeffs[:keep_h, :keep_w]

            # Separable 2-D inverse: transform, transpose, transform, transpose back.
            reconstructed = idct(idct(padded_freq).T).T
            if use_ratio:
                # Rescale for the size change between the stored coefficient grid
                # and the target parameter.
                reconstructed = reconstructed * (h / n_rows) * (w / n_cols)
            reconstructed = reconstructed * stats['std'] + stats['mean']

            # copy_ converts to the parameter's device and dtype.
            param.copy_(reconstructed[:h, :w])

    return model