# visualize posterior development over n (conditional on a single dataset):
# - evaluate quantiles [✓]
# - subset dataset and run model on each [✓]
# - plot quantiles [✓]

import torch
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import matplotlib.colors as colors
from metabeta.models.approximators import Approximator


# -----------------------------------------------------------------------------
# plot over n
def subsetFFX(batch: dict, batch_idx: int = 0) -> dict:
    ''' Expand dataset {batch_idx} of a fixed-effects batch into a batch of
        progressively larger prefixes of its observations.

        Only tensor entries are kept (other values are dropped). The selected
        dataset is copied n times, where n is read from its 'n' entry; copy k
        (0-based) keeps the first k + 1 observations: its 'n' becomes k + 1
        and 'mask_n', 'X' and 'y' are cleared on positions k + 1 .. n - 1.
        The input batch is not modified (the selected slice is cloned). '''
    single = {}
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        single[key] = value[batch_idx:batch_idx + 1].clone()

    total = int(single['n'])
    # one independent (materialised) copy of the dataset per prefix length
    out = {}
    for key, value in single.items():
        out[key] = value.repeat((total,) + (1,) * (value.ndim - 1))

    for row in range(total):
        dropped = slice(row + 1, total)  # observations hidden from this copy
        out['n'][row] = row + 1
        out['mask_n'][row, dropped] = False
        out['X'][row, dropped] = 0
        out['y'][row, dropped] = 0
    return out
        
        
def subsetMFX(batch: dict, batch_idx: int = 0) -> dict:
    ''' Expand dataset {batch_idx} of a mixed-effects batch into a batch of
        progressively larger per-group prefixes of its observations.

        Only tensor entries are kept. With n = X.shape[-2] (the padded
        per-group length), the dataset is copied n times; in copy k (0-based)
        every group keeps at most its first k + 1 observations: 'n_i' is
        clamped to k + 1 and 'mask_n', 'X', 'y' are cleared on per-group
        positions k + 1 .. n - 1. 'n' is set to the per-copy total sum(n_i).
        The input batch is not modified (the selected slice is cloned). '''
    # extract batch_idx (clone so in-place edits below never touch `batch`)
    ds = {k: v[batch_idx:batch_idx + 1].clone() for k, v in batch.items()
          if isinstance(v, torch.Tensor)}

    # repeat all tensors n times (repeat materialises independent copies)
    n = ds['X'].shape[-2]
    ds = {k: v.repeat(n, *[1] * (v.ndim - 1)) for k, v in ds.items()}

    # dynamically subset: copy i sees at most i + 1 observations per group
    ns, mask_n = ds['n_i'], ds['mask_n']
    X, y = ds['X'], ds['y']
    for i in range(n):
        ns[i].clamp_(max=i + 1)  # in-place on the row view of ns
        mask_n[i, :, i + 1:n] = False
        X[i, :, i + 1:n] = 0
        y[i, :, i + 1:n] = 0

    # 'n' previously received the per-group n_i tensor; store the total
    # instead, in the same layout/dtype as the original 'n' entry if possible
    total = ns.sum(dim=-1)
    if 'n' in ds and ds['n'].numel() == total.numel():
        total = total.reshape(ds['n'].shape).to(ds['n'].dtype)
    ds.update(n=total, n_i=ns, mask_n=mask_n, X=X, y=y)
    return ds


