from dataloaders.lpsda import LPSDALoader1D
from dataloaders.analytic import Advection1D, Diffusion1D
from dataloaders.pdebench.pdebench_loader import FNODatasetSingle
import torch

def extend(a, n, axis = 0):
    """Zero-pad ``a`` along ``axis`` so that its size there becomes ``n``.

    Args:
        a: tensor (or array-like accepted by ``torch.as_tensor``) to pad.
        n: target length along ``axis``; must be >= ``a.shape[axis]``.
        axis: dimension to pad (negative values are allowed).

    Returns:
        A new tensor holding ``a`` in the leading positions of ``axis`` followed
        by zeros, with the same dtype and device as ``a``.

    Raises:
        ValueError: if ``n`` is smaller than ``a.shape[axis]``.
    """
    a = torch.as_tensor(a)
    current = a.shape[axis]
    if n < current:
        raise ValueError(
            f"cannot extend axis {axis} of size {current} to smaller size {n}"
        )
    zeros_shape = list(a.shape)
    zeros_shape[axis] = n - current
    # Match dtype/device so padding never changes the dtype (e.g. int -> float)
    # and concatenation does not fail when `a` lives on a GPU.
    padding = torch.zeros(zeros_shape, dtype=a.dtype, device=a.device)
    return torch.cat([a, padding], dim=axis)

def l2_inner(a, b, axis = 0):
    """Discrete inner product <a, b> = sum(a * conj(b)) along ``axis``.

    ``a`` may be shorter than ``b`` along ``axis``; it is zero-padded at the end
    of that axis (via ``extend``) up to ``b.shape[axis]`` before the product.
    The remaining dimensions must be broadcast-compatible.

    Returns:
        Tensor with ``axis`` reduced away.
    """
    target_len = b.shape[axis]
    padded_a = extend(a, target_len, axis = axis)
    product = padded_a * b.conj()
    return product.sum(dim=axis)

def l2_norm(a, axis=0):
    """Return sqrt(<a, a>) along ``axis``, using ``l2_inner``.

    For complex ``a`` the result is complex-typed, with a zero imaginary part
    up to rounding.
    """
    self_inner = l2_inner(a, a, axis=axis)
    return torch.sqrt(self_inner)

def compute_gamma(u_hr, u_lr, axis = 0):
    """Compare the Fourier spectra of a high-res and a low-res signal.

    Each signal is divided by its own number of samples along ``axis`` and then
    real-FFT'd along ``axis``, so coefficients from different resolutions are on
    a comparable scale. The result is

        <F_lr, F_hr> / ||F_lr||^2

    where ``<., .>`` is ``l2_inner``. Its zero-padding extends the low-res
    spectrum to the high-res number of modes.

    Args:
        u_hr: high-res signal (tensor or array-like).
        u_lr: low-res signal (tensor or array-like); it must not have more
            samples than ``u_hr`` along ``axis``.
        axis: dimension holding the samples.

    Returns:
        Complex tensor with ``axis`` reduced away. It is nan/inf where the
        low-res spectrum is all zeros.
    """
    n_hr = u_hr.shape[axis]
    n_lr = u_lr.shape[axis]
    # torch.as_tensor accepts numpy arrays and tensors of any float dtype/device.
    # The legacy torch.Tensor(...) constructor raises for non-default-dtype or
    # non-CPU tensors, and silently downcasts float64 arrays to float32.
    u_hr = torch.as_tensor(u_hr)
    u_lr = torch.as_tensor(u_lr)

    yf = torch.fft.rfft(1/n_hr * u_hr, dim=axis)
    yf_lowres = torch.fft.rfft(1/n_lr * u_lr, dim=axis)
    return l2_inner(yf_lowres, yf, axis = axis)/l2_norm(yf_lowres, axis = axis)**2

def compute_beta(u_hr, u_lr, axis = 0):
    """Squared residual ||u_hr - gamma * u_lr||^2 along ``axis``.

    ``u_lr`` is first zero-padded (via ``extend``) to ``u_hr``'s length along
    ``axis``, so it must already be a tensor. ``gamma`` comes from
    ``compute_gamma`` on the padded signal and is broadcast back over ``axis``.

    NOTE(review): the padding here is in the spatial domain. As a result,
    compute_gamma normalizes the padded u_lr by the high-res length, so this
    gamma generally differs from the one compute_eta obtains from the unpadded
    u_lr. Unless both signals already share a grid, the pointwise difference
    also pairs low-res samples with high-res samples at different indices'
    positions. Confirm this is intended (Fourier interpolation of u_lr may have
    been meant).

    Returns:
        Real tensor with ``axis`` reduced away.
    """
    padded_lr = extend(u_lr, u_hr.shape[axis], axis = axis)
    scale = compute_gamma(u_hr, padded_lr, axis = axis)
    # Re-insert the reduced axis so `scale` lines up with u_hr element-wise.
    scale_full = scale.unsqueeze(axis).expand_as(u_hr)
    residual = u_hr - scale_full * padded_lr
    return l2_norm(residual, axis = axis).abs() ** 2

