import torch
from torch import distributions as D
import matplotlib.pyplot as plt


def standardize(tensor: torch.Tensor, target_std: float = 1., dim=1, moments=True
                ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mean = tensor.mean(dim, keepdim=True)
    std = tensor.std(dim, keepdim=True) + 1e-12
    std /= target_std
    out = (tensor - mean) / std
    if moments:
        out = (out, mean, std)
    return out

def standardize(tensor: torch.Tensor, target_denom: float = 1., dim=1, moments=True
                ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mean = tensor.mean(dim, keepdim=True)
    std = torch.ones_like(mean) * target_denom
    out = (tensor - mean) / std
    if moments:
        out = (out, mean, std)
    return out

def covariance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Unbiased sample covariance between ``a`` and ``b`` along dimension 1.

    Both inputs are centered along dimension 1. When ``a`` is 2-D and ``b`` is
    3-D, ``a`` gets a trailing singleton axis so it broadcasts against every
    slice of the last axis of ``b``. The sum of products is divided by
    ``a.shape[1] - 1``.
    """
    dev_a = a - torch.mean(a, dim=1, keepdim=True)
    dev_b = b - torch.mean(b, dim=1, keepdim=True)
    needs_feature_axis = a.dim() == 2 and b.dim() == 3
    if needs_feature_axis:
        dev_a = dev_a[..., None]
    dof = a.shape[1] - 1
    return torch.sum(dev_a * dev_b, dim=1) / dof

def batch_covariance(data: torch.Tensor) -> torch.Tensor:
    """Unbiased covariance matrix over dimension 1 for each batch entry.

    ``data`` is centered along dimension 1; the result is the product of the
    transposed (dims 1 and 2 swapped) deviations with the deviations, divided
    by ``data.shape[1] - 1``.
    """
    deviations = data - torch.mean(data, dim=1, keepdim=True)
    dof = data.shape[1] - 1
    scatter = torch.matmul(torch.transpose(deviations, 1, 2), deviations)
    return scatter / dof

def RMSE(a: torch.Tensor, b: torch.Tensor):
    """Root-mean-square difference between ``a`` and ``b`` over all elements."""
    residual = a - b
    return torch.sqrt(torch.mean(torch.square(residual)))

    
# -----------------------------------------------------------------------------
# Fixed Effects

def datasetFfx(b: int, n: int, d: int):
    """Simulate ``b`` linear-regression datasets with ``n`` rows and ``d`` predictors.

    Returns a dict with the design matrix ``X`` (intercept column of ones plus
    the predictors ``x``), the noise-free prediction ``y_hat``, the response
    ``y = y_hat + eps``, the coefficients ``ffx`` and the noise scale
    ``sigma_eps``. The order of the random draws is kept fixed so that seeded
    runs stay reproducible.
    """
    # per-dataset scales
    noise_scale = D.HalfNormal(1.).sample((b,))
    coef_scale = D.HalfNormal(3.).sample((b, d + 1))
    feature_scale = D.HalfNormal(1.).sample((b, d))

    # predictors: scaled standard-normal draws plus one random offset per column
    spread = torch.randn((b, n, d)) * feature_scale.view(b, 1, d)
    offset = torch.randn((b, 1, d))
    x = spread + offset
    X = torch.cat([torch.ones((b, n, 1)), x], dim=-1)

    # coefficients (intercept first)
    ffx = torch.randn((b, d + 1)) * coef_scale

    # noise: passed through standardize() and then scaled per dataset
    raw_noise = torch.randn((b, n))
    eps = standardize(raw_noise, moments=False) * noise_scale.view(b, 1)

    # response
    y_hat = torch.einsum('bnd,bd->bn', X, ffx)
    y = y_hat + eps

    return {'X': X, 'x': x, 'y': y, 'y_hat': y_hat,
            'ffx': ffx, 'sigma_eps': noise_scale}


def standardizeFfx(ds: dict):
    """Standardize a fixed-effects dataset and print recovery diagnostics.

    Reads ``y``, ``X``, ``ffx`` and ``sigma_eps`` from ``ds``, transforms the
    data with ``standardize``, maps the parameters into the transformed space
    and back again, and prints RMSE checks for the intercept, the first slope
    and the noise scale. The transformed data, parameters, predictions and the
    moments used are written back into ``ds``; nothing is returned.
    """
    # unpack
    y, X, ffx, sigma_eps = ds['y'], ds['X'], ds['ffx'], ds['sigma_eps']
    # sizes are taken from the data instead of module-level globals
    b, _, p = X.shape  # p = number of coefficients (intercept + predictors)

    # data
    X_, mean_X, std_X = standardize(X)
    y_, mean_y, std_y = standardize(y, 5.)

    # parameters
    sigma_eps_ = sigma_eps / std_y.view(b)
    ffx_ = ffx * std_X.view(b, p) / std_y
    # both X_ and y_ are mean-centered, so no intercept is needed
    ffx_[:, 0] = 0.
    y_hat_ = torch.einsum('bnd,bd->bn', X_, ffx_)
    eps_ = y_ - y_hat_

    # ffx recovered
    ffx_r = ffx_ / std_X.view(b, p) * std_y
    # ffx_r[:, 0] is still zero here, so the sum only covers the slopes
    ffx_r[:, 0] = mean_y.view(b) - (mean_X.view(b, p) * ffx_r).sum(1)
    y_hat_r = torch.einsum('bnd,bd->bn', X, ffx_r)

    # check intercept
    rmse = RMSE(ffx_r[:, 0], ffx[:, 0])
    print(f'RMSE (intercept): {rmse:.3f}')

    if p > 1:  # a slope only exists when there is at least one predictor
        rmse = RMSE(ffx_r[:, 1], ffx[:, 1])
        print(f'RMSE (slope): {rmse:.3f}')

    # check sigma
    rmse = RMSE(sigma_eps_, eps_.std(1))
    print(f'RMSE (sigma, direct): {rmse:.3f}')

    # analytical solution: Var(y - y_hat) = Var(y) + Var(y_hat) - 2 Cov(y, y_hat)
    eps_var = y_.var(1) + y_hat_.var(1) - 2 * covariance(y_, y_hat_)
    # clamp: rounding can push a near-zero variance slightly below zero
    rmse = RMSE(sigma_eps_.squeeze(), eps_var.clamp_min(0.).sqrt())
    print(f'RMSE (sigma, analytical 1): {rmse:.3f}')

    # analytical without y_hat_: Var(y_hat) = beta' Cov(X) beta
    y_hat_var = (ffx_.unsqueeze(2) * batch_covariance(X_) * ffx_.unsqueeze(1)).sum((-1,-2))
    cov_y_Xb = (covariance(y_, X_) * ffx_).sum(1)
    eps_var = y_.var(1) + y_hat_var - 2*cov_y_Xb
    rmse = RMSE(sigma_eps_.squeeze(), eps_var.clamp_min(0.).sqrt())
    print(f'RMSE (sigma, analytical 2): {rmse:.3f}')

    # export
    out = {
        'ffx_': ffx_, 'ffx_r': ffx_r,
        'y_hat_': y_hat_, 'y_hat_r': y_hat_r,
        'X_': X_, 'y_': y_,
        'mean_X': mean_X, 'mean_y': mean_y,
        'std_X': std_X, 'std_y': std_y
        }
    ds.update(out)


# def fitFFX(y: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
#     coefs = []
#     intercepts = []
#     for i in range(b):
#         # Convert torch tensors to numpy arrays for sklearn
#         x_np = X[i].numpy()
#         y_np = y[i].numpy()
#         model = LinearRegression()
#         model.fit(x_np, y_np)
#         coefs.append(torch.from_numpy(model.coef_))
#         intercepts.append(torch.from_numpy(model.intercept_))
#     coefs = torch.stack(coefs).squeeze(1)
#     intercepts = torch.stack(intercepts)
#     ffx_hat = torch.cat([intercepts, coefs], dim=1)
#     return ffx_hat
        
def compareFFX(ds: dict):
    """Compare least-squares estimates before and after standardization.

    Fits ``torch.linalg.lstsq`` on the original data and on the standardized
    data stored in ``ds``, maps the standardized estimates back to the
    original scale, and prints three RMSE values. ``ds`` is only read: the
    standardized design matrix is copied before its intercept column is
    overwritten.
    """
    # sizes are taken from the data instead of module-level globals
    b, p = ds['ffx'].shape

    # Original
    X, y = ds['X'], ds['y']
    ffx_hat = torch.linalg.lstsq(X, y).solution
    rmse = RMSE(ffx_hat, ds['ffx'])
    print(f'RMSE (estimated vs. true in original space): {rmse:.3f}')

    # Standardized
    # clone so that writing the intercept column does not alter ds['X_']
    X_ = ds['X_'].clone()
    X_[..., 0] = 1.
    y_ = ds['y_']
    ffx_hat_ = torch.linalg.lstsq(X_, y_).solution
    ffx_hat_[..., 0] = 0.
    rmse = RMSE(ffx_hat_, ds['ffx_'])
    print(f'RMSE (estimated vs. true in standardized space): {rmse:.3f}')

    # Recovered
    ffx_hat_r = ffx_hat_ / ds['std_X'].view(b, p) * ds['std_y']
    # the intercept entry is zero at this point, so the sum only covers slopes
    ffx_hat_r[:, 0] = ds['mean_y'].view(b) - (ds['mean_X'].view(b, p) * ffx_hat_r).sum(1)
    rmse = RMSE(ffx_hat_r, ffx_hat)
    print(f'RMSE (recovered vs. original estimates): {rmse:.3f}')
    # rmse = (ffx_hat_r - ds['ffx']).square().mean().sqrt()
    # print(f'RMSE (recovered vs. true): {rmse:.3f}')

# -----------------------------------------------------------------------------
# Mixed Effects

def datasetMfx(b: int, m: int, n: int, d: int, q: int):
    """Simulate ``b`` mixed-effects datasets with ``m`` groups of ``n`` rows each.

    ``X`` holds an intercept column plus ``d`` predictors, ``Z`` is the slice
    of the first ``q`` columns of ``X``. Random effects are centered and
    rescaled across groups so that their std along the group axis equals
    ``sigma_rfx``. Returns a dict with the data, the noise-free prediction,
    the parameters and the scale parameters. The order of the random draws is
    kept fixed so that seeded runs stay reproducible.
    """
    # per-dataset scales
    noise_scale = D.HalfNormal(1.).sample((b,))
    rfx_scale = D.HalfNormal(10.).sample((b, q))
    coef_scale = D.HalfNormal(3.).sample((b, d + 1))
    feature_scale = D.HalfNormal(1.).sample((b, d))

    # design matrices
    spread = torch.randn((b, m, n, d)) * feature_scale.view(b, 1, 1, d)
    offset = torch.randn((b, 1, 1, d))
    x = spread + offset
    X = torch.cat([torch.ones((b, m, n, 1)), x], dim=-1)
    Z = X[..., :q]

    # fixed effects
    ffx = torch.randn((b, d + 1)) * coef_scale

    # random effects: center and scale across groups, then apply sigma_rfx
    rfx_raw = torch.randn((b, m, q))
    rfx_centered = rfx_raw - rfx_raw.mean(1, keepdim=True)
    # rfx, _ = torch.linalg.qr(rfx)
    rfx_unit = rfx_centered / rfx_centered.std(1, keepdim=True)
    rfx = rfx_unit * rfx_scale.view(b, 1, q)

    # noise: passed through standardize() and then scaled per dataset
    raw_noise = torch.randn((b, m, n))
    eps = standardize(raw_noise, dim=(1, 2), moments=False) * noise_scale.view(b, 1, 1)

    # response = global part + group-specific part + noise
    global_part = torch.einsum('bmnd,bd->bmn', X, ffx)
    group_part = torch.einsum('bmnq,bmq->bmn', Z, rfx)
    y_hat = global_part + group_part
    y = y_hat + eps

    return {'X': X, 'Z': Z, 'y': y, 'y_hat': y_hat,
            'ffx': ffx, 'rfx': rfx,
            'sigma_eps': noise_scale, 'sigma_rfx': rfx_scale}


def standardizeMfx(ds: dict):
    """Standardize a mixed-effects dataset and print recovery diagnostics.

    Reads the data and parameters from ``ds``, transforms ``y``, ``X`` and
    ``Z`` with ``standardize`` over dimensions 1 and 2, maps fixed effects,
    random effects and scale parameters into the transformed space and back,
    and prints RMSE checks for the scale parameters. Results are written back
    into ``ds``; nothing is returned.
    """
    # unpack
    y, X, Z = ds['y'], ds['X'], ds['Z']
    ffx, rfx = ds['ffx'], ds['rfx']
    sigma_eps, sigma_rfx = ds['sigma_eps'], ds['sigma_rfx']
    # sizes are taken from the data instead of module-level globals
    b, p, q = X.shape[0], X.shape[-1], Z.shape[-1]

    # data
    # the reduction dims must go in by keyword; the second positional argument
    # of standardize is the scale, not dim
    y_, mean_y, std_y = standardize(y, dim=(1,2))
    X_, mean_X, std_X = standardize(X, dim=(1,2))
    X_[..., 0] = 1.
    Z_, mean_Z, std_Z = standardize(Z, dim=(1,2))
    Z_[..., 0] = 1.

    # fixed effects
    mean_Xbeta = (mean_X.view(b, p)[:, 1:] * ffx[:, 1:]).sum(-1)
    ffx_ = ffx * std_X.view(b, p) / std_y.view(b, 1)
    # the intercept absorbs the shifts caused by centering X and y
    ffx_[:, 0] = (ffx[:, 0] + mean_Xbeta - mean_y.view(b)) / std_y.view(b)
    # random effects
    mean_Zb = (mean_Z.view(b, 1, q)[..., 1:] * rfx[..., 1:]).sum(-1)
    rfx_ = rfx * std_Z.view(b, 1, q) / std_y
    # the random intercept absorbs the shift caused by centering Z
    rfx_[..., 0] = (rfx[..., 0] + mean_Zb) / std_y.view(b,1)

    # variance parameters
    sigma_eps_ = sigma_eps / std_y.view(b)
    sigma_rfx_ = sigma_rfx * std_Z.view(b,q) / std_y.view(b,1)
    # random intercept gets a special treatment:
    cov = batch_covariance(mean_Z.view(b, 1, q) * rfx) # covariance matrix of weighted rfx
    cov_sum = cov.sum(dim=(-1,-2)) - sigma_rfx[:, 0].square()
    sigma_rfx_[:, 0] = (sigma_rfx[:, 0].square() + cov_sum).sqrt() / std_y.view(b)

    # y_hat_
    mu_g = torch.einsum('bmnd,bd->bmn', X_, ffx_)
    mu_l = torch.einsum('bmnq,bmq->bmn', Z_, rfx_)
    y_hat_ = mu_g + mu_l
    eps_ = y_ - y_hat_

    # fixed and random effects recovered
    ffx_r = ffx_ / std_X.view(b, p) * std_y.view(b, 1)
    ffx_r[:, 0] = ffx_[:, 0] * std_y.view(b) + mean_y.view(b) - mean_Xbeta
    rfx_r = rfx_ / std_Z.view(b, 1, q) * std_y
    rfx_r[..., 0] = rfx_[..., 0] * std_y.view(b,1) - mean_Zb
    mu_g = torch.einsum('bmnd,bd->bmn', X, ffx_r)
    mu_l = torch.einsum('bmnq,bmq->bmn', Z, rfx_r)
    y_hat_r = mu_g + mu_l

    # check recovery (one RMSE per random-effect component, averaged over the batch)
    rmse = (sigma_rfx_ - rfx_.std(1)).square().mean(0).sqrt()
    for i in range(q):
        print(f'RMSE (sigma {i}): {rmse[i]:.3f}')

    rmse = RMSE(sigma_eps_, eps_.std((1,2)))
    print(f'RMSE (sigma error): {rmse:.3f}')

    # export
    out = {
        'ffx_': ffx_, 'ffx_r': ffx_r,
        'rfx_': rfx_, 'rfx_r': rfx_r,
        'y_hat_': y_hat_, 'y_hat_r': y_hat_r,
        'X_': X_, 'y_': y_,
        }
    ds.update(out)
    
    

# -----------------------------------------------------------------------------
# Plotting functions

def _plotXy(ax, x, y, fx, title=''):
    """Scatter ``y`` against ``x`` on ``ax`` and overlay regression lines.

    ``fx`` holds intercept and slope in its first two entries. A tensor with
    more than one dimension is treated as one (intercept, slope) row per
    group and each row gets its own dashed line; otherwise a single line is
    drawn.
    """
    ax.plot(x, y, '.')
    if fx.dim() > 1:
        # resolve the colormap here so the helper does not depend on a global
        # that only exists when the file is run as a script
        palette = plt.get_cmap('tab10')
        for i in range(fx.shape[0]):  # one color per group
            ax.axline((0, fx[i,0]), slope=fx[i,1], color=palette(i), linestyle='--')
    else:
        ax.axline((0, fx[0]), slope=fx[1])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    # draw both coordinate axes through the origin
    ax.axhline(linewidth=1, color="k")
    ax.axvline(linewidth=1, color="k")
    ax.grid()
    ax.set_title(title)
    
    
def plotXyFFX(ds: dict, b=0):
    """Plot dataset ``b`` of a fixed-effects ``ds`` before and after standardization.

    Each of the two panels shows column 1 of the design matrix against the
    response, together with the line given by the matching coefficients.
    """
    fig, axs = plt.subplots(1,2)
    panels = (
        ('X', 'y', 'ffx', 'original'),
        ('X_', 'y_', 'ffx_', 'standardized'),
    )
    for ax, (key_X, key_y, key_ffx, label) in zip(axs, panels):
        predictor = ds[key_X][b, :, 1]
        response = ds[key_y][b]
        coefs = ds[key_ffx][b]
        _plotXy(ax, predictor, response, coefs, label)
    fig.tight_layout()
    

def plotXyMFX(ds: dict, b=0):
    """Plot dataset ``b`` of a mixed-effects ``ds`` before and after standardization.

    For each panel the fixed effects are repeated once per group and the
    random effects are added to the leading coefficients, giving one
    regression line per group.
    """
    fig, axs = plt.subplots(1,2)
    m = ds['X'].shape[1]  # number of groups
    panels = (
        ('X', 'y', 'ffx', 'rfx', 'original'),
        ('X_', 'y_', 'ffx_', 'rfx_', 'standardized'),
    )
    for ax, (key_X, key_y, key_ffx, key_rfx, label) in zip(axs, panels):
        x = ds[key_X][b, ..., 1].permute(1,0)
        y = ds[key_y][b].permute(1,0)
        rfx = ds[key_rfx][b]
        # sizes are taken from the tensors instead of module-level globals
        q = rfx.shape[-1]
        # clone: expand() returns a view that must not be written to
        mfx = ds[key_ffx][b].unsqueeze(0).expand(m, -1).clone()
        mfx[:, :q] += rfx
        _plotXy(ax, x, y, mfx, label)
    fig.tight_layout()
    

def plotParams(ds: dict, d: int, prefix='ffx'):
    """Scatter true against recovered parameters, one panel per parameter.

    Reads ``ds[prefix]`` and ``ds[prefix + '_r']`` and draws ``d + 1`` panels.
    Parameters are indexed along the last axis, so tensors with extra leading
    axes are plotted per parameter as well.
    """
    # squeeze=False keeps axs an array even when there is only one panel
    fig, axs = plt.subplots(1, d+1, squeeze=False)
    true = ds[prefix]
    recovered = ds[prefix + '_r']
    for i, ax in enumerate(axs[0]):
        ax.plot(true[..., i], recovered[..., i], '.')
        ax.set_xlabel('true')
        ax.set_ylabel('recovered')
        ax.grid()
        ax.set_title(f'theta {i}')
    fig.suptitle('parameter recovery')
    fig.tight_layout()
        
    
def plotPred(ax, y: torch.Tensor, y_hat: torch.Tensor, title=''):
    """Scatter observed ``y`` (vertical) against ``y_hat`` (horizontal) on ``ax``.

    A dashed identity line is added as reference. The axis limits are read
    before the line is drawn and restored afterwards, so the view is the one
    chosen for the scatter.
    """
    ax.plot(y_hat, y, '.')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.axline((0, 0), slope=1, color='gray', zorder=0, linestyle='--')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid()
    # labels follow the plotted data: predictions on x, observations on y
    ax.set_xlabel('predicted')
    ax.set_ylabel('y')
    ax.set_title(title)
    
    
def plotPredictions(ds: dict, i: int):
    """Draw three prediction panels for dataset ``i``.

    The panels pair ``y`` with ``y_hat``, ``y_`` with ``y_hat_`` and ``y``
    with ``y_hat_r``.
    """
    fig, axs = plt.subplots(1, 3)
    panels = (
        ('y', 'y_hat', 'original'),
        ('y_', 'y_hat_', 'standardized'),
        ('y', 'y_hat_r', 'recovered'),
    )
    for ax, (key_obs, key_pred, label) in zip(axs, panels):
        plotPred(ax, ds[key_obs][i], ds[key_pred][i], title=label)
    fig.suptitle('prediction recovery')
    fig.tight_layout()
    
    
###############################################################################   
if __name__ == '__main__':

    # Settings for the demo run. Helpers in this file may read these names as
    # module globals, so keep them defined here.
    cmap = plt.get_cmap('tab10')
    # torch.manual_seed(0)
    b, d, m = 64, 2, 5
    q = d + 1

    # FFX: simulate, standardize, plot, and compare least-squares fits
    ds = datasetFfx(b=b, n=50, d=d)
    standardizeFfx(ds)
    if d == 1:
        # the x/y line plot only makes sense for a single predictor
        plotXyFFX(ds)
    plotParams(ds, d)
    plotPredictions(ds, 1)
    compareFFX(ds)

    # # MFX
    # ds = datasetMfx(b=b, m=m, n=50, d=d, q=q)
    # standardizeMfx(ds)
    # if d==1:
    #     plotXyMFX(ds)
    # plotParams(ds, d)
    # plotParams(ds, d, prefix='rfx')
    # plotPredictions(ds, 1)

