import matplotlib.pyplot as plt
import sklearn.neighbors

from models import Poisson2D, RBF, Cauchy, Kernel, make_rbf, make_aniso_rbf, Convection
import jax
import jax.numpy as jnp
import jax.random as jr
import cheb
import tqdm
import pickle
import yaml
import os
import yaml
import scipy as sp
import scipy.optimize
from tqdm import tqdm
from collections import defaultdict


def retrieve_samples(config, key=None):
    """Return the training sample locations described by ``config``.

    Args:
        config: configuration dict. Reads ``config['key']`` (only when ``key``
            is None), ``config['dim']`` and ``config['train']['samples']``
            entries ``N`` and ``dist``.
        key: optional JAX PRNG key. Defaults to ``config['key']``.

    Returns:
        For ``dist == 'uniform'``: an ``(N, dim)`` array drawn uniformly from
        ``[-1, 1)``. For ``dist == 'chebyshev'``: the observation grid from
        ``retrieve_grid`` flattened to rows of length ``dim``.

    Raises:
        NotImplementedError: for any other ``dist``.
    """
    rng_key = config['key'] if key is None else key
    sample_cfg = config['train']['samples']
    n_samples = sample_cfg['N']
    dim = config['dim']
    dist = sample_cfg['dist']

    if dist == 'chebyshev':
        # Third element of retrieve_grid's tuple is the stacked grid of points.
        grid_points = retrieve_grid(config, eval=False)[2]
        return grid_points.reshape((-1, dim))
    if dist != 'uniform':
        raise NotImplementedError('This distribution is not implemented')

    # Map U[0, 1) draws onto [-1, 1).
    unit_draws = jr.uniform(shape=(n_samples, dim), key=rng_key)
    return 2 * (unit_draws - 0.5)



def retrieve_op(config):
    """Construct the operator named by ``config['train']['truth']['id']``.

    The id is matched case-insensitively. Supported ids are ``"poisson-exa"``
    (``Poisson2D()``) and ``"convection"`` (``Convection`` with boundary
    condition ``x -> sin(-pi * x)`` and ``beta = -15 / pi``).

    Raises:
        NotImplementedError: if the id is not one of the above.
    """
    op_id = config['train']['truth']['id'].lower()

    if op_id == "poisson-exa":
        return Poisson2D()
    if op_id == "convection":
        return Convection(boundary_condition=lambda x: jnp.sin(-jnp.pi * x), beta=-15/jnp.pi)

    raise NotImplementedError("This operator is not implemented")


def retrieve_grid(config, eval=False):
    """Build the 2-D tensor-product Chebyshev grid and its weights.

    Args:
        config: configuration dict. Reads ``config['eval_gridpts']`` when
            ``eval`` is true, otherwise
            ``config['train']['samples']['obs_gridpts']``, plus ``config['dim']``.
        eval: select the evaluation grid size instead of the observation grid size.

    Returns:
        ``(ax_grid, weights, dim_grid, dim_weights)``:
        - ``ax_grid`` and ``weights`` are the 1-D points/weights from ``cheb.gridpts``.
        - ``dim_grid`` stacks ``jnp.meshgrid(ax_grid, ax_grid)`` on the last axis.
        - ``dim_weights`` is ``jnp.kron(weights, weights)``.

    Raises:
        AssertionError: if ``config['dim'] != 2``.
    """
    n_grid = config['eval_gridpts'] if eval else config['train']['samples']['obs_gridpts']

    assert config['dim'] == 2, "dimension must be two in current implementation"

    axis_pts, axis_wts = cheb.gridpts(n_grid, with_weights=True)
    grid = jnp.stack(jnp.meshgrid(axis_pts, axis_pts), axis=-1)
    # Assuming 1-D weights: kron(w, w)[i*n + j] = w[i] * w[j] is symmetric in
    # (i, j), so it matches the flattened grid whatever meshgrid's indexing.
    grid_wts = jnp.kron(axis_wts, axis_wts)

    return axis_pts, axis_wts, grid, grid_wts


