import numpy as np


def bin_buffer(replay, n, low=-0.24, high=0.24):
    """Group replay states into an ``n x n`` grid keyed by their (x, y) position.

    Bin edges are ``np.linspace(low, high, n)`` and each coordinate is mapped to a
    cell index with ``np.digitize``: for edges ``b``, a value with
    ``b[i-1] <= v < b[i]`` lands in cell ``i``. Cell 0 on each axis therefore only
    receives values below ``low``, which are filtered out, so it stays empty. A value
    exactly equal to ``high`` (for which ``np.digitize`` alone would return ``n``,
    past the end of the grid) is placed in the last cell ``n - 1``.

    Args:
        replay: buffer whose ``all()`` returns an iterable of transitions. For each
            transition, ``transition[0]`` is the state (anything with ``.flatten()``)
            and ``transition[-1][1]`` is unpacked as the ``(x, y)`` position.
        n: number of bin edges per axis, which is also the grid size.
        low: lower bound of the accepted coordinate range (inclusive).
        high: upper bound of the accepted coordinate range (inclusive).
            Transitions with x or y outside ``[low, high]`` are skipped.

    Returns:
        ``(datas, bins)``: ``datas[iy][ix]`` is a NumPy array stacking the flattened
        states that fell in that cell (an empty array with ``size == 0`` when the cell
        has none); ``bins`` is the edge array from ``np.linspace``.
    """
    bins = np.linspace(low, high, n)
    cells = [[[] for _ in range(n)] for _ in range(n)]
    last = n - 1
    for transition in replay.all():
        x, y = transition[-1][1]
        if not (low <= x <= high and low <= y <= high):
            continue
        # np.digitize returns n for a value equal to the final edge; clamp it so the
        # closed upper boundary maps into the last cell instead of raising IndexError.
        ix = min(int(np.digitize(x, bins)), last)
        iy = min(int(np.digitize(y, bins)), last)
        cells[iy][ix].append(transition[0].flatten())
    datas = [[np.array(cell) for cell in row] for row in cells]
    return datas, bins


def plot_covariance_dqn(sf, replay, test_w, i_task, iter):
    """Bin replay states on a 40x40 (x, y) grid and save per-cell SF value statistics.

    States from ``replay`` are grouped with ``bin_buffer``. For every non-empty cell
    and every policy ``i_policy`` in ``range(sf.n_tasks)``, it calls
    ``mu, Sigma = sf.get_successor(cell_states, i_policy)``. The result is projected
    onto each weight vector ``w`` in ``sf.fit_w + list(test_w)``:

    * mean grid:     ``np.mean(mu @ w)``
    * variance grid: ``np.mean(w.T @ Sigma @ w)``

    Empty cells keep 0 in both grids. Each (weight, policy) grid is clipped from above
    at its own 99th percentile. The grids are then tiled, one row of tiles per weight
    vector and one column per policy, and written as comma-separated text to
    ``reacher_sfdqn_mean_{iter}.csv`` and ``reacher_sfdqn_var_{iter}.csv``.

    Args:
        sf: successor-feature agent exposing ``n_tasks``, ``fit_w`` (concatenated
            with a list via ``+``) and ``get_successor(states, i_policy)``.
        replay: replay buffer accepted by ``bin_buffer``.
        test_w: sized collection of additional weight vectors to evaluate.
        i_task: unused; kept for signature compatibility.
        iter: suffix inserted into the output file names.
    """
    n = 40
    datas, _ = bin_buffer(replay, n)

    # Training weights followed by test weights, built once. The grid rows are sized
    # from this same list so every evaluated weight has a row and no row is left unused.
    weights = sf.fit_w + list(test_w)
    n_policies = sf.n_tasks
    grids_mean = [[np.zeros((n, n)) for _ in range(n_policies)] for _ in weights]
    grids_var = [[np.zeros((n, n)) for _ in range(n_policies)] for _ in weights]

    for ir, row in enumerate(datas):
        for ic, cell in enumerate(row):
            if cell.size == 0:
                continue  # nothing binned here; both grids keep 0
            for i_policy in range(n_policies):
                mu, Sigma = sf.get_successor(cell, i_policy)
                for j, w in enumerate(weights):
                    grids_mean[j][i_policy][ir, ic] = np.mean(mu @ w)
                    grids_var[j][i_policy][ir, ic] = np.mean(w.T @ Sigma @ w)

    # Clip the upper tail of each grid at its own 99th percentile so a few extreme
    # cells do not dominate the saved values.
    grids_mean = [[np.minimum(g, np.percentile(g, 99.)) for g in row] for row in grids_mean]
    grids_var = [[np.minimum(g, np.percentile(g, 99.)) for g in row] for row in grids_var]

    # Tile: one horizontal strip per weight vector, one tile per policy within a strip.
    grid_mean = np.vstack([np.hstack(row) for row in grids_mean])
    grid_var = np.vstack([np.hstack(row) for row in grids_var])
    np.savetxt('reacher_sfdqn_mean_{}.csv'.format(iter), grid_mean, delimiter=',')
    np.savetxt('reacher_sfdqn_var_{}.csv'.format(iter), grid_var, delimiter=',')


