'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''

from abc import ABC, abstractmethod
from functools import partial
import yaml
import torch
from torch.nn import functional as F
import torchaudio
import scipy
from scipy.optimize import fsolve
import numpy as np
from clap.hook import CLAP_Module

# =================
# Operation classes
# =================

class LinearOperator(ABC):
    """Base class for linear measurement operators ``A``.

    Subclasses are expected to override :meth:`forward` (``A x``) and
    :meth:`transpose` (``A^T x``); the two projection helpers below are
    expressed purely in terms of those two methods. The base versions of
    ``forward`` and ``transpose`` are placeholders that return ``None``.
    """

    def forward(self, data, **kwargs):
        """Apply the operator, i.e. compute ``A x``. Placeholder: returns None."""
        return None

    def transpose(self, data, **kwargs):
        """Apply the adjoint, i.e. compute ``A^T x``. Placeholder: returns None."""
        return None

    def ortho_project(self, data, **kwargs):
        """Return ``(I - A^T A) x`` for ``x = data``.

        ``kwargs`` are forwarded unchanged to both ``forward`` and ``transpose``.
        """
        measured = self.forward(data, **kwargs)
        back_projected = self.transpose(measured, **kwargs)
        return data - back_projected

    def project(self, data, measurement, **kwargs):
        """Return ``(I - A^T A) y - A x`` for ``y = measurement`` and ``x = data``.

        ``kwargs`` are forwarded unchanged to ``ortho_project`` and ``forward``.
        """
        complement = self.ortho_project(measurement, **kwargs)
        return complement - self.forward(data, **kwargs)