def retrieve_model(config, key=None):
    """Build the kernel model described by ``config['model']``.

    Supported ids (case-insensitive) are ``"rbf"`` (``make_rbf``) and
    ``"aniso-rbf"`` (``make_aniso_rbf``). Both receive the operator from
    ``retrieve_op(config)``.

    Args:
        config: configuration dict; reads ``config['model']`` and, when
            ``key`` is None, ``config['key']``.
        key: PRNG key. It is currently unused; only the disabled MLP branch
            would consume it.

    Raises:
        ValueError: for ``"mlp"`` (not yet implemented) and any other id.
    """
    if key is None:
        key = config['key']

    model_cfg = config['model']
    model_id = model_cfg['id'].lower()

    if model_id == 'rbf':
        return make_rbf(bandwidth=model_cfg['bandwidth'], operator=retrieve_op(config))
    if model_id == 'aniso-rbf':
        return make_aniso_rbf(bandwidth=model_cfg['bandwidth'],
                              theta=model_cfg['theta'],
                              scale=model_cfg['scale'],
                              operator=retrieve_op(config))

    # "mlp" (make_mlp is disabled) and unknown ids share the same error.
    raise ValueError("IMPLEMENT OTHER KERNELS NEXT")


class Dataset():
    """Training and evaluation data for one problem instance.

    Builds the training sample locations, the training (observation) grid and
    the evaluation grid from ``config``. Evaluates ``operator``'s solution and
    forcing on them, and adds Gaussian noise to the training targets.

    Attributes set here:
        config: the configuration dict passed in.
        X_sample: training sample locations (from ``retrieve_samples``).
        Y_sample, Y_sample_noisy: exact and noisy solution values at ``X_sample``.
        Z_grid, Z_grid_weights: training grid points and weights (``retrieve_grid``).
        Z_grid_eval, Z_grid_weights_eval: evaluation grid points and weights.
        Y_grid, Y_grid_eval: exact solution on the training / evaluation grid.
        R_grid, R_grid_noisy: exact and noisy forcing on the training grid.
        R_grid_eval: exact forcing on the evaluation grid.
        N, M: number of sample points and number of training-grid points.
    """

    def __init__(self, config, operator):
        self.config = config
        dim = config['dim']

        # JAX PRNG keys must not be reused. The previous implementation used
        # config['key'] for the sample locations, the data noise and the
        # forcing noise, which made those three draws correlated. Derive
        # independent subkeys instead.
        key_samples, key_data, key_pinn = jr.split(config['key'], 3)
        self.X_sample = retrieve_samples(config, key=key_samples)

        _, _, dim_grid, W = retrieve_grid(config)
        _, _, dim_grid_eval, W_eval = retrieve_grid(config, eval=True)

        # Flatten the stacked grids into one point per row.
        self.Z_grid = dim_grid.reshape((-1, dim))
        self.Z_grid_weights = W
        self.Z_grid_eval = dim_grid_eval.reshape((-1, dim))
        self.Z_grid_weights_eval = W_eval

        self.N = len(self.X_sample)
        self.M = len(self.Z_grid)

        # Config values are variances; noise is scaled by their square root.
        data_noise_var = config['train']['truth']['data-noise']
        pinn_noise_var = config['train']['truth']['deriv-noise']

        # NOTE(review): noise shapes (N, 1)/(M, 1) assume eval_solution and
        # eval_forcing return column vectors; confirm against `models`.
        self.Y_sample = operator.eval_solution(self.X_sample)
        self.Y_sample_noisy = self.Y_sample + jnp.sqrt(data_noise_var) * jr.normal(key_data, shape=(self.N, 1))
        self.Y_grid = operator.eval_solution(self.Z_grid)
        self.Y_grid_eval = operator.eval_solution(self.Z_grid_eval)
        self.R_grid = operator.eval_forcing(self.Z_grid)
        self.R_grid_noisy = self.R_grid + jnp.sqrt(pinn_noise_var) * jr.normal(key_pinn, shape=(self.M, 1))
        self.R_grid_eval = operator.eval_forcing(self.Z_grid_eval)