def plotOverN(quantiles: torch.Tensor, targets: torch.Tensor, names) -> None:
    ''' Plot posterior quantiles of each parameter against the subset size n.

        quantiles: indexed [subset, parameter, ..., level]; along the last
            axis, index 1 is drawn as the line and indices 0 / 2 as the
            shaded band. Subset k is plotted at x = k + 1.
        targets: indexed [subset, parameter]; only the first row is used.
        names: per-parameter labels (list or array) aligned with targets.

        Parameters whose target is exactly 0 are skipped. If quantiles has one
        more parameter column than targets, that trailing column is dropped. '''
    # matplotlib cannot consume tensors that require grad or live on a GPU
    quantiles = quantiles.detach().cpu()
    targets = targets.detach().cpu()

    # prepare targets and quantiles
    if quantiles.shape[1] == targets.shape[1] + 1:
        quantiles = quantiles[:, :-1]
    target = targets[0]
    keep = target != 0.
    target_ = target[keep]
    quantiles_ = quantiles[:, keep]
    # works for both list and array `names` (boolean indexing needs an array)
    names_ = [name for name, k in zip(names, keep.tolist()) if k]

    ns = torch.arange(1, quantiles.shape[0] + 1)
    _, ax = plt.subplots(figsize=(8, 6))
    for i in range(int(keep.sum())):
        quantiles_i = quantiles_[:, i]
        # take the colour from the default property cycle so that the band
        # and the reference line match the median curve of this parameter
        line, = ax.plot(ns, quantiles_i[..., 1], label=names_[i])
        color = line.get_color()
        ax.fill_between(ns, quantiles_i[..., 0], quantiles_i[..., 2],
                        color=color, alpha=0.15)
        ax.axhline(y=float(target_[i]), color=color, linestyle=':',
                   linewidth=1.5)

    # labels, integer ticks on n, legend outside the axes
    ax.set_xlabel('n')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.grid(True)


def unfold(model: Approximator, full: dict, batch_idx: int,
           fx_type: str = 'auto') -> None:
    ''' Plot how the global posterior for dataset {batch_idx} of {full}
        evolves as more of its observations are revealed.

        The dataset is expanded into progressively larger subsets
        (subsetMFX when fx_type == 'mfx', otherwise subsetFFX). The model is
        run once on all subsets, and the 2.5% / 50% / 97.5% quantiles of
        proposed['global'] are plotted against n with plotOverN.

        fx_type: 'mfx', 'ffx' or 'auto'. 'auto' (default) selects 'mfx' when
            the batch has an 'n_i' entry, which subsetMFX requires. This
            replaces a reference to an undefined global `cfg`. '''
    if fx_type == 'auto':
        # NOTE(review): assumes fixed-effects batches carry no 'n_i' key
        fx_type = 'mfx' if 'n_i' in full else 'ffx'
    subset = subsetMFX if fx_type == 'mfx' else subsetFFX
    batch = subset(full, batch_idx)
    # inference only: no autograd graph is needed for plotting
    with torch.no_grad():
        _, proposed, _ = model(batch, sample=True)
        quantiles = model.quantiles(proposed['global'], [.025, .500, .975])
    targets = model.targets(batch)
    names = model.names(batch)
    plotOverN(quantiles, targets, names)
        
        
def plotOverT(time: torch.Tensor, losses: torch.Tensor,
              q: tuple = (.025, .500, .975), kl: bool = False):
    ''' Plot the spread of per-dataset losses over training time.

        time: x-axis values, shape (n_iter).
        losses: shape (n_iter, batch); quantiles are taken over the batch axis.
        q: (lower, center, upper) quantile levels; q[1] is drawn as the line,
            q[0]..q[2] as the shaded band. Any sequence of three floats works.
        kl: if True the y-axis is labelled 'D(Optimal | Model)',
            otherwise '-log p(theta)'. '''
    # matplotlib cannot consume tensors that require grad or live on a GPU
    losses = losses.detach().cpu()
    # one vectorised pass; torch.quantile needs q in the input's dtype
    levels = torch.tensor([q[0], q[1], q[2]], dtype=losses.dtype)
    lower, center, upper = torch.quantile(losses, levels, dim=-1)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(time, center, color='darkgreen')
    ax.fill_between(time, lower, upper, color='darkgreen', alpha=0.3)
    ax.set_xlabel('datasets [10k]')
    ylabel = 'D(Optimal | Model)' if kl else '-log p(theta)'
    ax.set_ylabel(ylabel)
    ax.grid(True)