def plot_covariance_c51(sf, replay, test_w, i_task, iter):
    """Bin replay states on a 40x40 (x, y) grid and save per-cell value statistics (C51 variant).

    States from ``replay`` are grouped with ``bin_buffer``. For every non-empty cell
    and every policy ``i_policy`` in ``range(sf.n_tasks)``:

    * ``sf.method == 'exact'``: calls ``sf.compute_utilities(cell_states, i_policy, ws)``
      with ``ws = sf.true_w + list(test_w)``. The mean of each returned utility fills
      the mean grid, and the variance grids are left at zero.
    * otherwise: calls ``mu, Sigma = sf.compute_mean_variance(cell_states, i_policy)``.
      For each ``w`` in ``sf.fit_w + list(test_w)``, the mean grid gets
      ``np.mean(mu @ w)`` and ``var = Sigma @ (w ** 2)`` is computed. With
      ``sf.method == 'gauss'`` the variance grid gets ``np.mean(var)``. For any other
      method it gets ``np.mean(-1/beta * log(max(1 - 0.5 * beta**2 * var, 1e-14)))``,
      where ``beta = sf.risk_aversion``.

    Empty cells keep 0. Each (weight, policy) grid is clipped from above at its own
    99th percentile. The grids are then tiled, one row of tiles per weight vector and
    one column per policy, and written as comma-separated text to
    ``reacher_sfc51_mean_{iter}.csv`` and ``reacher_sfc51_var_{iter}.csv``.

    Args:
        sf: successor-feature agent exposing ``n_tasks`` and ``method``. In 'exact'
            mode it must also expose ``true_w`` and ``compute_utilities``. Otherwise it
            must expose ``fit_w``, ``compute_mean_variance`` and, for non-'gauss'
            methods, ``risk_aversion``.
        replay: replay buffer accepted by ``bin_buffer``.
        test_w: sized collection of additional weight vectors to evaluate.
        i_task: unused; kept for signature compatibility.
        iter: suffix inserted into the output file names.
    """
    n = 40
    datas, _ = bin_buffer(replay, n)

    # The method does not change during the sweep, so resolve the branch and the
    # weight list once. Grid rows are sized from the weights actually evaluated.
    exact = sf.method == 'exact'
    gauss = sf.method == 'gauss'
    weights = (sf.true_w if exact else sf.fit_w) + list(test_w)
    n_policies = sf.n_tasks
    grids_mean = [[np.zeros((n, n)) for _ in range(n_policies)] for _ in weights]
    grids_var = [[np.zeros((n, n)) for _ in range(n_policies)] for _ in weights]

    for ir, row in enumerate(datas):
        for ic, cell in enumerate(row):
            if cell.size == 0:
                continue  # nothing binned here; both grids keep 0
            for i_policy in range(n_policies):
                if exact:
                    qs = sf.compute_utilities(cell, i_policy, weights)
                    for j, q in enumerate(qs):
                        grids_mean[j][i_policy][ir, ic] = np.mean(q)
                    continue
                mu, Sigma = sf.compute_mean_variance(cell, i_policy)
                for j, w in enumerate(weights):
                    grids_mean[j][i_policy][ir, ic] = np.mean(mu @ w)
                    var = Sigma @ (w ** 2)
                    if gauss:
                        grids_var[j][i_policy][ir, ic] = np.mean(var)
                    else:
                        # The 1e-14 floor keeps the log argument strictly positive when
                        # 0.5 * beta**2 * var >= 1.
                        grids_var[j][i_policy][ir, ic] = np.mean(
                            (-1. / sf.risk_aversion) * np.log(np.maximum(
                                1. - 0.5 * sf.risk_aversion ** 2 * var, 1e-14)))

    # Clip the upper tail of each grid at its own 99th percentile so a few extreme
    # cells do not dominate the saved values.
    grids_mean = [[np.minimum(g, np.percentile(g, 99.)) for g in row] for row in grids_mean]
    grids_var = [[np.minimum(g, np.percentile(g, 99.)) for g in row] for row in grids_var]

    # Tile: one horizontal strip per weight vector, one tile per policy within a strip.
    grid_mean = np.vstack([np.hstack(row) for row in grids_mean])
    grid_var = np.vstack([np.hstack(row) for row in grids_var])
    np.savetxt('reacher_sfc51_mean_{}.csv'.format(iter), grid_mean, delimiter=',')
    np.savetxt('reacher_sfc51_var_{}.csv'.format(iter), grid_var, delimiter=',')
    