def linear_diagnostics(dataset, kernel, config, param_overwrite=None, verbose=False):
    """Compute the PILE score and posterior diagnostics of the linear kernel model.

    For each hyperparameter configuration this function:
    1. Assembles the joint covariance of the noisy observations ``Y`` (at
       ``dataset.X_sample``) and the noisy forcing ``R`` (at ``dataset.Z_grid``).
    2. Computes PILE and its log-det / quadratic-form parts.
    3. Computes ``m_ppd = C_cross @ C^{-1} [Y; R]`` and
       ``S_ppd = eta * (C_grid - C_cross @ C^{-1} @ C_cross.T)`` on the
       evaluation grid, and derives weighted error / variance summaries.

    Args:
        dataset: object exposing the arrays used below (e.g. a ``Dataset``).
        kernel: object with ``K``, ``H`` and ``G`` methods returning kernel blocks.
        config: configuration dict. ``config['train']['reg']`` supplies the
            defaults ``DATA`` (gma), ``PINN`` (rho) and ``NORM`` (eta).
        param_overwrite: optional mapping from any of ``'gma'``, ``'rho'``,
            ``'eta'``, ``'N'``, ``'M'`` to a list of values. Missing or empty
            entries fall back to the defaults. The Cartesian product of the
            lists is evaluated.
        verbose: show a tqdm progress bar over the configurations.

    Returns:
        A result dict when ``param_overwrite`` is None. Otherwise, a list of
        result dicts in product order (``gma`` outermost, ``M`` innermost).

    Notes:
        An ``eta`` of 0 is replaced by the bounded scalar minimiser of PILE
        over ``[1, 5 * (N + M)]`` (at most 50 iterations).
    """
    import itertools  # local import: leaves the module import block untouched

    X = dataset.X_sample
    Y = dataset.Y_sample_noisy
    Z = dataset.Z_grid
    Z_eval = dataset.Z_grid_eval
    R = dataset.R_grid_noisy
    W = dataset.Z_grid_weights
    W_eval = dataset.Z_grid_weights_eval

    R_precision_coeffs = 1 / W
    # NOTE(review): for uniform samples this is 1/N, i.e. a quadrature-style
    # *weight*, whereas every other coefficient here is an *inverse* weight
    # (1/W). Confirm which scaling is intended.
    if dataset.config['train']['samples']['dist'] == 'uniform':
        Y_precision_coeffs = 1 / len(X) * jnp.ones((len(X),))
    else:
        Y_precision_coeffs = 1 / W

    Y_tru = dataset.Y_grid_eval
    Y_tru_norm = jnp.sum(W_eval * Y_tru.flatten() ** 2)
    R_tru = dataset.R_grid_eval
    R_tru_norm = jnp.sum(W_eval * R_tru.flatten() ** 2)

    # Kernel blocks between training inputs.
    Kxx = kernel.K(X, X)
    Hxz = kernel.H(X, Z)
    Gzz = kernel.G(Z, Z)

    # Kernel blocks involving the evaluation grid.
    Kxxb = kernel.K(X, Z_eval)
    Hxbz = kernel.H(Z_eval, Z)
    Hxzb = kernel.H(X, Z_eval)
    Gzzb = kernel.G(Z, Z_eval)
    Kxbxb = kernel.K(Z_eval, Z_eval)
    Hxbzb = kernel.H(Z_eval, Z_eval)
    Gzbzb = kernel.G(Z_eval, Z_eval)

    # Quantities that do not depend on the hyperparameters: build them once.
    obs_block = jnp.concatenate((Y, R), axis=0)
    tru_block = jnp.concatenate((Y_tru, R_tru), axis=0)
    # Gzzb.T is redundant here; it mirrors the math and keeps the layout general.
    cov_block_cross = jnp.block([[Kxxb.T, Hxbz],
                                 [Hxzb.T, Gzzb.T]])
    cov_block_grid = jnp.block([[Kxbxb, Hxbzb], [Hxbzb.T, Gzbzb]])
    M_eval = len(Z_eval)

    gma = config['train']['reg']['DATA']
    rho = config['train']['reg']['PINN']
    eta = config['train']['reg']['NORM']
    N = dataset.N
    M = dataset.M

    if param_overwrite is None:
        cnfs = [(gma, rho, eta, N, M)]
    else:
        overrides = dict(param_overwrite)

        def _candidates(name, default):
            # A missing or empty override list means "use the default only".
            values = overrides.get(name, [])
            return [default] if len(values) == 0 else values

        cnfs = list(itertools.product(_candidates('gma', gma), _candidates('rho', rho),
                                      _candidates('eta', eta), _candidates('N', N),
                                      _candidates('M', M)))

    def _cov_block(_gma, _rho, _eta):
        # Joint covariance of [Y; R], with diagonal noise terms scaled by reg / eta.
        return jnp.block([[Kxx + (_gma / _eta) * jnp.diag(Y_precision_coeffs), Hxz],
                          [Hxz.T, Gzz + (_rho / _eta) * jnp.diag(R_precision_coeffs)]])

    def _pile_terms(cov_block, _eta, n_total):
        # Returns (PILE, log-det part, quadratic-form part, C^{-1} [Y; R]).
        obs_block_inv = jnp.linalg.lstsq(cov_block, obs_block)[0]
        _norm = (0.5 / _eta) * jnp.sum(obs_block * obs_block_inv)
        _det = 0.5 * jnp.linalg.slogdet(cov_block)[1] + n_total * 0.5 * jnp.log(2 * jnp.pi * _eta)
        return (_norm + _det) / n_total, _det, _norm, obs_block_inv

    def _diagnostics(_gma, _rho, _eta, n_total):
        cov_block = _cov_block(_gma, _rho, _eta)
        PILE, _det, _norm, obs_block_inv = _pile_terms(cov_block, _eta, n_total)

        cov_block_var_reduction = cov_block_cross @ jnp.linalg.lstsq(cov_block, cov_block_cross.T)[0]
        m_ppd = cov_block_cross @ obs_block_inv
        S_ppd = _eta * (cov_block_grid - cov_block_var_reduction)

        # The first M_eval rows belong to the solution, the rest to the forcing.
        sq_err = ((m_ppd - tru_block) ** 2).flatten()
        sq_posterior_var = sq_err + jnp.diag(S_ppd)
        data_abs_var = jnp.sum(W_eval * sq_posterior_var[:M_eval])
        phys_abs_var = jnp.sum(W_eval * sq_posterior_var[M_eval:])
        data_rel_var = data_abs_var / Y_tru_norm
        phys_rel_var = phys_abs_var / R_tru_norm
        data_gen_err = jnp.sum(W_eval * sq_err[:M_eval])
        phys_gen_err = jnp.sum(W_eval * sq_err[M_eval:])
        if jnp.isnan(PILE).any():
            print("Oh no!")
        return {'PILE': PILE, 'PILE_det': _det, 'PILE_norm': _norm,
                'phys_rel_var': phys_rel_var, 'data_rel_var': data_rel_var,
                'phys_gen_err': phys_gen_err, 'data_gen_err': data_gen_err,
                'data_abs_var': data_abs_var, 'phys_abs_var': phys_abs_var,
                'm_ppd': m_ppd, 'S_ppd': S_ppd}

    results = []
    pbar = tqdm(total=len(cnfs)) if verbose else None
    try:
        for (gma, rho, eta, N, M) in cnfs:
            # NOTE(review): overriding N or M only changes this normaliser and
            # the eta search bound. The kernel blocks always use every point
            # in `dataset`, so the log(2*pi*eta) term counts N + M rather than
            # the actual covariance size. Confirm this is intended.
            n_total = N + M
            if eta == 0:
                res = sp.optimize.minimize_scalar(
                    lambda x: _pile_terms(_cov_block(gma, rho, x), x, n_total)[0],
                    bounds=(1, 5 * (M + N)), method="bounded", options={"maxiter": 50})
                eta = res['x']

            diagnostics = _diagnostics(gma, rho, eta, n_total)
            results.append({'gma': gma, 'rho': rho, 'eta': eta, 'N': N, 'M': M, **diagnostics})

            if pbar is not None:
                pbar.update(1)
    finally:
        # The bar was previously never closed.
        if pbar is not None:
            pbar.close()

    return results[0] if param_overwrite is None else results