def compute_eta(u_hr, u_lr, axis = 0):
    """Return |gamma - 1|, where gamma comes from ``compute_gamma`` along ``axis``.

    A value of 0 means the low-res spectrum matches the high-res one exactly
    under this measure.
    """
    gamma = compute_gamma(u_hr, u_lr, axis = axis)
    deviation = gamma - 1.0
    return deviation.abs()

def compute_normalized_beta(u_hr, u_lr, axis=0):
    """``compute_beta`` divided by the squared norm of ``u_lr`` along ``axis``."""
    numerator = compute_beta(u_hr, u_lr, axis = axis)
    denominator = l2_norm(u_lr, axis = axis).abs() ** 2
    return numerator / denominator

def compute_eta_by_dataset(dataset):
    """Compute eta/beta resolution-consistency metrics for a paired dataset.

    Supported datasets (checked in this order):
      * ``LPSDALoader1D`` with ``chunk_train is False``: uses ``dataset.u``
        (last dim squeezed) and ``dataset.u_hr``.
      * ``Advection1D`` / ``Diffusion1D``: uses ``dataset.u`` and
        ``dataset.u_hr`` (last dim squeezed), converted with ``torch.Tensor``.
      * ``FNODatasetSingle``: uses ``dataset.data`` and ``dataset.data_hr``
        (last dim squeezed).

    Metrics are reduced along axis 1, producing (N, T) arrays. The layout is
    presumably (N, X, T) with axis 1 the spatial grid — TODO confirm against
    the loaders.

    Returns:
        dict mapping metric names to Python floats. For each step ``t`` it has
        "Eta/eta{t}", "Eta/eta(1-eta){t}", "Beta/beta{t}" and
        "Normalized Beta/beta{t}", each averaged over the first axis. It also
        has overall means under "Eta/mean_eta", "Eta/mean_eta(1-eta)",
        "Beta/mean_beta" and "Normalized Beta/mean_normalized_beta". For an
        unsupported dataset a message is printed and ``{}`` is returned.
    """
    if isinstance(dataset, LPSDALoader1D) and dataset.chunk_train is False:
        u_lowr = dataset.u.squeeze(-1)
        u_hr = dataset.u_hr
    elif isinstance(dataset, (Advection1D, Diffusion1D)):
        u_lowr = torch.Tensor(dataset.u.squeeze(-1))
        u_hr = torch.Tensor(dataset.u_hr.squeeze(-1))
    elif isinstance(dataset, FNODatasetSingle):
        u_lowr = dataset.data.squeeze(-1)
        u_hr = dataset.data_hr.squeeze(-1)
    else:
        # Best-effort: callers get an empty metrics dict rather than an exception.
        print(
            "Eta Computation Failed: unsupported dataset "
            f"{type(dataset).__name__} (expected LPSDALoader1D with "
            "chunk_train=False, Advection1D, Diffusion1D or FNODatasetSingle)"
        )
        return {}

    eta = compute_eta(u_hr, u_lowr, axis=1)  # (N, T)
    beta = compute_beta(u_hr, u_lowr, axis=1)  # (N, T)
    beta_normalized = compute_normalized_beta(u_hr, u_lowr, axis=1)  # (N, T)
    eta_one_minus_eta = eta * (1 - eta)  # (N, T), computed once for all steps

    metrics = {}
    for t in range(eta.shape[1]):
        metrics[f"Eta/eta{t}"] = eta[:, t].mean().item()
        metrics[f"Eta/eta(1-eta){t}"] = eta_one_minus_eta[:, t].mean().item()
        metrics[f"Beta/beta{t}"] = beta[:, t].mean().item()
        metrics[f"Normalized Beta/beta{t}"] = beta_normalized[:, t].mean().item()
    metrics.update({
        "Eta/mean_eta": eta.mean().item(),
        "Eta/mean_eta(1-eta)": eta_one_minus_eta.mean().item(),
        "Beta/mean_beta": beta.mean().item(),
        "Normalized Beta/mean_normalized_beta": beta_normalized.mean().item(),
    })
    return metrics
             