class DenoiseOperator(LinearOperator):
    """Identity operator used for plain denoising.

    Every method returns its ``data`` argument unchanged. The signatures
    accept the same optional arguments as the ``LinearOperator`` base class
    (``**kwargs`` and, for :meth:`project`, ``measurement``) so that code
    written against the base interface can call this operator without a
    ``TypeError``; the extra arguments are ignored.
    """

    def __init__(self, device):
        # Stored for parity with the other operators; not read by this class.
        self.device = device

    def forward(self, data, **kwargs):
        """Return ``data`` unchanged; ``kwargs`` are ignored."""
        return data

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged; ``kwargs`` are ignored."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data`` unchanged; ``kwargs`` are ignored.

        NOTE(review): this does not follow the base-class formula
        ``data - transpose(forward(data))``, which would be zero for an
        identity operator. The existing pass-through behavior is kept.
        """
        return data

    def project(self, data, measurement=None, **kwargs):
        """Return ``data`` unchanged; ``measurement`` and ``kwargs`` are ignored."""
        return data

class InpaintingOperator(LinearOperator):
    """Masking operator: multiplies the input by a caller-supplied mask.

    The mask is not stored on the operator; it must be passed on every call
    through the ``mask`` keyword argument.
    """

    def __init__(self, device):
        # Device the mask is moved to before it is applied.
        self.device = device

    def forward(self, data, **kwargs):
        """Return ``data * mask`` with the mask moved to ``self.device``.

        Args:
            data: value to be masked.
            **kwargs: must contain ``mask``, an object exposing ``.to(device)``
                whose result can be multiplied with ``data``.

        Raises:
            ValueError: if ``mask`` is missing or ``None`` ("Require mask"),
                or if moving/applying the mask fails. In the latter case the
                underlying exception is attached as ``__cause__`` so the real
                problem (shape, dtype, device, ...) is not hidden.
        """
        mask = kwargs.get('mask')
        if mask is None:
            raise ValueError("Require mask")
        try:
            return data * mask.to(self.device)
        except Exception as err:
            # Keep ValueError as the exception type existing callers may
            # catch, but report what actually failed. KeyboardInterrupt and
            # SystemExit are deliberately not intercepted.
            raise ValueError(f"Failed to apply mask: {err}") from err

    def transpose(self, data, **kwargs):
        """Return ``data`` unchanged (the mask is not applied here)."""
        return data

    def ortho_project(self, data, **kwargs):
        """Return ``data - forward(data)``, i.e. the part removed by the mask."""
        return data - self.forward(data, **kwargs)


class NonLinearOperator(ABC):
    """Base class for non-linear measurement operators.

    Subclasses override :meth:`forward`; the base version is a placeholder
    that returns ``None``.
    """

    def forward(self, data, **kwargs):
        """Apply the operator to ``data``. Placeholder: returns None."""
        return None

    def project(self, data, measurement, **kwargs):
        """Return ``data + measurement - forward(data)``.

        ``kwargs`` are forwarded to ``forward`` so that operators needing
        per-call arguments behave the same here as when ``forward`` is called
        directly (this mirrors ``LinearOperator.project``).
        """
        return data + measurement - self.forward(data, **kwargs)

def get_clip_value_from_SDR(seg, SDRdesired):
    """Find the symmetric clipping threshold that produces a requested SDR.

    The search is delegated to ``scipy.optimize.fsolve`` with 0.1 as the
    starting threshold. The objective is the absolute difference between
    ``SDRdesired`` and the SDR (``20 * log10(||x|| / ||x - clip(x)||)``)
    obtained when the signal is clipped to ``[-threshold, threshold]``.

    Args:
        seg (Tensor): audio segment to clip (the original notes describe it
            as shape (T,)); it is moved to CPU and converted to NumPy.
        SDRdesired (float): target signal-to-distortion ratio.

    Returns:
        The first element of the solution array returned by ``fsolve``.
    """
    samples = seg.cpu().numpy()

    def sdr_gap(threshold, signal, target_sdr):
        clipped = np.clip(signal, -threshold, threshold)
        # The small constant keeps the ratio finite when nothing is clipped.
        distortion = np.linalg.norm(signal - clipped) + 1e-7
        achieved_sdr = 20 * np.log10(np.linalg.norm(signal) / distortion)
        return np.abs(achieved_sdr - target_sdr)

    solution = fsolve(sdr_gap, 0.1, args=(samples, SDRdesired))
    return solution[0]

class Declipping:
    """Hard-clipping degradation with a threshold derived from a target SDR.

    The threshold is computed once, at construction time, from the reference
    ``data`` via ``get_clip_value_from_SDR``. Calling the instance clips its
    input to ``[clip_min_value, clip_max_value]``.
    """

    def __init__(self, data, sdr, device):
        threshold = get_clip_value_from_SDR(data, sdr)
        self.clip_max_value = threshold
        self.clip_min_value = - threshold
        # Stored only; not read by this class.
        self.device = device

    def __call__(self, data):
        """Return ``data`` clipped to the stored symmetric range."""
        return torch.clip(data, min=self.clip_min_value, max=self.clip_max_value)

class BWE:
    """Band-limiting degradation: low-pass filters a waveform with a FIR filter.

    The filter taps are designed once in ``__init__`` and kept in
    ``self.filter``; calling the instance filters its input with them.
    """

    def __init__(self, fc=4000, order=200, beta=1, sr=16000):
        self.order = order
        self.fc = fc
        self.beta = beta
        self.sr = sr
        self.filter = self.get_FIR_lowpass(order, fc, beta, sr)

    def get_FIR_lowpass(self, order, fc, beta, sr):
        """Design a FIR low-pass filter with ``scipy.signal.firwin``.

        Args:
            order (int): number of taps (passed as ``numtaps``).
            fc (float): cutoff frequency, in the same units as ``sr``.
            beta (float): passed to ``firwin`` as ``width``, i.e. the
                approximate transition width used for the Kaiser design
                (despite the parameter's name).
            sr (float): sampling rate (passed as ``fs``).

        Returns:
            B (Tensor): float32 tensor of shape (1, 1, order) holding the taps.
        """
        # Import the submodule explicitly: the module header only has
        # ``import scipy``, which does not guarantee that ``scipy.signal`` is
        # available as an attribute on every SciPy version.
        from scipy import signal

        taps = signal.firwin(numtaps=order, cutoff=fc, width=beta, window="kaiser", fs=sr)
        # Two leading singleton dimensions: (order,) -> (1, 1, order).
        return torch.FloatTensor(taps).unsqueeze(0).unsqueeze(0)

    def apply_low_pass_firwin(self, y, filter):
        """Filter ``y`` with FIR taps using ``torchaudio.functional.convolve``.

        Args:
            y (Tensor): signal to filter; documented as shape (B, T).
            filter (Tensor): FIR taps; documented as shape (1, 1, order).

        Returns:
            y_lpf (Tensor): filtered signal with the inserted channel
            dimension removed again ("same" mode keeps the length of ``y``).
        """
        taps = filter.to(y.device)
        # Insert a singleton dimension at index 1 to line up with the taps.
        y = y.unsqueeze(1)
        y_lpf = torchaudio.functional.convolve(y, taps, "same")
        return y_lpf.squeeze(1)

    def __call__(self, data):
        """Return ``data`` low-pass filtered with the stored taps."""
        return self.apply_low_pass_firwin(data, self.filter)


class Intensity:
    """Smoothed log-RMS curve extractor.

    Calling an instance computes a framed RMS of the input, converts it to
    ``20 * log10`` scale and smooths the result with Savitzky-Golay FIR
    coefficients designed in ``__init__``.
    """

    def __init__(self, ctx_window, fir_order = 2, frame_length = 2048, hop_length = 256, device='cuda:0'):
        """Design the smoothing filter.

        Args:
            ctx_window (int): window length passed to ``savgol_coeffs``.
            fir_order (int): polynomial order passed to ``savgol_coeffs``.
            frame_length (int): accepted but not used (see note below).
            hop_length (int): accepted but not used (see note below).
            device: device on which the filter coefficients are stored.
        """
        # Import the submodule explicitly: the module header only has
        # ``import scipy``, which does not guarantee that ``scipy.signal`` is
        # available as an attribute on every SciPy version.
        from scipy import signal

        # NOTE(review): ``frame_length`` and ``hop_length`` are ignored here;
        # ``__call__`` relies on the defaults of ``rms_torch`` (1024 / 256).
        # Left as is because wiring them in would change existing outputs.
        coeffs = signal.savgol_coeffs(ctx_window, fir_order, delta=1.0, pos=None, use='conv')
        # Shape (1, ctx_window), float32, on ``device``.
        self.fir_coeff = torch.from_numpy(coeffs).to(device).unsqueeze(0).to(torch.float32)
        self.device = device

    def rms_torch(self, y, frame_length: int = 1024, hop_length: int = 256, dtype=torch.float32) -> torch.Tensor:
        """Return the per-frame RMS of ``y`` along its last dimension.

        The last dimension is zero-padded by ``frame_length // 2`` on both
        sides, cut into frames of ``frame_length`` samples every
        ``hop_length`` samples, and each frame is reduced to
        ``sqrt(mean(|x|^2))``.
        """
        pad_size = int(frame_length // 2)
        y = F.pad(y, (pad_size, pad_size), mode="constant")

        # New trailing dimension holds the samples of each frame.
        frames = y.unfold(-1, frame_length, hop_length)

        power = torch.mean(self.abs2_torch(frames, dtype), dim=-1)
        return torch.sqrt(power)

    def abs2_torch(self, x: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Return the element-wise squared magnitude of ``x``.

        Complex inputs use ``|x|^2``; real inputs are simply squared. The
        result is cast to ``dtype`` unless ``dtype`` is None.
        """
        if torch.is_complex(x):
            y = x.abs().pow(2)
        else:
            y = x.pow(2)

        if dtype is not None:
            y = y.to(dtype)
        return y

    def frame_torch(self, x, frame_length: int, hop_length: int, axis: int = -1) -> torch.Tensor:
        """Slice ``x`` into overlapping frames along ``axis`` without copying.

        Args:
            x: array-like converted with ``torch.as_tensor``.
            frame_length (int): number of samples per frame.
            hop_length (int): step between the starts of consecutive frames.
            axis (int): dimension to frame; negative values count from the end.

        Returns:
            A strided view with one extra trailing dimension of size
            ``frame_length``; the frame index replaces ``axis``.
        """
        x = torch.as_tensor(x)

        if axis < 0:
            axis += x.dim()
        # Bring the framed dimension last so the stride arithmetic is uniform.
        x = x.transpose(axis, -1)

        num_frames = (x.shape[-1] - frame_length) // hop_length + 1

        # Frame i starts hop_length * i samples in; samples within a frame
        # keep the original step.
        new_shape = x.shape[:-1] + (num_frames, frame_length)
        new_strides = x.stride()[:-1] + (hop_length * x.stride(-1), x.stride(-1))
        xw = torch.as_strided(x, size=new_shape, stride=new_strides)

        # Move the frame index back to where ``axis`` was.
        return xw.transpose(-2, axis)

    def __call__(self, data):
        """Return the smoothed ``20 * log10`` RMS curve of ``data``."""
        data = self.rms_torch(data)
        # The small constant avoids log10(0) on silent frames.
        data = 20 * torch.log10(data + 1e-7)
        intensity_curve = torchaudio.functional.convolve(data, self.fir_coeff, "same")
        return intensity_curve
class StyleGram:
    """Gram matrix of one CLAP style-feature layer.

    A ``CLAP_Module`` is built and loaded from ``ckpt`` at construction time
    with gradients disabled on its model. Calling the instance returns
    ``F^T F`` for the selected feature layer ``F``.
    """

    def __init__(self, ckpt=None, device = 'cuda:0', layer_index = 3):
        clap = CLAP_Module(enable_fusion=False, device=device, amodel= 'HTSAT-tiny', tmodel='roberta')
        clap.load_ckpt(ckpt, model_id=1)
        clap.model.requires_grad_(False)
        self.clap = clap
        # Index into the sequence returned by ``get_style_feature_from_data``.
        self.layer_index = layer_index

    def __call__(self, data):
        """Return the Gram matrix of the selected layer for the first item."""
        layer_outputs = self.clap.get_style_feature_from_data(data, use_tensor=True)
        # Keep only the first entry along the leading dimension
        # (presumably the batch — TODO confirm against the CLAP wrapper).
        selected = layer_outputs[self.layer_index][0, :, :]
        return torch.mm(selected.t(), selected)
        

# =============
# Noise classes
# =============



class Noise(ABC):
    """Base class for noise models; instances are callable.

    Calling an instance delegates to :meth:`forward`, which subclasses
    override. The base ``forward`` is a placeholder that returns ``None``.
    """

    def forward(self, data):
        """Apply the noise model to ``data``. Placeholder: returns None."""
        return None

    def __call__(self, data):
        """Shorthand for ``self.forward(data)``."""
        result = self.forward(data)
        return result

class Clean(Noise):
    """Noise-free model: the input is handed back untouched."""

    def forward(self, data):
        """Return ``data`` itself (no copy, no modification)."""
        untouched = data
        return untouched

class GaussianNoise(Noise):
    """Additive Gaussian noise scaled by ``sigma``."""

    def __init__(self, sigma):
        # Multiplier applied to the standard-normal samples.
        self.sigma = sigma

    def forward(self, data):
        """Return ``data`` plus standard-normal noise multiplied by ``sigma``.

        The noise has the same shape, dtype and device as ``data``.
        """
        # ``randn_like`` already allocates on the device of ``data``.
        perturbation = torch.randn_like(data) * self.sigma
        return data + perturbation