if __name__ == "__main__":
    # Sanity check for cnfs/template.yml:
    # 1. Build the dataset.
    # 2. Scan RBF bandwidths and plot the log condition number of K(X, X).
    # 3. Run the linear diagnostics for one regularisation setting.
    with open('cnfs/template.yml', encoding='utf-8') as f:
        cnf = yaml.safe_load(f)
    cnf["key"] = jr.key(cnf["prng_seed"])
    # Grid points per axis of the observation grid; sets the bandwidth scale.
    N = cnf['train']['samples']['obs_gridpts']
    cnf['model']['bandwidth'] = 1.2 * jnp.pi**2 / (2 * N**2)

    operator = retrieve_op(cnf)
    dataset = Dataset(cnf, operator)

    conds = []
    bwidths = jnp.geomspace(0.5, 10, 10)
    for h in bwidths:
        cnf['model']['bandwidth'] = h * jnp.pi**2 / (2 * N**2)
        kernel = retrieve_model(cnf)
        Kxx = kernel.K(dataset.X_sample, dataset.X_sample)
        conds.append(jnp.linalg.cond(Kxx))
        print(f'h={h}, cond={conds[-1]}')
    plt.plot(bwidths, jnp.log(jnp.array(conds)))
    plt.show()
    # h = 10 appears to work nicely.
    # NOTE: `kernel` below is the one built in the last loop iteration, i.e.
    # for the largest bandwidth factor in `bwidths` (h = 10).

    result = linear_diagnostics(dataset, kernel, cnf, param_overwrite={'eta': [1], 'gma': [0.001], 'rho': [0.001]}, verbose=True)
    print(result)

