import numpy as np
import torch
import logging
from utils.wavelets import decomposition_fields

class Fields():
    """Read-only accessor for multi-scale fields stored in a .npz archive.

    The archive is indexed by the scale number as a string ('0', '1', ...).
    Entry '0' must be a 3-dimensional array; its first and last dimensions
    are recorded as `n` and `L`.
    """

    def __init__(self, fname) -> None:
        """
        Open the archive and record its dimensions.

        input: the fname of a .npz file

        Sets:
            fields : the object returned by np.load(fname)
            n      : first dimension of entry '0'
            L      : last dimension of entry '0'
            J      : log2(L) truncated to an integer (coarsest scale index)
        """
        archive = np.load(fname)
        # Unpacking into exactly three names enforces that entry '0' is 3-D.
        n_samples, _, side = archive['0'].shape
        self.fields = archive
        self.n = n_samples
        self.L = side
        self.J = np.log2(side).astype(int)

    def get_fields_at_scale(self, j):
        """
        Return the archive entry for scale j as a torch tensor.

        For j == 0 a singleton axis is inserted at position 1, giving
        (n, 1, L, L). For j > 0 the stored array is returned unchanged
        (expected to be (n, 4, L/2^j, L/2^j) -- TODO confirm against the
        code that wrote the archive).

        Asserts that j <= J.
        """
        assert j<=self.J, f'The coarsest scale is J={self.J} < j={j}'

        stored = self.fields[str(j)]
        if j != 0:
            return torch.from_numpy(stored)

        # Scale 0 is stored without a channel axis; add one so that every
        # scale is returned as a 4-D tensor.
        with_channel = stored[:, np.newaxis]
        return torch.from_numpy(with_channel)

    def get_details_at_scale(self, j):
        """
        Return every channel except the first of the scale-j tensor.

        Presumably the 3 detail channels, i.e. (n, 3, L/2^j, L/2^j), for
        j > 0. For j == 0 there is a single channel, so the result has an
        empty channel axis.
        """
        stack = self.get_fields_at_scale(j)
        return stack[:, 1:, :, :]

    def get_low_freqs_at_scale(self,j):
        """
        Return only the first channel of the scale-j tensor, keeping the
        channel axis: shape (n, 1, H, W), with H and W the spatial sizes of
        the stored entry.
        """
        stack = self.get_fields_at_scale(j)
        return stack[:, 0:1, :, :]


def decompose_and_save(fields, fname = None, folder = 'save/data/'):
    """
    Decompose a batch of fields scale by scale and optionally save the result.

    input:
        fields = (n,L,L) array
        fname = default None; file name of the archive (with or without the
                '.npz' extension). Nothing is written when None.
        folder = default save/data/; directory in which the archive is
                 written. A trailing separator is optional and the directory
                 is created if it does not exist.

    The result is a dict D where D['0'] is the input array and, for
    j = 1..J with J = int(log2(L)), D[str(j)] is the output of
    decomposition_fields applied to channel 0 of the previous scale
    (expected to be a (n,4,L/2^j,L/2^j) array -- the shape is determined by
    decomposition_fields and is not checked here).

    If fname is not None, the dict is saved with np.savez under
    folder/fname and the path of the written file is printed.

    return: dict D
    """
    import os  # local import: `os` is not among this module's top-level imports

    L = fields.shape[-1]
    J = np.log2(L).astype(int)
    decomposed_fields_at_scale = {'0': fields}

    current = fields
    for j in range(J):
        current = decomposition_fields(current)
        decomposed_fields_at_scale[str(j + 1)] = current
        # Only channel 0 is fed to the next, coarser decomposition.
        current = current[:, 0, :, :]

    if fname is not None:
        # os.path.join works whether or not `folder` ends with a separator;
        # plain string concatenation silently produced a wrong path otherwise.
        path = os.path.join(folder, fname)
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        np.savez(path, **decomposed_fields_at_scale)
        # np.savez appends '.npz' only when the name does not already end
        # with it, so mirror that rule to report the real file name.
        saved_path = path if path.endswith('.npz') else path + '.npz'
        print('Decomposed fields saved at : ' + saved_path)

    return decomposed_fields_at_scale